bitorch_engine.utils.model_helper.pad_embedding_dim

bitorch_engine.utils.model_helper.pad_embedding_dim(weight: Tensor) → Tensor

This function takes a PyTorch tensor “weight” representing an embedding matrix and pads its embedding dimension to the smallest multiple of 8 that is greater than or equal to the current embedding dimension. It computes the remainder of the current embedding dimension divided by 8 and appends as many columns as are needed to reach that multiple, each filled with -1. If the embedding dimension is already a multiple of 8, no columns are added. The function then returns the padded tensor.
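The padding step described above can be sketched as follows. This is a minimal illustration, not the bitorch_engine implementation: the name `pad_embedding_dim_sketch` is hypothetical, and the sketch assumes that the embedding dimension is the last axis of `weight` and that an already aligned tensor is returned unchanged rather than copied.

```python
import torch


def pad_embedding_dim_sketch(weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: pad the last (embedding) dimension of `weight`
    up to the next multiple of 8, filling the new columns with -1."""
    embedding_dim = weight.size(-1)
    remainder = embedding_dim % 8
    if remainder == 0:
        # Assumption: an already aligned tensor is returned as-is.
        return weight
    num_pad = 8 - remainder
    # Filler columns share the input's leading shape, dtype, and device.
    filler = torch.full(
        (*weight.shape[:-1], num_pad), -1, dtype=weight.dtype, device=weight.device
    )
    # Append the filler columns after the existing ones.
    return torch.cat([weight, filler], dim=-1)


if __name__ == "__main__":
    w = torch.randn(10, 13)
    padded = pad_embedding_dim_sketch(w)
    print(padded.shape)  # torch.Size([10, 16])
```

For an embedding dimension of 13, the remainder modulo 8 is 5, so 3 columns of -1 are appended and the result has 16 columns.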

Parameters:

weight (torch.Tensor) – The embedding weight matrix to be padded

Returns:

The weight tensor with its embedding dimension padded to a multiple of 8

Return type:

torch.Tensor