import math
from torch import nn
import torch
from torch.nn import init
from bitorch_engine.utils.model_helper import qweight_update_fn


class nBitConvParameter(torch.nn.Parameter):
    """
    A custom parameter class for n-bit convolutional layers, extending torch.nn.Parameter.

    This class is designed to support n-bit convolutional layers, which are particularly
    useful in models requiring efficient memory usage and specialized optimization techniques.

    Args:
        data (torch.Tensor, optional): The initial data for the parameter. Defaults to None.
        requires_grad (bool, optional): Flag indicating whether gradients should be computed
            for this parameter in the backward pass. Defaults to True.
    """
    def __new__(cls,
                data: torch.Tensor = None,
                requires_grad: bool = True
                ):
        return super().__new__(cls, data=data, requires_grad=requires_grad)
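
    # Hypothetical usage sketch (not part of the original module): wrapping a freshly
    # initialized conv weight tensor so optimizers can recognize it and route it
    # through the custom n-bit update rule defined below.
    #
    #     w = torch.empty(64, 3, 3, 3)          # (out_channels, in_channels, kH, kW)
    #     torch.nn.init.kaiming_uniform_(w, a=math.sqrt(5))
    #     qparam = nBitConvParameter(w, requires_grad=True)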
    @staticmethod
    def update(qweight: torch.nn.Parameter, exp_avg_s: torch.Tensor = None, exp_avg_l: torch.Tensor = None,
               step: torch.Tensor = None, lr: float = 1e-4, weight_decay: float = 0.0, beta1: float = 0.99,
               beta2: float = 0.9999, eps: float = 1e-6, dtype=torch.half, correct_bias=None, projector=None,
               grad: torch.Tensor = None) -> None:
        """
        This method defines how to update quantized weights with quantized gradients.
        It may involve operations such as applying momentum or adjusting weights based on some optimization algorithm.

        Args:
            qweight (torch.nn.Parameter): The current quantized weight parameter to be updated.
            exp_avg_s (torch.Tensor, optional): Exponential moving average of squared gradients. Used in optimization algorithms like Adam.
            exp_avg_l (torch.Tensor, optional): Exponential moving average of the gradients. Also used in optimizers like Adam.
            step (torch.Tensor, optional): The current step or iteration in the optimization process. Can be used to adjust the learning rate or for other conditional operations in the update.
            lr (float, optional): Learning rate. A hyperparameter that determines the step size at each iteration while moving toward a minimum of the loss function.
            weight_decay (float, optional): Weight decay (L2 penalty). A regularization term that helps prevent overfitting by penalizing large weights.
            beta1 (float, optional): The exponential decay rate for the first-moment estimates. A hyperparameter for optimizers like Adam.
            beta2 (float, optional): The exponential decay rate for the second-moment estimates. Another hyperparameter for Adam-like optimizers.
            eps (float, optional): A small constant for numerical stability.
            dtype (torch.dtype, optional): The data type to be used for computations.
            correct_bias (optional): Whether to apply bias correction (specific to certain models like BERT).
            projector (optional): An optional gradient projector to apply during the update.
            grad (torch.Tensor, optional): The gradient tensor, used when a projector is provided.

        Returns:
            None: The function updates `qweight` in-place and does not return anything.
        """
        assert isinstance(qweight, nBitConvParameter), \
            'Error: the type of qweight must be nBitConvParameter.'
        qweight_update_fn(qweight=qweight, exp_avg_s=exp_avg_s, exp_avg_l=exp_avg_l,
                          step=step, lr=lr, weight_decay=weight_decay, beta1=beta1, beta2=beta2,
                          correct_bias=correct_bias, eps=eps, dtype=dtype, projector=projector, grad=grad)
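
# Hypothetical usage sketch (not part of the original module): how a custom optimizer
# step might delegate to nBitConvParameter.update above. The `state` dict and its keys
# ("exp_avg_s", "exp_avg_l", "step") are illustrative assumptions, not part of this module.
#
#     if isinstance(param, nBitConvParameter) and param.grad is not None:
#         nBitConvParameter.update(
#             qweight=param,
#             exp_avg_s=state["exp_avg_s"],
#             exp_avg_l=state["exp_avg_l"],
#             step=state["step"],
#             lr=1e-4, weight_decay=0.0,
#             beta1=0.99, beta2=0.9999, eps=1e-6,
#             dtype=torch.half, correct_bias=False,
#             projector=None, grad=param.grad,
#         )
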
class nBitConv2dBase(nn.Module):
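    """
    Base class for 2D convolutional layers with n-bit quantized weights and activations.
    Subclasses are expected to implement the quantization-specific hooks
    (`prepare_params`, `generate_quantized_weight`, and `_check_forward`).
    """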
    def __init__(self, in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 a_bit: int = 4,
                 w_bit: int = 4,
                 device=None,
                 dtype=torch.float):
        """
        Initializes the nBitConv2dBase module, a base class for creating convolutional layers
        with n-bit quantized weights.

        Args:
            in_channels (int): The number of input channels in the convolutional layer.
            out_channels (int): The number of output channels in the convolutional layer.
            kernel_size (int): The size of the convolutional kernel.
            stride (int, optional): The stride of the convolution. Defaults to 1.
            padding (int, optional): The padding added to all sides of the input tensor. Defaults to 0.
            dilation (int, optional): The spacing between kernel elements. Defaults to 1.
            a_bit (int, optional): The bit-width for activation quantization. Defaults to 4.
            w_bit (int, optional): The bit-width for weight quantization. Defaults to 4.
            device (optional): The device on which the module will be allocated. Defaults to None.
            dtype (optional): The desired data type of the parameters. Defaults to torch.float.
        """
        super(nBitConv2dBase, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.device = device
        self.dtype = dtype
        self.a_bit = a_bit
        self.w_bit = w_bit
        self.weight = None
        self.qweight = None
        self.reset_parameters()
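
    # Hypothetical construction sketch (illustrative only): a 4-bit conv layer with
    # 3 input channels, 64 output channels and a 3x3 kernel. In practice a concrete
    # subclass implementing the abstract hooks would be instantiated instead.
    #
    #     conv = nBitConv2dBase(in_channels=3, out_channels=64, kernel_size=3,
    #                           stride=1, padding=1, a_bit=4, w_bit=4)
    #     print(conv.weight.shape)  # torch.Size([64, 3, 3, 3])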
    def reset_parameters(self) -> None:
        """
        Initializes or resets the weight parameter using Kaiming uniform initialization.
        """
        self.weight = torch.nn.Parameter(torch.empty(
            (self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)))
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
    def prepare_params(self) -> None:
        """
        Prepares and initializes the model parameters for training.

        Note:
            This method MUST be called after model initialization and before training starts to
            ensure the weights are properly prepared for efficient computation.
        """
        raise NotImplementedError("Subclasses should implement this method.")
    def set_weight_data(self, x: torch.Tensor) -> None:
        """
        Sets the weight parameter with the provided tensor.

        Args:
            x (torch.Tensor): The tensor to be used as the new weight parameter.
        """
        self.weight = nn.Parameter(x, requires_grad=False)
    def set_quantized_weight_data(self, x: torch.Tensor) -> None:
        """
        Sets the quantized weight parameter with the provided tensor.

        Args:
            x (torch.Tensor): The tensor to be used as the new quantized weight parameter.
        """
        self.qweight = nn.Parameter(x, requires_grad=False)
    def generate_quantized_weight(self, qweight_only: bool = False) -> None:
        """
        Generates and sets the quantized weight based on the current weight parameter.
        This method should be overridden by subclasses to implement specific quantization logic.

        Args:
            qweight_only (bool, optional): If True, the original weight tensor is discarded to save memory.
        """
        raise NotImplementedError("Subclasses should implement this method.")
    def _check_forward(self, x: torch.Tensor) -> None:
        """
        Checks the input tensor before the forward pass. This method should be implemented by subclasses.

        Args:
            x (torch.Tensor): The input tensor to the layer.
        """
        raise NotImplementedError("Subclasses should implement this method.")
    @property
    def opt_weight(self):
        """
        Returns the proper weight parameter for the forward pass. If the model is in evaluation
        mode and quantized weights are available, it returns the quantized weights; otherwise,
        it returns the original weights.

        Returns:
            torch.nn.Parameter: The optimal weight parameter for the forward pass.
        """
        if not self.training and self.qweight is None:
            self.generate_quantized_weight()
        return self.weight if self.training else self.qweight
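
# Hypothetical subclass sketch (illustrative only, not part of the original module):
# a minimal concrete layer showing how the abstract hooks and the `opt_weight` property
# are intended to interact. The quantization below is a naive symmetric uniform rounding
# placeholder, not the engine's actual n-bit kernel.
#
#     class ToyNBitConv2d(nBitConv2dBase):
#         def prepare_params(self) -> None:
#             # e.g. cast or repack self.weight for the target compute kernel
#             pass
#
#         def generate_quantized_weight(self, qweight_only: bool = False) -> None:
#             qmax = 2 ** (self.w_bit - 1) - 1
#             scale = self.weight.abs().max() / qmax
#             q = torch.round(self.weight / scale).clamp_(-qmax, qmax) * scale
#             self.set_quantized_weight_data(q)
#             if qweight_only:
#                 self.weight = None  # discard the full-precision copy to save memory
#
#         def _check_forward(self, x: torch.Tensor) -> None:
#             assert x.size(1) == self.in_channels
#
#         def forward(self, x: torch.Tensor) -> torch.Tensor:
#             self._check_forward(x)
#             # opt_weight returns self.weight in training mode and the quantized
#             # weights (generating them on demand) in evaluation mode.
#             return torch.nn.functional.conv2d(
#                 x, self.opt_weight, stride=self.stride,
#                 padding=self.padding, dilation=self.dilation)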