bitorch_engine.layers.qconv.nbit.layer.nBitConvParameter

class bitorch_engine.layers.qconv.nbit.layer.nBitConvParameter(data: Tensor | None = None, requires_grad: bool = True)

A custom parameter class for n-bit convolution layers, extending torch.nn.Parameter.

It is designed for n-bit convolution layers and is particularly useful in models that require efficient memory usage and specialized optimization techniques.

Parameters:
  • data (torch.Tensor, optional) – The initial data for the parameter. Defaults to None.

  • requires_grad (bool, optional) – Flag indicating whether gradients should be computed for this parameter in the backward pass. Defaults to True.
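The constructor follows the (data, requires_grad) signature of torch.nn.Parameter. As a minimal sketch of that subclassing pattern, the hypothetical class SketchNBitConvParameter below (not part of bitorch_engine) mirrors only the documented constructor. It performs no quantization, and the packed int32 tensor is an illustrative assumption.

```python
import torch


class SketchNBitConvParameter(torch.nn.Parameter):
    """Hypothetical stand-in for nBitConvParameter (not the library class).

    It only mirrors the documented constructor: (data=None, requires_grad=True).
    """

    def __new__(cls, data=None, requires_grad=True):
        # torch.nn.Parameter substitutes an empty tensor when data is None.
        return super().__new__(cls, data, requires_grad)


# Wrap a hypothetical packed n-bit weight tensor. Integer tensors cannot
# require gradients, so requires_grad is switched off explicitly.
packed = torch.zeros(4, 2, dtype=torch.int32)
qweight = SketchNBitConvParameter(packed, requires_grad=False)

# With no arguments, the documented defaults apply (data=None, requires_grad=True).
empty = SketchNBitConvParameter()
```

Because the class is still a torch.nn.Parameter, assigning it to a module attribute registers it like any other parameter.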

Methods

  • update – Defines how to update quantized weights with quantized gradients.


static update(qweight: Parameter, exp_avg_s: Tensor | None = None, exp_avg_l: Tensor | None = None, step: Tensor | None = None, lr: float = 0.0001, weight_decay: float = 0.0, beta1: float = 0.99, beta2: float = 0.9999, eps: float = 1e-06, dtype=torch.float16, correct_bias=None, projector=None, grad: Tensor | None = None) → None

This method defines how to update quantized weights with quantized gradients. An implementation may, for example, apply momentum or adjust the weights according to an optimization algorithm such as Adam.

Parameters:
  • qweight (torch.nn.Parameter) – The current quantized weight parameter to be updated.

  • exp_avg_s (torch.Tensor, optional) – Exponential moving average of squared gradients. Used in optimization algorithms like Adam.

  • exp_avg_l (torch.Tensor, optional) – Exponential moving average of the gradients. Also used in optimizers like Adam.

  • step (torch.Tensor, optional) – The current step (iteration count) of the optimization process. It can be used to adjust the learning rate or to drive other step-dependent logic in the update.

  • lr (float, optional) – Learning rate: the step size taken at each iteration while moving toward a minimum of the loss function. Defaults to 0.0001.

  • weight_decay (float, optional) – Weight decay (L2 penalty), a regularization term that helps prevent overfitting by penalizing large weights. Defaults to 0.0.

  • beta1 (float, optional) – Exponential decay rate for the first-moment estimates, as in Adam-like optimizers. Defaults to 0.99.

  • beta2 (float, optional) – Exponential decay rate for the second-moment estimates, as in Adam-like optimizers. Defaults to 0.9999.

  • eps (float, optional) – Small constant for numerical stability. Defaults to 1e-06.

  • dtype (torch.dtype, optional) – Data type used for the computations. Defaults to torch.float16.

  • correct_bias (optional) – Whether to apply Adam-style bias correction to the moment estimates (some setups, such as the original BERT optimizer, disable it).

  • projector (optional) – Whether to use a gradient projector.

  • grad (torch.Tensor, optional) – Gradient tensor, used when a projector is in use.
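The state tensors and hyperparameters listed above are those of Adam-like optimizers. As a rough sketch of how they typically combine, the hypothetical helper sketch_adam_update applies one AdamW-style step to a full-precision tensor. It is not bitorch_engine's update: it ignores quantization, dtype, and projector, and it assumes decoupled (AdamW-style) weight decay, which this page does not specify.

```python
import torch


def sketch_adam_update(weight, grad, exp_avg_l, exp_avg_s, step,
                       lr=1e-4, weight_decay=0.0, beta1=0.99, beta2=0.9999,
                       eps=1e-6, correct_bias=True):
    """Hypothetical AdamW-style step on a full-precision tensor, in place.

    exp_avg_l: EMA of gradients (first moment), updated in place.
    exp_avg_s: EMA of squared gradients (second moment), updated in place.
    step: one-element tensor holding the step count, incremented in place.
    """
    with torch.no_grad():
        step += 1
        t = step.item()
        # Update the first- and second-moment estimates.
        exp_avg_l.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_s.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        step_size = lr
        if correct_bias:
            # Bias correction for the zero-initialised moment estimates.
            step_size = lr * (1 - beta2 ** t) ** 0.5 / (1 - beta1 ** t)
        denom = exp_avg_s.sqrt().add_(eps)
        weight.addcdiv_(exp_avg_l, denom, value=-step_size)
        if weight_decay > 0.0:
            # Decoupled (AdamW-style) weight decay: an assumption of this sketch.
            weight.add_(weight, alpha=-lr * weight_decay)


w = torch.zeros(2)
g = torch.tensor([1.0, -2.0])
m, v, step = torch.zeros(2), torch.zeros(2), torch.tensor(0)
sketch_adam_update(w, g, m, v, step, lr=0.1, beta1=0.9, beta2=0.999)
```

The usage example passes the common 0.9/0.999 betas instead of the documented defaults so the numbers are easy to check: after the first bias-corrected step, each weight moves by roughly lr against the sign of its gradient, giving w ≈ [-0.1, 0.1].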

Returns:

The method is expected to update qweight in place and does not return anything.

Return type:

None

Raises:

NotImplementedError – Indicates that the function has not yet been implemented.
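Because the base method raises NotImplementedError, a concrete parameter type has to override update. The following minimal sketch shows that pattern with hypothetical classes (SketchBaseParameter, SketchSGDParameter). A plain SGD rule stands in for the real quantized update, and only the lr and grad arguments are used.

```python
import torch


class SketchBaseParameter(torch.nn.Parameter):
    """Hypothetical stand-in for the documented base class."""

    @staticmethod
    def update(qweight, lr=1e-4, grad=None, **kwargs):
        # Mirrors the documented behaviour: the base class leaves update unimplemented.
        raise NotImplementedError


class SketchSGDParameter(SketchBaseParameter):
    """Hypothetical subclass; plain SGD replaces the real quantized update."""

    @staticmethod
    def update(qweight, lr=1e-4, grad=None, **kwargs):
        if grad is None:
            grad = qweight.grad
        with torch.no_grad():
            # Update in place; like the documented method, nothing is returned.
            qweight.add_(grad, alpha=-lr)


p = SketchSGDParameter(torch.ones(3))
SketchSGDParameter.update(p, lr=0.5, grad=torch.full((3,), 2.0))

try:
    SketchBaseParameter.update(p)
    base_raised = False
except NotImplementedError:
    base_raised = True
```

Keeping update a staticmethod that receives the parameter explicitly matches the documented signature, so an optimizer can call it through the parameter's class without holding a bound instance.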