bitorch_engine.layers.qembedding.binary.layer.BinaryEmbeddingCuda
- class bitorch_engine.layers.qembedding.binary.layer.BinaryEmbeddingCuda(*args: int, num_embeddings: int, embedding_dim: int, padding_idx: int | None = None, dtype: dtype = torch.float32, **kwargs: int)[source]
Binarized version of the embedding layer. Note: this class is currently experimental and may not always work as expected. It stores its weight parameter (qweight) as bit-packed uint8 values, which requires a custom optimizer for training. A usage sketch follows the attribute descriptions below.
- Parameters:
num_embeddings (int) – Size of the dictionary of embeddings.
embedding_dim (int) – The size of each embedding vector.
padding_idx (Optional[int]) – Specifies a padding index. Embeddings at this index will be initialized with zeros.
dtype (torch.dtype) – Data type of the embeddings.
- weight
The original floating-point weights, not used during forward passes but necessary for initializing qweight.
- Type:
Parameter
- qweight
The quantized, bit-packed weights used for embeddings.
- Type:
Parameter
- scale_w
Row-wise scaling factor for the embedding dictionary.
- Type:
torch.Tensor
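A minimal usage sketch (illustrative only; the CUDA device placement and the sizes below are assumptions, not taken from the library documentation):

```python
import torch
from bitorch_engine.layers.qembedding.binary.layer import BinaryEmbeddingCuda

# Construct a binarized embedding table. padding_idx=0 zero-initializes row 0.
emb = BinaryEmbeddingCuda(
    num_embeddings=30522,   # size of the embedding dictionary (illustrative)
    embedding_dim=768,      # size of each embedding vector (illustrative)
    padding_idx=0,
).to("cuda")                # assumption: the layer is intended for CUDA devices

# qweight is the bit-packed uint8 table used at lookup time; the floating-point
# weight only serves to initialize it, and scale_w holds a row-wise scale.
print(emb.qweight.dtype)    # expected: torch.uint8
print(emb.scale_w.shape)    # expected: one scaling factor per embedding row
```

Because qweight is bit-packed, standard optimizers such as torch.optim.Adam cannot update it directly; training this layer requires the custom optimizer mentioned above.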
Methods
__init__ – Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward – Forward pass to generate embeddings for input indices.
init_weight – Initializes weight parameters.
Prepares and initializes the binary weight (qweight) and scale factor (scale_w) for the embedding.
Resets parameters, including filling the padding index with zeros if specified.
Attributes
training
- __init__(*args: int, num_embeddings: int, embedding_dim: int, padding_idx: int | None = None, dtype: dtype = torch.float32, **kwargs: int) None [source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(input: Tensor) Tensor [source]
Forward pass to generate embeddings for input indices (see the lookup sketch after the return type below).
- Parameters:
input (Tensor) – Tensor of indices for which embeddings are to be generated.
- Returns:
The resulting embeddings.
- Return type:
Tensor
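A hedged lookup sketch; the batch shape, device placement, and printed output shape are assumptions based on standard embedding semantics rather than the library documentation:

```python
import torch
from bitorch_engine.layers.qembedding.binary.layer import BinaryEmbeddingCuda

layer = BinaryEmbeddingCuda(num_embeddings=100, embedding_dim=64, padding_idx=0).to("cuda")

# Indices of shape (batch, sequence_length); positions equal to padding_idx
# map to the zero-initialized padding row.
indices = torch.tensor([[1, 5, 0, 0],
                        [7, 2, 3, 0]], dtype=torch.long, device="cuda")

out = layer(indices)   # lookup against the bit-packed qweight table
print(out.shape)       # expected: torch.Size([2, 4, 64]), i.e. input shape + (embedding_dim,)
```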
- init_weight() None [source]
Initializes weight parameters. This includes the floating-point weights (for initial setup), the quantized and bit-packed weights (qweight), and the scaling factor (scale_w).
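The exact packing layout is internal to the CUDA implementation. Purely as an illustration of what "bit-packed weights with a row-wise scale" means, here is a hypothetical helper (not the library's code):

```python
import torch

def binarize_and_pack(weight: torch.Tensor):
    """Hypothetical illustration of row-wise binarization with bit-packing.

    This is NOT the library's actual packing scheme; it only shows the idea:
    signs are stored as bits (8 per uint8) and a per-row scale preserves
    magnitude information.
    """
    # Row-wise scale: mean absolute value of each embedding row.
    scale_w = weight.abs().mean(dim=1)

    # Map non-negative entries to bit 1, negative entries to bit 0,
    # then pack 8 bits into each uint8 value.
    bits = (weight >= 0).to(torch.uint8)                      # (rows, dim)
    bits = bits.reshape(weight.shape[0], -1, 8)               # assumes dim % 8 == 0
    shifts = torch.arange(8, dtype=torch.uint8)
    qweight = (bits << shifts).sum(dim=2, dtype=torch.uint8)  # (rows, dim // 8)

    return qweight, scale_w

w = torch.randn(16, 64)
qw, s = binarize_and_pack(w)
print(qw.dtype, qw.shape, s.shape)   # torch.uint8 torch.Size([16, 8]) torch.Size([16])
```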