Source code for bitorch_engine.layers.qlinear.nbit.cuda.utils

import torch
from bitorch_engine.layers.qlinear.nbit import MPQWeightParameter


def unpack_qweight(qweight: MPQWeightParameter) -> torch.Tensor:
    """
    Reconstructs the fp16 weight tensor from the input quantized weight parameter.

    Parameters:
        qweight (MPQWeightParameter): The quantized weight parameter object containing all
            necessary quantization information.

    Returns:
        torch.Tensor: The reconstructed weight tensor in fp16 format.

    Raises:
        ValueError: If essential attributes are missing in the input qweight parameter.
        NotImplementedError: For quantization types that are not yet supported.

    Supported quantization styles:
        1. GPTQ style with g_index.
        2. GPTQ style without g_index.
        3. Mixed-bit quantization.
    """
    layer_type = getattr(qweight, 'layer_type', None)
    if layer_type is None:
        raise ValueError("Error: invalid 'layer_type' attribute of qweight in 'unpack_qweight'.")

    # Process based on layer type
    if qweight.layer_type == 1:
        # GPTQ style (with or without g_index)
        wf = torch.tensor(list(range(0, 32, qweight.w_bit)), dtype=torch.int32,
                          device=qweight.device).unsqueeze(0)
        # Unpack the int32-packed weights by shifting out w_bit-wide slices and masking them.
        weight = torch.bitwise_right_shift(
            torch.unsqueeze(qweight, 1).expand(-1, 32 // qweight.w_bit, -1),
            wf.unsqueeze(-1)
        ).to(torch.int16 if qweight.w_bit == 8 else torch.int8).view(-1, qweight.size(-1))
        torch.bitwise_and(weight, (2 ** qweight.w_bit) - 1, out=weight)

        if qweight.asym:
            # 1. GPTQ style with g_index (asymmetric): unpack the zero points the same way.
            zeros_unpack = torch.bitwise_right_shift(
                torch.unsqueeze(qweight.zeros, 2).expand(-1, -1, 32 // qweight.w_bit),
                wf.unsqueeze(0)
            ).to(torch.int16 if qweight.w_bit == 8 else torch.int8)
            torch.bitwise_and(zeros_unpack, (2 ** qweight.w_bit) - 1, out=zeros_unpack)
            zeros_unpack = zeros_unpack + 1
            zeros = zeros_unpack.reshape(-1, qweight.size(-1))
            weights = qweight.scales[qweight.g_idx.long()] * (weight - zeros[qweight.g_idx.long()])
        else:
            # 2. Non-asymmetric case: scales/zeros applied directly, with or without g_index.
            if qweight.g_idx is None:
                scales = qweight.scales.unsqueeze(1).repeat(
                    1, weight.size(0) // qweight.scales.size(0), 1).view(-1, qweight.scales.size(-1))
                zeros = qweight.zeros.unsqueeze(1).repeat(
                    1, weight.size(0) // qweight.zeros.size(0), 1).view(-1, qweight.zeros.size(-1))
                weights = weight.mul(scales) - zeros
            else:
                weights = weight * qweight.scales[qweight.g_idx.long()] - qweight.zeros[qweight.g_idx.long()]
    elif qweight.layer_type == 2:
        # 3. MBWQLinear layer (mixed-bit quantization)
        weights = None
        try:
            from bitorch_engine.layers.qlinear.nbit.cuda import MBWQLinearCuda
            use_mbwq = True if qweight.q_group_map is not None else False
            if not use_mbwq:
                weights = MBWQLinearCuda.q42fp_weight(qweight.data, qweight.scales, qweight.zeros,
                                                      qweight.group_size, qweight.w_bit, qweight.q_perm)
            else:
                weights = MBWQLinearCuda.exl2fp_weight(qweight.data, qweight.scales, qweight.zeros,
                                                       qweight.q_perm, qweight.q_group_map, qweight.rows)
        except ModuleNotFoundError as e:
            print(f"Error: Module not found: {e}.")
    else:
        raise NotImplementedError("Error: 'layer_type' not yet supported!")
    return weights
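
# Illustrative sketch only (not part of the module): the same shift-and-mask
# arithmetic used by the GPTQ-style branch above, applied to a single made-up
# 4-bit packed column. The names `packed`, `wf` and `vals` are toy examples.
w_bit = 4
packed = torch.tensor([[0x76543210]], dtype=torch.int32)  # one int32 row holding eight 4-bit codes
wf = torch.arange(0, 32, w_bit, dtype=torch.int32).unsqueeze(0)
vals = torch.bitwise_right_shift(
    packed.unsqueeze(1).expand(-1, 32 // w_bit, -1),
    wf.unsqueeze(-1)
).to(torch.int8).view(-1, packed.size(-1))
vals = torch.bitwise_and(vals, (2 ** w_bit) - 1)  # -> column [0, 1, 2, 3, 4, 5, 6, 7]
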
def pack_fp_weight(weight: torch.Tensor, qweight: MPQWeightParameter,
                   unpacked_zeros: torch.Tensor = None) -> torch.Tensor:
    """
    Packs the fp16 weight into a quantized weight format using the attributes defined in the
    MPQWeightParameter.

    This function handles three main scenarios:
        1. GPTQ style quantization with group index (g_index).
        2. GPTQ style quantization without g_index.
        3. Mixed-bit quantization (currently not implemented).

    Parameters:
        weight (torch.Tensor): The floating-point weights to be quantized and packed.
        qweight (MPQWeightParameter): An object containing quantization parameters.
        unpacked_zeros (torch.Tensor, optional): Pre-unpacked zero points. If provided, they are
            used directly instead of unpacking qweight.zeros.

    Returns:
        torch.Tensor: The packed integer tensor representing the quantized weights.

    Raises:
        ValueError: If 'layer_type' attribute is invalid or not present.
        NotImplementedError: For unimplemented quantization methods, like mixed-bit quantization.
    """
    layer_type = getattr(qweight, 'layer_type', None)
    scales = getattr(qweight, 'scales', None)
    zeros = getattr(qweight, 'zeros', None)
    w_bit = getattr(qweight, 'w_bit', None)
    asym = getattr(qweight, 'asym', None)
    g_idx = getattr(qweight, 'g_idx', None)

    if layer_type is None:
        raise ValueError("Error: invalid 'layer_type' attribute in 'pack_fp_weight' method.")

    # Process based on layer_type and existence of q_group_map for quantization
    if layer_type == 1 or (layer_type == 2 and qweight.q_group_map is None):  # MPQLinear or MBWQLinear-q4
        if asym:  # this if-branch is for classical GPTQ-style models
            if unpacked_zeros is not None:
                zeros = unpacked_zeros
            elif zeros.dtype == torch.int32:
                wf = torch.tensor(list(range(0, 32, w_bit)), dtype=torch.int32,
                                  device=qweight.device).unsqueeze(0)
                zeros_unpack = torch.bitwise_right_shift(
                    torch.unsqueeze(zeros, 2).expand(-1, -1, 32 // w_bit),
                    wf.unsqueeze(0)).to(torch.int16 if w_bit == 8 else torch.int8)
                torch.bitwise_and(zeros_unpack, (2 ** w_bit) - 1, out=zeros_unpack)
                zeros_unpack = zeros_unpack + 1
                zeros = zeros_unpack.reshape(-1, qweight.size(-1))
            else:
                raise ValueError("Error: got invalid dtype of qweight.zeros while packing fp weight.")
            intweight = torch.round(weight / scales[g_idx.long()]
                                    + zeros[g_idx.long()]).to(torch.int32).clamp(0, 2 ** w_bit - 1)
        else:
            if g_idx is None:
                # Adjust scales and zeros for symmetric quantization without group index
                scales = scales.unsqueeze(1).repeat(
                    1, weight.size(0) // scales.size(0), 1).view(-1, scales.size(-1))
                zeros = zeros.unsqueeze(1).repeat(
                    1, weight.size(0) // zeros.size(0), 1).view(-1, zeros.size(-1))
                if hasattr(qweight, "q_perm") and qweight.q_perm is not None:
                    q_perm = qweight.q_perm.unsqueeze(1).repeat(1, weight.size(1)).long()
                    weight = torch.gather(weight, dim=0, index=q_perm)
                intweight = torch.round((weight + zeros) / scales).to(torch.int32).clamp(0, 2 ** w_bit - 1)
            else:
                # Calculate integer weights for symmetric quantization with group index
                intweight = torch.round((weight + zeros[g_idx.long()])
                                        / scales[g_idx.long()]).to(torch.int32).clamp(0, 2 ** w_bit - 1)

        # Perform parallel bitpacking
        wf = torch.tensor(list(range(0, 32, w_bit)), dtype=torch.int32,
                          device=qweight.device).unsqueeze(0)
        intweight = torch.sum(
            torch.bitwise_left_shift(
                intweight.reshape(-1, 32 // w_bit, intweight.size(-1)),
                wf.unsqueeze(-1)
            ),
            dim=1,
            dtype=torch.int32
        )
    else:
        # TODO: Placeholder for channel-mix quantization method
        raise NotImplementedError("Error: pack_fp_weight for MBWQLinear using channel-mix quantization not supported yet.")
    return intweight.to(torch.int32)
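
# Illustrative sketch only (not part of the module): the reverse direction on
# toy data, quantizing one 8-row column group with a single scale and folding
# the eight 4-bit codes into one int32 row via the same left-shift-and-sum
# pattern as above. Zero points are omitted here for brevity; the real packer
# applies `zeros` before or after scaling depending on `asym`.
w_bit = 4
group = 0.5 * torch.arange(8, dtype=torch.float32).view(8, 1)  # one group: 8 rows, 1 column
scale = torch.tensor(0.5)                                      # toy per-group scale
intw = torch.round(group / scale).to(torch.int32).clamp(0, 2 ** w_bit - 1)  # codes 0..7
wf = torch.arange(0, 32, w_bit, dtype=torch.int32).unsqueeze(0)
packed = torch.sum(
    torch.bitwise_left_shift(intw.reshape(-1, 32 // w_bit, intw.size(-1)), wf.unsqueeze(-1)),
    dim=1, dtype=torch.int32
)
# packed.item() == 0x76543210: the toy value unpacked in the sketch after unpack_qweight above.
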
def make_group_map(q_groups: torch.Tensor, num_qrows: int) -> torch.Tensor:
    """
    Creates a mapping of quantization groups for handling irregular group sizes in quantized models.

    This function generates a tensor representing the mapping of groups, where each group might
    have a different size due to the quantization process. The mapping is used to organize or
    access quantized weights or parameters based on their group assignment.

    Parameters:
        q_groups (torch.Tensor): A tensor containing information about the quantization groups.
            It is expected to hold pairs of values, where each pair consists of 'bits' and
            'start index' for each group.
        num_qrows (int): The total number of quantization rows, representing the overall size of
            the quantization dimension.

    Returns:
        torch.Tensor: A tensor of short integers representing the group mapping. Each group is
            represented by its index followed by the inverse row index within the group.

    Example:
        Given a q_groups tensor indicating group sizes and num_qrows indicating the total
        quantization rows, this function calculates the group mapping required for accessing or
        organizing the quantized parameters.
    """
    num_groups = q_groups.numel() // 2
    group_map = []
    for i in range(num_groups):
        bits = q_groups[i * 2]
        if i < num_groups - 1:
            qrows = q_groups[i * 2 + 3] - q_groups[i * 2 + 1]
        else:
            qrows = num_qrows - q_groups[i * 2 + 1]
        rows = qrows * 32 // bits
        for j in range(rows):
            group_map.append(i)
            group_map.append(rows - j)
    return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)
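
# Illustrative sketch only (toy values, not from a real checkpoint): two groups
# described as (bits, start row) pairs over a packed tensor with four int32 rows,
# i.e. group 0 is 4-bit and covers packed rows 0-1, group 1 is 8-bit and covers
# packed rows 2-3.
q_groups = torch.tensor([4, 0, 8, 2], dtype=torch.short)
gmap = make_group_map(q_groups, num_qrows=4)
# gmap.view(-1, 2) pairs each unpacked row with (group index, inverse row index):
# the first 16 pairs belong to group 0 (2 rows * 32 // 4 bits) and start with
# [0, 16], [0, 15], ...; the remaining 8 pairs belong to group 1 (2 rows * 32 // 8 bits).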