bitorch_engine.layers.qlinear.nbit.cutlass.q4_layer.Q4MatMulFunction

class bitorch_engine.layers.qlinear.nbit.cutlass.q4_layer.Q4MatMulFunction(*args, **kwargs)

This class implements a custom autograd function for quantized matrix multiplication (MatMul) using 4-bit quantization. It quantizes the inputs to 4 bits, performs the MatMul operation, and then dequantizes the result. This operation is designed to work with low-precision arithmetic to improve computational efficiency while maintaining reasonable accuracy.

Both the forward and backward methods are implemented as static methods, allowing this class to be used directly without instantiation.

Methods

backward – Backward pass of the Q4MatMulFunction.

forward – Forward pass of the Q4MatMulFunction.

static backward(ctx: BackwardCFunction, output_gradient: Tensor) → Tuple[Tensor, ...]

Backward pass of the Q4MatMulFunction. Computes gradients for the input tensors based on the output gradient.

This method calculates the gradients for both input tensors and their clipping values, taking into account the quantization performed during the forward pass. It uses a custom backward operation defined in q_linear_cutlass to compute these gradients efficiently.

Parameters:
  • ctx – Context object containing saved tensors from the forward pass.

  • output_gradient (torch.Tensor) – Gradient of the loss with respect to the output of this function.

Returns:

A tuple containing gradients for x, y, x_clip, y_clip, and None placeholders for eps and is_train, as these do not require gradients.

Return type:

Tuple[torch.Tensor, …]
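
The shape of the returned tuple can be illustrated with a small pure-PyTorch stand-in. This is only a sketch under stated assumptions, not the library's implementation: FakeQ4MatMul and fake_quant are hypothetical names, the symmetric integer range [-7, 7] is assumed, and the gradients use a straight-through estimator with a PACT-style clip gradient in place of the custom backward operation from q_linear_cutlass. What it does mirror is the documented contract: one gradient per differentiable input, then None for eps and is_train.

```python
import torch


def fake_quant(t, clip, eps):
    # Assumed scheme: symmetric 4-bit range [-7, 7], scale derived from the clip value.
    scale = torch.maximum(clip, eps) / 7.0
    return torch.clamp(torch.round(t / scale), -7, 7) * scale


class FakeQ4MatMul(torch.autograd.Function):
    """Hypothetical stand-in; the real class dispatches to CUTLASS kernels."""

    @staticmethod
    def forward(ctx, x, y, x_clip, y_clip, eps, is_train):
        xq = fake_quant(x, x_clip, eps)
        yq = fake_quant(y, y_clip, eps)
        if is_train:
            # Only keep tensors around when a backward pass is expected.
            ctx.save_for_backward(x, y, xq, yq, x_clip, y_clip)
        return xq @ yq

    @staticmethod
    def backward(ctx, output_gradient):
        x, y, xq, yq, x_clip, y_clip = ctx.saved_tensors
        # Standard matmul gradients with respect to the quantized operands.
        grad_xq = output_gradient @ yq.transpose(-1, -2)
        grad_yq = xq.transpose(-1, -2) @ output_gradient
        # Straight-through estimator: pass gradients only inside the clip range.
        x_in = x.abs() <= x_clip
        y_in = y.abs() <= y_clip
        grad_x = grad_xq * x_in
        grad_y = grad_yq * y_in
        # Clipped elements route their gradient to the clip value instead.
        grad_x_clip = (grad_xq * ~x_in * x.sign()).sum().reshape(x_clip.shape)
        grad_y_clip = (grad_yq * ~y_in * y.sign()).sum().reshape(y_clip.shape)
        # One entry per forward argument; eps and is_train get None.
        return grad_x, grad_y, grad_x_clip, grad_y_clip, None, None


# The function is invoked through .apply, without creating an instance.
x = torch.randn(3, 4, requires_grad=True)
y = torch.randn(4, 5, requires_grad=True)
x_clip = torch.tensor([2.0], requires_grad=True)
y_clip = torch.tensor([2.0], requires_grad=True)
eps = torch.tensor([1e-6])

out = FakeQ4MatMul.apply(x, y, x_clip, y_clip, eps, True)
out.sum().backward()
```

After the backward call, x.grad, y.grad, x_clip.grad, and y_clip.grad are populated, while eps receives no gradient.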

static forward(ctx, x: Tensor, y: Tensor, x_clip: Tensor, y_clip: Tensor, eps: Tensor, is_train: bool) → Tensor

Forward pass of the Q4MatMulFunction. Quantizes inputs, performs MatMul, and dequantizes the output.

Parameters:
  • ctx – Context object to save tensors for backward computation.

  • x (torch.Tensor) – Input tensor 1.

  • y (torch.Tensor) – Input tensor 2.

  • x_clip (torch.Tensor) – Clipping value for input tensor x, used for quantization range.

  • y_clip (torch.Tensor) – Clipping value for input tensor y, used for quantization range.

  • eps (torch.Tensor) – Epsilon value to avoid division by zero during quantization.

  • is_train (bool) – Flag indicating if the model is in training mode.

Returns:

The result of the quantized MatMul operation.

Return type:

torch.Tensor
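
The quantize, multiply, dequantize sequence described for the forward pass can be traced with plain Python numbers. The sketch below is an illustration only: the helper names, the symmetric integer range [-7, 7], and the scale formula max(clip, eps) / 7 are assumptions, since this page does not specify the exact quantization scheme used by the CUTLASS kernel.

```python
def quantize(values, clip, eps=1e-6, qmax=7):
    # Assumed scheme: scale from the clip value, guarded by eps against division by zero.
    scale = max(clip, eps) / qmax
    quantized = [
        [max(-qmax, min(qmax, round(v / scale))) for v in row]
        for row in values
    ]
    return quantized, scale


def int_matmul(a, b):
    # Plain integer matrix product, standing in for the low-precision kernel.
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def q4_matmul_reference(x, y, x_clip, y_clip, eps=1e-6):
    xq, x_scale = quantize(x, x_clip, eps)
    yq, y_scale = quantize(y, y_clip, eps)
    acc = int_matmul(xq, yq)
    # Dequantize by multiplying with both scales.
    return [[v * x_scale * y_scale for v in row] for row in acc]


# With clip = 7.0 the scale is 1.0, so quantization reduces to round-and-clamp.
x = [[1.2, -3.6], [9.0, 0.4]]
y = [[2.0, 0.0], [0.0, 2.0]]
result = q4_matmul_reference(x, y, x_clip=7.0, y_clip=7.0)
print(result)  # [[2.0, -8.0], [14.0, 0.0]]
```

Here 9.0 is clamped to 7 before the multiplication, which is where the clipping values limit the representable range.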