bitorch_engine.utils.model_helper.qweight_update_fn
- bitorch_engine.utils.model_helper.qweight_update_fn(qweight: Parameter, exp_avg_s: Tensor | None = None, exp_avg_l: Tensor | None = None, step: Tensor | None = None, lr: float = 0.0001, weight_decay: float = 0.0, beta1: float = 0.99, beta2: float = 0.9999, eps: float = 1e-06, dtype=torch.float16, correct_bias=None, projector=None, grad: Tensor | None = None) → None
This function defines how quantized weights are updated with quantized gradients. An implementation may, for example, apply momentum or adjust the weights according to an optimization algorithm such as Adam. As documented under Raises, it currently raises NotImplementedError.
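To make "updating quantized weights" concrete, here is a minimal sketch that assumes symmetric per-tensor integer quantization (w ≈ q * scale). It dequantizes each weight, takes a plain gradient-descent step in floating point, then rounds back to the nearest quantization level and clamps to the signed 8-bit range. The helper name `sgd_step_on_quantized`, the int8 scheme and the use of float gradients are all assumptions for illustration; this is not bitorch_engine's actual behavior.

```python
def sgd_step_on_quantized(qweight, scale, grad, lr=1e-4, qmin=-128, qmax=127):
    """Hypothetical helper: update integer-quantized weights in place.

    Assumes symmetric per-tensor quantization, i.e. w_float = q * scale.
    `qweight` is a list of ints; `grad` holds float gradients w.r.t. w_float.
    Returns None, mirroring the in-place contract documented below.
    """
    for i, (q, g) in enumerate(zip(qweight, grad)):
        w = q * scale                # dequantize to a float weight
        w -= lr * g                  # plain gradient-descent step in float space
        q_new = round(w / scale)     # requantize to the nearest integer level
        qweight[i] = max(qmin, min(qmax, q_new))  # clamp to the int8 range
```

Because the result is rounded after every step, any update smaller than half a quantization step (|lr * g| < scale / 2) is lost entirely, which is a known difficulty when training directly in low precision.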
- Parameters:
qweight (torch.nn.Parameter) – The current quantized weight parameter to be updated.
exp_avg_s (torch.Tensor, optional) – Exponential moving average of squared gradients (the second-moment estimate), as used in optimizers such as Adam.
exp_avg_l (torch.Tensor, optional) – Exponential moving average of the gradients (the first-moment estimate), as used in optimizers such as Adam.
step (torch.Tensor, optional) – The current step (iteration count) of the optimization process. Can be used to adjust the learning rate, to compute bias correction, or for other step-dependent logic in the update.
lr (float, optional) – Learning rate. A hyperparameter that determines the step size at each iteration while moving toward a minimum of a loss function.
weight_decay (float, optional) – Weight decay (L2 penalty). A regularization term that helps to prevent overfitting by penalizing large weights.
beta1 (float, optional) – The exponential decay rate for the first-moment estimates. A hyperparameter for optimizers like Adam.
beta2 (float, optional) – The exponential decay rate for the second-moment estimates. Another hyperparameter for Adam-like optimizers.
eps (float, optional) – A small constant for numerical stability.
dtype (torch.dtype, optional) – The data type to be used for computations.
correct_bias (optional) – Whether to apply Adam-style bias correction to the moment estimates (some BERT training setups disable it).
projector (optional) – Whether to use a gradient projector.
grad (torch.Tensor, optional) – The gradient tensor; used when a projector is provided.
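The hyperparameters above match a standard Adam step. The sketch below shows how they typically interact, assuming a textbook Adam update with the L2 penalty folded into the gradient. It works on plain Python float lists rather than tensors so it runs without dependencies, and it ignores quantization, `dtype` and `projector`. The name `adam_step` and the integer `step` argument are hypothetical; this is not bitorch_engine's implementation.

```python
import math


def adam_step(weights, grads, exp_avg_s, exp_avg_l, step, lr=1e-4,
              weight_decay=0.0, beta1=0.99, beta2=0.9999, eps=1e-6,
              correct_bias=True):
    """Hypothetical Adam-style update on plain float lists, done in place.

    exp_avg_s: EMA of squared gradients (second moment), decayed with beta2.
    exp_avg_l: EMA of gradients (first moment), decayed with beta1.
    step: 1-based count of updates so far, including this one.
    """
    # Bias-correction denominators; both are 1.0 when correction is off.
    bc1 = (1.0 - beta1 ** step) if correct_bias else 1.0
    bc2 = (1.0 - beta2 ** step) if correct_bias else 1.0
    for i, g in enumerate(grads):
        # L2 penalty: add weight_decay * w to the gradient.
        g = g + weight_decay * weights[i]
        exp_avg_l[i] = beta1 * exp_avg_l[i] + (1.0 - beta1) * g
        exp_avg_s[i] = beta2 * exp_avg_s[i] + (1.0 - beta2) * g * g
        m_hat = exp_avg_l[i] / bc1
        v_hat = exp_avg_s[i] / bc2
        # eps keeps the denominator away from zero.
        weights[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
```

On the first step with correct_bias=True, m_hat equals the gradient and v_hat its square, so each weight moves by roughly lr against the sign of its gradient.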
- Returns:
The function is expected to update qweight in place and does not return anything.
- Return type:
None
- Raises:
NotImplementedError – Indicates that the function has not yet been implemented.