bitorch_engine.utils.model_helper.qweight_update_fn

bitorch_engine.utils.model_helper.qweight_update_fn(qweight: Parameter, exp_avg_s: Tensor | None = None, exp_avg_l: Tensor | None = None, step: Tensor | None = None, lr: float = 0.0001, weight_decay: float = 0.0, beta1: float = 0.99, beta2: float = 0.9999, eps: float = 1e-06, dtype=torch.float16, correct_bias=None, projector=None, grad: Tensor | None = None) → None[source]

This method defines how to update quantized weights with quantized gradients. It may involve operations such as applying momentum or adjusting the weights according to an optimization algorithm.

Parameters:
  • qweight (torch.nn.Parameter) – The current quantized weight parameter to be updated.

  • exp_avg_s (torch.Tensor, optional) – Exponential moving average of squared gradients. Used in optimization algorithms like Adam.

  • exp_avg_l (torch.Tensor, optional) – Exponential moving average of the gradients. Also used in optimizers like Adam.

  • step (torch.Tensor, optional) – The current step or iteration in the optimization process. Can be used to adjust learning rate or for other conditional operations in the update process.

  • lr (float, optional) – Learning rate. A hyperparameter that determines the step size at each iteration while moving toward a minimum of a loss function.

  • weight_decay (float, optional) – Weight decay (L2 penalty). A regularization term that helps to prevent overfitting by penalizing large weights.

  • beta1 (float, optional) – The exponential decay rate for the first moment estimates. A hyperparameter for optimizers like Adam.

  • beta2 (float, optional) – The exponential decay rate for the second moment estimates. Another hyperparameter for Adam-like optimizers.

  • eps (float, optional) – A small constant for numerical stability.

  • dtype (torch.dtype, optional) – The data type to be used for computations.

  • correct_bias (optional) – Whether to apply bias correction (specific to certain models like BERT).

  • projector (optional) – Whether to use a gradient projector.

  • grad (torch.Tensor, optional) – The gradient tensor; used when a projector is provided.

Returns:

The function is expected to update the qweight in-place and does not return anything.

Return type:

None

Raises:

NotImplementedError – Indicates that the function has not yet been implemented.
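As the reference above notes, the library function itself may raise NotImplementedError. To illustrate what an update with these parameters typically looks like, the following is a minimal Adam-style sketch that mirrors the documented signature and in-place, return-None contract. The function name `qweight_update_sketch` and its body are assumptions for illustration, not the library's actual implementation; it also omits the dtype, correct_bias, and projector handling.

```python
import torch
from torch.nn import Parameter

def qweight_update_sketch(qweight, exp_avg_s=None, exp_avg_l=None, step=None,
                          lr=1e-4, weight_decay=0.0, beta1=0.99, beta2=0.9999,
                          eps=1e-6, grad=None):
    """Illustrative Adam-style update sketch; not bitorch_engine's code."""
    step += 1  # advance the step counter in place
    # exp_avg_l: running average of gradients (first moment, decay beta1).
    exp_avg_l.mul_(beta1).add_(grad, alpha=1 - beta1)
    # exp_avg_s: running average of squared gradients (second moment, decay beta2).
    exp_avg_s.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    update = exp_avg_l / exp_avg_s.sqrt().add_(eps)
    if weight_decay > 0.0:
        update = update + weight_decay * qweight.data  # L2 penalty term
    # Update qweight in place; the function returns None.
    qweight.data.add_(update, alpha=-lr)

torch.manual_seed(0)
w = Parameter(torch.randn(4, 4))
state = dict(exp_avg_s=torch.zeros_like(w.data),
             exp_avg_l=torch.zeros_like(w.data),
             step=torch.zeros(1))
before = w.data.clone()
qweight_update_sketch(w, **state, grad=torch.randn(4, 4))
changed = not torch.equal(before, w.data)
```

Note that the optimizer state (exp_avg_s, exp_avg_l, step) is owned by the caller and mutated in place, which matches the documented behavior of passing these tensors in as arguments rather than storing them on the function.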