bitorch_engine.layers.qlinear.nbit.cutlass.q8_layer.Q8LinearFunction
- class bitorch_engine.layers.qlinear.nbit.cutlass.q8_layer.Q8LinearFunction(*args, **kwargs)
Implements a quantized linear function using 8-bit quantization for both activations and weights.
This class performs the forward and backward passes of a linear layer with quantization, leveraging CUTLASS kernels for efficient computation. The quantization scheme is based on a fixed-point representation, where scale_a and scale_w are scaling factors for the activations and weights, respectively.
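The fixed-point scheme can be illustrated with a minimal pure-Python sketch. The rounding and clamping details below are assumptions for illustration; the actual CUTLASS kernel operates on packed int8 tensors on the GPU:

```python
def quantize_int8(values, scale, eps=1e-8):
    """Quantize floats to the int8 range [-128, 127] using a scale factor.

    `scale` plays the role of scale_a / scale_w; `eps` guards against
    division by zero, matching the `eps` argument of forward().
    """
    s = max(scale, eps)
    return [max(-128, min(127, round(v / s))) for v in values]

def dequantize_int8(q_values, scale):
    """Map int8 values back to the floating-point domain."""
    return [qi * scale for qi in q_values]

x = [0.5, -1.25, 3.0]
qx = quantize_int8(x, scale=0.05)        # → [10, -25, 60]
x_hat = dequantize_int8(qx, scale=0.05)  # recovers x up to quantization error
```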
Methods

backward: Backward pass of the quantized linear function.
forward: Forward pass of the quantized linear function.
- static backward(ctx: BackwardCFunction, output_gradient: Tensor) → Tuple[Tensor, ...]
Backward pass of the quantized linear function.
- Parameters:
ctx – Context object containing saved tensors from the forward pass.
output_gradient (torch.Tensor) – Gradient of the loss with respect to the output of this layer.
- Returns:
A tuple containing gradients with respect to the input, weight, and scale_a, followed by None placeholders for scale_w, eps, and is_train, which do not require gradients.
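The gradient structure above (tensor gradients followed by None placeholders) mirrors the `torch.autograd.Function` convention of returning one entry per forward argument. A hedged sketch of the underlying linear-layer gradients, written on plain nested lists so the arithmetic is visible; this assumes gradients flow through the dequantized weight (a straight-through-style approximation, not necessarily the exact kernel math):

```python
def linear_backward(grad_out, x, weight):
    """Toy backward for y = x @ weight.T.

    Returns (grad_x, grad_weight); in Q8LinearFunction.backward these
    would be followed by a scale_a gradient and None placeholders for
    scale_w, eps, and is_train.
    """
    rows, out_f = len(grad_out), len(grad_out[0])
    in_f = len(x[0])
    # grad_x = grad_out @ weight   (weight has shape [out_f, in_f])
    grad_x = [[sum(grad_out[r][o] * weight[o][i] for o in range(out_f))
               for i in range(in_f)] for r in range(rows)]
    # grad_weight = grad_out.T @ x
    grad_w = [[sum(grad_out[r][o] * x[r][i] for r in range(rows))
               for i in range(in_f)] for o in range(out_f)]
    return grad_x, grad_w
```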
- static forward(ctx, x: Tensor, weight: Tensor, scale_a: Tensor, scale_w: Tensor, eps: Tensor, is_train: bool) → Tensor
Forward pass of the quantized linear function.
- Parameters:
ctx – Context object to save variables for backward computation.
x (torch.Tensor) – Input tensor.
weight (torch.Tensor) – Weight tensor.
scale_a (torch.Tensor) – Scaling factor for the activations.
scale_w (torch.Tensor) – Scaling factor for the weights.
eps (torch.Tensor) – Epsilon tensor for quantization to avoid division by zero.
is_train (bool) – Flag indicating whether the forward pass is for training or inference.
- Returns:
The output tensor of the quantized linear operation.
- Return type:
torch.Tensor
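Putting the parameters together, the forward computation reduces to an integer matrix product followed by a floating-point rescale. A minimal pure-Python sketch of that arithmetic (the real layer runs this on CUDA via CUTLASS; int8 packing and tensor shapes are simplified here, and the exact rounding is an assumption):

```python
def q8_linear_forward(x, weight, scale_a, scale_w, eps=1e-8):
    """y ≈ (quantize(x) @ quantize(weight).T) * scale_a * scale_w.

    x has shape [rows, in_f]; weight has shape [out_f, in_f].
    """
    def quant(mat, scale):
        s = max(scale, eps)  # eps avoids division by zero, as in forward()
        return [[max(-128, min(127, round(v / s))) for v in row] for row in mat]

    qx, qw = quant(x, scale_a), quant(weight, scale_w)
    out_f, in_f = len(qw), len(qw[0])
    # Integer accumulation per output element, then a single rescale.
    return [[sum(qx[r][i] * qw[o][i] for i in range(in_f)) * scale_a * scale_w
             for o in range(out_f)] for r in range(len(qx))]

x = [[0.5, -1.0]]   # one row, two input features
w = [[1.0, 2.0]]    # one output feature
y = q8_linear_forward(x, w, scale_a=0.05, scale_w=0.1)  # ≈ 0.5*1.0 + (-1.0)*2.0 = -1.5
```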