bitorch_engine.layers.qembedding.binary.layer.BinaryEmbeddingForward
- class bitorch_engine.layers.qembedding.binary.layer.BinaryEmbeddingForward(*args, **kwargs)[source]
Experimental class for the forward pass of binary embeddings. Note: This class is experimental and may not always work as expected. Because the weight parameter (qweight) is stored as bit-packed binary weights of uint8 type, a custom optimizer is required for training.
- Parameters:
input (torch.Tensor) – Input tensor with shape (batch_size, seq_length).
qweight (torch.Tensor) – The embedding weight matrix, bit-packed.
embed_scale (torch.Tensor) – Scaling factor for the embeddings.
ori_embedding_dim (int) – Original embedding dimension before packing.
is_train (bool) – Flag indicating if the operation is in training mode.
- Returns:
The resulting tensor after applying embedding lookup and unpacking.
- Return type:
torch.Tensor
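A hedged usage sketch follows. The packed layout of qweight (eight binary values per uint8 along the embedding dimension) and the shape of embed_scale are assumptions for illustration; in practice the packed weights come from the engine's own packing utilities:

    import torch
    from bitorch_engine.layers.qembedding.binary.layer import BinaryEmbeddingForward

    num_embeddings, ori_embedding_dim = 1000, 64
    batch_size, seq_length = 4, 16

    # Assumed layout: 8 binary values per uint8, so the packed width
    # is ori_embedding_dim // 8.
    qweight = torch.randint(
        0, 256, (num_embeddings, ori_embedding_dim // 8), dtype=torch.uint8
    )
    embed_scale = torch.ones(1)  # scaling factor applied after unpacking

    input_ids = torch.randint(0, num_embeddings, (batch_size, seq_length))

    # autograd.Function subclasses are invoked through .apply()
    output = BinaryEmbeddingForward.apply(
        input_ids, qweight, embed_scale, ori_embedding_dim, False
    )
    # output has shape (batch_size, seq_length, ori_embedding_dim)

Passing is_train=False skips saving tensors for the backward pass, which is appropriate for inference.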
Methods
- backward(ctx, output_gradient) – Backward pass for the BinaryEmbeddingForward function, providing gradients for the input and packed weights.
- forward(ctx, input, qweight, embed_scale, ori_embedding_dim, is_train) – Forward pass for binary embedding lookup.
- static backward(ctx: BackwardCFunction, output_gradient: Tensor) → Tuple[Tensor, ...] [source]
Backward pass for the BinaryEmbeddingForward function, providing gradients for the input and packed weights.
This method is designed to handle the gradient flow for bit-packed parameters, specifically for binary embeddings, ensuring the correct update behavior in conjunction with a custom optimizer designed for such data types.
- Parameters:
ctx (torch.autograd.function.BackwardCFunction) – The context object where saved tensors are retrieved.
output_gradient (torch.Tensor) – The gradient of the loss with respect to the output of the forward pass.
- Returns:
A tuple containing gradients for the inputs used in the forward pass. Most elements are None, as only specific parameters require gradients.
- Return type:
Tuple[Tensor, ...]
Note
The optimizer update suggested for this setup involves a specialized approach. The necessary operations are:
- Perform an XOR operation: apply XOR between qweight and grad_weight.
- Create a mask identifying the non-zero positions in grad_weight.
- Update qweight only at positions where grad_weight is non-zero, keeping the original qweight values elsewhere (see the sketch after this note).
This process is essential due to the binary, packed nature of the weights, which requires careful manipulation to ensure correct updates during training. The sparse_update_embedding_qweight method can be used to update qweight within an optimizer.
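A minimal sketch of the update step described above, assuming qweight and grad_weight are bit-packed uint8 tensors of identical shape (this illustrates the XOR-and-mask scheme, not the library's actual optimizer; for the real update path see sparse_update_embedding_qweight):

    import torch

    def xor_masked_update(qweight: torch.Tensor, grad_weight: torch.Tensor) -> torch.Tensor:
        # XOR flips exactly those packed bits that are set in grad_weight.
        flipped = qweight ^ grad_weight
        # Mask of positions that carry at least one bit to update.
        mask = grad_weight != 0
        # Keep the original qweight values wherever grad_weight is zero.
        return torch.where(mask, flipped, qweight)

Since x ^ 0 == x, the mask does not change the result here; it mirrors the masking step described above and becomes relevant when updates are applied sparsely, e.g. only to rows touched in the current batch.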
- static forward(ctx, input: Tensor, qweight: Tensor, embed_scale: Tensor, ori_embedding_dim: int, is_train: bool) → Tensor [source]
Forward pass for binary embedding lookup.
This function performs embedding lookup for each index in the input tensor, using bit-packed binary weights. It supports dynamic scaling and unpacking of the weights to retrieve the original embeddings.
- Parameters:
ctx – Context object for backward pass.
input (torch.Tensor) – Input tensor containing indices, with shape (batch_size, seq_length).
qweight (torch.Tensor) – Bit-packed binary weights tensor for embeddings.
embed_scale (torch.Tensor) – Scaling factors for the embeddings, to be applied after unpacking.
ori_embedding_dim (int) – Original embedding dimension before packing.
is_train (bool) – Flag indicating whether the operation is in training mode.
- Returns:
The resulting embedding tensor after lookup and scaling, with shape (batch_size, seq_length, ori_embedding_dim).
- Return type:
torch.Tensor
Note
This function saves the tensors required for the backward pass when running in training mode.
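To make the unpacking step concrete, here is a hedged sketch of how one packed row could be expanded back to ori_embedding_dim values. The low-bit-first order and the mapping of bit 1 to +embed_scale and bit 0 to -embed_scale are assumptions; the kernel's actual bit order and sign convention may differ:

    import torch

    def unpack_row(packed_row: torch.Tensor, ori_embedding_dim: int,
                   embed_scale: torch.Tensor) -> torch.Tensor:
        # Extract the 8 bits of each uint8 (low bit first; order is an assumption).
        bits = (packed_row.unsqueeze(-1) >> torch.arange(8)) & 1
        bits = bits.reshape(-1)[:ori_embedding_dim].float()
        # Map bit 1 -> +1 and bit 0 -> -1, then apply the scaling factor.
        return (bits * 2.0 - 1.0) * embed_scale

    packed = torch.tensor([0b10110001], dtype=torch.uint8)
    print(unpack_row(packed, 8, torch.tensor(0.5)))  # eight values of +/-0.5

The slice to ori_embedding_dim is why the forward pass needs the original embedding dimension: it trims any padding bits introduced by packing.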