bitorch_engine.layers.qembedding.binary.layer.BinaryEmbeddingForward
- class bitorch_engine.layers.qembedding.binary.layer.BinaryEmbeddingForward(*args, **kwargs)
- Experimental class for the forward pass of binary embeddings. Note: this class is experimental and may not always work as expected. Because the weight parameter (qweight) is stored as uint8 bit-packed binary weights, training requires a custom optimizer.
- Parameters:
- input (torch.Tensor) – Input tensor with shape (batch_size, seq_length). 
- weight (torch.Tensor) – The embedding weight matrix, bit-packed. 
- embed_scale (torch.Tensor) – Scaling factor for the embeddings. 
- ori_embedding_dim (int) – Original embedding dimension before packing. 
- is_train (bool) – Flag indicating if the operation is in training mode. 
 
- Returns:
- The resulting tensor after applying embedding lookup and unpacking. 
- Return type:
- torch.Tensor 
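The relationship between ori_embedding_dim and the packed weight can be sketched as follows. pack_binary_weights is a hypothetical helper, not the library's packing routine. It assumes that each uint8 holds 8 binary weights in least-significant-bit-first order, that non-negative values map to bit 1 and negative values to bit 0, and that the last byte is zero-padded. Under these assumptions the packed tensor has ceil(ori_embedding_dim / 8) columns.

```python
import torch


def pack_binary_weights(weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical packing sketch (assumed convention, not bitorch_engine's).

    weight: float tensor of shape (num_embeddings, ori_embedding_dim).
    Returns a uint8 tensor of shape (num_embeddings, ceil(ori_embedding_dim / 8)).
    """
    num_embeddings, dim = weight.shape
    # Assumed sign convention: non-negative -> bit 1, negative -> bit 0.
    bits = (weight >= 0).to(torch.uint8)
    # Zero-pad so the embedding dimension is a multiple of 8.
    pad = (-dim) % 8
    if pad:
        zeros = torch.zeros(num_embeddings, pad, dtype=torch.uint8, device=weight.device)
        bits = torch.cat([bits, zeros], dim=1)
    # Group 8 consecutive bits per byte.
    bits = bits.view(num_embeddings, -1, 8)
    # Assumed bit order: bit i of each group goes to position i (LSB first).
    shifts = torch.arange(8, dtype=torch.uint8, device=weight.device)
    # The shifted bits never overlap, so summing them equals OR-ing them.
    return (bits << shifts).sum(dim=-1).to(torch.uint8)
```

For example, a 10-dimensional row packs into two bytes, and the second byte carries 6 padding bits. This is why the original dimension has to be passed separately to recover the unpacked embedding.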
 - Methods
- backward: Backward pass for the BinaryEmbeddingForward function, providing gradients for the input and packed weights.
- forward: Forward pass for binary embedding lookup.
- static backward(ctx: BackwardCFunction, output_gradient: Tensor) → Tuple[Tensor, ...]
- Backward pass for the BinaryEmbeddingForward function, providing gradients for the input and packed weights.
- This method handles the gradient flow for the bit-packed parameters of binary embeddings, so that updates behave correctly when combined with a custom optimizer designed for such data types.
- Parameters:
- ctx (torch.autograd.function.BackwardCFunction) – The context object from which saved tensors are retrieved. 
- output_gradient (torch.Tensor) – The gradient of the loss with respect to the output of the forward pass. 
 
- Returns:
- A tuple with one entry per argument of forward (excluding ctx). Most entries are None, as only specific parameters require gradients. 
 
- Return type:
- Tuple 
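The tuple layout follows the general torch.autograd.Function contract: backward returns one value per forward argument, with None for arguments that need no gradient. The toy class below illustrates that contract only. ToyScaledEmbedding is hypothetical and uses dense float weights rather than packed ones, so it does not show how BinaryEmbeddingForward computes its gradients.

```python
import torch


class ToyScaledEmbedding(torch.autograd.Function):
    """Hypothetical dense stand-in, used only to illustrate the backward tuple layout."""

    @staticmethod
    def forward(ctx, input, weight, embed_scale, embedding_dim, is_train):
        # Look up rows and scale them; embedding_dim is unused, kept for signature parity.
        out = weight[input] * embed_scale
        if is_train:
            # Save what backward needs (backward is only valid when is_train is True).
            ctx.save_for_backward(input, embed_scale)
            ctx.num_embeddings = weight.shape[0]
        return out

    @staticmethod
    def backward(ctx, output_gradient):
        input, embed_scale = ctx.saved_tensors
        dim = output_gradient.shape[-1]
        # Scatter-add the scaled output gradient into the rows that were looked up.
        grad_weight = torch.zeros(
            ctx.num_embeddings, dim,
            dtype=output_gradient.dtype, device=output_gradient.device,
        )
        grad_weight.index_add_(
            0, input.reshape(-1), (output_gradient * embed_scale).reshape(-1, dim)
        )
        # One entry per forward argument (excluding ctx):
        # input, weight, embed_scale, embedding_dim, is_train. Only weight gets a gradient.
        return None, grad_weight, None, None, None
```

Integer indices, Python ints and bools cannot receive gradients, which is why their slots in the returned tuple are None.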
 - Note
- The optimizer update suggested for this setup requires a specialized approach. The necessary operations are:
- Apply XOR between qweight and grad_weight
- Create a mask identifying non-zero positions in grad_weight 
- Update qweight only at positions where grad_weight is non-zero, keeping the original qweight values elsewhere 
 - These steps are essential because the weights are binary and bit-packed: each uint8 element holds several weights, so an update has to flip individual bits rather than add a real-valued step.
- The sparse_update_embedding_qweight method can be used to update qweight in an optimizer.
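The three update steps can be sketched in plain PyTorch. The helper name xor_masked_update is hypothetical, and this is not the implementation of sparse_update_embedding_qweight. It assumes that qweight and grad_weight are uint8 tensors of the same shape.

```python
import torch


def xor_masked_update(qweight: torch.Tensor, grad_weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of the documented update: flip packed bits where grad_weight is set.

    Both inputs are assumed to be uint8 tensors with identical shapes.
    """
    # Step 1: XOR flips every bit of qweight whose corresponding bit in grad_weight is 1.
    flipped = torch.bitwise_xor(qweight, grad_weight)
    # Step 2: mask of positions (bytes) where the gradient is non-zero.
    mask = grad_weight != 0
    # Step 3: take the flipped value where the mask is set and keep qweight elsewhere.
    return torch.where(mask, flipped, qweight)
```

XOR with a zero byte already leaves qweight unchanged, so at byte granularity the mask does not alter the result. It mirrors the documented procedure and makes explicit that only positions with a non-zero gradient are touched.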
 - static forward(ctx, input: Tensor, qweight: Tensor, embed_scale: Tensor, ori_embedding_dim: int, is_train: bool) → Tensor
- Forward pass for binary embedding lookup.
- This function looks up each index in the input tensor in the bit-packed binary weights, unpacks the selected rows, and applies the scaling factors to recover embeddings of the original dimension.
- Parameters:
- ctx – Context object used to save tensors for the backward pass. 
- input (torch.Tensor) – Input tensor containing indices, with shape (batch_size, seq_length). 
- qweight (torch.Tensor) – Bit-packed binary weights tensor for embeddings. 
- embed_scale (torch.Tensor) – Scaling factors for the embeddings, to be applied after unpacking. 
- ori_embedding_dim (int) – Original embedding dimension before packing. 
- is_train (bool) – Flag indicating whether the operation is in training mode. 
 
- Returns:
- The resulting embedding tensor after lookup and scaling, with shape (batch_size, seq_length, ori_embedding_dim). 
- Return type:
- torch.Tensor 
 - Note
- If is_train is True, this function saves the tensors needed for the backward pass.
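A pure-PyTorch reference for the lookup, unpack and scale steps can be sketched as follows. The function name reference_binary_embedding is hypothetical, and several details are assumptions rather than documented behavior:

- the bit order (least significant bit first)
- the mapping of bit 1 to +1 and bit 0 to -1
- that embed_scale broadcasts against the output

The library's own kernel may differ.

```python
import torch


def reference_binary_embedding(
    input: torch.Tensor,
    qweight: torch.Tensor,
    embed_scale: torch.Tensor,
    ori_embedding_dim: int,
) -> torch.Tensor:
    """Hypothetical reference lookup under assumed packing conventions."""
    # Gather packed rows: (batch_size, seq_length, packed_dim), uint8.
    rows = qweight[input]
    # Extract the 8 bits of every byte, LSB first (assumed order).
    shifts = torch.arange(8, dtype=torch.uint8, device=qweight.device)
    bits = (rows.unsqueeze(-1) >> shifts) & 1  # (..., packed_dim, 8)
    # Flatten the bit groups and drop padding bits beyond the original dimension.
    bits = bits.flatten(start_dim=-2)[..., :ori_embedding_dim]
    # Map {0, 1} -> {-1, +1} (assumed convention), then apply the scaling factors.
    signs = bits.to(torch.float32) * 2 - 1
    return signs * embed_scale
```

The result has shape (batch_size, seq_length, ori_embedding_dim). This matches the documented return shape, because the padding bits are sliced away before scaling.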