bitorch_engine.layers.qlinear.nbit.mps.mpq_layer.MPQLinearMlx
- class bitorch_engine.layers.qlinear.nbit.mps.mpq_layer.MPQLinearMlx(*args, **kwargs)[source]
Represents an MPS-compatible implementation of the mixed precision quantized (MPQ) linear layer, inheriting from MPQLinearBase. This class is specifically optimized for MPS devices, supporting operations with quantized weights and activations in a mixed precision format. It uses the MLX library to perform efficient quantized matrix multiplication on MPS devices.
The layer supports weight quantization bit widths (w_bit) of 2, 4, or 8 and a fixed activation bit width (a_bit) of 16, ensuring compatibility with common hardware accelerators and optimizing performance for deep learning inference tasks on MPS-enabled Apple devices.
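A minimal usage sketch follows. The constructor keywords (in_channels, out_channels, w_bit) are assumptions about what MPQLinearBase accepts, not a confirmed signature; consult the base class for the authoritative argument names.

```python
import torch
from bitorch_engine.layers.qlinear.nbit.mps.mpq_layer import MPQLinearMlx

# Hypothetical constructor keywords (in_channels, out_channels, w_bit); the
# real names are defined by MPQLinearBase and may differ.
layer = MPQLinearMlx(in_channels=1024, out_channels=4096, w_bit=4)
layer.prepare_params()  # must run before inference (see prepare_params() below)
```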
- qweight
Quantized weights of the layer, adhering to specified precision.
- Type:
torch.nn.Parameter
- w_bit
Bit width for weight quantization.
- Type:
int
- a_bit
Bit width for activation quantization, fixed at 16.
- Type:
int
- scales
Scale factors for quantized weights, calculated during parameter preparation.
- Type:
torch.Tensor
- zeros
Zero points for quantized weights, supporting asymmetric quantization (illustrated in the sketch after this list).
- Type:
torch.Tensor
- check_parameters()[source]
Validates the quantization parameters to ensure they meet the requirements.
- prepare_params()[source]
Prepares and decompresses quantized parameters for the forward pass. Must be called before performing inference to correctly set up the layer parameters.
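To make the roles of scales and zeros concrete, here is an illustrative sketch of the asymmetric dequantization arithmetic they imply. The group layout and shapes below are assumptions chosen for illustration, not MPQLinearMlx internals.

```python
import torch

def dequantize_sketch(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor) -> torch.Tensor:
    # Asymmetric scheme: w ≈ scales * (q - zeros), with scales and zeros
    # broadcast over the quantized codes.
    return scales * (qweight.to(scales.dtype) - zeros)

q = torch.randint(0, 16, (4, 8))   # 4-bit codes in [0, 15] (assumed layout)
s = torch.full((4, 1), 0.05)       # per-row scale factors
z = torch.full((4, 1), 8.0)        # per-row zero points
w = dequantize_sketch(q, s, z)     # approximate floating-point weights
```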
Methods
- __init__(*args, **kwargs): Initializes the MPQLinearMlx layer with given arguments and keyword arguments, setting up the layer to use MLX with mixed precision quantized weights and activations.
- check_parameters(): Ensures that the quantization bit widths for weights (w_bit) and activations (a_bit) are valid.
- forward(x): Performs the forward pass of the MPQLinearMlx layer using quantized weights and activations.
- prepare_params(): This method should be executed before the actual forward pass.
Attributes
- training
- __init__(*args, **kwargs) → None [source]
Initializes the MPQLinearMlx layer with given arguments and keyword arguments, setting up the layer to use MLX with mixed precision quantized weights and activations.
- check_parameters() → None [source]
Ensures that the quantization bit widths for weights (w_bit) and activations (a_bit) are valid. Raises an assertion error if the conditions are not met.
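Based on the documented constraints (w_bit must be 2, 4, or 8; a_bit must be 16), here is a sketch of the kind of validation this method performs; the actual implementation may differ.

```python
def check_parameters_sketch(w_bit: int, a_bit: int) -> None:
    # Mirrors the documented constraints; not the library's actual code.
    assert w_bit in (2, 4, 8), f"w_bit must be 2, 4, or 8, got {w_bit}"
    assert a_bit == 16, f"a_bit is fixed at 16, got {a_bit}"
```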
- forward(x: Tensor) → Tensor [source]
Performs the forward pass of the MPQLinearMlx layer using quantized weights and activations.
- Parameters:
x (torch.Tensor) – The input tensor with shape (batch size, number of features).
- Returns:
The output tensor resulting from the quantized linear transformation and bias addition.
- Return type:
torch.Tensor
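Continuing the earlier construction sketch, a hedged example of the forward call; the shapes follow the (batch size, number of features) contract above, and the half-precision input matches the fixed a_bit of 16.

```python
import torch

# `layer` is the prepared MPQLinearMlx sketched above (assumed
# in_channels=1024, out_channels=4096).
x = torch.randn(32, 1024, dtype=torch.float16, device="mps")
y = layer(x)
print(y.shape)  # expected: torch.Size([32, 4096])
```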
- prepare_params() → None [source]
This method should be executed before the actual forward pass. It mainly decompresses quantized parameters such as qscale and qzero. This step could be simplified or eliminated in the future by a kernel implementation that decompresses during kernel computation.
One can use the “prepare_bie_layers” method from project_root.utils.model_helper to call this function.
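The referenced prepare_bie_layers helper is not shown here, so its signature is assumed; an equivalent manual preparation pass might look like this sketch.

```python
import torch.nn as nn
from bitorch_engine.layers.qlinear.nbit.mps.mpq_layer import MPQLinearMlx

def prepare_mpq_layers(model: nn.Module) -> None:
    # Manual alternative to the prepare_bie_layers helper: walk the module
    # tree and prepare every MPQLinearMlx layer before inference.
    for module in model.modules():
        if isinstance(module, MPQLinearMlx):
            module.prepare_params()
```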