bitorch_engine.utils.model_helper.init_weight
- bitorch_engine.utils.model_helper.init_weight(weight: torch.Tensor, cls: typing.Type[torch.nn.parameter.Parameter] = torch.nn.parameter.Parameter) → Tuple[Tensor, Tensor]
Initializes binary parameters using pre-trained weights if available.
This function calculates the weight scale from either the provided pre-trained weights or the initial weights; the scaling aims to preserve the average magnitude of the original weights. For training, it converts the weights to int8, a 4x size reduction relative to 32-bit floating point. For inference, it prepares the weights for a fully bit-packed uint8 format (one bit per weight), a 32x size reduction relative to 32-bit floating point.
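The scale-and-binarize step can be illustrated with a short NumPy sketch. It is not bitorch_engine's implementation: it assumes the scale is the mean absolute value of the weights and that the int8 tensor holds signs in {-1, +1} (with zeros mapped to +1). The function `binarize_with_scale` is hypothetical. Under these assumptions, `scale * sign(w)` has the same average magnitude as `w`, and int8 storage is 4x smaller than float32.

```python
import numpy as np

def binarize_with_scale(weight: np.ndarray):
    """Hypothetical sketch: sign-binarize a float array and compute a scale.

    Assumption: the scale is the mean absolute value of the weights, so that
    scale * sign(weight) keeps the same average magnitude as the input.
    """
    scale = np.float32(np.abs(weight).mean())
    # Map weights to {-1, +1}; zeros become +1 (an assumption of this sketch).
    signs = np.where(weight >= 0, 1, -1).astype(np.int8)
    return signs, scale

rng = np.random.default_rng(0)
w = rng.standard_normal((64, 128)).astype(np.float32)
w_bin, scale = binarize_with_scale(w)

# The rescaled binary weights keep the original average magnitude.
print(np.isclose(np.abs(scale * w_bin).mean(), np.abs(w).mean()))  # True
# int8 storage is 4x smaller than float32 storage.
print(w.nbytes // w_bin.nbytes)  # 4
```

Using the mean absolute value as the scale is one common choice for binary networks. Other scaling rules would change only the `scale` line.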
- Parameters:
weight (Tensor) – The initial floating-point weight tensor.
cls (Type[torch.nn.Parameter]) – Class of the output weight. Defaults to torch.nn.Parameter.
- Returns:
A tuple containing the initialized weight as a torch.nn.Parameter in int8 format and the scale of the weight.
- Return type:
Tuple[Tensor, Tensor]
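This page does not show how the returned int8 weight is later converted into the bit-packed uint8 format for inference. The NumPy sketch below shows one way that conversion can work. It assumes the int8 tensor holds signs in {-1, +1} and stores +1 as bit 1 and -1 as bit 0, using NumPy's default big-endian bit order. bitorch_engine's actual bit layout may differ. `pack_signs` and `unpack_signs` are hypothetical helpers.

```python
import numpy as np

def pack_signs(signs: np.ndarray) -> np.ndarray:
    """Hypothetical: pack a {-1, +1} int8 array into uint8, one bit per weight."""
    bits = (signs.reshape(-1) > 0).astype(np.uint8)  # +1 -> 1, -1 -> 0
    return np.packbits(bits)

def unpack_signs(packed: np.ndarray, shape: tuple) -> np.ndarray:
    """Hypothetical: invert pack_signs, returning {-1, +1} int8 values in `shape`."""
    n = int(np.prod(shape))
    bits = np.unpackbits(packed, count=n)
    return (bits.astype(np.int8) * 2 - 1).reshape(shape)

rng = np.random.default_rng(1)
signs = np.where(rng.standard_normal((64, 128)) >= 0, 1, -1).astype(np.int8)
packed = pack_signs(signs)

print(packed.dtype, packed.size)  # uint8 1024
print(np.array_equal(unpack_signs(packed, signs.shape), signs))  # True
# Packed storage is 8x smaller than int8, i.e. 32x smaller than float32.
print(signs.size * 4 // packed.nbytes)  # 32
```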