Source code for bitorch_engine.layers.qlinear.nbit.cuda.mpq_layer

import torch
from torch.autograd import Function
import typing
import math

from bitorch_engine.layers.qlinear.nbit import MPQLinearBase
from bitorch_engine.utils.safe_import import import_extension
from bitorch_engine.utils.model_helper import flatten_x, unflatten_x
from bitorch_engine.layers.qlinear.nbit.cuda.utils import unpack_qweight

q_linear_cuda = import_extension("q_linear_cuda")
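# q_linear_cuda is the compiled CUDA extension backing this module; the forward and
# backward passes below call its mpq_forward and mpq_grad_input kernels.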


class MPQLinearCudaFunction(Function):
    """
    A custom autograd function for mixed-precision quantized (MPQ) linear operations on CUDA.

    This function supports forward and backward passes of a linear layer with quantized weights,
    allowing for efficient computation on CUDA devices. It is specifically designed for scenarios
    where both activations and weights are quantized to different bit-widths for reduced memory
    footprint and computational efficiency, particularly useful in low-power and memory-constrained
    environments such as edge devices.

    The forward pass performs quantized matrix multiplication using custom CUDA kernels, while the
    backward pass computes gradients with respect to the input and updates the privileged gradient
    of the quantized weights.
    """

    @staticmethod
    def forward(ctx,
                x: torch.Tensor,
                qweight: torch.Tensor,
                a_bit: int,
                w_bit: int,
                scales: torch.Tensor,
                zeros: torch.Tensor,
                g_idx: torch.Tensor,
                asym: bool,
                is_training: bool,
                privileged_grad: torch.Tensor = None) -> torch.Tensor:
        """
        Forward pass for the MPQ linear operation.

        Args:
            ctx: Context object to save information for backward computation.
            x (torch.Tensor): Input tensor.
            qweight (torch.Tensor): Quantized weights.
            a_bit (int): Activation quantization bit-width.
            w_bit (int): Weight quantization bit-width.
            scales (torch.Tensor): Quantization scales.
            zeros (torch.Tensor): Quantization zero points for asymmetric quantization.
            g_idx (torch.Tensor): Indices for grouped quantization.
            asym (bool): Flag for asymmetric quantization.
            is_training (bool): Training mode flag.
            privileged_grad (torch.Tensor, optional): Buffer attached to qweight during training
                to receive the weight gradient.

        Returns:
            torch.Tensor: The result of the quantized linear operation.
        """
        def setup_qweight():
            qweight.scales = scales
            qweight.zeros = zeros
            qweight.g_idx = g_idx
            qweight.w_bit = w_bit
            qweight.asym = asym
            qweight.layer_type = 1

        x, original_shape = flatten_x(x)

        if x.size(0) > 32:  # use the PyTorch matmul path for larger batches
            setup_qweight()
            # Reconstruct the floating-point weight
            fp_weight = unpack_qweight(qweight)
            output = torch.matmul(x, fp_weight)
        else:
            output = q_linear_cuda.mpq_forward(x, qweight, scales, zeros, g_idx, a_bit, w_bit, asym)

        if is_training:
            qweight.privileged_grad = privileged_grad
            if qweight.scales is None:
                setup_qweight()
            ctx.a_bit = a_bit
            ctx.save_for_backward(x, qweight)

        output = unflatten_x(output, original_shape)
        return output

    @staticmethod
    @typing.no_type_check
    def backward(ctx: torch.autograd.function.BackwardCFunction,
                 output_gradient: torch.Tensor) -> typing.Tuple[torch.Tensor, ...]:
        """
        Backward pass for the MPQ linear operation.

        Computes gradients with respect to the input and updates the privileged gradient of the
        quantized weights.

        Args:
            ctx: Context object with saved tensors from the forward pass.
            output_gradient (torch.Tensor): Gradient of the loss with respect to the output of
                this operation.

        Note:
            A specific optimizer, bitorch_engine.optim.mpq_adamw, is available in conjunction
            with this layer.

        Returns:
            tuple: A tuple containing the gradient with respect to the input, the qweight tensor
            (whose weight gradient is carried in its privileged_grad attribute), and None
            placeholders for the remaining arguments that do not require gradients.
        """
        output_gradient, shape = flatten_x(output_gradient)
        input, qweight = ctx.saved_tensors
        w_bit = qweight.w_bit
        a_bit = ctx.a_bit
        asym = qweight.asym

        if qweight.requires_grad:  # This additional check is required by peft training.
            assert qweight.privileged_grad is not None, \
                "The privileged gradient of qweight can not be None in the backward pass."

        # to fp16 or bf16, and (m, n) -> (n, m)
        output_gradient = output_gradient.to(input.dtype)

        # ==================================================================#
        ## grad_input = output_gradient.mm(weight)  # (m, n) * (n, k) = (m, k)
        # (k/32*w_bit, n) * (n, m) = (k, m) -> (m, k)
        grad_input = q_linear_cuda.mpq_grad_input(qweight.data, qweight.scales, qweight.zeros,
                                                  qweight.g_idx, output_gradient, a_bit, w_bit, asym)
        # ==================================================================#

        if qweight.requires_grad:  # This additional check is required by peft training.
            qweight.privileged_grad = input.t().mm(output_gradient)  # (k, m) * (m, n) = (k, n)

        grad_input = unflatten_x(grad_input, shape)
        # qweight is returned in its gradient slot; the dedicated optimizer reads its privileged_grad.
        return grad_input, qweight, None, None, None, None, None, None, None, None
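

# Illustrative usage sketch: a direct call to the autograd function above, assuming that
# qweight, scales, zeros and g_idx have already been packed and prepared by an
# MPQLinearCuda layer (defined below). The helper name and the bit-widths passed here
# are example values only, chosen to match the checks in MPQLinearCuda.
def _example_mpq_function_call(x: torch.Tensor,
                               qweight: torch.Tensor,
                               scales: torch.Tensor,
                               zeros: torch.Tensor,
                               g_idx: torch.Tensor) -> torch.Tensor:
    # x: (batch, in_features) half-precision activations on a CUDA device
    return MPQLinearCudaFunction.apply(
        x, qweight, 16, 4, scales, zeros, g_idx,
        False,  # asym: symmetric quantization assumed here
        False,  # is_training: inference only, nothing is saved for backward
        None)   # privileged_grad buffer is only needed during training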


class MPQLinearCuda(MPQLinearBase):
    """
    Represents a CUDA-compatible implementation of the mixed precision quantized (MPQ) linear
    layer, inheriting from MPQLinearBase. This class is specifically optimized for CUDA devices,
    supporting operations with quantized weights and activations in a mixed precision format.

    The layer supports quantization bit-widths for weights (w_bit) of 2, 4, or 8 and a fixed
    activation bit-width (a_bit) of 16, ensuring compatibility with common hardware accelerators
    and optimizing performance for deep learning inference tasks on CUDA-enabled GPUs.

    Attributes:
        qweight (torch.nn.Parameter): Quantized weights of the layer, adhering to the specified precision.
        w_bit (int): Bit-width for weight quantization.
        a_bit (int): Bit-width for activation quantization, fixed at 16.
        scales (torch.Tensor): Scale factors for quantized weights, calculated during parameter preparation.
        zeros (torch.Tensor): Zero points for quantized weights, supporting asymmetric quantization.

    Methods:
        check_parameters: Validates the quantization parameters to ensure they meet the requirements.
        prepare_params: Prepares and decompresses quantized parameters for the forward pass.
            Must be called before performing inference to correctly set up the layer parameters.
        forward: Executes the forward pass of the layer using quantized operations.
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Initializes the MPQLinearCuda layer with the given arguments and keyword arguments,
        setting up the layer for CUDA execution with mixed precision quantized weights and
        activations.
        """
        super(MPQLinearCuda, self).__init__(*args, **kwargs)
        self.qweight.layer_type = 1
        self.check_parameters()

    def check_parameters(self) -> None:
        """
        Ensures that the quantization bit-widths for weights (w_bit) and activations (a_bit) are
        valid. Raises an assertion error if the conditions are not met.
        """
        assert self.w_bit in [1, 2, 4, 8], f"The value of w_bit ({self.w_bit}) must be 1, 2, 4 or 8."
        assert self.a_bit == 16, f"The value of a_bit ({self.a_bit}) must be 16."

    def prepare_params(self) -> None:
        """
        This method should be executed before the actual forward pass. It mainly decompresses
        quantized parameters such as qscales and qzeros. This step could be simplified or
        eliminated in the future by a kernel implementation that decompresses during kernel
        computation.

        One can use the "prepare_bie_layers" method from project_root.utils.model_helper to call
        this function.
        """
        try:
            if self.use_gba_quant:
                if self.group_size < 256:  # we don't need double quantization for larger group sizes
                    buffer_shape = (math.ceil(self.in_channels / self.group_size), self.out_channels)
                    if self.asym:
                        qscales = self.qscales.unsqueeze(-1) if self.w_bit == 2 else self.qscales
                        self.zeros = self.qzeros
                    else:
                        qstatistic = self.qstatistic.to(torch.uint8)
                        # unpack the 4-bit quantized scale (high nibble) and zero point (low nibble)
                        qscales = (qstatistic & 0xF0) >> 4
                        qzeros = qstatistic & 0x0F
                        self.zeros = ((qzeros.to(self.dtype) - self.qzeros_zeros)
                                      * self.qzeros_scales).view(buffer_shape)
                    self.scales = ((qscales.to(self.dtype) - self.qscales_zeros)
                                   * self.qscales_scales).view(buffer_shape)

                    # release some buffers which will not be used anymore
                    del self.qscales_zeros
                    del self.qscales_scales
                    if self.asym:
                        del self.qscales
                    else:
                        del self.qstatistic
                        del self.qzeros_zeros
                        del self.qzeros_scales
            else:  # gptq
                self.zeros = self.qzeros

            # release variables not in use
            if self.disable_bias:
                del self.bias  # only used for loading checkpoints
            del self.wf
            # release cached GPU memory held by PyTorch's caching allocator
            torch.cuda.empty_cache()
        except Exception as e:
            raise RuntimeError(f"Error occurred during parameter preparation in MPQLinearCuda layer: {e}")
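
    # Note on the symmetric double-quantization branch above: each uint8 entry of
    # qstatistic packs two 4-bit values, e.g. 0xA3 holds a quantized scale of 0xA (10)
    # in the high nibble and a quantized zero point of 0x3 (3) in the low nibble; both
    # are dequantized as (q - *_zeros) * *_scales and reshaped to
    # (ceil(in_channels / group_size), out_channels).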

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass of the MPQLinearCuda layer using quantized weights and activations.

        Args:
            x (torch.Tensor): The input tensor with shape (batch size, number of features).

        Returns:
            torch.Tensor: The output tensor resulting from the quantized linear transformation
            and bias addition.
        """
        data_device = x.device
        if not all(tensor.device == data_device for tensor in [self.qweight, self.scales, self.zeros, self.g_idx]):
            raise RuntimeError("Some tensors are not on the correct device, please make sure to move the layer to "
                               "the correct device and call 'finalize_quantized_layers'.")
        out = MPQLinearCudaFunction.apply(x, self.qweight, self.a_bit, self.w_bit, self.scales,
                                          self.zeros, self.g_idx, self.asym, self.training,
                                          self.privileged_grad)
        if not self.disable_bias:
            out = out + self.bias
        return out
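

# Minimal usage sketch (illustrative only). The constructor keyword names below are
# assumptions inferred from the attributes referenced above and from MPQLinearBase;
# consult the base class for the authoritative signature. The quantized buffers
# (qweight, qscales/qzeros or qstatistic) are normally populated from a checkpoint
# before prepare_params() is called.
#
#   layer = MPQLinearCuda(in_channels=4096, out_channels=4096,
#                         w_bit=4, a_bit=16, dtype=torch.half).cuda()
#   # ... load the packed qweight and quantization statistics from a checkpoint ...
#   layer.prepare_params()                 # decompress scales / zero points once
#   x = torch.randn(8, 4096, dtype=torch.half, device="cuda")
#   y = layer(x)                           # -> (8, out_channels)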