bitorch_engine.layers.qlinear.binary.cutlass.layer.BinaryLinearCutlass
- class bitorch_engine.layers.qlinear.binary.cutlass.layer.BinaryLinearCutlass(*args, **kwargs)[source]
A specialized binary linear layer that leverages CUTLASS for efficient low-precision computations.
This class extends BinaryLinearBase with optimizations specific to binary neural networks. It uses CUTLASS kernels for binary GEMM operations, executing them efficiently on GPUs. The layer supports both training and inference in binary precision, significantly reducing memory footprint and computational cost.
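A minimal usage sketch (hedged: it assumes a CUDA device, a built bitorch-engine CUTLASS extension, and hypothetical layer dimensions):
>>> import torch
>>> from bitorch_engine.layers.qlinear.binary.cutlass.layer import BinaryLinearCutlass
>>> layer = BinaryLinearCutlass(input_features=1024, out_features=512, device=torch.device("cuda"))
>>> layer.prepare_params()  # convert weights to int8 and compute weight scales
>>> x = torch.randn(32, 1024, device="cuda")  # (batch size, number of input features)
>>> y = layer(x)  # binary GEMM through a CUTLASS kernel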
- bits_binary_word
The bit-width of the binary words used in CUTLASS operations, typically set to 8.
- Type:
int
- gemm_kernel_id
Identifier for the CUTLASS kernel to be used for the GEMM operation.
- Type:
int
- bias_a
Layer-wise bias parameter for input activations.
- Type:
torch.nn.Parameter
- scale_a
Scale factor for input activations, aiding in quantization.
- Type:
torch.nn.Parameter
- scale_w
Scale factor for weights, essential for maintaining numerical accuracy.
- Type:
torch.nn.Parameter
- prepare_params()[source]
Prepares and initializes model parameters for training, converting weights to int8 format.
- generate_quantized_weight(qweight_only=False)[source]
Performs bit-packing on 32-bit weights to reduce memory usage.
- set_weight_data(x)[source]
Prepares the weight data for computation, calling prepare_params internally.
- select_gemm_kernel(x)[source]
Evaluates and selects the appropriate CUTLASS kernel based on input dimensions.
- forward(x)[source]
Defines the forward pass for the binary linear layer, leveraging CUTLASS for efficient computation.
Methods
- __init__(*args, **kwargs) – Initializes the BinaryLinearBase class with specified configurations.
- forward(x) – Defines the forward pass of the binary linear layer.
- generate_quantized_weight(qweight_only=False) – Performs bit-packing on the 32-bit floating-point weights to reduce the model's memory footprint.
- prepare_params() – Prepares and initializes the model parameters for training, specifically converting floating-point weights to int8 format.
- select_gemm_kernel(x) – Selects the most appropriate GEMM kernel from the available CUTLASS kernels for the binary operation.
- set_activation(x) – Normalizes input activations using a layer-wise scale factor and adds bias.
- set_weight_data(x) – Prepares the weight data for the binary linear operation.
Attributes
- training
- __init__(*args, **kwargs)[source]
Initializes the BinaryLinearBase class with specified configurations.
- Parameters:
input_features (int) – Dimension of input features after bit-packing.
out_features (int) – Dimension of output features or hidden states.
device (torch.device, optional) – Device on which to allocate tensors. Defaults to None.
dtype (torch.dtype, optional) – Data type for floating-point weights. Defaults to torch.float.
symmetric (bool, optional) – If True, quantization is symmetric around 0. Defaults to True.
- forward(x: Tensor) → Tensor [source]
Defines the forward pass of the binary linear layer.
This method applies normalization and bias to the input activations, selects the appropriate GEMM kernel, and performs the binary linear operation using the optimized CUTLASS kernel.
- Parameters:
x (torch.Tensor) – Input tensor with shape (batch size, number of input features).
- Returns:
The output of the binary linear operation, ready for further processing in the network.
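An illustrative inference call (shapes are hypothetical and reuse the layer from the class-level sketch above):
>>> x = torch.randn(8, 1024, device="cuda")  # (batch size, number of input features)
>>> with torch.no_grad():  # plain inference; autograd is not needed here
...     out = layer(x)  # normalization and bias, kernel selection, then binary GEMM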
- generate_quantized_weight(qweight_only: bool = False) → None [source]
Performs bit-packing on the 32-bit floating-point weights to reduce the model’s memory footprint.
This method converts the full-precision weights to a quantized format designed for binary linear operations. It facilitates efficient computation on hardware that supports binary operations by packing the 1-bit weights into 8-bit binary words (see bits_binary_word).
- Parameters:
qweight_only (bool) – If True, the original floating-point weights are discarded to save memory, leaving only the quantized weights.
Note
The quantized weights are stored as a new parameter qweight within the class.
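A sketch of packing the weights for a memory-lean deployment (assuming prepare_params has already been called on the layer from the examples above):
>>> layer.generate_quantized_weight(qweight_only=True)  # bit-pack, drop floating-point weights
>>> layer.qweight is not None  # the packed weights now live in qweight (see note)
True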
- prepare_params() → None [source]
Prepares and initializes the model parameters for training, specifically converting floating-point weights to int8 format.
This method leverages the init_weight function to convert the model’s floating-point weights to int8, achieving a significant reduction in memory usage. It also computes a scale for the weights, which is essential for maintaining the numerical fidelity of the model’s computations in the lower precision format. The conversion to int8 format is particularly beneficial for accelerating training and inference on hardware that supports lower precision arithmetic.
Note
This method MUST be called after model initialization and before training starts to ensure the weights are properly prepared for efficient computation.
One can use the prepare_bie_layers method from project_root.utils.model_helper to call this function.
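Two hedged ways to trigger the conversion, per the note above; the helper's import path is written exactly as the note gives it, with project_root standing for your project's root package:
>>> layer.prepare_params()  # per layer: int8 weights plus a weight scale
>>> from project_root.utils.model_helper import prepare_bie_layers
>>> model = torch.nn.Sequential(layer)  # any module tree containing bitorch-engine layers
>>> prepare_bie_layers(model)  # prepares every such layer in one call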
- select_gemm_kernel(x: Tensor) → None [source]
Selects the most appropriate GEMM kernel from the available CUTLASS kernels for the binary operation.
This selection is based on the dimensions of the input activation tensor and the layer’s output features. The method evaluates available kernels and chooses the optimal one for the given dimensions, enhancing computational efficiency.
- Parameters:
x (torch.Tensor) – The input activation tensor used to determine the optimal GEMM kernel.
- Returns:
None. The ID of the selected GEMM kernel is stored in gemm_kernel_id and used for subsequent operations.
Note
This function is intended to be called during the warmup phase of the model, before actual training or inference begins.
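A warmup sketch (the batch size is hypothetical; call this once with representative input dimensions before the training or inference loop):
>>> warmup = torch.randn(32, 1024, device="cuda")  # representative activation shape
>>> layer.select_gemm_kernel(warmup)  # the choice is recorded in layer.gemm_kernel_id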
- set_activation(x: Tensor) → Tensor [source]
Normalizes input activations using a layer-wise scale factor and adds bias. This method is called during the forward pass to apply preprocessing to the input activations.
- Parameters:
x (torch.Tensor) – The input activations to the binary linear layer.
- Returns:
The normalized and biased activations ready for the binary linear operation.
- Return type:
torch.Tensor
Note
The scale factor scale_a is dynamically initialized based on the input’s statistical properties if it has not been set previously.
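A sketch of this preprocessing step in isolation (it is normally invoked by forward; shapes are hypothetical):
>>> # scale_a is initialized from the input's statistics on first use (see note above)
>>> x = torch.randn(32, 1024, device="cuda")
>>> x_prep = layer.set_activation(x)  # scaled by scale_a, shifted by bias_a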
- set_weight_data(x: Tensor)[source]
Prepares the weight data for the binary linear operation. This method is an extension of the base class’s method and additionally calls prepare_params to ensure that the weights are properly formatted for efficient computation.
- Parameters:
x (torch.Tensor) – The activation tensor that may influence how weights are prepared.
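A hedged sketch of supplying new weight data (the tensor shape below is an assumption; check the base class's set_weight_data for the exact expected shape):
>>> new_w = torch.randn(512, 1024, device="cuda")  # assumed (out_features, input_features)
>>> layer.set_weight_data(new_w)  # stores the data, then re-runs prepare_params()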