bitorch_engine.layers.qlinear.nbit.layer.nBitLinearBase

class bitorch_engine.layers.qlinear.nbit.layer.nBitLinearBase(in_channels: int, out_channels: int, a_bit: int = 4, w_bit: int = 4, device=None, dtype=torch.float32)[source]

A base class for n-bit Quantization-Aware Training (QAT) linear layers. This class provides a framework for implementing layers that operate with low-bitwidth activations and weights during training, and supports quantization for efficient inference. It maintains both floating-point and quantized weights to facilitate the QAT process.
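
A sketch of the intended QAT workflow, assuming a hypothetical concrete subclass MyNBitLinear (nBitLinearBase itself is a base class and is not instantiated directly; the import path is illustrative):

    import torch
    from my_project.layers import MyNBitLinear  # hypothetical concrete subclass

    layer = MyNBitLinear(in_channels=256, out_channels=512, a_bit=4, w_bit=4)
    layer.prepare_params()             # required once, before training starts

    layer.train()                      # training uses floating-point weights
    out = layer(torch.randn(8, 256))
    out.sum().backward()               # QAT: gradients flow to the fp weights

    layer.eval()                       # inference uses low-bitwidth weights
    layer.generate_quantized_weight()
    out = layer(torch.randn(8, 256))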

in_channels

The number of input features expected by the layer, i.e. the input feature dimension after bit-packing.

Type:

int

out_channels

The number of output features produced by the layer.

Type:

int

a_bit

The bit-width for activations used during training. Defaults to 4 bits.

Type:

int

w_bit

The bit-width for weights used during training and inference. Defaults to 4 bits.

Type:

int

device

The device on which the layer’s parameters are stored. Defaults to None, which means the default device is used.

dtype

The data type for the layer’s parameters. Defaults to torch.float32.

Note

This class is designed to be subclassed by specific implementations of n-bit linear layers, which should provide mechanisms for parameter preparation (prepare_params), weight quantization (generate_quantized_weight), and other necessary operations.
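
A minimal subclass sketch under these assumptions: the floating-point weights live in self.weight, the weight layout follows torch.nn.Linear’s (out_channels, in_channels) convention, and the quantization scheme is a toy stand-in. None of this is the library’s actual implementation.

    import torch
    import torch.nn.functional as F
    from bitorch_engine.layers.qlinear.nbit.layer import nBitLinearBase

    class SimpleNBitLinear(nBitLinearBase):
        """Illustrative subclass; real implementations pack weights into
        low-bitwidth storage and dispatch to optimized kernels."""

        def prepare_params(self) -> None:
            # Nothing to pre-process in this toy; real subclasses typically
            # re-lay-out the floating-point weights for efficient compute.
            pass

        def generate_quantized_weight(self, qweight_only: bool = False) -> None:
            # Toy symmetric uniform quantization to 2**w_bit levels; stores a
            # de-quantized copy rather than a packed tensor for simplicity.
            qmax = 2 ** (self.w_bit - 1) - 1
            scale = self.weight.data.abs().max() / qmax  # assumes self.weight
            q = torch.clamp(torch.round(self.weight.data / scale), -qmax, qmax)
            self.set_quantized_weight_data(q * scale)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # opt_weight resolves to the fp weights in training mode and to
            # the quantized weights in inference mode (see property below).
            return F.linear(x, self.opt_weight)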

Methods

__init__

Initialize internal Module state, shared by both nn.Module and ScriptModule.

generate_quantized_weight

Generates and sets the quantized weights based on the current floating-point weights.

prepare_params

Prepares and initializes the model parameters for training.

reset_parameters

Initializes or resets the floating-point weights of the layer using Kaiming uniform initialization.

set_quantized_weight_data

Sets the quantized weights of the layer to the provided tensor.

set_weight_data

Sets the floating-point weights of the layer to the provided tensor.

Attributes

opt_weight

Returns the optimal weight for the current mode (training or inference).

training

Boolean flag inherited from torch.nn.Module; True when the layer is in training mode and False in evaluation mode.

__init__(in_channels: int, out_channels: int, a_bit: int = 4, w_bit: int = 4, device=None, dtype=torch.float32) None[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

generate_quantized_weight(qweight_only: bool = False) None[source]

Generates and sets the quantized weights based on the current floating-point weights. This method must be implemented by subclasses and is crucial for converting floating-point weights to low-bitwidth quantized weights for inference.

Parameters:

qweight_only (bool) – If True, only quantized weights are generated.
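
For example, quantized weights for a whole model can be generated up front before deployment (a sketch; model is assumed to be an already-built torch.nn.Module containing n-bit layers):

    from bitorch_engine.layers.qlinear.nbit.layer import nBitLinearBase

    model.eval()
    for module in model.modules():
        if isinstance(module, nBitLinearBase):
            module.generate_quantized_weight(qweight_only=True)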

property opt_weight: Parameter

Returns the optimal weight for the current mode (training or inference). If the model is in inference mode and quantized weights are not yet generated, it triggers quantized weight generation.

Returns:

The current optimal weights (floating-point for training, quantized for inference).

Return type:

torch.nn.Parameter
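
The mode-dependent behavior, sketched for an instance layer of a concrete subclass (assumed to exist):

    layer.train()
    w = layer.opt_weight    # floating-point weights during training

    layer.eval()
    w = layer.opt_weight    # quantized weights; generated on first access if
                            # generate_quantized_weight() was not called yet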

prepare_params() None[source]

Prepares and initializes the model parameters for training.

Note

This method MUST be called after model initialization and before training starts to ensure the weights are properly prepared for efficient computation.

reset_parameters() None[source]

Initializes or resets the floating-point weights of the layer using Kaiming uniform initialization.
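
This presumably mirrors the standard PyTorch convention for linear layers; a sketch of such a reset (the a=math.sqrt(5) slope is torch.nn.Linear’s default and an assumption here):

    import math
    import torch

    def reset_parameters_sketch(weight: torch.nn.Parameter) -> None:
        # Kaiming uniform init, as used by torch.nn.Linear.reset_parameters().
        torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))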

set_quantized_weight_data(x: Tensor) None[source]

Sets the quantized weights of the layer to the provided tensor.

Parameters:

x (torch.Tensor) – The tensor to set as the new quantized weights.
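
One use is restoring already-quantized weights from a checkpoint instead of re-quantizing (a sketch; layer, the file name, and the state-dict key are all hypothetical):

    import torch

    state = torch.load("checkpoint.pt")
    layer.set_quantized_weight_data(state["qweight"])  # hypothetical key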

set_weight_data(x: Tensor) None[source]

Sets the floating-point weights of the layer to the provided tensor.

Parameters:

x (torch.Tensor) – The tensor to set as the new weights.
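
For example, to warm-start QAT from a pretrained full-precision layer (a sketch; layer is an instance of a concrete subclass and the weight shapes are assumed to match):

    import torch

    fp = torch.nn.Linear(256, 512, bias=False)   # pretrained fp layer
    layer.set_weight_data(fp.weight.data)        # copy fp weights over
    layer.prepare_params()                       # then prepare for QAT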