bitorch_engine.layers.qembedding.binary.layer.BinaryEmbeddingForward

class bitorch_engine.layers.qembedding.binary.layer.BinaryEmbeddingForward(*args, **kwargs)

Experimental class implementing the forward pass of binary embeddings. It may not always work as expected. Because the weight parameter (qweight) is stored as bit-packed binary weights of type uint8, training requires a custom optimizer.
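To illustrate what "bit-packed" means here, the sketch below packs a matrix of +1/-1 values into uint8 bytes (8 binary weights per byte) and unpacks it again. The helper names pack_signs and unpack_signs, the most-significant-bit-first order, and the mapping (bit 1 encodes +1, bit 0 encodes -1) are all hypothetical assumptions; the library's actual layout may differ. The zero padding up to a multiple of 8 columns shows why the original embedding dimension has to be passed separately.

```python
import torch

# Hypothetical layout: bit 1 encodes +1, bit 0 encodes -1, and within each
# byte the most significant bit holds the lowest column index.
_SHIFTS = torch.arange(7, -1, -1).to(torch.uint8)

def pack_signs(w: torch.Tensor) -> torch.Tensor:
    """Pack a (num_embeddings, dim) tensor of +1/-1 values into uint8."""
    num_embeddings, dim = w.shape
    bits = (w > 0).to(torch.uint8)
    pad = (-dim) % 8  # zero-pad the columns up to a multiple of 8
    if pad:
        bits = torch.cat([bits, bits.new_zeros(num_embeddings, pad)], dim=1)
    bits = bits.reshape(num_embeddings, -1, 8)
    # Shift each bit into its position and sum the 8 bits of every byte.
    return (bits << _SHIFTS).sum(dim=-1).to(torch.uint8)

def unpack_signs(packed: torch.Tensor, ori_dim: int) -> torch.Tensor:
    """Inverse of pack_signs; drops the padding bits beyond ori_dim."""
    bits = (packed.unsqueeze(-1) >> _SHIFTS) & 1
    bits = bits.flatten(start_dim=-2)[..., :ori_dim]
    return bits.to(torch.float32) * 2 - 1

w = torch.tensor([[1., -1., 1., 1., -1., -1., -1., 1., 1., -1.]])
packed = pack_signs(w)
print(packed)  # tensor([[177, 128]], dtype=torch.uint8)
print(torch.equal(unpack_signs(packed, ori_dim=10), w))  # True
```

Ten weights occupy two bytes here; the last six bits of the second byte are padding, which unpack_signs discards using ori_dim.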

Parameters:
  • input (torch.Tensor) – Input tensor of embedding indices with shape (batch_size, seq_length).

  • weight (torch.Tensor) – The bit-packed embedding weight matrix (qweight).

  • embed_scale (torch.Tensor) – Scaling factor for the embeddings.

  • ori_embedding_dim (int) – Original embedding dimension before packing.

  • is_train (bool) – Flag indicating if the operation is in training mode.

Returns:

The resulting tensor after applying embedding lookup and unpacking.

Return type:

torch.Tensor

Methods

backward

Backward pass for the BinaryEmbeddingForward function, providing gradients for the input and packed weights.

forward

Forward pass for binary embedding lookup.


static backward(ctx: BackwardCFunction, output_gradient: Tensor) → Tuple[Tensor, ...]

Backward pass for the BinaryEmbeddingForward function, providing gradients for the input and packed weights.

This method is designed to handle the gradient flow for bit-packed parameters, specifically for binary embeddings, ensuring the correct update behavior in conjunction with a custom optimizer designed for such data types.

Parameters:
  • ctx (torch.autograd.function.BackwardCFunction) – The context object where saved tensors are retrieved.

  • output_gradient (torch.Tensor) – The gradient of the loss with respect to the output of the forward pass.

Returns:

A tuple with one entry per argument of forward (input, qweight, embed_scale, ori_embedding_dim, is_train). Most entries are None, as only specific parameters require gradients.
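As a generic PyTorch illustration of this return convention (not the library's implementation), the toy torch.autograd.Function below takes the same five forward arguments and therefore returns a five-element tuple from backward, with None for the arguments that need no gradient. It also saves tensors only when is_train is True. The class name ToyScaledEmbedding is hypothetical, and its weight is an ordinary float tensor rather than bit-packed uint8.

```python
import torch
from torch.autograd import Function

class ToyScaledEmbedding(Function):
    """Float-weight toy analogue; NOT the bit-packed library implementation."""

    @staticmethod
    def forward(ctx, input, weight, scale, ori_dim, is_train):
        if is_train:  # save what backward needs, only when training
            ctx.save_for_backward(input, scale)
            ctx.weight_shape = weight.shape
            ctx.ori_dim = ori_dim
        return weight[input][..., :ori_dim] * scale

    @staticmethod
    def backward(ctx, output_gradient):
        # Assumes forward ran with is_train=True.
        input, scale = ctx.saved_tensors
        grad_weight = output_gradient.new_zeros(ctx.weight_shape)
        # Each looked-up row receives scale * upstream gradient;
        # repeated indices accumulate via index_add_.
        grad_rows = (output_gradient * scale).reshape(-1, ctx.ori_dim)
        grad_weight[:, : ctx.ori_dim].index_add_(0, input.reshape(-1), grad_rows)
        # One entry per forward argument: input, weight, scale, ori_dim, is_train.
        return None, grad_weight, None, None, None

weight = torch.randn(5, 4, requires_grad=True)
indices = torch.tensor([[0, 2, 2], [4, 1, 0]])
out = ToyScaledEmbedding.apply(indices, weight, torch.tensor(0.5), 3, True)
out.sum().backward()
print(out.shape)  # torch.Size([2, 3, 3])
```

Returning a tuple of the wrong length raises an error in autograd, which is why the non-differentiable arguments (the integer indices and the two Python values) still get explicit None entries.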

Return type:

Tuple

Note

Because qweight is bit-packed, the optimizer update for this setup needs a specialized approach. The necessary operations are:

  • Apply a bitwise XOR between qweight and grad_weight.

  • Create a mask of the positions where grad_weight is non-zero.

  • Update qweight only at the masked positions, keeping the original qweight values elsewhere.

This procedure is necessary because the weights are binary and bit-packed, so they must be manipulated at the bit level to be updated correctly during training.

The sparse_update_embedding_qweight method can be used to perform this qweight update in an optimizer.
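The three update steps can be sketched in plain PyTorch as follows. The helper name xor_masked_update_ is hypothetical and is not the sparse_update_embedding_qweight method; it assumes qweight and grad_weight are uint8 tensors of the same shape, where a set bit in grad_weight means "flip this binary weight".

```python
import torch

@torch.no_grad()
def xor_masked_update_(qweight: torch.Tensor, grad_weight: torch.Tensor) -> torch.Tensor:
    """In-place update of bit-packed uint8 weights (hypothetical helper)."""
    mask = grad_weight != 0                                          # positions with pending bit flips
    flipped = torch.bitwise_xor(qweight[mask], grad_weight[mask])   # XOR only where needed
    qweight[mask] = flipped                                          # other entries are left untouched
    return qweight

qweight = torch.tensor([[177, 128]], dtype=torch.uint8)
grad_weight = torch.tensor([[3, 0]], dtype=torch.uint8)  # flip the two lowest bits of byte 0
xor_masked_update_(qweight, grad_weight)
print(qweight.tolist())  # [[178, 128]]
```

Since x ^ 0 == x, the masked write produces the same values as an unmasked XOR over the whole tensor; the mask restricts the write to the entries that actually change.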

static forward(ctx, input: Tensor, qweight: Tensor, embed_scale: Tensor, ori_embedding_dim: int, is_train: bool) → Tensor

Forward pass for binary embedding lookup.

This function performs an embedding lookup for each index in the input tensor using bit-packed binary weights. It unpacks the looked-up weights to recover embeddings of the original dimension and scales them by embed_scale.
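A pure-PyTorch reference for the lookup, unpacking, and scaling steps might look like the sketch below. It assumes a most-significant-bit-first layout in which bit 1 encodes +1 and bit 0 encodes -1, and that embed_scale is a scalar or broadcastable to the output's last dimension. The function name binary_embedding_reference is hypothetical, and the library's actual layout may differ.

```python
import torch

def binary_embedding_reference(input, qweight, embed_scale, ori_embedding_dim):
    """Lookup + unpack + scale; returns (batch_size, seq_length, ori_embedding_dim)."""
    shifts = torch.arange(7, -1, -1, device=qweight.device).to(torch.uint8)
    rows = qweight[input]                                        # (B, L, packed_dim), uint8
    bits = (rows.unsqueeze(-1) >> shifts) & 1                    # (B, L, packed_dim, 8)
    bits = bits.flatten(start_dim=-2)[..., :ori_embedding_dim]   # drop padding bits
    signs = bits.to(torch.float32) * 2 - 1                       # assumed: bit 1 -> +1, bit 0 -> -1
    return signs * embed_scale

qweight = torch.tensor([[177, 128], [0, 0], [255, 255]], dtype=torch.uint8)
indices = torch.tensor([[0, 2], [1, 0]])
out = binary_embedding_reference(indices, qweight, torch.tensor(0.5), ori_embedding_dim=10)
print(out.shape)          # torch.Size([2, 2, 10])
print(out[0, 0].tolist())  # [0.5, -0.5, 0.5, 0.5, -0.5, -0.5, -0.5, 0.5, 0.5, -0.5]
```

Each packed row of two bytes expands to 16 bits, of which only the first ori_embedding_dim = 10 are kept, giving the documented output shape.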

Parameters:
  • ctx – Context object used to save tensors for the backward pass.

  • input (torch.Tensor) – Input tensor containing indices, with shape (batch_size, seq_length).

  • qweight (torch.Tensor) – Bit-packed binary weights tensor for embeddings.

  • embed_scale (torch.Tensor) – Scaling factors for the embeddings, to be applied after unpacking.

  • ori_embedding_dim (int) – Original embedding dimension before packing.

  • is_train (bool) – Flag indicating whether the operation is in training mode.

Returns:

The resulting embedding tensor after lookup and scaling, with shape (batch_size, seq_length, ori_embedding_dim).

Return type:

torch.Tensor

Note

When is_train is True, this function saves the tensors needed for the backward pass.