import math
from torch import nn
import torch
from torch.nn import init
from bitorch_engine.utils.model_helper import qweight_update_fn
class MPQWeightParameter(nn.Parameter):
"""
A custom parameter class for quantized weights, extending torch.nn.Parameter,
with additional attributes specific to quantization.
Attributes:
privileged_grad: Optional tensor for privileged gradients (not used in standard backpropagation).
scales, zeros: Quantization scales and zero points for the affine quantization.
g_idx: Group index for weight quantization.
w_bit: Bit-width for weight quantization.
asym: Flag to indicate if asymmetric quantization is used.
group_size: The size of quantization groups.
layer_type: Type of layer (e.g., MPQLinear: 1, MBWQLinear: 2).
q_perm: Permutation indices for quantization groups.
qscales_zeros, qscales_scales, qzeros_zeros, qzeros_scales: Additional quantization parameters for calculating
(q)scales and (q)zeros.
q_group_map: Mapping from weights to quantization groups.
rows: Stores row information for each bit-width in the quantized weight matrix.
Parameters:
data (Tensor, optional): Parameter tensor.
requires_grad (bool, optional): If the parameter requires gradient. Default: True.
The rest of the parameters are specific to the quantization process and are optional.
"""
def __new__(cls, data=None,
requires_grad: bool = True,
privileged_grad: torch.Tensor = None,
scales: torch.Tensor = None,
zeros: torch.Tensor = None,
g_idx: torch.Tensor = None,
w_bit: int = -1,
asym: bool = False,
group_size: int = -1,
layer_type: int = -1,
q_perm: torch.Tensor = None,
qscales_zeros: torch.Tensor = None,
qscales_scales: torch.Tensor = None,
qzeros_zeros: torch.Tensor = None,
qzeros_scales: torch.Tensor = None,
q_group_map: torch.Tensor = None,
rows: list = None
):
return super().__new__(cls, data, requires_grad=requires_grad)
def __init__(self, data: torch.Tensor=None,
requires_grad: bool=True,
privileged_grad: torch.Tensor=None,
scales: torch.Tensor=None,
zeros: torch.Tensor=None,
g_idx: torch.Tensor=None,
w_bit: int=-1,
asym: bool=False,
group_size: int=-1,
layer_type: int=-1,
q_perm: torch.Tensor=None,
qscales_zeros: torch.Tensor=None,
qscales_scales: torch.Tensor=None,
qzeros_zeros: torch.Tensor=None,
qzeros_scales: torch.Tensor=None,
q_group_map: torch.Tensor=None,
rows: list=None):
self.privileged_grad = privileged_grad
self.scales = scales
self.zeros = zeros
self.g_idx = g_idx
self.w_bit = w_bit
self.asym = asym
self.group_size = group_size
self.layer_type = layer_type # layer_type: MPQLinear: 1, MBWQLinear: 2
self.q_perm = q_perm
self.qscales_zeros = qscales_zeros
self.qscales_scales = qscales_scales
self.qzeros_zeros = qzeros_zeros
self.qzeros_scales = qzeros_scales
self.q_group_map = q_group_map
self.rows = rows
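# A minimal usage sketch (illustrative only, not part of the library): wrapping a
# bit-packed int32 weight tensor in an MPQWeightParameter and attaching quantization
# metadata. The tensor shapes and group size below are assumptions for illustration.
#
# >>> import torch
# >>> packed = torch.empty((512, 4096), dtype=torch.int32)   # 4096 inputs packed at 4 bit: 4096 // 32 * 4 = 512 rows
# >>> scales = torch.ones((32, 4096), dtype=torch.half)      # one scale per group of 128 input elements
# >>> qweight = MPQWeightParameter(packed, requires_grad=False, scales=scales,
# ...                              w_bit=4, asym=False, group_size=128, layer_type=1)
# >>> qweight.w_bit, qweight.group_size
# (4, 128)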
@staticmethod
def update(qweight: torch.nn.Parameter, exp_avg_s: torch.Tensor=None, exp_avg_l: torch.Tensor=None,
step: torch.Tensor=None, lr:float=1e-4, weight_decay:float=0.0, beta1:float=0.99,
beta2:float=0.9999, eps: float = 1e-6, dtype=torch.half, correct_bias=None, projector=None,
grad:torch.Tensor=None) -> None:
"""
This method defines how to update quantized weights with quantized gradients.
It may involve operations such as applying momentum or adjusting weights based on some optimization algorithm.
Args:
qweight (torch.nn.Parameter): The current quantized weight parameter to be updated.
exp_avg_s (torch.Tensor, optional): Exponential moving average of squared gradients. Used in optimization algorithms like Adam.
exp_avg_l (torch.Tensor, optional): Exponential moving average of the gradients. Also used in optimizers like Adam.
step (torch.Tensor, optional): The current step or iteration in the optimization process. Can be used to adjust learning rate or for other conditional operations in the update process.
lr (float, optional): Learning rate. A hyperparameter that determines the step size at each iteration while moving toward a minimum of a loss function.
weight_decay (float, optional): Weight decay (L2 penalty). A regularization term that helps to prevent overfitting by penalizing large weights.
beta1 (float, optional): The exponential decay rate for the first moment estimates. A hyperparameter for optimizers like Adam.
beta2 (float, optional): The exponential decay rate for the second-moment estimates. Another hyperparameter for Adam-like optimizers.
eps (float, optional): A small constant for numerical stability.
dtype (torch.dtype, optional): The data type to be used for computations.
correct_bias (optional): Whether to apply bias correction (specific to certain models like BERT).
projector (optional): Optional gradient projector to apply to the gradient during the update.
grad (torch.Tensor, optional): Gradient tensor, used when a projector is provided.
Returns:
None: The function is expected to update the `qweight` in-place and does not return anything.
Raises:
AssertionError: If `qweight` is not an instance of `MPQWeightParameter`.
"""
assert isinstance(qweight, MPQWeightParameter), 'Error: the type of qweight must be ' \
'MPQWeightParameter. '
qweight_update_fn(qweight=qweight, exp_avg_s=exp_avg_s, exp_avg_l=exp_avg_l,
step=step, lr=lr, weight_decay=weight_decay, beta1=beta1, beta2=beta2,
correct_bias=correct_bias, eps=eps, dtype=dtype, projector=projector, grad=grad)
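# Sketch of how an optimizer step might invoke this hook (illustrative only; `state`
# is a hypothetical optimizer state dict, and the exact state layout expected by
# qweight_update_fn is defined in bitorch_engine.utils.model_helper):
#
# >>> MPQWeightParameter.update(qweight,
# ...                           exp_avg_s=state["exp_avg_s"], exp_avg_l=state["exp_avg_l"],
# ...                           step=state["step"], lr=1e-4, weight_decay=0.0,
# ...                           dtype=torch.half, grad=qweight.privileged_grad)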
class nBitLinearParameter(torch.nn.Parameter):
"""
A custom parameter class for n-bit linear layer, extending torch.nn.Parameter.
This class is designed to support n-bit linear layers, particularly useful
in models requiring efficient memory usage and specialized optimization techniques.
Args:
data (torch.Tensor, optional): The initial data for the parameter. Defaults to None.
requires_grad (bool, optional): Flag indicating whether gradients should be computed
for this parameter in the backward pass. Defaults to True.
"""
def __new__(cls,
data: torch.Tensor=None,
requires_grad: bool=True
):
return super().__new__(cls, data=data, requires_grad=requires_grad)
@staticmethod
def update(qweight: torch.nn.Parameter, exp_avg_s: torch.Tensor=None, exp_avg_l: torch.Tensor=None,
step: torch.Tensor=None, lr:float=1e-4, weight_decay:float=0.0, beta1:float=0.99,
beta2:float=0.9999, eps: float = 1e-6, dtype=torch.half, correct_bias=None, projector=None,
grad:torch.Tensor=None) -> None:
"""
This method defines how to update quantized weights with quantized gradients.
It may involve operations such as applying momentum or adjusting weights based on some optimization algorithm.
Args:
qweight (torch.nn.Parameter): The current quantized weight parameter to be updated.
exp_avg_s (torch.Tensor, optional): Exponential moving average of squared gradients. Used in optimization algorithms like Adam.
exp_avg_l (torch.Tensor, optional): Exponential moving average of the gradients. Also used in optimizers like Adam.
step (torch.Tensor, optional): The current step or iteration in the optimization process. Can be used to adjust learning rate or for other conditional operations in the update process.
lr (float, optional): Learning rate. A hyperparameter that determines the step size at each iteration while moving toward a minimum of a loss function.
weight_decay (float, optional): Weight decay (L2 penalty). A regularization term that helps to prevent overfitting by penalizing large weights.
beta1 (float, optional): The exponential decay rate for the first moment estimates. A hyperparameter for optimizers like Adam.
beta2 (float, optional): The exponential decay rate for the second-moment estimates. Another hyperparameter for Adam-like optimizers.
eps (float, optional): A small constant for numerical stability.
dtype (torch.dtype, optional): The data type to be used for computations.
correct_bias (optional): Whether to apply bias correction (specific to certain models like BERT).
projector (optional): Optional gradient projector to apply to the gradient during the update.
grad (torch.Tensor, optional): Gradient tensor, used when a projector is provided.
Returns:
None: The function is expected to update the `qweight` in-place and does not return anything.
Raises:
AssertionError: If `qweight` is not an instance of `nBitLinearParameter`.
"""
assert isinstance(qweight, nBitLinearParameter), 'Error: the type of qweight must be ' \
'nBitLinearParameter. '
qweight_update_fn(qweight=qweight, exp_avg_s=exp_avg_s, exp_avg_l=exp_avg_l,
step=step, lr=lr, weight_decay=weight_decay, beta1=beta1, beta2=beta2,
correct_bias=correct_bias, eps=eps, dtype=dtype, projector=projector, grad=grad)
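# Minimal sketch: nBitLinearParameter behaves like a plain nn.Parameter and is updated
# through the same qweight_update_fn hook as MPQWeightParameter (shapes illustrative).
#
# >>> import torch
# >>> w = nBitLinearParameter(torch.randn(128, 64), requires_grad=True)
# >>> isinstance(w, torch.nn.Parameter)
# True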
class nBitLinearBase(nn.Module):
"""
A base class for n-bit Quantization-Aware Training (QAT) linear layers. This class provides a framework for
implementing layers that operate with low-bitwidth activations and weights during training, and supports
quantization for efficient inference. It maintains both floating-point and quantized weights to facilitate
the QAT process.
Attributes:
in_channels (int): The dimension of input features after bit-packing, indicating the number of input features to the layer.
out_channels (int): The dimension of output features, indicating the number of output features produced by the layer.
a_bit (int): The bit-width for activations used during training. Defaults to 4 bits.
w_bit (int): The bit-width for weights used during training and inference. Defaults to 4 bits.
device: The device on which the layer's parameters are stored. Defaults to `None`, which means the default device is used.
dtype: The data type for the layer's parameters. Defaults to `torch.float`.
Note:
This class is designed to be subclassed by specific implementations of n-bit linear layers, which should
provide mechanisms for parameter preparation (`prepare_params`), weight quantization (`generate_quantized_weight`),
and other necessary operations.
"""
def __init__(self, in_channels: int,
out_channels: int,
a_bit: int = 4,
w_bit: int = 4,
device=None,
dtype=torch.float) -> None:
super(nBitLinearBase, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.device = device
self.dtype = dtype
self.a_bit = a_bit
self.w_bit = w_bit
self.weight = None
self.qweight = None
self.reset_parameters()
def reset_parameters(self) -> None:
"""
Initializes or resets the floating-point weights of the layer using Kaiming uniform initialization.
"""
self.weight = torch.nn.Parameter(
torch.Tensor(self.out_channels, self.in_channels))
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def set_weight_data(self, x: torch.Tensor) -> None:
"""
Sets the floating-point weights of the layer to the provided tensor.
Args:
x (torch.Tensor): The tensor to set as the new weights.
"""
self.weight = nn.Parameter(x, requires_grad=False)
def prepare_params(self) -> None:
"""
Prepares and initializes the model parameters for training.
Note:
This method MUST be called after model initialization and before training starts to ensure the weights are
properly prepared for efficient computation.
"""
raise NotImplementedError("Subclasses should implement this method.")
def set_quantized_weight_data(self, x: torch.Tensor) -> None:
"""
Sets the quantized weights of the layer to the provided tensor.
Args:
x (torch.Tensor): The tensor to set as the new quantized weights.
"""
self.qweight = nn.Parameter(x, requires_grad=False)
def generate_quantized_weight(self, qweight_only: bool = False) -> None:
"""
Generates and sets the quantized weights based on the current floating-point weights. This method must be
implemented by subclasses and is crucial for converting floating-point weights to low-bitwidth quantized weights
for inference.
Args:
qweight_only (bool): If `True`, only quantized weights are generated.
"""
raise NotImplementedError("Subclasses should implement this method.")
def _check_forward(self, x: torch.Tensor) -> None:
"""
A placeholder method for checking the inputs to the forward pass. This method must be implemented by subclasses
to ensure the input tensor is suitable for processing by the layer.
"""
raise NotImplementedError("Subclasses should implement this method.")
@property
def opt_weight(self) -> torch.nn.Parameter:
"""
Returns the optimal weight for the current mode (training or inference). If the model is in inference mode
and quantized weights are not yet generated, it triggers quantized weight generation.
Returns:
torch.nn.Parameter: The current optimal weights (floating-point for training, quantized for inference).
"""
if not self.training and self.qweight is None:
self.generate_quantized_weight()
return self.weight if self.training else self.qweight
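# Illustrative subclass sketch (hypothetical, not an engine-provided layer): subclasses
# are expected to implement prepare_params(), generate_quantized_weight() and
# _check_forward(). The naive sign-based quantization below is only a stand-in showing
# where the real quantization logic plugs in.
#
# >>> class ToySignLinear(nBitLinearBase):
# ...     def prepare_params(self) -> None:
# ...         self.weight.data = self.weight.data.to(self.dtype)
# ...     def generate_quantized_weight(self, qweight_only: bool = False) -> None:
# ...         self.set_quantized_weight_data(torch.sign(self.weight.data))
# ...     def _check_forward(self, x: torch.Tensor) -> None:
# ...         assert x.size(-1) == self.in_channels, "input feature dim mismatch"
# ...     def forward(self, x: torch.Tensor) -> torch.Tensor:
# ...         self._check_forward(x)
# ...         return x.matmul(self.opt_weight.t())  # opt_weight picks weight vs. qweight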
class MPQLinearBase(nn.Module):
"""
Base class for mixed precision quantized (MPQ) linear layers, designed to support
the computational needs of large language models (LLMs) with mixed precision quantization,
such as 16-bit activations and 4-bit weights for efficient inference. It introduces optimized
computation for bitwise unpacking of quantized weights and 16-bit floating-point matrix multiplication,
tailored for various hardware platforms.
Unlike nBitLinearBase, MPQLinearBase serves as the base class for mixed precision quantized linear layers.
It mainly supports the mixed precision linear layers found in current LLMs, for example 16-bit activations
combined with 4-bit quantized weights for inference.
During inference, two main computations are involved: bitwise unpacking of qweight from low-bit storage to
16-bit float, and 16-bit matrix multiplication. The performance of both steps has been optimized for
different hardware platforms.
Attributes:
in_channels (int): The number of input features after bit-packing, representing the dimensionality
of the input space.
out_channels (int): The number of output features, representing the dimensionality of the output space.
a_bit (int): The bit-width used for activation quantization, defaulting to 16 bits for high precision.
w_bit (int): The bit-width used for weight quantization, aiming to reduce memory footprint and computational cost.
dtype (torch.dtype): The data type for computations within this layer, typically torch.half for efficiency.
group_size (int): The grouping size for quantization, affecting scale and zero-point calculation.
A value of -1 indicates that the entire input width is treated as one group.
use_gba_quant (bool): Flag to indicate the use of GBA-specific quantization techniques over GPTQ-compliant methods.
dq_group_size (int): Double quantization group size, specific to GBA quantization, for further granularity in quantization.
dq_mode (int): Double quantization mode, catering to different versions and requirements of LLaMA models.
disable_bias (bool): If True, the bias term is omitted from the linear computation, reducing parameters and computation.
asym (bool): Indicates whether asymmetric quantization is used, offering an alternative to symmetric quantization strategies.
Methods:
initialize(): Initializes parameters and quantization buffers based on the selected quantization method.
init_gptq(): Sets up parameters specific to GPTQ quantization.
init_gba(): Configures buffers and scales for GBA quantization, accommodating for asymmetry and double quantization modes.
set_qweight_data(data): Updates the quantized weight tensor with new data.
generate_quantized_weight(): Placeholder for weight quantization method, to be implemented by subclasses.
check_parameters(): Placeholder for parameter validation, ensuring correct layer configuration.
prepare_params(): Prepares quantized parameters for the forward pass, potentially decompressing quantized values.
"""
def __init__(self,
in_channels: int,
out_channels: int,
a_bit: int = 16,
w_bit: int = 4,
dtype=torch.half,
group_size=-1,
use_gba_quant=True,
dq_group_size=-1,
dq_mode=2,
disable_bias=True,
asym=False,
requires_grad=True) -> None:
"""
Args:
in_channels (int): dim of input features after bit-packing
out_channels (int): dim of hidden states
a_bit: activation bits
w_bit: weight bits
dtype: data type used in this layer
group_size: number of weight elements that share one scale and zero-point factor
disable_bias: whether to disable the bias term
use_gba_quant: True: use GBA-specific quantization, False: use GPTQ-compliant methods
dq_group_size: GBA-specific parameter. Indicates the double quantization group size.
dq_mode: GBA-specific parameter. Indicates the double quantization mode, which is used to adapt to different LLaMA versions.
asym: GBA-specific parameter. True for asymmetric quantization, False for symmetric quantization.
requires_grad (bool): Indicates whether gradient calculation should be enabled for the parameters.
"""
super(MPQLinearBase, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.a_bit = a_bit
self.w_bit = w_bit
self.dtype = dtype
self.maxq = 2 ** self.w_bit - 1
self.group_size = group_size if group_size > -1 else self.in_channels
self.asym = asym
self.disable_bias = disable_bias
self.use_gba_quant = use_gba_quant
self.dq_group_size = dq_group_size
self.requires_grad = requires_grad
self.dq_mode = dq_mode
self.initialize()
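# Derived quantities set up by this constructor (illustrative numbers): with w_bit=4
# and group_size=128, maxq == 2 ** 4 - 1 == 15, i.e. quantized weights take integer
# values in [0, 15]; passing group_size=-1 instead treats the whole input dimension
# as a single group (group_size == in_channels).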
def initialize(self) -> None:
"""
Initializes layer parameters and quantization buffers. This method sets up the infrastructure
for either GBA or GPTQ quantization methods, based on the layer configuration. It allocates
memory for quantized weights, scales, zero-points, and other necessary buffers, ensuring
they are ready for the quantization process.
"""
## Both GPTQ and GBA methods require qweights, scales, zeros and group_index
# trainable params
self.qweight = MPQWeightParameter(
data = torch.empty((self.in_channels // 32 * self.w_bit, self.out_channels),
dtype=torch.int32),
requires_grad=self.requires_grad,
w_bit=self.w_bit,
asym=self.asym,
group_size=self.group_size,
)
self.privileged_grad = torch.empty((self.in_channels, self.out_channels), dtype=self.dtype) if self.requires_grad else None
# non-trainable params
self.register_buffer('g_idx', torch.tensor([i // self.group_size for i in range(self.in_channels)],
dtype=torch.int32))
# These buffers must still be initialized for checkpoint loading. Unused params will be released in "prepare_params()"
self.register_buffer('bias', torch.zeros((self.out_channels), dtype=self.dtype))
self.register_buffer("wf", torch.tensor(list(range(0, 32, self.w_bit)), dtype=torch.int32).unsqueeze(0))
# init
if self.use_gba_quant:
self.init_gba()
else:
self.init_gptq()
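# Shape sanity check for the packed qweight created above (illustrative numbers):
# with in_channels=4096, out_channels=11008 and w_bit=4, each int32 element packs
# 32 / 4 = 8 quantized values along the input dimension, so
#   qweight.shape == (4096 // 32 * 4, 11008) == (512, 11008)
# and g_idx maps each of the 4096 input positions to its quantization group.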
def init_gptq(self) -> None:
"""
Initializes parameters and buffers specific to the GPTQ quantization method. This includes
setting up zero-point buffers, scale factors, and ensuring asymmetric quantization is enabled.
GPTQ, being a more general quantization approach, requires specific buffers to hold quantization
parameters for accurate computation and minimal precision loss.
"""
self.register_buffer('qzeros', torch.zeros((math.ceil(self.in_channels / self.group_size),
self.out_channels // 32 * self.w_bit), dtype=torch.int32))
self.register_buffer('scales', torch.ones((math.ceil(self.in_channels / self.group_size),
self.out_channels), dtype=self.dtype))
self.asym = True
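# Resulting GPTQ buffer shapes (illustrative: in_channels=4096, out_channels=4096,
# group_size=128, w_bit=4):
#   qzeros.shape == (ceil(4096 / 128), 4096 // 32 * 4) == (32, 512)
#   scales.shape == (ceil(4096 / 128), 4096)           == (32, 4096)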
def init_gba(self) -> None:
"""
Prepares the layer for GBA-specific quantization, configuring buffers for scales, zero-points,
and statistics for double quantization if enabled. GBA quantization allows for fine-tuned control
over the quantization process, accommodating asymmetric quantization and providing additional
parameters to adjust for different model versions and requirements.
"""
# non-trainable params
if self.dq_group_size == -1:
self.dq_group_size = self.out_channels
buffer_shape_1 = (
math.ceil(self.in_channels / self.group_size),
math.ceil(self.out_channels / self.dq_group_size),
1
)
buffer_shape_2 = (
math.ceil(self.in_channels / self.group_size),
math.ceil(self.out_channels / self.dq_group_size),
self.dq_group_size
)
if self.asym:
self.register_buffer('qzeros', torch.zeros((math.ceil(self.in_channels / self.group_size),
self.out_channels // 32 * self.w_bit), dtype=torch.int32))
if self.w_bit == 4:
self.register_buffer('qscales', torch.ones(buffer_shape_2, dtype=torch.uint8))
else:
self.register_buffer('qscales', torch.ones((math.ceil(self.in_channels / self.group_size),
self.out_channels), dtype=torch.uint8))
else:
self.register_buffer('qstatistic', torch.ones(buffer_shape_2, dtype=torch.uint8))
self.register_buffer('qzeros_zeros', torch.zeros(buffer_shape_1, dtype=self.dtype))
self.register_buffer('qzeros_scales', torch.ones(buffer_shape_1, dtype=self.dtype))
if self.dq_mode == 1:
self.register_buffer('qscales_zeros', torch.zeros((1, self.out_channels, 1), dtype=self.dtype))
self.register_buffer('qscales_scales', torch.ones((1, self.out_channels, 1), dtype=self.dtype))
else:
self.register_buffer('qscales_zeros', torch.zeros(buffer_shape_1, dtype=self.dtype))
self.register_buffer('qscales_scales', torch.ones(buffer_shape_1, dtype=self.dtype))
self.register_buffer('scales', torch.ones((math.ceil(self.in_channels / self.group_size), self.out_channels),
dtype=self.dtype))
self.register_buffer('zeros', torch.zeros((math.ceil(self.in_channels / self.group_size), self.out_channels),
dtype=self.dtype))
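# Resulting GBA buffer shapes for the symmetric, double-quantized case (illustrative:
# in_channels=4096, out_channels=4096, group_size=128, dq_group_size=256, asym=False,
# w_bit=4, dq_mode=2):
#   qstatistic.shape    == (32, 16, 256)  # (ceil(in/group), ceil(out/dq_group), dq_group)
#   qzeros_zeros.shape  == (32, 16, 1)    # one value per double-quantization group
#   qscales_zeros.shape == (32, 16, 1)    # dq_mode != 1, so per dq group rather than per column
#   scales.shape == zeros.shape == (32, 4096)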
def set_qweight_data(self, data: torch.Tensor) -> None:
"""
Updates the quantized weight tensor with new data. This method is crucial for adjusting the quantized
weights based on training or fine-tuning processes, ensuring the layer's weights reflect the most
recent updates.
Args:
data (torch.Tensor): The new quantized weight data to be set in the layer.
"""
self.qweight.data = data
def generate_quantized_weight(self, qweight_only: bool = False) -> None:
"""
A placeholder method for the weight quantization process. Subclasses should implement this method
to define how the layer's weights are quantized based on the current configuration and quantization
method. This operation is typically executed before saving the model weights or performing inference to ensure that the weights
are in the appropriate quantized format.
Args:
qweight_only (bool): A flag to indicate whether only the quantized weights need to be generated,
without considering other quantization parameters like scales or zero-points.
Default is False, which means all relevant quantization parameters are generated.
"""
raise NotImplementedError("this method has not been implemented.")
def check_parameters(self) -> None:
"""
Validates the configuration and parameters of the layer to ensure they are set correctly for the
quantization process. This method should check for common configuration errors and ensure that all
required parameters for the selected quantization method are correctly initialized.
Raises:
NotImplementedError: Indicates that the method has not been implemented yet and needs to be
provided by subclasses.
"""
raise NotImplementedError("Subclasses should implement this method.")
def prepare_params(self) -> None:
"""
This method should be executed before the actual forward pass. It mainly decompresses quantized parameters
such as qscales and qzeros. This step could be simplified or eliminated in the future by a kernel
implementation that can decompress during kernel computation.
One can use the "prepare_bie_layers" method from project_root.utils.model_helper to call this function.
Note:
This method should be called before executing the forward pass, especially after loading
the model from a checkpoint or before inference to ensure that quantized parameters are
correctly prepared.
Raises:
NotImplementedError: Indicates that the method has not been implemented yet and should be
provided by subclasses.
"""
raise NotImplementedError("Subclasses should implement this method.")