bitorch_engine.layers.qlinear.nbit.cutlass.q4_layer.Q4LinearFunction

class bitorch_engine.layers.qlinear.nbit.cutlass.q4_layer.Q4LinearFunction(*args, **kwargs)[source]

Implements a custom linear function with quantization for forward and backward passes. This function specifically supports 4-bit quantization for activations and weights during the forward pass, and provides gradients for input, weights, and scale factors during the backward pass.

Note: Both forward and backward methods are static methods.

The forward pass quantizes inputs and weights to 4 bits using the specified scale factors and performs a linear (fully connected) operation. The backward pass computes gradients for the quantized inputs, weights, and scale factors, taking the quantization effects into account.
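The 4-bit quantization step can be sketched as follows. This is a minimal pure-Python illustration of symmetric signed 4-bit quantization (range [-8, 7]); the helper names and the exact rounding/clamping scheme are assumptions, not taken from the library's source.

```python
def quantize_4bit(x, scale, eps=1e-5):
    """Quantize a value to the signed 4-bit integer range [-8, 7].

    Sketch only: divide by the scale (guarded by eps to avoid division
    by zero), round to the nearest integer, and clamp to the
    representable range. Q4LinearFunction's actual scheme may differ.
    """
    q = round(x / max(scale, eps))
    return max(-8, min(7, q))


def dequantize_4bit(q, scale):
    """Recover an approximation of the original value."""
    return q * scale
```

Note how values whose magnitude exceeds the representable range are clamped: with `scale=0.1`, the value `2.0` maps to `7` rather than `20`.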

Methods

backward

Backward pass of the Q4Linear function.

forward

Forward pass of the Q4Linear function.

static backward(ctx: BackwardCFunction, output_gradient: Tensor) → Tuple[Tensor, ...][source]

Backward pass of the Q4Linear function.

Computes gradients for the input tensor, weight tensor, and scale factors based on the output gradient. Adjusts gradients based on quantization ranges and effects to ensure proper gradient flow for quantized operations.

Parameters:
  • ctx – Context object with saved tensors from the forward pass.

  • output_gradient (torch.Tensor) – Gradient of the loss with respect to the output of this function.

Returns:

Tuple containing gradients for the input tensor, weight tensor, and activation scale factor, plus None placeholders for scale_w and eps, which do not receive gradients directly.

Return type:

Tuple[torch.Tensor, …]
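The "adjusts gradients based on quantization ranges" behaviour described above is commonly implemented as a straight-through estimator: the incoming gradient passes through unchanged where the pre-quantization value fell inside the representable range, and is zeroed where the value was clamped. A pure-Python sketch under that assumption (the helper name and per-element form are hypothetical; the actual backward may differ):

```python
def ste_grad(output_grad, x, scale, qmin=-8, qmax=7):
    """Straight-through estimator for a clamped round-to-nearest quantizer.

    Passes each incoming gradient through where x/scale lies inside the
    signed 4-bit range [qmin, qmax], and blocks it where the value was
    clamped. A sketch only, not the library's implementation.
    """
    grads = []
    for g, v in zip(output_grad, x):
        inside = qmin <= v / scale <= qmax
        grads.append(g if inside else 0.0)
    return grads
```

Zeroing the gradient at clamped positions prevents the optimizer from pushing saturated values further out of range, which is the usual rationale for this adjustment in quantization-aware training.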

static forward(ctx, x: Tensor, weight: Tensor, scale_a: Tensor, scale_w: Tensor, eps: Tensor, is_train: bool) → Tensor[source]

Forward pass of the Q4Linear function.

Parameters:
  • ctx – Context object to save variables for backward computation.

  • x (torch.Tensor) – Input tensor.

  • weight (torch.Tensor) – Weight tensor.

  • scale_a (torch.Tensor) – Scale factor for input quantization.

  • scale_w (torch.Tensor) – Scale factor for weight quantization.

  • eps (torch.Tensor) – Epsilon tensor for quantization to avoid division by zero.

  • is_train (bool) – Flag indicating if the model is in training mode.

Returns:

The output tensor after applying the linear operation on quantized inputs and weights.

Return type:

torch.Tensor
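Putting the pieces together, the assumed semantics of the forward pass can be sketched as a pure-Python reference: quantize the input and each weight row to signed 4-bit integers, accumulate the dot products in integer arithmetic, then rescale by the product of the two scale factors. This is an illustration of the assumed math only; the actual CUTLASS kernel operates on packed 4-bit tensors on the GPU and may differ in detail.

```python
def q4_linear_forward(x, weight, scale_a, scale_w, eps=1e-5):
    """Reference forward pass for a 4-bit quantized linear layer (sketch).

    x:      list of input features
    weight: list of weight rows, one per output feature
    Quantizes both operands to the signed 4-bit range [-8, 7],
    accumulates in integer arithmetic, then dequantizes the result.
    """
    def q4(v, s):
        return max(-8, min(7, round(v / max(s, eps))))

    qx = [q4(v, scale_a) for v in x]
    out = []
    for row in weight:                            # one output feature per row
        qw = [q4(w, scale_w) for w in row]
        acc = sum(a * b for a, b in zip(qx, qw))  # integer accumulate
        out.append(acc * scale_a * scale_w)       # rescale to real values
    return out
```

With well-chosen scale factors the dequantized output closely tracks the full-precision result, at a fraction of the memory and compute cost.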