bitorch_engine.layers.qlinear.nbit.cutlass.q8_layer.Q8LinearFunction

class bitorch_engine.layers.qlinear.nbit.cutlass.q8_layer.Q8LinearFunction(*args, **kwargs)[source]

Implements a quantized linear function using 8-bit quantization for both activations and weights.

This class performs the forward and backward passes of a linear layer with quantization, leveraging CUTLASS kernels for efficient computation. The quantization scheme is based on a fixed-point representation, where scale_a and scale_w are the scaling factors for activations and weights, respectively.
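The fixed-point scheme can be sketched in plain Python: a value is divided by its scale, rounded, and clamped to the signed 8-bit range [-128, 127]. This is an illustrative sketch, not the CUTLASS kernel; the helper names `quantize_q8` and `dequantize_q8` are hypothetical.

```python
def quantize_q8(x: float, scale: float, eps: float = 1e-8) -> int:
    # Map a real value to a signed 8-bit integer via a fixed-point scale.
    # 'eps' guards against division by zero, mirroring the 'eps' argument
    # of Q8LinearFunction.forward.
    q = round(x / max(scale, eps))
    return max(-128, min(127, q))

def dequantize_q8(q: int, scale: float) -> float:
    # Recover an approximation of the original real value.
    return q * scale

# Round-trip: 0.5 with scale 0.01 quantizes to 50 and dequantizes back to 0.5.
# Values outside the representable range saturate at the int8 limits.
```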

Methods

backward

Backward pass of the quantized linear function.

forward

Forward pass of the quantized linear function.

static backward(ctx: BackwardCFunction, output_gradient: Tensor) → Tuple[Tensor, ...]

Backward pass of the quantized linear function.

Parameters:
  • ctx – Context object containing saved tensors from the forward pass.

  • output_gradient (torch.Tensor) – Gradient of the loss with respect to the output of this layer.

Returns:

A tuple containing gradients with respect to the input, weight, and scale_a, plus None placeholders for scale_w, eps, and is_train, which do not require gradients.
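The gradient structure can be sketched with naive matrix arithmetic. This is a hypothetical reference for a linear layer computing y = x @ weight.T, ignoring quantization-aware details such as the actual scale_a gradient, which depends on the quantizer internals inside the real kernel.

```python
def matmul(a, b):
    # Naive matrix multiply for small illustrative matrices.
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def transpose(m):
    return [list(row) for row in zip(*m)]

def q8_linear_backward_sketch(grad_out, x, weight):
    # For y = x @ weight.T, the standard linear-layer gradients are:
    grad_x = matmul(grad_out, weight)        # dL/dx:      (batch, in)
    grad_w = matmul(transpose(grad_out), x)  # dL/dweight: (out, in)
    # The real backward also returns a gradient for scale_a; it is
    # omitted (None) here because it depends on the quantizer internals.
    # scale_w, eps, and is_train receive None placeholders.
    return grad_x, grad_w, None, None, None, None
```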

static forward(ctx, x: Tensor, weight: Tensor, scale_a: Tensor, scale_w: Tensor, eps: Tensor, is_train: bool) → Tensor

Forward pass of the quantized linear function.

Parameters:
  • ctx – Context object to save variables for backward computation.

  • x (torch.Tensor) – Input tensor.

  • weight (torch.Tensor) – Weight tensor.

  • scale_a (torch.Tensor) – Scaling factor for the activations.

  • scale_w (torch.Tensor) – Scaling factor for the weights.

  • eps (torch.Tensor) – Epsilon tensor for quantization to avoid division by zero.

  • is_train (bool) – Flag indicating whether the forward pass is for training or inference.

Returns:

The output tensor of the quantized linear operation.

Return type:

torch.Tensor
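The arithmetic of the forward pass can be sketched as: quantize activations and weights to int8 with their respective scales, multiply in integer arithmetic, then dequantize the accumulator with the product of both scales. This is a plain-Python reference under those assumptions, not the CUTLASS implementation; in practice the function is invoked as Q8LinearFunction.apply(x, weight, scale_a, scale_w, eps, is_train) with CUDA tensors.

```python
def q8_linear_forward_sketch(x, weight, scale_a, scale_w, eps=1e-8):
    # Quantize a real value to int8 using a fixed-point scale.
    def q(v, s):
        return max(-128, min(127, round(v / max(s, eps))))

    qx = [[q(v, scale_a) for v in row] for row in x]       # int8 activations
    qw = [[q(v, scale_w) for v in row] for row in weight]  # int8 weights
    # y = x @ weight.T in integer arithmetic, then dequantized by the
    # product of the activation and weight scales.
    return [[sum(qx[i][k] * qw[j][k] for k in range(len(qx[0])))
             * scale_a * scale_w
             for j in range(len(qw))] for i in range(len(qx))]

# x = [[1.0, 2.0]], weight = [[0.5, 0.5]]:
# exact result is 1.0*0.5 + 2.0*0.5 = 1.5; the quantized sketch
# recovers it up to rounding error determined by the scales.
```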