Source code for bitorch_engine.optim.diode_beta

import math
import warnings
from typing import Callable, Iterable, Tuple
import torch
from torch import nn
from torch.optim import Optimizer

from bitorch_engine.layers.qembedding.binary import BinaryEmbeddingParameter
from bitorch_engine.layers.qlinear.binary import BinaryLinearParameter
from bitorch_engine.layers.qlinear.nbit import nBitLinearParameter, MPQWeightParameter
from bitorch_engine.layers.qconv.binary import BinaryConvParameter
from bitorch_engine.layers.qconv.nbit import nBitConvParameter

from packaging import version

from .galore_projector import GaLoreProjector


def check_pytorch_version(required_version):
    """
    Checks if the current PyTorch version meets the specified minimum version requirement.

    Args:
        required_version (str): The minimum version of PyTorch required.

    Raises:
        Exception: If the current PyTorch version is below the required minimum version.
    """
    # Get the current version of PyTorch
    current_version = torch.__version__

    # Compare the current version with the required minimum version
    if version.parse(current_version) < version.parse(required_version):
        raise Exception(
            f"Current PyTorch version {current_version} is below the required minimum "
            f"version {required_version}."
        )
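
# Illustrative sketch (not part of the original module): check_pytorch_version() is the guard
# used by DiodeMix.__init__ below; it simply raises on an installation that is too old, e.g.
#
#     check_pytorch_version("1.5.0")    # passes on any PyTorch >= 1.5.0
#     check_pytorch_version("99.0.0")   # raises Exception("... below the required minimum version ...")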

class DiodeMix(Optimizer):
    """
    DiodeMix is a custom optimizer designed for efficient optimization, leveraging adaptive
    learning rates and momentum. It is particularly suited for deep learning tasks involving
    parameters with binary, n-bit, or standard floating-point values.

    This implementation is based on "Diode: Reinventing Binary Neural Networks Training with
    Sign Descent Optimization", Guo, Nianhui et al., 2024.

    Attributes:
        params (Iterable[nn.parameter.Parameter]): Iterable of parameters to optimize.
        lr (float, optional): Learning rate. Defaults to 1e-4.
        betas (Tuple[float, float], optional): Coefficients used for computing running averages
            of the gradient and its square. Defaults to (0.99, 0.9999).
        eps (float, optional): Term added to the denominator to improve numerical stability.
            Defaults to 1e-6.
        weight_decay (float, optional): Weight decay (L2 penalty). Defaults to 0.0.
        correct_bias (bool, optional): Whether to apply bias correction to the adaptive learning
            rate. Defaults to True.
        dtype (torch.dtype): Data type used in the qweight update computation.

    Methods:
        __init__(): Initializes the optimizer with the given parameters and options.
        step(closure=None): Performs a single optimization step.

    Raises:
        ValueError: If any of the parameters (learning rate, betas, epsilon) are out of their
            expected range.

    Note:
        This optimizer checks that the learning rate and beta values are within valid ranges and
        raises ValueError if not. It does not support sparse gradients; in that case it raises an
        error directing the user to SparseAdam instead.
    """
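
    # Usage sketch (illustrative, not part of the original module; assumes a model whose
    # quantized layers expose the parameter types imported above):
    #
    #     optimizer = DiodeMix(model.parameters(), lr=1e-4, betas=(0.99, 0.9999))
    #     for x, y in loader:
    #         loss = criterion(model(x), y)
    #         loss.backward()
    #         optimizer.step()
    #         optimizer.zero_grad()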

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.99, 0.9999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
        dtype: torch.dtype = torch.float,
    ):
        check_pytorch_version("1.5.0")  # Tensor.add_ with the alpha keyword requires PyTorch >= 1.5.0
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
        defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "correct_bias": correct_bias,
        }
        self.dtype = dtype
        super().__init__(params, defaults)
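
    # Param-group note (sketch, based on the GaLore handling in step() below): a parameter group
    # may additionally carry the keys "rank", "update_proj_gap", "scale" and "proj_type", in which
    # case its gradients are low-rank projected through GaLoreProjector. The concrete values here
    # are illustrative only:
    #
    #     param_groups = [
    #         {"params": regular_params},
    #         {"params": galore_params, "rank": 128, "update_proj_gap": 200,
    #          "scale": 0.25, "proj_type": "std"},
    #     ]
    #     optimizer = DiodeMix(param_groups, lr=1e-4)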

    @torch.no_grad()
    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model and returns the
                loss. Defaults to None.

        Returns:
            The loss from the closure call, if any.

        Note:
            This method updates the weights of the parameters based on their gradients. It handles
            different types of parameters (binary, n-bit, or floating-point) differently to
            optimize their weights efficiently. For binary and n-bit parameters, it updates the
            quantized weights directly. For standard floating-point parameters, it applies
            Adam-like updates.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                w, grad, lr, wd, beta1, beta2, state, eps, correct_bias = (
                    p, p.grad, group["lr"], group["weight_decay"], *group["betas"],
                    self.state[p], group["eps"], group["correct_bias"],
                )

                if isinstance(p, MPQWeightParameter):
                    grad = p.privileged_grad

                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                if "step" not in state:
                    state["step"] = torch.zeros(1)

                # init GaLore projection
                projector = None
                if "rank" in group:
                    if "projector" not in state:
                        state["projector"] = GaLoreProjector(
                            group["rank"],
                            update_proj_gap=group["update_proj_gap"],
                            scale=group["scale"],
                            proj_type=group["proj_type"],
                        )
                    projector = state["projector"]
                    grad = projector.project(grad.to(self.dtype), state["step"].item())

                # exp_avg initialization
                if "exp_avg_s" not in state:
                    delta = torch.rand_like(w, dtype=self.dtype).mul_(1e-3)
                    if isinstance(p, BinaryEmbeddingParameter):
                        state["exp_avg_s"] = -(w.data.clone().sign_().to(self.dtype).mul_(delta))
                    elif isinstance(p, (BinaryLinearParameter, BinaryConvParameter)):
                        state["exp_avg_l"] = torch.zeros_like(w, dtype=self.dtype)
                        state["exp_avg_s"] = -(w.data.clone().sign_().to(self.dtype).mul_(delta))
                    else:
                        state["exp_avg_l"] = torch.zeros_like(grad, dtype=self.dtype)
                        state["exp_avg_s"] = torch.zeros_like(grad, dtype=self.dtype)

                # q-layers update: delegate to the parameter class's own update routine
                if isinstance(p, (BinaryLinearParameter, BinaryConvParameter, BinaryEmbeddingParameter,
                                  nBitLinearParameter, nBitConvParameter, MPQWeightParameter)):
                    p_class = type(p)
                    p_class.update(
                        qweight=p,
                        exp_avg_s=state["exp_avg_s"],
                        # binary embedding parameters only track exp_avg_s, so fall back to None here
                        exp_avg_l=state.get("exp_avg_l"),
                        step=state["step"],
                        lr=lr,
                        weight_decay=wd,
                        beta1=beta1,
                        beta2=beta2,
                        correct_bias=correct_bias,
                        eps=eps,
                        dtype=self.dtype,
                        projector=projector,
                        grad=grad,
                    )
                else:  # standard AdamW-style update for floating-point parameters
                    exp_avg, exp_avg_sq, step = state["exp_avg_l"], state["exp_avg_s"], state["step"]
                    step.add_(1)

                    # Decay the first and second moment running average coefficients
                    # (in-place operations update both averages at the same time).
                    exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                    denom = exp_avg_sq.sqrt().add_(eps)

                    step_size = lr
                    if correct_bias:  # no bias correction for BERT-style training
                        bias_correction1 = 1.0 - beta1 ** step.item()
                        bias_correction2 = 1.0 - beta2 ** step.item()
                        step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                    # compute the normalized gradient
                    norm_grad = exp_avg / denom

                    # GaLore projection back to the full parameter space
                    if "rank" in group:
                        norm_grad = projector.project_back(norm_grad)

                    w.add_(norm_grad, alpha=-step_size)

                    # Just adding the square of the weights to the loss function is *not*
                    # the correct way of using L2 regularization/weight decay with Adam,
                    # since that will interact with the m and v parameters in strange ways.
                    #
                    # Instead we want to decay the weights in a manner that doesn't interact
                    # with the m/v parameters. This is equivalent to adding the square
                    # of the weights to the loss with plain (non-momentum) SGD.
                    # Add the decoupled weight decay at the end.
                    if group["weight_decay"] > 0.0:
                        w.add_(w, alpha=-lr * group["weight_decay"])

        return loss
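
# Update-rule sketch for the floating-point branch above (decoupled-weight-decay Adam, shown for
# correct_bias=True; with correct_bias=False the bias-correction factor is dropped). This is a
# reading aid, not part of the original module:
#
#     m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
#     v_t = beta2 * v_{t-1} + (1 - beta2) * g_t ** 2
#     w   = w - lr * (sqrt(1 - beta2 ** t) / (1 - beta1 ** t)) * m_t / (sqrt(v_t) + eps)
#     w   = w - lr * weight_decay * w      # decay applied to the already-updated weight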