Source code for bitorch_engine.optim.galore_projector

"""
This is the original GaLore projector implementation from the `GaLore repo <https://github.com/jiaweizzhao/GaLore/tree/master/galore_torch>`_,
released under the `Apache-2.0 license <https://github.com/jiaweizzhao/GaLore?tab=Apache-2.0-1-ov-file>`_.

@misc{zhao2024galore,
      title={GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection},
      author={Jiawei Zhao and Zhenyu Zhang and Beidi Chen and Zhangyang Wang and Anima Anandkumar and Yuandong Tian},
      year={2024},
      eprint={2403.03507},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
"""

import torch

class GaLoreProjector:
    def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std'):
        self.rank = rank
        self.verbose = verbose
        self.update_proj_gap = update_proj_gap
        self.scale = scale
        self.ortho_matrix = None
        self.proj_type = proj_type

    def project(self, full_rank_grad, iter):
        # Project the full-rank gradient into the rank-`self.rank` subspace,
        # refreshing the cached projection matrix every `update_proj_gap` steps.
        if self.proj_type == 'std':
            # project along the shorter dimension so the cached factor stays small
            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
            else:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
        elif self.proj_type == 'reverse_std':
            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
            else:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
        elif self.proj_type == 'right':
            if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
            low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
        elif self.proj_type == 'left':
            if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
            low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
        elif self.proj_type == 'full':
            if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full')
            low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t()
        else:
            raise ValueError("proj_type should be 'std', 'reverse_std', 'left', 'right' or 'full'")
        return low_rank_grad

    def project_back(self, low_rank_grad):
        # Map a low-rank update back to the full parameter shape and apply `scale`.
        if self.proj_type == 'std':
            if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
            else:
                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
        elif self.proj_type == 'reverse_std':
            if low_rank_grad.shape[0] <= low_rank_grad.shape[1]:  # note this is different from std
                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
            else:
                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
        elif self.proj_type == 'right':
            full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
        elif self.proj_type == 'left':
            full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
        elif self.proj_type == 'full':
            full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]
        else:
            raise ValueError("proj_type should be 'std', 'reverse_std', 'left', 'right' or 'full'")
        return full_rank_grad * self.scale

    # SVD decomposition: return the top-`rank` singular vectors as the projection matrix
    def get_orthogonal_matrix(self, weights, rank, type):
        module_params = weights

        # run the SVD in float32, then cast the result back to the original dtype/device
        if module_params.data.dtype != torch.float:
            float_data = False
            original_type = module_params.data.dtype
            original_device = module_params.data.device
            matrix = module_params.data.float()
        else:
            float_data = True
            matrix = module_params.data

        U, s, Vh = torch.linalg.svd(matrix, full_matrices=False)

        # make the smaller factor the orthogonal matrix
        if type == 'right':
            # top-`rank` right singular vectors; the original code also formed
            # U[:, :rank] @ torch.diag(s[:rank]) here but never used it
            B = Vh[:rank, :]
            if not float_data:
                B = B.to(original_device).type(original_type)
            return B
        elif type == 'left':
            # top-`rank` left singular vectors; the unused diag(s) @ Vh product is dropped
            A = U[:, :rank]
            if not float_data:
                A = A.to(original_device).type(original_type)
            return A
        elif type == 'full':
            A = U[:, :rank]
            B = Vh[:rank, :]
            if not float_data:
                A = A.to(original_device).type(original_type)
                B = B.to(original_device).type(original_type)
            return [A, B]
        else:
            raise ValueError('type should be left, right or full')
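
Below is a minimal usage sketch, not part of the original GaLore file: the (256, 64) gradient shape, rank 4, and the plain scaled-gradient step (standing in for the Adam update that GaLore pairs this projector with) are illustrative assumptions. With `proj_type='std'` and more rows than columns, `project` compresses the column dimension onto the top-`rank` right singular vectors, and `project_back` restores the full shape.

if __name__ == "__main__":
    projector = GaLoreProjector(rank=4, update_proj_gap=200, scale=1.0, proj_type='std')

    # gradient of a hypothetical (256, 64) weight; rows >= columns, so the
    # 'std' path projects from the right
    full_rank_grad = torch.randn(256, 64)

    low_rank_grad = projector.project(full_rank_grad, 0)         # -> (256, 4)

    # the cached projection matrix has orthonormal rows: P @ P.T ~= I
    P = projector.ortho_matrix                                   # (4, 64)
    assert torch.allclose(P @ P.t(), torch.eye(4), atol=1e-4)

    # stand-in optimizer step in the low-rank space (GaLore uses Adam here)
    low_rank_update = -1e-3 * low_rank_grad

    # map the update back to the full parameter shape (also applies `scale`)
    full_rank_update = projector.project_back(low_rank_update)   # -> (256, 64)
    assert full_rank_update.shape == full_rank_grad.shape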