bitorch_engine.layers.qlinear.nbit.cutlass.q4_layer.Q4LinearCutlass

class bitorch_engine.layers.qlinear.nbit.cutlass.q4_layer.Q4LinearCutlass(*args, **kwargs)[source]

This class implements a quantized linear layer using the CUTLASS library, specifically designed for 4-bit quantization.

bias_a

Bias parameter for the activations, initialized to zeros.

Type:

torch.nn.Parameter

scale_a

Scale parameter for the activation quantization, initialized to zero.

Type:

torch.nn.Parameter

scale_w

Scale parameter for the weight quantization, initialized to one.

Type:

torch.nn.Parameter

eps

A small epsilon value used to avoid division by zero, registered as a buffer.

Type:

torch.Tensor

The class inherits from nBitLinearBase, extending it with 4-bit quantization capabilities. It introduces parameters and methods for managing and applying quantization to both weights and activations within a linear (fully connected) layer context.

The quantization process involves scaling floating-point weights to a 4-bit representation, which significantly reduces memory usage and computational cost, especially on hardware that supports low-precision arithmetic. This class is designed to work with neural network models where efficiency and speed are critical, such as on edge devices or in high-performance computing environments.
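The scaling step described above can be illustrated with a minimal sketch in plain Python. It assumes symmetric quantization to the signed 4-bit range [-8, 7]. The function names and the choice of scale (largest magnitude mapped to 7) are hypothetical and are not taken from the Q4LinearCutlass implementation.

```python
def quantize_4bit(weights, scale):
    """Map floats to signed 4-bit integer codes in [-8, 7] using a given scale."""
    q_min, q_max = -8, 7  # signed 4-bit range
    return [max(q_min, min(q_max, round(w / scale))) for w in weights]


def dequantize(q_weights, scale):
    """Recover approximate float values from the integer codes."""
    return [q * scale for q in q_weights]


weights = [0.31, -0.74, 0.05, 1.20, -1.50]
# Hypothetical scale choice: map the largest magnitude onto the code 7.
scale = max(abs(w) for w in weights) / 7
q = quantize_4bit(weights, scale)
approx = dequantize(q, scale)
```

Each value that falls inside the representable range is recovered with an error of at most half a quantization step (scale / 2).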

prepare_params()[source]

This function is called after initialization and checkpoint loading, and before the actual forward pass.

generate_quantized_weight()[source]

Quantizes the weights and optionally removes the floating-point weights to save memory.

_check_forward()[source]

Checks that the input dimensions are compatible with the weight dimensions and certain constraints are met.

set_activation()[source]

Quantizes the activation to 8-bit representation.

forward()[source]

Defines the computation performed at every call.

Methods

__init__

Initializes the Q4LinearCutlass layer, setting up parameters for activation and weight quantization.

forward

Forward pass of the layer.

generate_quantized_weight

Performs weight quantization.

prepare_params

Prepares and initializes the model parameters for training.

set_activation

Quantizes the activation to 8-bit by applying a scaling factor and optionally adds a learnable bias.

Attributes

training

__init__(*args, **kwargs)[source]

Initializes the Q4LinearCutlass layer, setting up parameters for activation and weight quantization.

forward(x: Tensor) → Tensor[source]

Forward pass of the layer. It performs necessary pre-checks on the input tensor, quantizes the activations, and then applies the quantized linear function using the quantized weights.

Parameters:

x (torch.Tensor) – The input tensor with shape (batch size, number of features).

Returns:

The output tensor after applying the quantized linear transformation.

Return type:

torch.Tensor
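The three stages of the forward pass (pre-check, activation quantization, quantized linear function) can be sketched in plain Python as follows. This is only an illustration under stated assumptions: the helper names are hypothetical, the scales are passed in as fixed numbers, and the real layer dispatches the integer matrix product to a CUTLASS kernel rather than Python loops.

```python
def check_forward(x, q_weight):
    """Hypothetical pre-check: each input row must match the weight's in_features."""
    in_features = len(q_weight[0])
    for row in x:
        if len(row) != in_features:
            raise ValueError(f"expected {in_features} features, got {len(row)}")


def quantize_activation(x, scale_a):
    """Round activations to signed 8-bit integer codes in [-128, 127]."""
    return [[max(-128, min(127, round(v / scale_a))) for v in row] for row in x]


def quantized_linear(x, q_weight, scale_a, scale_w):
    """Check, quantize activations, integer matmul, then rescale to float."""
    check_forward(x, q_weight)
    q_x = quantize_activation(x, scale_a)
    out = []
    for row in q_x:
        # Integer dot product of one input row with each output unit's weights.
        acc = [sum(a * w for a, w in zip(row, w_row)) for w_row in q_weight]
        # Undo both scales to return to the floating-point domain.
        out.append([v * scale_a * scale_w for v in acc])
    return out


q_weight = [[1, -3, 0], [6, -7, 2]]  # (out_features=2, in_features=3), 4-bit codes
x = [[0.5, -0.25, 1.0]]              # (batch size=1, number of features=3)
y = quantized_linear(x, q_weight, scale_a=0.25, scale_w=0.5)
```

Because the accumulation is done on integers, the two scale factors only need to be applied once per output element.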

generate_quantized_weight(qweight_only: bool = False) → None[source]

Performs weight quantization. This method should be called before saving the model’s weights to ensure that the weights are quantized. If qweight_only is set to True, the original weights are discarded to save memory, keeping only the quantized weights.

Parameters:

qweight_only (bool) – If True, retains only the quantized weights and discards the original weights.
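To see why keeping only the quantized weights saves memory, consider that two 4-bit codes fit into one byte. The sketch below shows one possible packing scheme in plain Python. The storage layout actually used by Q4LinearCutlass is not described here, so the low-nibble-first order and the function names are assumptions made for illustration.

```python
def pack_int4(values):
    """Pack signed 4-bit integers (range [-8, 7]) two per byte, low nibble first."""
    values = list(values)
    if len(values) % 2:
        values.append(0)  # pad to an even count
    packed = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        # Masking with 0xF keeps the two's-complement low 4 bits of each code.
        packed.append((lo & 0xF) | ((hi & 0xF) << 4))
    return bytes(packed)


def unpack_int4(packed, count):
    """Inverse of pack_int4: restore `count` signed 4-bit integers."""
    out = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):
            # Nibbles 8..15 represent the negative codes -8..-1.
            out.append(nibble - 16 if nibble >= 8 else nibble)
    return out[:count]


q = [1, -3, 0, 6, -7]
packed = pack_int4(q)
restored = unpack_int4(packed, len(q))
```

Five codes occupy three bytes here, compared with twenty bytes for the same values stored as 32-bit floats.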

prepare_params() → None[source]

Prepares and initializes the model parameters for training.

Note

This method MUST be called after model initialization and before training starts to ensure the weights are properly prepared for efficient computation.

The “prepare_bie_layers” helper from project_root.utils.model_helper can be used to call this function.
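The idea behind such a helper can be sketched as a loop that calls prepare_params() on every layer that defines it. The function below is a hypothetical stand-in written for illustration, not the actual prepare_bie_layers implementation, and the two small classes exist only to make the sketch runnable without the library.

```python
def prepare_layers(modules):
    """Call prepare_params() on every module that defines it; return the count."""
    prepared = 0
    for module in modules:
        prepare = getattr(module, "prepare_params", None)
        if callable(prepare):
            prepare()
            prepared += 1
    return prepared


class FakeQuantLayer:
    """Stand-in for a quantized layer, used only to make this sketch runnable."""

    def __init__(self):
        self.ready = False

    def prepare_params(self):
        self.ready = True


class PlainLayer:
    """Stand-in for an ordinary layer without prepare_params()."""


layers = [FakeQuantLayer(), PlainLayer(), FakeQuantLayer()]
count = prepare_layers(layers)
```

With a real PyTorch model, the iterable passed in would typically be model.modules(), after initialization and checkpoint loading and before the first forward pass.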

set_activation(x: Tensor) → Tensor[source]

Quantizes the activation to 8-bit by applying a scaling factor and optionally adds a learnable bias. The scaling factor is computed based on the input tensor’s mean absolute value, considering the quantization precision.

Parameters:

x (torch.Tensor) – The input activation tensor.

Returns:

The quantized activation tensor.

Return type:

torch.Tensor
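The description of set_activation does not give the exact formula, so the following plain-Python sketch assumes an LSQ-style rule, scale = 2 * mean(|x|) / sqrt(q_max) with q_max = 127 for 8 bits, and assumes the optional bias is added before scaling. Both choices, and the function names, are illustrative assumptions rather than the library's confirmed behavior.

```python
import math


def activation_scale(x, bits=8):
    """Assumed LSQ-style scale: 2 * mean(|x|) / sqrt(q_max)."""
    q_max = 2 ** (bits - 1) - 1  # 127 for 8 bits
    mean_abs = sum(abs(v) for v in x) / len(x)
    return 2 * mean_abs / math.sqrt(q_max)


def quantize_activation(x, scale, bias=0.0, bits=8, eps=1e-8):
    """Shift by an optional bias, divide by the scale, round, and clamp to 8 bits."""
    q_min, q_max = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    scale = max(scale, eps)  # eps guards against division by zero
    return [max(q_min, min(q_max, round((v + bias) / scale))) for v in x]


x = [0.5, -1.0, 2.0, -0.25, 0.0, 3.5]
scale = activation_scale(x)
q_x = quantize_activation(x, scale)
```

Tying the scale to the mean absolute value keeps typical activations well inside the 8-bit range, while rare large values are clamped.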