"""
This is the original Galore optimizer implementation from `Galore repo <https://github.com/jiaweizzhao/GaLore/tree/master/galore_torch>`_
with `Apache-2.0 License <https://github.com/jiaweizzhao/GaLore?tab=Apache-2.0-1-ov-file>`_
@misc{zhao2024galore,
title={GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection},
author={Jiawei Zhao and Zhenyu Zhang and Beidi Chen and Zhangyang Wang and Anima Anandkumar and Yuandong Tian},
year={2024},
eprint={2403.03507},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
"""
import torch
class GaLoreProjector:
    """Project full-rank gradients into a low-rank subspace (and back) via truncated SVD."""
    def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std'):
        self.rank = rank                        # target rank of the low-rank projection
        self.verbose = verbose
        self.update_proj_gap = update_proj_gap  # recompute the projection matrix every N steps
        self.scale = scale                      # factor applied to the projected-back gradient
        self.ortho_matrix = None                # cached orthogonal factor(s) from the SVD
        self.proj_type = proj_type              # 'std', 'reverse_std', 'right', 'left' or 'full'
def project(self, full_rank_grad, iter):
if self.proj_type == 'std':
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
else:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == 'reverse_std':
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
else:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
elif self.proj_type == 'right':
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
elif self.proj_type == 'left':
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == 'full':
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full')
            low_rank_grad = torch.matmul(torch.matmul(self.ortho_matrix[0].t(), full_rank_grad), self.ortho_matrix[1].t())
        else:
            raise ValueError(f"proj_type should be 'std', 'reverse_std', 'right', 'left' or 'full', got '{self.proj_type}'")
        return low_rank_grad
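
    # Shape sketch for proj_type='std' (the default): for a gradient G of shape
    # (m, n) with m >= n, the cached factor is the (rank x n) matrix P of top
    # right singular vectors, and project() maps G -> G @ P.t() of shape (m, rank);
    # when m < n, it caches the (m x rank) left factor Q and maps G -> Q.t() @ G.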
def project_back(self, low_rank_grad):
if self.proj_type == 'std':
if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
else:
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
elif self.proj_type == 'reverse_std':
if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
else:
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
elif self.proj_type == 'right':
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
elif self.proj_type == 'left':
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
elif self.proj_type == 'full':
            full_rank_grad = torch.matmul(torch.matmul(self.ortho_matrix[0], low_rank_grad), self.ortho_matrix[1])
        else:
            raise ValueError(f"proj_type should be 'std', 'reverse_std', 'right', 'left' or 'full', got '{self.proj_type}'")
        return full_rank_grad * self.scale
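
    # project_back() maps the (optimizer-transformed) low-rank update back to the
    # original shape, e.g. R of shape (m, rank) -> R @ P of shape (m, n), and
    # multiplies the result by self.scale.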
    # compute the orthogonal projection factor(s) via truncated SVD
def get_orthogonal_matrix(self, weights, rank, type):
module_params = weights
if module_params.data.dtype != torch.float:
float_data = False
original_type = module_params.data.dtype
original_device = module_params.data.device
matrix = module_params.data.float()
else:
float_data = True
matrix = module_params.data
        U, s, Vh = torch.linalg.svd(matrix, full_matrices=False)
        # keep only the smaller factor, which is always the orthogonal matrix;
        # the complementary factor (U @ diag(s) or diag(s) @ Vh) is never used
        if type == 'right':
            B = Vh[:rank, :]
            if not float_data:
                B = B.to(original_device).type(original_type)
            return B
        elif type == 'left':
            A = U[:, :rank]
            if not float_data:
                A = A.to(original_device).type(original_type)
            return A
        elif type == 'full':
            A = U[:, :rank]
            B = Vh[:rank, :]
            if not float_data:
                A = A.to(original_device).type(original_type)
                B = B.to(original_device).type(original_type)
            return [A, B]
        else:
            raise ValueError("type should be 'left', 'right' or 'full'")
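

# A minimal usage sketch (not part of the original module): project a gradient,
# let an optimizer step happen in the low-rank space, then project the update
# back. The shapes, rank, and scale below are illustrative assumptions.
if __name__ == '__main__':
    torch.manual_seed(0)
    grad = torch.randn(256, 64)                 # full-rank gradient, m >= n
    projector = GaLoreProjector(rank=16, update_proj_gap=200, scale=0.25)

    low_rank = projector.project(grad, iter=0)  # (256, 64) -> (256, 16)
    assert low_rank.shape == (256, 16)

    # an optimizer such as Adam would update `low_rank` here; we pass it through
    full_rank = projector.project_back(low_rank)  # (256, 16) -> (256, 64), times scale
    assert full_rank.shape == grad.shape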