bitorch_engine.layers.qlinear.nbit.cuda.mbwq_layer.MBWQLinearCuda
- class bitorch_engine.layers.qlinear.nbit.cuda.mbwq_layer.MBWQLinearCuda(*args, use_mbw: bool = True, groups=64, rows_packed=64, **kwargs)[source]
Implements a Mixed-BitWidth Quantized (MBWQ) linear layer for CUDA devices. This layer extends the functionality of MPQLinearBase by supporting mixed bit-width quantization schemes to optimize model size and computational efficiency while running on CUDA-enabled hardware. It allows for flexible quantization configurations across different parts of the network, enabling fine-grained control over the trade-off between accuracy and performance.
- use_mbw
Flag to enable mixed-bitwidth quantization. When set to True, the layer uses different quantization bit widths for different parts of the weight matrix. Defaults to True.
- Type:
bool
- qweight
Quantized weights of the layer, stored at the bit-widths specified by the use_mbw setting.
- Type:
torch.Tensor
- rows
A list of seven elements describing how the weight rows are distributed across the different bit-widths, along with kernel permutation information for mixed-bitwidth settings.
- Type:
list
- scales
Scale factors used for quantization, applicable in mixed-bitwidth mode.
- Type:
torch.Tensor
- zeros
Zero-point values used for quantization, applicable in mixed-bitwidth mode.
- Type:
torch.Tensor
- check_parameters()[source]
Validates the layer’s parameters to ensure compatibility with the chosen quantization settings.
- load_state_dict()[source]
Custom logic to load the layer’s state dictionary, handling mixed-bitwidth configurations.
- set_scales()[source]
Sets the scale factors for quantization, necessary for initializing or updating the quantized model.
- set_zeros()[source]
Sets the zero points for quantization, necessary for initializing or updating the quantized model.
- prepare_params()[source]
Prepares the layer’s parameters for inference, optimizing memory usage and computational efficiency.
- forward()[source]
Defines the forward pass of the layer with quantized weights and possibly mixed bit-width configurations.
Methods
- __init__(*args, use_mbw=True, groups=64, rows_packed=64, **kwargs) – Initializes the MBWQLinearCuda layer with optional mixed-bitwidth quantization.
- check_parameters() – Validates the layer’s parameters against the expected constraints for data type and quantization settings.
- exl2fp_weight(qweight, scales, zeros, q_perm, q_group_map, rows) – Reconstructs full-precision weights from weights that were quantized using an extended level 2 (exl2) quantization scheme.
- forward(x) – Defines the forward pass of the MBWQLinearCuda layer with quantized weights and mixed bit-width configurations.
- load_state_dict(state_dict, strict=True) – Custom logic to load the state dictionary, handling special cases for mixed-bitwidth quantization.
- prepare_params() – Prepares the layer’s parameters for inference; should be executed before the actual forward pass.
- q42fp_weight(qweight, scales, zeros, group_size, bits, q_perm) – Converts quantized weights (qweight) from 4-bit or 2-bit quantization back to full-precision (floating-point) weights.
- set_scales() – Sets the scale factors for quantization, necessary for initializing or updating the quantized model.
- set_zeros() – Sets the zero points for quantization, necessary for initializing or updating the quantized model.
Attributes
training
- __init__(*args, use_mbw: bool = True, groups=64, rows_packed=64, **kwargs) → None [source]
Initializes the MBWQLinearCuda layer with optional mixed-bitwidth quantization.
- Parameters:
*args – Variable length argument list to be passed to the base class initializer.
use_mbw (bool, optional) – Specifies whether to use mixed-bitwidth quantization. Defaults to True.
groups (int, optional) – Defaults to 64.
rows_packed (int, optional) – Defaults to 64.
**kwargs – Arbitrary keyword arguments to be passed to the base class initializer.
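Below is a hedged construction sketch. The in_features, out_features, w_bit, and dtype keyword names are assumptions: the accepted arguments are defined by the MPQLinearBase initializer, which is not documented in this section, so those names and the chosen sizes are illustrative only.

```python
import torch
from bitorch_engine.layers.qlinear.nbit.cuda.mbwq_layer import MBWQLinearCuda

# Hypothetical construction -- keyword names other than use_mbw, groups and
# rows_packed come from the (undocumented here) MPQLinearBase initializer.
layer = MBWQLinearCuda(
    in_features=4096,     # assumed keyword
    out_features=4096,    # assumed keyword
    w_bit=4,              # assumed keyword
    dtype=torch.half,     # assumed keyword
    use_mbw=True,         # mixed-bitwidth quantization (documented above)
    groups=64,            # documented default
    rows_packed=64,       # documented default
).cuda()
```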
- check_parameters() → None [source]
Validates the layer’s parameters against the expected constraints for data type and quantization settings.
- static exl2fp_weight(qweight: Tensor, scales: Tensor, zeros: Tensor, q_perm: Tensor, q_group_map: Tensor, rows: list) → Tensor [source]
Reconstructs full-precision weights from weights that were quantized using an extended level 2 (exl2) quantization scheme.
This function handles the de-quantization process for weights that were quantized with a more complex, perhaps non-linear, quantization method. It requires multiple parameters including scales, zero points, and mappings that define how the quantized values should be translated back to full-precision values.
- Parameters:
qweight (torch.Tensor) – The quantized weights tensor.
scales (torch.Tensor) – Scale factors for the quantized weights.
zeros (torch.Tensor) – The quantized zero points.
q_perm (torch.Tensor) – A permutation tensor that specifies the order of quantized weights.
q_group_map (torch.Tensor) – A mapping tensor that groups quantized weights.
rows (list) – A list specifying the rows (or indices) of weights to be processed.
- Returns:
The reconstructed full-precision weights tensor.
- Return type:
torch.Tensor
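As a hedged debugging sketch, the static method can be called on a layer’s loaded buffers to inspect the reconstructed weights. Only the method’s parameter names are documented here; whether the layer exposes q_perm and q_group_map buffers under exactly these attribute names is an assumption.

```python
# Hedged sketch: reconstruct full-precision weights for inspection.
# layer.q_perm and layer.q_group_map are assumed attribute names; qweight,
# scales, zeros and rows are documented attributes of MBWQLinearCuda.
w_fp = MBWQLinearCuda.exl2fp_weight(
    layer.qweight, layer.scales, layer.zeros,
    layer.q_perm, layer.q_group_map, layer.rows,
)
print(w_fp.shape, w_fp.dtype)
```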
- forward(x: Tensor) → Tensor [source]
Defines the forward pass of the MBWQLinearCuda layer with quantized weights and mixed bit-width configurations.
- Parameters:
x (torch.Tensor) – The input tensor to the layer.
- Returns:
The output tensor resulting from applying the quantized linear transformation.
- Return type:
torch.Tensor
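A minimal usage sketch follows. It assumes a layer constructed as above with an input width of 4096 (hypothetical) and half-precision activations; the only calls used are the documented prepare_params() and forward().

```python
import torch

layer.prepare_params()  # must run before the first forward pass (see below)
x = torch.randn(8, 4096, dtype=torch.half, device="cuda")  # assumed shape/dtype
y = layer(x)            # quantized, mixed bit-width linear transformation
print(y.shape)
```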
- load_state_dict(state_dict, strict=True) → None [source]
Custom logic to load the state dictionary, handling special cases for mixed-bitwidth quantization.
- Parameters:
state_dict (dict) – The state dictionary to load.
strict (bool) – Specifies whether to enforce strict loading of state_dict keys.
- prepare_params() → None [source]
This method should be executed before the actual forward pass. It prepares the layer for inference, handling memory management and parameter adaptation.
The “prepare_bie_layers” method from project_root.utils.model_helper can be used to call this function.
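A hedged sketch of the typical load-then-prepare sequence is shown below; the checkpoint file name is a placeholder, and since the exact signature of prepare_bie_layers is not shown in this section, the commented-out helper call is an assumption.

```python
import torch

# Placeholder checkpoint path; keys are expected to match the layer's buffers.
state = torch.load("mbwq_checkpoint.pt", map_location="cuda")
layer.load_state_dict(state, strict=True)
layer.prepare_params()  # memory management / parameter adaptation before inference

# Alternatively, prepare every such layer in a model at once (assumed usage):
# from project_root.utils.model_helper import prepare_bie_layers
# prepare_bie_layers(model)
```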
- static q42fp_weight(qweight: Tensor, scales: Tensor, zeros: Tensor, group_size: int, bits: int, q_perm: Tensor) → Tensor [source]
Converts quantized weights (qweight) from 4-bit or 2-bit quantization back to full-precision (floating-point) weights.
This function is used for de-quantizing weights that were previously quantized to a 4-bit or 2-bit representation, applying the necessary scales and zero points for an accurate reconstruction.
- Parameters:
qweight (torch.Tensor) – The quantized weights tensor in 4-bit or 2-bit representation.
scales (torch.Tensor) – The scale factors associated with each quantized weight for conversion back to full-precision.
zeros (torch.Tensor) – The zero points associated with each quantized weight, used during the de-quantization process.
group_size (int) – The size of the weight groups that were quantized together. This parameter is crucial for correctly reshaping the tensor during de-quantization.
bits (int) – The quantization bit-width; either 4 or 2.
q_perm (torch.Tensor) – A permutation tensor that specifies the order of quantized weights.
- Returns:
The de-quantized (full-precision) weights tensor.
- Return type:
torch.Tensor
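To make the de-quantization step concrete, here is a self-contained, illustrative reference of the group-wise formula w = (q - zero) * scale that this kind of conversion applies. It is not the library kernel: q42fp_weight additionally unpacks the bit-packed qweight tensor and applies the q_perm reordering, and the row-major grouping and tensor shapes used here are assumptions.

```python
import torch

def dequantize_groupwise(q_int: torch.Tensor,   # (rows, cols) integer codes, already unpacked
                         scales: torch.Tensor,  # (rows // group_size, cols)
                         zeros: torch.Tensor,   # (rows // group_size, cols)
                         group_size: int) -> torch.Tensor:
    # Map every row to its quantization group, then apply w = (q - zero) * scale.
    rows = q_int.shape[0]
    group_idx = torch.arange(rows, device=q_int.device) // group_size
    return (q_int.float() - zeros[group_idx].float()) * scales[group_idx].float()

q = torch.randint(0, 16, (128, 64))       # 4-bit codes in [0, 15]
scales = torch.rand(128 // 32, 64)        # one scale per 32-row group (assumed layout)
zeros = torch.full((4, 64), 8.0)          # one zero point per group
w_fp = dequantize_groupwise(q, scales, zeros, group_size=32)
```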