bitorch_engine.layers.qlinear.nbit.layer.nBitLinearBase
- class bitorch_engine.layers.qlinear.nbit.layer.nBitLinearBase(in_channels: int, out_channels: int, a_bit: int = 4, w_bit: int = 4, device=None, dtype=torch.float32)[source]
A base class for n-bit Quantization-Aware Training (QAT) linear layers. This class provides a framework for implementing layers that operate with low-bitwidth activations and weights during training, and supports quantization for efficient inference. It maintains both floating-point and quantized weights to facilitate the QAT process.
- in_channels
The dimension of input features after bit-packing, indicating the number of input features to the layer.
- Type:
int
- out_channels
The dimension of output features, indicating the number of output features produced by the layer.
- Type:
int
- a_bit
The bit-width for activations used during training. Defaults to 4 bits.
- Type:
int
- w_bit
The bit-width for weights used during training and inference. Defaults to 4 bits.
- Type:
int
- device
The device on which the layer’s parameters are stored. Defaults to None, which means the default device is used.
- dtype
The data type for the layer’s parameters. Defaults to torch.float32.
Note
This class is designed to be subclassed by specific implementations of n-bit linear layers, which should provide mechanisms for parameter preparation (prepare_params), weight quantization (generate_quantized_weight), and other necessary operations.
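To make the dual-weight QAT pattern described above concrete, here is a minimal, self-contained sketch of what a subclass might look like. This is illustrative only: it uses a plain `torch.nn.Module` stand-in rather than `nBitLinearBase`, and the round-to-nearest symmetric quantization in `generate_quantized_weight` is a hypothetical scheme, not the engine's actual packed low-bit kernel.

```python
import torch
import torch.nn as nn

class ToyNBitLinear(nn.Module):
    """Illustrative stand-in for an n-bit QAT linear layer (not the real API)."""

    def __init__(self, in_channels: int, out_channels: int, w_bit: int = 4):
        super().__init__()
        self.w_bit = w_bit
        # Floating-point weights, kept for training.
        self.weight = nn.Parameter(torch.empty(out_channels, in_channels))
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        self.qweight = None  # generated lazily for inference

    def generate_quantized_weight(self, qweight_only: bool = False) -> None:
        # Hypothetical symmetric uniform quantization to 2**w_bit - 1 levels.
        levels = 2 ** (self.w_bit - 1) - 1
        scale = self.weight.detach().abs().max() / levels
        q = torch.clamp(torch.round(self.weight.detach() / scale), -levels, levels)
        self.qweight = nn.Parameter(q * scale, requires_grad=False)
        if qweight_only:
            self.weight = None  # drop the float weights to save memory

    @property
    def opt_weight(self) -> nn.Parameter:
        # Float weights while training; quantized weights for inference,
        # generated on first use, mirroring the documented opt_weight behavior.
        if not self.training and self.qweight is None:
            self.generate_quantized_weight()
        return self.weight if self.training else self.qweight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(x, self.opt_weight)
```

Switching the module with `.eval()` and running a forward pass then transparently uses the quantized weights, which is the behavior the `opt_weight` property documents below.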
Methods
- __init__: Initialize internal Module state, shared by both nn.Module and ScriptModule.
- generate_quantized_weight: Generates and sets the quantized weights based on the current floating-point weights.
- prepare_params: Prepares and initializes the model parameters for training.
- reset_parameters: Initializes or resets the floating-point weights of the layer using Kaiming uniform initialization.
- set_quantized_weight_data: Sets the quantized weights of the layer to the provided tensor.
- set_weight_data: Sets the floating-point weights of the layer to the provided tensor.
Attributes
- opt_weight: Returns the optimal weight for the current mode (training or inference).
- training
- __init__(in_channels: int, out_channels: int, a_bit: int = 4, w_bit: int = 4, device=None, dtype=torch.float32) None [source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- generate_quantized_weight(qweight_only: bool = False) None [source]
Generates and sets the quantized weights based on the current floating-point weights. This method must be implemented by subclasses and is crucial for converting floating-point weights to low-bitwidth quantized weights for inference.
- Parameters:
qweight_only (bool) – If True, only quantized weights are generated.
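The quantized weights produced here are typically stored bit-packed, which is why `in_channels` above is described as a dimension "after bit-packing". As an illustration only, the helpers below pack pairs of 4-bit values into single bytes; the actual packed layout used by the engine's kernels may differ.

```python
import torch

def pack_4bit(q: torch.Tensor) -> torch.Tensor:
    """Pack pairs of 4-bit values (0..15) along the last dim into uint8 bytes.

    Illustrative scheme only; not necessarily bitorch_engine's layout.
    """
    assert q.shape[-1] % 2 == 0, "last dimension must be even to pack in pairs"
    q = q.to(torch.uint8)
    # High nibble from even positions, low nibble from odd positions.
    return (q[..., 0::2] << 4) | q[..., 1::2]

def unpack_4bit(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_4bit: recover the original 4-bit values."""
    hi = packed >> 4
    lo = packed & 0x0F
    # Interleave (hi, lo) pairs back into the original order.
    return torch.stack((hi, lo), dim=-1).reshape(*packed.shape[:-1], -1)
```

Packing halves the storage along the packed dimension, which matches the factor-of-two relationship one would expect between the logical feature count and the packed `in_channels` for 4-bit weights.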
- property opt_weight: Parameter
Returns the optimal weight for the current mode (training or inference). If the model is in inference mode and quantized weights are not yet generated, it triggers quantized weight generation.
- Returns:
The current optimal weights (floating-point for training, quantized for inference).
- Return type:
torch.nn.Parameter
- prepare_params() None [source]
Prepares and initializes the model parameters for training.
Note
This method MUST be called after model initialization and before training starts to ensure the weights are properly prepared for efficient computation.
- reset_parameters() None [source]
Initializes or resets the floating-point weights of the layer using Kaiming uniform initialization.
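For reference, the standard Kaiming uniform reset that PyTorch's own `nn.Linear` uses looks like the sketch below; the exact negative-slope argument `a` this layer passes is an assumption, mirroring `nn.Linear`'s default.

```python
import math
import torch

# Shape follows the (out_channels, in_channels) convention of a linear layer.
weight = torch.empty(16, 32)

# Kaiming uniform draws from U(-bound, bound) with bound = gain * sqrt(3 / fan_in);
# a=sqrt(5) is the slope nn.Linear uses (assumed here, not confirmed by the source).
torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
```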