bitorch_engine.utils.model_helper

Functions

binary_matmul_forward_post_processing

Post-processes the output tensor of a binary matrix multiplication operation.

flatten_x

Flattens a 3D tensor into a 2D tensor by combining the first two dimensions.
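A minimal sketch of the reshape described above. It uses nested Python lists in place of a tensor of shape (batch, seq_len, hidden). The helper name flatten_3d is hypothetical and is not the library's signature. For a 3D PyTorch tensor x, the equivalent operation is x.reshape(-1, x.shape[-1]).

```python
from typing import List

Matrix = List[List[float]]

def flatten_3d(x: List[Matrix]) -> Matrix:
    """Merge the first two dimensions: (B, S, H) -> (B * S, H).

    Hypothetical helper illustrating the idea behind flatten_x on nested lists.
    """
    # Rows are emitted batch by batch, so row b * S + s holds x[b][s].
    return [row for block in x for row in block]

x = [
    [[1, 2], [3, 4], [5, 6]],     # batch 0: S=3 rows of H=2
    [[7, 8], [9, 10], [11, 12]],  # batch 1
]
flat = flatten_3d(x)
print(len(flat), len(flat[0]))  # 6 2
```

Because row order follows batch order, the original shape is all that is needed to undo the operation.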

init_weight

Initializes binary parameters using pre-trained weights if available.

load_checkpoint

Loads a checkpoint into a given model.

pack_bie_layers

Packs the weights of the quantization layers in a given model for efficient storage.
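A generic illustration of why packing saves space. Each binary weight needs only one bit, so eight signs fit in one byte instead of eight floats. The conventions below are assumptions for illustration and are not necessarily bitorch_engine's packed layout: bit 1 means +1, LSB-first order within a byte, and zero-padded tail bits. The helper names pack_signs and unpack_signs are hypothetical.

```python
from typing import List

def pack_signs(weights: List[float]) -> bytes:
    """Pack the sign of each weight into one bit, 8 weights per byte.

    Assumed conventions (not necessarily bitorch_engine's layout):
    bit = 1 for w >= 0 (i.e. +1), bit = 0 for w < 0 (i.e. -1);
    weight i goes to bit (i % 8) of byte i // 8; the tail is zero-padded.
    """
    packed = bytearray((len(weights) + 7) // 8)
    for i, w in enumerate(weights):
        if w >= 0:
            packed[i // 8] |= 1 << (i % 8)
    return bytes(packed)

def unpack_signs(packed: bytes, n: int) -> List[int]:
    """Recover n values in {-1, +1} from the packed bits."""
    return [1 if (packed[i // 8] >> (i % 8)) & 1 else -1 for i in range(n)]

w = [0.3, -1.2, 0.0, -0.5, 2.1, 0.7, -0.1, -3.0, 0.4, -0.2]  # 10 float weights
packed = pack_signs(w)
print(len(packed))                   # 2
print(unpack_signs(packed, len(w)))  # [1, -1, 1, -1, 1, 1, -1, -1, 1, -1]
```

Real kernels typically pack into wider machine words and may use a different bit order. The storage ratio (one bit per binary weight) is the point here.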

pad_embedding_dim

Pads the embedding dimension of "weight", a PyTorch tensor holding the embedding matrix, to the smallest multiple of 8 that is greater than or equal to its current size.
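A minimal sketch of the rounding rule, using a list of rows to represent a (num_embeddings, embedding_dim) matrix. The helper names are hypothetical. Zero padding is also an assumption, since the description above does not say which value fills the new columns. For a PyTorch tensor of that shape, torch.nn.functional.pad(weight, (0, extra)) would append extra zero columns in the same way.

```python
from typing import List

def round_up_to_multiple(n: int, m: int) -> int:
    """Smallest multiple of m that is >= n (for n >= 0, m > 0)."""
    return ((n + m - 1) // m) * m

def pad_embedding_rows(weight: List[List[float]], multiple: int = 8) -> List[List[float]]:
    """Zero-pad each row (the embedding dimension) to a multiple of `multiple`.

    Zero padding is an assumption; the original description only says the
    dimension is padded, not what value fills the new columns.
    """
    dim = len(weight[0]) if weight else 0
    extra = round_up_to_multiple(dim, multiple) - dim
    return [row + [0.0] * extra for row in weight]

emb = [[0.1] * 13 for _ in range(4)]  # 4 embeddings, embedding dim 13
padded = pad_embedding_rows(emb)
print(len(padded), len(padded[0]))  # 4 16
```

A dimension that is already a multiple of 8 (for example 16) is left unchanged, because the rounding is "greater than or equal to".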

pad_last_2_dims_to_multiple_of_128

Pads the last two dimensions of a PyTorch tensor up to the next multiple of 128, i.e. the smallest multiple of 128 greater than or equal to each dimension's current size.
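A sketch of the padding arithmetic, assuming padding is added on the right/bottom side only. The helper names are hypothetical. The returned tuple follows the order torch.nn.functional.pad expects, which lists the last dimension first as (left, right) and then the second-to-last as (top, bottom).

```python
from typing import Tuple

def next_multiple(n: int, m: int = 128) -> int:
    """Smallest multiple of m that is >= n."""
    return -(-n // m) * m  # ceiling division via floor division of negatives

def pad_spec_last_2_dims(shape: Tuple[int, ...], m: int = 128) -> Tuple[int, int, int, int]:
    """Return (0, pad_cols, 0, pad_rows): right-side padding for the last two dims.

    The tuple is ordered last dimension first, the order that
    torch.nn.functional.pad expects.
    """
    if len(shape) < 2:
        raise ValueError("tensor must have at least two dimensions")
    rows, cols = shape[-2], shape[-1]
    return (0, next_multiple(cols, m) - cols, 0, next_multiple(rows, m) - rows)

print(pad_spec_last_2_dims((4, 300, 128)))  # (0, 0, 0, 84)
print(pad_spec_last_2_dims((1000, 77)))     # (0, 51, 0, 24)
```

With a tensor x, passing this tuple as torch.nn.functional.pad(x, pad_spec_last_2_dims(x.shape)) would zero-pad x so that its last two sizes become multiples of 128. Leading dimensions such as the batch are left untouched.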

prepare_bie_layers

Prepares binary and n-bit quantized layers within a given model for training or inference.

qweight_update_fn

Defines how quantized weights are updated using quantized gradients.

save_checkpoint

Saves the state of a quantized PyTorch model in a bit-packed format.

unflatten_x

Unflattens a 2D tensor back into a 3D tensor using the original shape, reversing the operation performed by flatten_x.
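A minimal sketch of the inverse reshape on nested lists. It shows why the original (batch, seq_len, hidden) shape must be supplied: the 2D form alone does not say where one batch element ends and the next begins. The helper name unflatten_2d is hypothetical. For a PyTorch tensor, x.reshape(original_shape) performs the same step.

```python
from typing import List, Tuple

Matrix = List[List[float]]

def unflatten_2d(x2d: Matrix, shape: Tuple[int, int, int]) -> List[Matrix]:
    """Split (B * S, H) rows back into (B, S, H), given the original shape.

    Hypothetical helper illustrating the idea behind unflatten_x on nested lists.
    """
    batch, seq_len, hidden = shape
    if len(x2d) != batch * seq_len or any(len(r) != hidden for r in x2d):
        raise ValueError("2D input does not match the original 3D shape")
    # Consecutive groups of seq_len rows belong to the same batch element.
    return [x2d[b * seq_len:(b + 1) * seq_len] for b in range(batch)]

original = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]  # B=3, S=2, H=2
flat = [row for block in original for row in block]                 # 6 rows of 2
restored = unflatten_2d(flat, (3, 2, 2))
print(restored == original)  # True
```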

update_zeros

Updates the zeros attribute of the qweight object based on its layer type.