bitorch_engine.layers.qlinear.nbit.cutlass.q4_layer.Q4MatMul

class bitorch_engine.layers.qlinear.nbit.cutlass.q4_layer.Q4MatMul(dtype=torch.float32, device=None, *args, **kwargs)[source]

A custom PyTorch module for performing quantized matrix multiplication, specifically designed for 4-bit quantization. This module quantizes inputs before multiplication based on a dynamic scaling factor and aims to maintain high precision in low-bitwidth computations.
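To make the quantize-then-multiply behavior concrete, here is a minimal sketch that emulates it in plain PyTorch. This is an illustration only, not the CUTLASS kernel this module actually dispatches to; the helper names and the signed 4-bit range [-8, 7] are assumptions.

```python
import math
import torch

def quantize_4bit(t: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Dynamic per-tensor scale as described above: alpha = 2*|t|.mean()/sqrt(Qp),
    # with Qp = 7 for 4-bit quantization (assumed; 127 for 8-bit).
    qp = 7
    alpha = (2 * t.abs().mean() / math.sqrt(qp)).clamp(min=eps)  # eps avoids division by zero
    q = torch.clamp(torch.round(t / alpha), -8, 7)  # signed 4-bit levels (assumed range)
    return q * alpha  # dequantize so the emulated matmul runs in float

def emulated_q4_matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Q4MatMul validates that inputs are batched (more than two dimensions).
    assert x.dim() > 2 and y.dim() > 2
    return torch.matmul(quantize_4bit(x), quantize_4bit(y))

x = torch.randn(2, 4, 8)
y = torch.randn(2, 8, 16)
out = emulated_q4_matmul(x, y)
print(out.shape)  # torch.Size([2, 4, 16])
```

The real module performs the multiplication on quantized values via CUTLASS; the emulation above only mirrors the scaling and rounding behavior at float precision.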

device

The device on which tensors will be allocated.

Type:

torch.device

dtype

The data type for the parameters and outputs.

Type:

torch.dtype

x_clip

The dynamic scaling factor for the first input tensor.

Type:

torch.nn.Parameter

y_clip

The dynamic scaling factor for the second input tensor.

Type:

torch.nn.Parameter

eps

A small epsilon value to prevent division by zero in computations.

Type:

torch.Tensor

Parameters:
  • dtype (torch.dtype) – The desired data type for computations (default: torch.float32).

  • device (torch.device, optional) – The device on which to perform computations.

  • *args – Variable length argument list for the parent class.

  • **kwargs – Arbitrary keyword arguments for the parent class.

Methods

__init__

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward

Forward pass of the Q4MatMul module.

set_activation_scale

Dynamically sets the scaling factors for input tensors x and y, based on their respective values.

Attributes

training

__init__(dtype=torch.float32, device=None, *args, **kwargs)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x: Tensor, y: Tensor) Tensor[source]

Forward pass of the Q4MatMul module. Validates input dimensions, sets activation scales, and performs quantized matrix multiplication.

Parameters:
  • x (torch.Tensor) – The first input tensor.

  • y (torch.Tensor) – The second input tensor.

Returns:

The result of the quantized matrix multiplication.

Return type:

torch.Tensor

Raises:

AssertionError – If either input tensor has two or fewer dimensions (both inputs must be batched, i.e. have more than two dimensions).

set_activation_scale(x: Tensor, y: Tensor) None[source]

Dynamically sets the scaling factors for input tensors x and y, based on their respective values. This scaling helps in quantizing the activations for multiplication.

The scale is calculated as: alpha = 2 * tensor.abs().mean() / math.sqrt(Qp), where Qp = 2^(b-1) - 1 is the largest positive quantized value for b-bit quantization (Qp = 127 for 8-bit; here, for 4-bit quantization, Qp = 7).

Parameters:
  • x (torch.Tensor) – The first input tensor for matrix multiplication.

  • y (torch.Tensor) – The second input tensor for matrix multiplication.
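The scale formula above can be checked numerically. A small sketch (the helper name is illustrative, and Qp = 7 for 4-bit follows from Qp = 2^(b-1) - 1, inferred from the 8-bit value quoted above):

```python
import math
import torch

def activation_scale(t: torch.Tensor, qp: int = 7) -> float:
    # alpha = 2 * |t|.mean() / sqrt(Qp); Qp = 127 for 8-bit, 7 for 4-bit (assumed)
    return (2 * t.abs().mean() / math.sqrt(qp)).item()

t = torch.ones(10)           # |t|.mean() == 1.0
alpha = activation_scale(t)  # 2 / sqrt(7) ~= 0.7559
```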