Source code for bitorch_engine.layers.qconv.nbit.layer

import math
from torch import nn
import torch
from torch.nn import init
from bitorch_engine.utils.model_helper import qweight_update_fn

class nBitConvParameter(torch.nn.Parameter):
    """
    A custom parameter class for n-bit conv layers, extending torch.nn.Parameter.

    This class is designed to support n-bit conv layers, particularly useful in models
    requiring efficient memory usage and specialized optimization techniques.

    Args:
        data (torch.Tensor, optional): The initial data for the parameter. Defaults to None.
        requires_grad (bool, optional): Flag indicating whether gradients should be computed
            for this parameter in the backward pass. Defaults to True.
    """
    def __new__(cls,
                data: torch.Tensor = None,
                requires_grad: bool = True):
        return super().__new__(cls, data=data, requires_grad=requires_grad)

    @staticmethod
    def update(qweight: torch.nn.Parameter, exp_avg_s: torch.Tensor = None, exp_avg_l: torch.Tensor = None,
               step: torch.Tensor = None, lr: float = 1e-4, weight_decay: float = 0.0, beta1: float = 0.99,
               beta2: float = 0.9999, eps: float = 1e-6, dtype=torch.half, correct_bias=None,
               projector=None, grad: torch.Tensor = None) -> None:
        """
        This method defines how to update quantized weights with quantized gradients.
        It may involve operations such as applying momentum or adjusting weights based on some
        optimization algorithm.

        Args:
            qweight (torch.nn.Parameter): The current quantized weight parameter to be updated.
            exp_avg_s (torch.Tensor, optional): Exponential moving average of squared gradients.
                Used in optimization algorithms like Adam.
            exp_avg_l (torch.Tensor, optional): Exponential moving average of the gradients.
                Also used in optimizers like Adam.
            step (torch.Tensor, optional): The current step or iteration in the optimization process.
                Can be used to adjust the learning rate or for other conditional operations in the
                update process.
            lr (float, optional): Learning rate. A hyperparameter that determines the step size at
                each iteration while moving toward a minimum of the loss function.
            weight_decay (float, optional): Weight decay (L2 penalty). A regularization term that
                helps to prevent overfitting by penalizing large weights.
            beta1 (float, optional): The exponential decay rate for the first moment estimates.
                A hyperparameter for optimizers like Adam.
            beta2 (float, optional): The exponential decay rate for the second-moment estimates.
                Another hyperparameter for Adam-like optimizers.
            eps (float, optional): A small constant for numerical stability.
            dtype (torch.dtype, optional): The data type to be used for computations.
            correct_bias (optional): Whether to apply bias correction (specific to certain models like BERT).
            projector (optional): Whether to use a gradient projector.
            grad (optional): Gradient tensor, used if a projector is used.

        Returns:
            None: The function is expected to update the `qweight` in-place and does not return anything.

        Raises:
            AssertionError: If `qweight` is not an instance of `nBitConvParameter`.
        """
        assert isinstance(qweight, nBitConvParameter), \
            'Error: the type of qweight must be nBitConvParameter.'
        qweight_update_fn(qweight=qweight, exp_avg_s=exp_avg_s, exp_avg_l=exp_avg_l,
                          step=step, lr=lr, weight_decay=weight_decay, beta1=beta1, beta2=beta2,
                          correct_bias=correct_bias, eps=eps, dtype=dtype, projector=projector,
                          grad=grad)

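# Hypothetical usage sketch (not part of the library source): it illustrates how an optimizer step
# might wrap a packed weight tensor in nBitConvParameter and delegate the in-place update to
# nBitConvParameter.update(). The packed uint8 layout, the optimizer-state dtypes, and the
# hyperparameter values below are assumptions chosen for illustration; the exact inputs expected by
# qweight_update_fn may differ.
def _example_qweight_update_sketch() -> None:
    packed = torch.randint(0, 16, (64, 32, 3, 3), dtype=torch.uint8)   # assumed 4-bit packed weight
    qweight = nBitConvParameter(packed, requires_grad=False)
    exp_avg_s = torch.zeros(packed.shape, dtype=torch.float)           # second-moment optimizer state
    exp_avg_l = torch.zeros(packed.shape, dtype=torch.float)           # first-moment optimizer state
    step = torch.zeros(1)
    nBitConvParameter.update(qweight, exp_avg_s=exp_avg_s, exp_avg_l=exp_avg_l, step=step,
                             lr=1e-4, weight_decay=0.0, beta1=0.99, beta2=0.9999, eps=1e-6)
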
class nBitConv2dBase(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1,
                 padding: int = 0, dilation: int = 1, a_bit: int = 4, w_bit: int = 4,
                 device=None, dtype=torch.float):
        """
        Initializes the nBitConv2dBase module, a base class for creating convolutional layers
        with n-bit quantized weights.

        Args:
            in_channels (int): The number of input channels in the convolutional layer.
            out_channels (int): The number of output channels in the convolutional layer.
            kernel_size (int): The size of the convolutional kernel.
            stride (int, optional): The stride of the convolution. Defaults to 1.
            padding (int, optional): The padding added to all sides of the input tensor. Defaults to 0.
            dilation (int, optional): The spacing between kernel elements. Defaults to 1.
            a_bit (int, optional): The bit-width for activation quantization. Defaults to 4.
            w_bit (int, optional): The bit-width for weight quantization. Defaults to 4.
            device (optional): The device on which the module will be allocated. Defaults to None.
            dtype (optional): The desired data type of the parameters. Defaults to torch.float.
        """
        super(nBitConv2dBase, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.device = device
        self.dtype = dtype
        self.a_bit = a_bit
        self.w_bit = w_bit
        self.weight = None
        self.qweight = None
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Initializes or resets the weight parameter using Kaiming uniform initialization.
        """
        self.weight = torch.nn.Parameter(torch.empty(
            (self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)))
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def prepare_params(self) -> None:
        """
        Prepares and initializes the model parameters for training.

        Note:
            This method MUST be called after model initialization and before training starts
            to ensure the weights are properly prepared for efficient computation.
        """
        raise NotImplementedError("Subclasses should implement this method.")

    def set_weight_data(self, x: torch.Tensor) -> None:
        """
        Sets the weight parameter with the provided tensor.

        Args:
            x (torch.Tensor): The tensor to be used as the new weight parameter.
        """
        self.weight = nn.Parameter(x, requires_grad=False)

    def set_quantized_weight_data(self, x: torch.Tensor) -> None:
        """
        Sets the quantized weight parameter with the provided tensor.

        Args:
            x (torch.Tensor): The tensor to be used as the new quantized weight parameter.
        """
        self.qweight = nn.Parameter(x, requires_grad=False)

    def generate_quantized_weight(self, qweight_only: bool = False) -> None:
        """
        Generates and sets the quantized weight based on the current weight parameter.
        This method should be overridden by subclasses to implement specific quantization logic.

        Args:
            qweight_only (bool, optional): If True, the original weight tensor is discarded to save memory.
        """
        raise NotImplementedError("Subclasses should implement this method.")

    def _check_forward(self, x: torch.Tensor) -> None:
        """
        Checks the input tensor before the forward pass. This method should be implemented by subclasses.

        Args:
            x (torch.Tensor): The input tensor to the layer.
        """
        raise NotImplementedError("Subclasses should implement this method.")

    @property
    def opt_weight(self):
        """
        Returns the proper weight parameter for the forward pass. In evaluation mode, it returns
        the quantized weights, generating them on first access if necessary; during training, it
        returns the original full-precision weights.

        Returns:
            torch.nn.Parameter: The optimal weight parameter for the forward pass.
        """
        if not self.training and self.qweight is None:
            self.generate_quantized_weight()
        return self.weight if self.training else self.qweight
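
# Hypothetical subclass sketch (not part of the library source): it shows how the abstract hooks of
# nBitConv2dBase fit together. The toy per-tensor rounding "quantization" and the float fallback
# convolution below are illustrative assumptions only; the real subclasses in bitorch_engine pack
# weights for dedicated low-bit kernels instead.
class _ExampleNBitConv2d(nBitConv2dBase):
    def prepare_params(self) -> None:
        # Nothing to pre-process in this toy example.
        pass

    def generate_quantized_weight(self, qweight_only: bool = False) -> None:
        # Toy per-tensor quantization: scale to the signed n-bit range, round, then dequantize so
        # the result stays a float tensor that F.conv2d can consume directly.
        w = self.weight.detach()
        levels = 2 ** (self.w_bit - 1) - 1
        scale = w.abs().max().clamp(min=1e-8) / levels
        q = torch.round(w / scale).clamp(-levels - 1, levels)
        self.set_quantized_weight_data(q * scale)
        if qweight_only:
            self.weight = None   # drop the full-precision copy to save memory

    def _check_forward(self, x: torch.Tensor) -> None:
        assert x.size(1) == self.in_channels, "channel mismatch"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self._check_forward(x)
        # opt_weight picks the quantized weight in eval mode and the float weight during training.
        return torch.nn.functional.conv2d(x, self.opt_weight, stride=self.stride,
                                          padding=self.padding, dilation=self.dilation)


# Usage sketch:
#   layer = _ExampleNBitConv2d(in_channels=3, out_channels=8, kernel_size=3, w_bit=4)
#   layer.eval()
#   out = layer(torch.randn(1, 3, 32, 32))   # opt_weight generates the quantized weights on first use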