bitorch_engine.utils.model_helper.init_weight

bitorch_engine.utils.model_helper.init_weight(weight: torch.Tensor, cls: Type[torch.nn.Parameter] = torch.nn.Parameter) → Tuple[Tensor, Tensor]

Initializes binary parameters using pre-trained weights if available.

This function computes the weight scale from the pre-trained weights when they are provided, and from the initial weights otherwise. It converts the weights to int8 for training (a 4x size reduction relative to float32) and prepares them for a fully bit-packed uint8 conversion for inference (a 32x size reduction relative to float32). The process aims to preserve the average magnitude of the original weights.
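The size figures and the magnitude-preserving scale can be illustrated with a minimal sketch in plain Python. This is not the library's implementation: the helper names `binarize_with_scale` and `pack_signs` are hypothetical, and the sketch assumes that the scale is the mean absolute weight, that zero maps to +1, and that bits are packed most-significant first.

```python
from array import array


def binarize_with_scale(weights):
    """Return (signs, scale) for a flat list of floats.

    Assumption: scale is the mean absolute value, so that scale * sign
    has the same average magnitude as the original weights.
    """
    if not weights:
        raise ValueError("weights must be non-empty")
    scale = sum(abs(w) for w in weights) / len(weights)
    # One signed byte (int8) per weight; zero is mapped to +1 by assumption.
    signs = array("b", (1 if w >= 0 else -1 for w in weights))
    return signs, scale


def pack_signs(signs):
    """Pack +1/-1 values into bytes, 8 weights per byte (a set bit means +1)."""
    packed = bytearray((len(signs) + 7) // 8)
    for i, s in enumerate(signs):
        if s > 0:
            packed[i // 8] |= 1 << (7 - i % 8)  # most-significant bit first
    return bytes(packed)


weights = [0.5, -0.25, 0.75, -1.0, 0.125, -0.375, 0.5, 0.5]
signs, scale = binarize_with_scale(weights)
packed = pack_signs(signs)

float32_bytes = 4 * len(weights)           # 4 bytes per float32 weight
int8_bytes = signs.itemsize * len(signs)   # 1 byte per int8 weight
packed_bytes = len(packed)                 # 1 bit per weight

print(scale)                          # mean |w|
print(float32_bytes // int8_bytes)    # training-format reduction
print(float32_bytes // packed_bytes)  # inference-format reduction
```

For these eight weights the int8 form takes 8 bytes and the packed form takes 1 byte, against 32 bytes in float32, which matches the 4x and 32x reductions described above.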

Parameters:
  • weight (Tensor) – The initial floating-point weight tensor.

  • cls (Type[torch.nn.Parameter]) – The class used for the output weight. Defaults to torch.nn.Parameter.

Returns:

A tuple containing the initialized weight as a torch.nn.Parameter in int8 format and the scale of the weight.

Return type:

Tuple[Tensor, Tensor]