bitorch_engine.layers.qlinear.nbit.cuda.mpq_layer.MPQLinearCudaFunction
- class bitorch_engine.layers.qlinear.nbit.cuda.mpq_layer.MPQLinearCudaFunction(*args, **kwargs)[source]
A custom autograd function for mixed-precision quantized (MPQ) linear operations on CUDA.
This function supports the forward and backward passes of a linear layer with quantized weights, enabling efficient computation on CUDA devices. It is designed for scenarios where activations and weights are quantized to different bit-widths, reducing memory footprint and computational cost. This is particularly useful in low-power and memory-constrained environments such as edge devices.
The forward pass performs quantized matrix multiplication using custom CUDA kernels, while the backward pass computes gradients with respect to the input and updates the privileged gradient of the quantized weights.
Methods
- backward – Backward pass for the MPQ linear operation.
- forward – Forward pass for the MPQ linear operation.
- static backward(ctx: BackwardCFunction, output_gradient: Tensor) Tuple[Tensor, ...] [source]
Backward pass for the MPQ linear operation.
Computes gradients with respect to the input and updates the privileged gradient of the quantized weights.
- Parameters:
ctx – Context object with saved tensors from the forward pass.
output_gradient (torch.Tensor) – Gradient of the loss with respect to the output of this operation.
Note
A dedicated optimizer, bitorch_engine.optim.mpq_adamw, is available for use in conjunction with this layer.
- Returns:
- A tuple containing the gradient with respect to the input, followed by None placeholders for the arguments that do not require gradients.
- Return type:
tuple
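The gradient math behind this backward pass can be illustrated with a minimal plain-Python sketch. The helper names below are illustrative, not the library's internals: for a linear layer y = x @ W^T, the input gradient is dL/dx = dL/dy @ W, and the weight gradient dL/dW = (dL/dy)^T @ x, which this function would accumulate into the privileged gradient buffer rather than return.

```python
# Minimal sketch (plain Python) of the gradients a linear backward pass
# produces; helper names are illustrative, not bitorch_engine internals.

def matmul(a, b):
    """Naive 2-D list matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def transpose(m):
    return [list(r) for r in zip(*m)]

x = [[1.0, 2.0]]     # one sample, 2 input features
w = [[3.0, 4.0]]     # 1 output feature, 2 input features (y = x @ w^T)
grad_out = [[1.0]]   # dL/dy flowing in from the next layer

grad_x = matmul(grad_out, w)             # dL/dx, returned to the caller
grad_w = matmul(transpose(grad_out), x)  # dL/dW, the privileged-gradient update
```

In the real function, grad_w is not part of the returned tuple; it is written into the privileged gradient of the quantized weights, which the companion optimizer then consumes.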
- static forward(ctx, x: Tensor, qweight: Tensor, a_bit: int, w_bit: int, scales: Tensor, zeros: Tensor, g_idx: Tensor, asym: bool, is_training: bool, privileged_grad: Tensor | None = None) Tensor [source]
Forward pass for the MPQ linear operation.
- Parameters:
ctx – Context object to save information for backward computation.
x (torch.Tensor) – Input tensor.
qweight (torch.Tensor) – Quantized weights.
a_bit (int) – Activation quantization bit-width.
w_bit (int) – Weight quantization bit-width.
scales (torch.Tensor) – Quantization scales.
zeros (torch.Tensor) – Quantization zero points for asymmetric quantization.
g_idx (torch.Tensor) – Indices for grouped quantization.
asym (bool) – Flag for asymmetric quantization.
is_training (bool) – Training mode flag.
privileged_grad (torch.Tensor, optional) – Buffer holding the privileged gradient of the quantized weights, updated during the backward pass.
- Returns:
The result of the quantized linear operation.
- Return type:
torch.Tensor
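To make the forward semantics concrete, here is a plain-Python sketch of what a grouped, asymmetric dequantization followed by a linear operation computes. The shapes and the dequantization formula w[i][j] = (q[i][j] - zero) * scale are assumptions for illustration; the actual CUDA kernel fuses these steps and never materializes the float weights.

```python
# Illustrative sketch, not the CUDA kernel: recover approximate float weights
# from qweight / scales / zeros / g_idx, then apply the linear operation.

def dequantize(qweight, scales, zeros, g_idx, asym):
    """w[i][j] = (q[i][j] - zero) * scale, with scale/zero selected per
    input channel j via its quantization group g_idx[j]."""
    rows, cols = len(qweight), len(qweight[0])   # out_features, in_features
    w = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            g = g_idx[j]                          # group of input channel j
            zero = zeros[i][g] if asym else 0.0   # zero point only if asymmetric
            w[i][j] = (qweight[i][j] - zero) * scales[i][g]
    return w

def linear(x, w):
    """y = x @ w^T for 2-D lists."""
    return [[sum(xi[k] * wr[k] for k in range(len(xi))) for wr in w]
            for xi in x]

# Toy example: 2 output features, 4 input features, 2 groups of size 2.
qweight = [[1, 2, 3, 4], [4, 3, 2, 1]]
scales  = [[0.5, 0.25], [1.0, 0.5]]   # one scale per (row, group)
zeros   = [[2, 2], [2, 2]]            # one zero point per (row, group)
g_idx   = [0, 0, 1, 1]                # input channel -> group
w = dequantize(qweight, scales, zeros, g_idx, asym=True)
y = linear([[1.0, 0.0, 0.0, 0.0]], w)
```

With this input selecting only the first channel, y reduces to the first dequantized column of each output row, which makes the effect of the per-group scales and zero points easy to check by hand.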