bitorch_engine.layers.qembedding.binary.layer.BinaryEmbeddingCuda

class bitorch_engine.layers.qembedding.binary.layer.BinaryEmbeddingCuda(*args: int, num_embeddings: int, embedding_dim: int, padding_idx: int | None = None, dtype: dtype = torch.float32, **kwargs: int)

Binarized version of the embedding layer. Note: this class is experimental and may not always work as expected. Its weights (qweight) are stored as bit-packed uint8 values, so training requires a custom optimizer.
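To make the bit-packing idea concrete, here is a minimal sketch that binarizes one row of floating-point weights by sign and packs eight bits into each uint8-sized byte. It assumes bit 1 encodes a non-negative weight and bit 0 a negative one, in least-significant-bit-first order; `pack_signs` and `unpack_signs` are hypothetical helpers, and the bit layout actually used by BinaryEmbeddingCuda may differ.

```python
def pack_signs(row):
    """Pack the signs of a list of floats into byte values (0..255).

    Assumption: bit = 1 for w >= 0 and bit = 0 for w < 0, LSB-first in each byte.
    The last byte is zero-padded when len(row) is not a multiple of 8.
    """
    packed = []
    for start in range(0, len(row), 8):
        byte = 0
        for offset, w in enumerate(row[start:start + 8]):
            if w >= 0:
                byte |= 1 << offset
        packed.append(byte)
    return packed


def unpack_signs(packed, length):
    """Recover +1.0/-1.0 values from packed bytes, keeping the first `length`."""
    signs = []
    for byte in packed:
        for offset in range(8):
            signs.append(1.0 if (byte >> offset) & 1 else -1.0)
    return signs[:length]


row = [0.3, -1.2, 0.0, 2.5, -0.1, -0.4, 0.9, 1.1, -0.7, 0.2]
packed = pack_signs(row)
print(packed)                          # [205, 2]
print(unpack_signs(packed, len(row)))  # [1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, 1.0]
```

Ten weights fit in two bytes here, which is where the memory saving of a bit-packed qweight comes from; the price is that standard optimizers cannot update individual bits, hence the need for a custom optimizer.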

Parameters:
  • num_embeddings (int) – Size of the dictionary of embeddings.

  • embedding_dim (int) – The size of each embedding vector.

  • padding_idx (Optional[int]) – Specifies a padding index. If set, the embedding vector at this index is initialized with zeros.

  • dtype (torch.dtype) – Data type of the embeddings. Defaults to torch.float32.
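num_embeddings and embedding_dim determine the sizes of the stored parameters. The sketch below computes those sizes and the storage saving, assuming qweight packs the embedding dimension eight binary values per uint8 byte and scale_w holds one value per row; both layouts are assumptions, and `packed_param_shapes` and `storage_bytes` are hypothetical helpers.

```python
def packed_param_shapes(num_embeddings, embedding_dim, bits_per_byte=8):
    # Assumption: qweight packs along embedding_dim, 8 binary values per uint8.
    packed_cols = (embedding_dim + bits_per_byte - 1) // bits_per_byte  # ceiling division
    return {
        "weight": (num_embeddings, embedding_dim),  # floating-point weights
        "qweight": (num_embeddings, packed_cols),   # bit-packed uint8
        "scale_w": (num_embeddings,),               # assumed: one scale per row
    }


def storage_bytes(num_embeddings, embedding_dim, float_bytes=4):
    """Return (float32 bytes, packed bytes); the scales are ignored."""
    rows, cols = packed_param_shapes(num_embeddings, embedding_dim)["qweight"]
    return num_embeddings * embedding_dim * float_bytes, rows * cols


shapes = packed_param_shapes(30000, 768)
print(shapes["qweight"])  # (30000, 96)
full, packed = storage_bytes(30000, 768)
print(full / packed)      # 32.0
```

Under these assumptions the packed dictionary is 32 times smaller than a float32 one, before counting the per-row scales.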

weight

The original floating-point weights, not used during forward passes but necessary for initializing qweight.

Type:

Parameter

qweight

The quantized, bit-packed weights used for embeddings.

Type:

Parameter

scale_w

Row-wise scaling factors (one per embedding vector) for the embedding dictionary.

Type:

torch.Tensor
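scale_w restores the magnitude information that binary signs discard. A common choice in XNOR-Net-style binarization is the mean absolute value of each row, because for fixed signs it minimizes the squared error between w and scale·sign(w). The sketch below assumes that formula; BinaryEmbeddingCuda may compute its scale differently, and `row_scales` and `binarize_rows` are hypothetical names.

```python
def row_scales(weight):
    """Per-row scaling factor: mean absolute value of each row (assumed formula)."""
    return [sum(abs(w) for w in row) / len(row) for row in weight]


def binarize_rows(weight):
    """Approximate each row as scale * sign(w), taking sign(0) as +1."""
    scales = row_scales(weight)
    return [[s * (1.0 if w >= 0 else -1.0) for w in row]
            for row, s in zip(weight, scales)]


weight = [[0.5, -1.5, 1.0, -1.0],
          [0.25, 0.25, -0.25, 0.25]]
print(row_scales(weight))     # [1.0, 0.25]
print(binarize_rows(weight))  # [[1.0, -1.0, 1.0, -1.0], [0.25, 0.25, -0.25, 0.25]]
```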

Methods

  • __init__ – Initialize internal Module state, shared by both nn.Module and ScriptModule.

  • forward – Forward pass to generate embeddings for input indices.

  • init_weight – Initializes weight parameters.

  • prepare_params – Prepares and initializes the binary weight (qweight) and scale factor (scale_w) for the embedding.

  • reset_parameters – Resets parameters, including filling the padding index with zeros if specified.

Attributes

  • training – Whether the module is in training mode (a bool inherited from torch.nn.Module).

__init__(*args: int, num_embeddings: int, embedding_dim: int, padding_idx: int | None = None, dtype: dtype = torch.float32, **kwargs: int) → None

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(input: Tensor) → Tensor

Forward pass to generate embeddings for input indices.

Parameters:

input (Tensor) – Tensor of indices for which embeddings are to be generated.

Returns:

The resulting embeddings.

Return type:

Tensor
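As a reference for what the lookup computes, here is a pure-Python sketch: each index selects a packed row, the bits are unpacked to ±1, and the result is multiplied by that row's scale. It assumes LSB-first packing (bit 1 maps to +1), an output of scale_w[i] times the unpacked signs, and a flat list of indices; `unpack_row` and `binary_embedding_lookup` are hypothetical, and the real layer (CUDA-based, as its name suggests) operates on tensors.

```python
def unpack_row(packed_row, embedding_dim):
    """Unpack LSB-first bits into +1.0/-1.0 values (assumed bit layout)."""
    values = []
    for byte in packed_row:
        for offset in range(8):
            values.append(1.0 if (byte >> offset) & 1 else -1.0)
    return values[:embedding_dim]


def binary_embedding_lookup(indices, qweight, scale_w, embedding_dim):
    """Return scale_w[i] * unpack(qweight[i]) for each index i."""
    return [[scale_w[i] * v for v in unpack_row(qweight[i], embedding_dim)]
            for i in indices]


qweight = [[0b00001111], [0b10100101]]  # two rows, embedding_dim = 8
scale_w = [0.5, 2.0]
out = binary_embedding_lookup([1, 0, 1], qweight, scale_w, 8)
print(out[0])  # [2.0, -2.0, 2.0, -2.0, -2.0, 2.0, -2.0, 2.0]
print(out[1])  # [0.5, 0.5, 0.5, 0.5, -0.5, -0.5, -0.5, -0.5]
```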

init_weight() → None

Initializes weight parameters. This includes the floating-point weights (for initial setup), the quantized and bit-packed weights (qweight), and the scaling factor (scale_w).

prepare_params() → None

Prepares and initializes the binary weight (qweight) and scale factor (scale_w) for the embedding. Must be called after model initialization and checkpoint loading.

The “prepare_bie_layers” helper in project_root.utils.model_helper can be used to call this method.
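Because qweight and scale_w are derived from the floating-point weight, prepare_params has to run after any checkpoint has been loaded; otherwise the packed weights reflect the initial values. The sketch below shows that ordering with a hypothetical `ToyBinaryEmbedding` stand-in and a hypothetical `prepare_all` helper that calls prepare_params on every object defining it. The per-row mean-absolute scale and LSB-first packing are assumptions, and this is not the signature of prepare_bie_layers, which is not documented here. In a real PyTorch model the loop would typically iterate over `model.modules()`.

```python
class ToyBinaryEmbedding:
    """Hypothetical stand-in for a layer that exposes prepare_params()."""

    def __init__(self, weight):
        self.weight = weight  # float rows (list of lists)
        self.qweight = None   # packed signs, filled by prepare_params()
        self.scale_w = None   # per-row scales, filled by prepare_params()

    def prepare_params(self):
        # Assumed scheme: per-row mean |w| and LSB-first sign bits, 8 per byte.
        self.scale_w = [sum(abs(w) for w in row) / len(row) for row in self.weight]
        self.qweight = [
            [sum(1 << k for k, w in enumerate(row[j:j + 8]) if w >= 0)
             for j in range(0, len(row), 8)]
            for row in self.weight
        ]


def prepare_all(modules):
    """Call prepare_params() on each object that defines it; return how many."""
    count = 0
    for module in modules:
        prepare = getattr(module, "prepare_params", None)
        if callable(prepare):
            prepare()
            count += 1
    return count


layer = ToyBinaryEmbedding([[0.5, -1.5, 1.0, -1.0]])
layer.weight = [[-0.5, 1.5, -1.0, 1.0]]  # e.g. values restored from a checkpoint
print(prepare_all([layer, object()]))    # 1
print(layer.qweight, layer.scale_w)      # [[10]] [1.0]
```

Had prepare_all run before the weights were overwritten, qweight would still encode the signs of the initial values.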

reset_parameters(weight: Tensor) → None

Resets parameters, including filling the padding index with zeros if specified.

Parameters:

weight (torch.Tensor) – The weight tensor to reset.
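For comparison, torch.nn.Embedding's own reset_parameters draws the weights from N(0, 1) and then zeros the padding_idx row. The sketch below mirrors that pattern on a list-of-rows weight; the normal initialization is borrowed from nn.Embedding as an assumption (BinaryEmbeddingCuda may initialize differently), and `reset_weight_rows` is a hypothetical helper.

```python
import random


def reset_weight_rows(weight, padding_idx=None, seed=0):
    """Refill weight (list of rows) in place from N(0, 1), then zero the padding row."""
    rng = random.Random(seed)
    for row in weight:
        for j in range(len(row)):
            row[j] = rng.gauss(0.0, 1.0)
    if padding_idx is not None:
        # Negative indices count from the end, as with Python lists.
        weight[padding_idx][:] = [0.0] * len(weight[padding_idx])
    return weight


w = [[0.0] * 4 for _ in range(3)]
reset_weight_rows(w, padding_idx=1)
print(w[1])  # [0.0, 0.0, 0.0, 0.0]
```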