bitorch_engine.layers.qlinear.nbit.cutlass.q4_layer.Q4MatMul
- class bitorch_engine.layers.qlinear.nbit.cutlass.q4_layer.Q4MatMul(dtype=torch.float32, device=None, *args, **kwargs)[source]
A custom PyTorch module for performing quantized matrix multiplication, specifically designed for 4-bit quantization. This module quantizes inputs before multiplication based on a dynamic scaling factor and aims to maintain high precision in low-bitwidth computations.
- device
The device on which tensors will be allocated.
- Type:
torch.device
- dtype
The data type for the parameters and outputs.
- Type:
torch.dtype
- x_clip
The dynamic scaling factor for the first input tensor.
- Type:
torch.nn.Parameter
- y_clip
The dynamic scaling factor for the second input tensor.
- Type:
torch.nn.Parameter
- eps
A small epsilon value to prevent division by zero in computations.
- Type:
torch.Tensor
- Parameters:
dtype (torch.dtype) – The desired data type for computations (default: torch.float32).
device (torch.device, optional) – The device on which to perform computations.
*args – Variable length argument list for the parent class.
**kwargs – Arbitrary keyword arguments for the parent class.
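A usage sketch is shown below. The import path follows the module path at the top of this page, and the choice of torch.float16 inputs and a CUDA device are assumptions; the CUTLASS backend requires a CUDA-capable GPU, so this will not run on CPU-only machines.

```python
import torch
from bitorch_engine.layers.qlinear.nbit.cutlass.q4_layer import Q4MatMul

# Assumes a CUDA device is available; CUTLASS kernels do not run on CPU.
matmul = Q4MatMul(dtype=torch.float16, device=torch.device("cuda"))

# forward() expects inputs with more than two dimensions (batched matmul).
x = torch.randn(4, 128, 64, dtype=torch.float16, device="cuda")
y = torch.randn(4, 64, 32, dtype=torch.float16, device="cuda")

out = matmul(x, y)  # quantized 4-bit matrix multiplication
```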
Methods
- __init__ – Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward – Forward pass of the Q4MatMul module.
- set_activation_scale – Dynamically sets the scaling factors for input tensors x and y, based on their respective values.
Attributes
- training
- __init__(dtype=torch.float32, device=None, *args, **kwargs)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, y: Tensor) Tensor [source]
Forward pass of the Q4MatMul module. Validates input dimensions, sets activation scales, and performs quantized matrix multiplication.
- Parameters:
x (torch.Tensor) – The first input tensor.
y (torch.Tensor) – The second input tensor.
- Returns:
The result of the quantized matrix multiplication.
- Return type:
torch.Tensor
- Raises:
AssertionError – If either input tensor has two or fewer dimensions (batched inputs with more than two dimensions are expected).
- set_activation_scale(x: Tensor, y: Tensor) None [source]
Dynamically sets the scaling factors for input tensors x and y, based on their respective values. This scaling helps in quantizing the activations for multiplication.
The scale is calculated as: alpha = 2 * tensor.abs().mean() / math.sqrt(Qp), where Qp = 2**(b-1) - 1 for b-bit signed quantization (127 for 8-bit; 7 for the 4-bit case used here).
- Parameters:
x (torch.Tensor) – The first input tensor for matrix multiplication.
y (torch.Tensor) – The second input tensor for matrix multiplication.
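The scale rule above can be sketched in plain Python. The helper name `activation_scale` and the use of Qp = 2**(b-1) − 1 (7 for signed 4-bit values) are assumptions for illustration; the actual module computes this on tensors and stores the result in `x_clip` / `y_clip`.

```python
import math

def activation_scale(values, bits=4):
    # Sketch of the documented rule: alpha = 2 * mean(|v|) / sqrt(Qp).
    qp = 2 ** (bits - 1) - 1  # Qp = 7 for 4-bit, 127 for 8-bit (assumed)
    mean_abs = sum(abs(v) for v in values) / len(values)
    return 2 * mean_abs / math.sqrt(qp)

# Toy activation vector: mean(|v|) = 0.5, so alpha = 1.0 / sqrt(7)
alpha = activation_scale([0.5, -1.0, 0.25, -0.25])
```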