bitorch_engine.layers.qlinear.nbit.cuda.mbwq_layer.MBWQLinearCuda

class bitorch_engine.layers.qlinear.nbit.cuda.mbwq_layer.MBWQLinearCuda(*args, use_mbw: bool = True, groups=64, rows_packed=64, **kwargs)[source]

Implements a Mixed-BitWidth Quantized (MBWQ) linear layer for CUDA devices. This layer extends the functionality of MPQLinearBase by supporting mixed bit-width quantization schemes to optimize model size and computational efficiency while running on CUDA-enabled hardware. It allows for flexible quantization configurations across different parts of the network, enabling fine-grained control over the trade-off between accuracy and performance.
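The following is a minimal usage sketch, not a verified recipe: the positional constructor arguments (input and output feature sizes), the fp16 activation dtype, and the input shape are assumptions about the underlying MPQLinearBase interface; only use_mbw, groups, and rows_packed are documented below.

    import torch
    from bitorch_engine.layers.qlinear.nbit.cuda.mbwq_layer import MBWQLinearCuda

    # Assumed: the base class takes the input and output feature sizes as its
    # first two positional arguments; the documented keyword arguments keep
    # their defaults (use_mbw=True, groups=64, rows_packed=64).
    layer = MBWQLinearCuda(4096, 4096, use_mbw=True).cuda()

    # Quantized parameters (qweight, scales, zeros, ...) are normally loaded
    # from a checkpoint first, e.g. layer.load_state_dict(quantized_state).

    # prepare_params() must run before the first forward pass.
    layer.prepare_params()

    x = torch.randn(8, 4096, dtype=torch.half, device="cuda")  # assumed fp16 input
    y = layer(x)  # quantized linear transformation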

use_mbw

Flag to enable mixed-bitwidth quantization. When set to True, the layer uses different quantization bit widths for different parts of the weight matrix. Defaults to True.

Type:

bool

qweight

Quantized weights of the layer, stored with specified bit-widths as per use_mbw setting.

Type:

torch.Tensor

rows

A list of 7 elements describing how the weights are distributed across the different bit widths, together with kernel permutation information used in mixed-bitwidth settings.

Type:

list

scales

Scale factors used for quantization, applicable in mixed-bitwidth mode.

Type:

torch.Tensor

zeros

Zero-point values used for quantization, applicable in mixed-bitwidth mode.

Type:

torch.Tensor

check_parameters()[source]

Validates the layer’s parameters to ensure compatibility with the chosen quantization settings.

load_state_dict()[source]

Custom logic to load the layer’s state dictionary, handling mixed-bitwidth configurations.

set_scales()[source]

Sets the scale factors for quantization, necessary for initializing or updating the quantized model.

set_zeros()[source]

Sets the zero points for quantization, necessary for initializing or updating the quantized model.

prepare_params()[source]

Prepares the layer’s parameters for inference, optimizing memory usage and computational efficiency.

forward()[source]

Defines the forward pass of the layer with quantized weights and possibly mixed bit-width configurations.

q42fp_weight()[source]

Reconstructs full-precision weights from 4-bit or 2-bit quantized qweight.

exl2fp_weight()[source]

Reconstructs full-precision weights from exl2-quantized qweight.

Methods

__init__

Initializes the MBWQLinearCuda layer with optional mixed-bitwidth quantization.

check_parameters

Validates the layer's parameters against the expected constraints for data type and quantization settings.

exl2fp_weight

Reconstructs full-precision weights from weights that were quantized using an extended level 2 (exl2) quantization scheme.

forward

Defines the forward pass of the MBWQLinearCuda layer with quantized weights and mixed bit-width configurations.

load_state_dict

Custom logic to load the state dictionary, handling special cases for mixed-bitwidth quantization.

prepare_params

Prepares the layer’s parameters for inference; should be executed before the first forward pass.

q42fp_weight

Converts quantized weights (qweight) from 4-bit or 2-bit quantization back to full-precision (floating-point) weights.

set_scales

Sets the scale factors for quantization, necessary for initializing or updating the quantized model.

set_zeros

Sets the zero points for quantization, necessary for initializing or updating the quantized model.

Attributes

training

__init__(*args, use_mbw: bool = True, groups=64, rows_packed=64, **kwargs) None[source]

Initializes the MBWQLinearCuda layer with optional mixed-bitwidth quantization.

Parameters:
  • *args – Variable length argument list to be passed to the base class initializer.

  • use_mbw (bool, optional) – Specifies whether to use mixed-bitwidth quantization. Defaults to True.

  • **kwargs – Arbitrary keyword arguments to be passed to the base class initializer.

check_parameters() None[source]

Validates the layer’s parameters against the expected constraints for data type and quantization settings.

static exl2fp_weight(qweight: Tensor, scales: Tensor, zeros: Tensor, q_perm: Tensor, q_group_map: Tensor, rows: list) Tensor[source]

Reconstructs full-precision weights from weights that were quantized using an extended level 2 (exl2) quantization scheme.

This function handles the de-quantization process for weights that were quantized with a more complex, potentially non-uniform, quantization method. It requires multiple parameters, including scales, zero points, and mapping tensors that define how the quantized values are translated back to full-precision values.

Parameters:
  • qweight (torch.Tensor) – The quantized weights tensor.

  • scales (torch.Tensor) – Scale factors for the quantized weights.

  • zeros (torch.Tensor) – The quantized zero points.

  • q_perm (torch.Tensor) – A permutation tensor that specifies the order of quantized weights.

  • q_group_map (torch.Tensor) – A mapping tensor that groups quantized weights.

  • rows (list) – A list specifying the rows (or indices) of weights to be processed.

Returns:

The reconstructed full-precision weights tensor.

Return type:

torch.Tensor
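A hedged call sketch follows; `layer` is assumed to be an already-loaded MBWQLinearCuda instance, and the q_perm and q_group_map attribute names are assumptions, while qweight, scales, zeros, and rows are documented above.

    from bitorch_engine.layers.qlinear.nbit.cuda.mbwq_layer import MBWQLinearCuda

    # Sketch: recover full-precision weights from an exl2-quantized layer.
    # `layer` is assumed to be an MBWQLinearCuda instance whose quantized
    # buffers are already loaded; the q_perm and q_group_map attribute names
    # are assumptions.
    fp_weight = MBWQLinearCuda.exl2fp_weight(
        qweight=layer.qweight,
        scales=layer.scales,
        zeros=layer.zeros,
        q_perm=layer.q_perm,
        q_group_map=layer.q_group_map,
        rows=layer.rows,
    )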

forward(x: Tensor) Tensor[source]

Defines the forward pass of the MBWQLinearCuda layer with quantized weights and mixed bit-width configurations.

Parameters:

x (torch.Tensor) – The input tensor to the layer.

Returns:

The output tensor resulting from applying the quantized linear transformation.

Return type:

torch.Tensor

load_state_dict(state_dict, strict=True) None[source]

Custom logic to load the state dictionary, handling special cases for mixed-bitwidth quantization.

Parameters:
  • state_dict (dict) – The state dictionary to load.

  • strict (bool) – Specifies whether to enforce strict loading of state_dict keys.

prepare_params() None[source]

This method should be executed before the first forward pass. It prepares the layer for inference, handling memory management and parameter adaptation.

One can use the “prepare_bie_layers” method from project_root.utils.model_helper to call this function.
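A hedged sketch of the helper-based call; the import path bitorch_engine.utils.model_helper and the helper’s exact signature are assumptions inferred from the “project_root.utils.model_helper” reference above.

    # Assumed import path for the helper referenced above.
    from bitorch_engine.utils.model_helper import prepare_bie_layers

    # Assumed usage: walk `model` and call prepare_params() on every
    # bitorch-engine layer before running inference.
    prepare_bie_layers(model)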

static q42fp_weight(qweight: Tensor, scales: Tensor, zeros: Tensor, group_size: int, bits: int, q_perm: Tensor) Tensor[source]

Converts quantized weights (qweight) from 4-bit or 2-bit quantization back to full-precision (floating-point) weights.

This function is used for de-quantizing weights that were previously quantized to a 4-bit or 2-bit representation, applying the necessary scales and zero points for accurate reconstruction.

Parameters:
  • qweight (torch.Tensor) – The quantized weights tensor in 4-bit or 2-bit representation.

  • scales (torch.Tensor) – The scale factors associated with each quantized weight for conversion back to full-precision.

  • zeros (torch.Tensor) – The zero points associated with each quantized weight, used during the de-quantization process.

  • group_size (int) – The size of the weight groups that were quantized together. This parameter is crucial for correctly reshaping the tensor during de-quantization.

  • bits (int) – The bit width used for quantization; either 4 or 2.

  • q_perm (torch.Tensor) – A permutation tensor that specifies the order of quantized weights.

Returns:

The de-quantized (full-precision) weights tensor.

Return type:

torch.Tensor
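A hedged call sketch; `layer`, its q_perm attribute, and the group_size value are assumptions made for illustration, and the quantized buffers are assumed to be already loaded.

    from bitorch_engine.layers.qlinear.nbit.cuda.mbwq_layer import MBWQLinearCuda

    # Sketch: de-quantize 4-bit packed weights back to full precision.
    # `layer`, its q_perm attribute, and the group_size value are assumptions.
    fp_weight = MBWQLinearCuda.q42fp_weight(
        qweight=layer.qweight,
        scales=layer.scales,
        zeros=layer.zeros,
        group_size=128,  # assumed quantization group size
        bits=4,          # bit width of the packed weights (4 or 2)
        q_perm=layer.q_perm,
    )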

set_scales(scales: Tensor | None = None) None[source]

Sets the scale factors for quantization, necessary for initializing or updating the quantized model.

Parameters:

scales (torch.Tensor, optional) – The scale factors to be applied. If None, no action is taken.

set_zeros(zeros: Tensor | None = None) None[source]

Sets the zero points for quantization, necessary for initializing or updating the quantized model.

Parameters:

zeros (torch.Tensor, optional) – The zero points to be applied. If None, no action is taken.
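A brief hedged sketch of supplying externally computed quantization parameters; the tensor names and shapes are placeholders, since the exact layout depends on the group size and bit-width configuration.

    # Sketch: hand externally computed quantization parameters to the layer.
    # `my_scales` and `my_zeros` are placeholder tensors; their shapes must
    # match the layer's quantization configuration.
    layer.set_scales(scales=my_scales)  # per-group scale factors
    layer.set_zeros(zeros=my_zeros)     # per-group zero points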