bitorch_engine.layers.qlinear.nbit.mps.mpq_layer.MPQLinearMlxFunction

class bitorch_engine.layers.qlinear.nbit.mps.mpq_layer.MPQLinearMlxFunction(*args, **kwargs)[source]

A custom autograd function for mixed-precision quantized (MPQ) linear operations on MPS, accelerated by MLX.

This function supports the forward and backward passes of a linear layer with quantized weights, allowing for efficient computation on MPS devices. It is designed for scenarios in which activations and weights are quantized to different bit-widths, reducing the memory footprint and computational cost; this is particularly useful in low-power, memory-constrained environments such as edge devices.

The forward pass performs quantized matrix multiplication using optimized MLX implementations, while the backward pass computes the gradient with respect to the input and updates the privileged gradient of the quantized weights.
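As a rough illustration of the numerics, a group-wise quantized matmul of this kind is equivalent to dequantizing the weights group by group and running a dense matmul. The sketch below assumes a standard GPTQ-style layout (integer codes of shape (out_features, in_features) with per-group scales and zero points); it is not the MLX kernel, which operates on packed weights and avoids materializing the dense matrix.

import torch

# Minimal sketch of the numerics only, assuming a GPTQ-style group-wise
# layout; the actual MLX kernel works on packed quantized weights.
def dequantized_linear(x: torch.Tensor, w_q: torch.Tensor, scales: torch.Tensor,
                       zeros: torch.Tensor, group_size: int) -> torch.Tensor:
    out_f, in_f = w_q.shape
    groups = in_f // group_size
    # (w_q - zeros) * scales, applied per group along the input dimension
    w = (w_q.float().view(out_f, groups, group_size)
         - zeros.view(out_f, groups, 1)) * scales.view(out_f, groups, 1)
    return x @ w.view(out_f, in_f).t()

x = torch.randn(2, 8)
w_q = torch.randint(0, 16, (4, 8))   # 4-bit codes, left unpacked for clarity
scales = torch.rand(4, 2)            # one scale per (row, group)
zeros = torch.full((4, 2), 8.0)      # one zero point per (row, group)
y = dequantized_linear(x, w_q, scales, zeros, group_size=4)  # shape (2, 4)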

Methods

backward

Backward pass for the MPQ linear operation.

forward

Forward pass for the MPQ linear operation.

static backward(ctx: BackwardCFunction, output_gradient: Tensor) → Tuple[Tensor, ...][source]

Backward pass for the MPQ linear operation.

Computes gradients with respect to the input and updates the privileged gradient of the quantized weights.

Parameters:
  • ctx – Context object with saved tensors from the forward pass.

  • output_gradient (torch.Tensor) – Gradient of the loss with respect to the output of this operation.

Note

A dedicated optimizer, bitorch_engine.optim.mpq_adamw, is available for use in conjunction with this layer.

Returns:

A tuple containing the gradient with respect to the input and None placeholders for the other arguments that do not require gradients.

Return type:

tuple
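To make the return convention concrete, here is a minimal, generic analogue of this pattern (an illustration, not the library's implementation): the backward returns the input gradient plus None placeholders for every other forward argument, while the weight gradient is written onto the parameter itself, mirroring the privileged gradient consumed by the dedicated optimizer. The attribute name privileged_grad below is a hypothetical stand-in.

import torch
from typing import Tuple

class _ToyMPQLinear(torch.autograd.Function):
    # Generic sketch of the pattern only; the real layer dispatches to MLX
    # kernels and operates on packed quantized weights.
    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor, is_training: bool) -> torch.Tensor:
        ctx.save_for_backward(x)
        ctx.weight = weight  # keep a direct reference so backward can write
                             # the side-channel weight gradient (sketch only)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, output_gradient: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        (x,) = ctx.saved_tensors
        weight = ctx.weight
        grad_x = output_gradient @ weight        # gradient w.r.t. the input
        # Stash the weight gradient on the parameter instead of returning it,
        # so autograd receives None for the weight argument.
        weight.privileged_grad = output_gradient.t() @ x
        return grad_x, None, None                # None for weight and is_training

x = torch.randn(2, 8, requires_grad=True)
w = torch.randn(4, 8)                            # stands in for the packed weight
_ToyMPQLinear.apply(x, w, True).sum().backward()
print(x.grad.shape, w.privileged_grad.shape)     # (2, 8) and (4, 8)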

static forward(ctx, x: Tensor, qweight: MPQWeightParameter, w_bit: int, scales: Tensor, zeros: Tensor, group_size: int, is_training: bool) → Tensor[source]

Forward pass for the MPQ linear operation. Currently, only symmetric quantization is supported.

Parameters:
  • ctx – Context object to save information for backward computation.

  • x (torch.Tensor) – Input tensor.

  • qweight (MPQWeightParameter) – Quantized weight parameter containing the packed weights.

  • w_bit (int) – Weight quantization bit-width.

  • scales (torch.Tensor) – Quantization scales.

  • zeros (torch.Tensor) – Quantization zero points for asymmetric quantization.

  • group_size (int) – The group size of the quantized weights.

  • is_training (bool) – Training mode flag.

Returns:

The result of the quantized linear operation.

Return type:

torch.Tensor
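For orientation, a hedged usage sketch: the function is invoked through .apply(), the standard entry point for torch.autograd.Function subclasses. qweight, scales, and zeros below are placeholders standing in for the packed parameters of an already-quantized layer (their construction is layer-specific and not shown), and the bit-width and group size are illustrative assumptions.

import torch
from bitorch_engine.layers.qlinear.nbit.mps.mpq_layer import MPQLinearMlxFunction

x = torch.randn(8, 4096, device="mps")   # (batch, in_features)
out = MPQLinearMlxFunction.apply(
    x,          # input tensor
    qweight,    # MPQWeightParameter with packed quantized weights (placeholder)
    4,          # w_bit: weight bit-width (assumed 4-bit for the sketch)
    scales,     # quantization scales (placeholder)
    zeros,      # quantization zero points (placeholder)
    128,        # group_size (assumed)
    True,       # is_training
)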