bitorch_engine.layers.qmha.binary.layer.BMHA
- class bitorch_engine.layers.qmha.binary.layer.BMHA(input_dim: int, hidden_dim: int, num_heads: int, dtype: torch.dtype = torch.float32, *args, **kwargs)[source]
Implements a binary version of multi-head attention (MHA) in which the linear transformations are executed with binary operations to improve efficiency. This class is designed to work with binary weights and can be particularly useful for deployment in resource-constrained environments or in models where computational efficiency is crucial. A usage sketch follows the attribute list below.
- dtype
Data type for the computations, typically float or binary.
- Type:
torch.dtype
- num_heads
Number of attention heads.
- Type:
int
- head_dim
Dimension of each attention head.
- Type:
int
- hidden_dim
Dimension of the hidden layer.
- Type:
int
- input_dim
Dimension of the input layer.
- Type:
int
- q_linear
Linear transformation for the query vector using binary operations.
- v_linear
Linear transformation for the value vector using binary operations.
- k_linear
Linear transformation for the key vector using binary operations.
- dropout
Dropout layer to prevent overfitting.
- Type:
nn.Dropout
- out
Final linear layer to project the attention output back to input dimensionality.
- Raises:
ValueError – If the hidden size is not a multiple of the number of attention heads.
Note: The implementation of this class is still in the EXPERIMENTAL STAGE!
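A minimal usage sketch, assuming the import path shown above, that the binary kernels are available on the current device, and that the mask follows the common 1 = attend, 0 = ignore convention (an assumption, not stated on this page); sizes are illustrative only:

    import torch
    from bitorch_engine.layers.qmha.binary.layer import BMHA

    # Illustrative sizes; hidden_dim must be divisible by num_heads.
    batch_size, seq_len = 2, 16
    input_dim, hidden_dim, num_heads = 64, 64, 8

    bmha = BMHA(input_dim=input_dim, hidden_dim=hidden_dim, num_heads=num_heads)

    hidden_states = torch.randn(batch_size, seq_len, input_dim)
    mask = torch.ones(batch_size, seq_len)  # assumed convention: 1 = attend, 0 = ignore

    output, attention_scores = bmha(hidden_states, mask)
    print(output.shape)            # (batch_size, seq_len, input_dim)
    print(attention_scores.shape)  # (batch_size, num_heads, seq_len, seq_len)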
Methods
__init__ – Initializes the BMHA module with the specified parameters.
forward – Forward pass for the binary multi-head attention layer.
Attributes
training
- __init__(input_dim: int, hidden_dim: int, num_heads: int, dtype: torch.dtype = torch.float32, *args, **kwargs)[source]
Initializes the BMHA module with the specified parameters. It sets up the binary linear layers for the queries, keys, and values, along with the output projection layer. It also validates that the hidden dimension is evenly divisible by the number of heads to ensure that the dimensions align properly for multi-head attention.
- Parameters:
input_dim (int) – The size of each input vector.
hidden_dim (int) – The size of the hidden dimension. Must be divisible by num_heads.
num_heads (int) – The number of attention heads.
dtype (torch.dtype, optional) – The data type for computations. Defaults to torch.float32.
*args – Variable length argument list for the parent class.
**kwargs – Arbitrary keyword arguments for the parent class.
- Raises:
ValueError – If hidden_dim is not divisible by num_heads (a minimal sketch of this check is shown below).
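The divisibility requirement and the resulting per-head dimension can be summarized by the following sketch; it is illustrative only and is not the library's exact code:

    hidden_dim, num_heads = 64, 8  # example values
    if hidden_dim % num_heads != 0:
        raise ValueError(
            f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})."
        )
    head_dim = hidden_dim // num_heads  # dimension of each attention head (here: 8)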
- forward(hidden_states: Tensor, mask: Tensor | None = None) → Tuple[Tensor, Tensor][source]
Forward pass for the binary multi-head attention layer (a reference sketch of the attention computation follows the return description below).
- Parameters:
hidden_states (torch.Tensor) – Input tensor of shape (batch_size, sequence_length, input_dim).
mask (torch.Tensor, optional) – Optional mask tensor to exclude certain positions from the attention mechanism. Shape (batch_size, sequence_length).
- Returns:
- A tuple containing:
Output tensor of shape (batch_size, sequence_length, input_dim).
Attention scores tensor of shape (batch_size, num_heads, sequence_length, sequence_length).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
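For orientation, here is a reference sketch of standard scaled dot-product attention written with plain PyTorch ops, producing tensors with the shapes documented above. It is not the binary kernel used by BMHA; the helper name reference_attention and the mask convention (1 = keep, 0 = mask out) are assumptions for illustration:

    import math
    import torch
    import torch.nn.functional as F

    def reference_attention(q, k, v, mask=None, p_drop=0.0):
        # q, k, v: (batch, num_heads, seq_len, head_dim); mask: (batch, seq_len), 1 = keep.
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # (batch, heads, seq, seq)
        if mask is not None:
            scores = scores.masked_fill(mask[:, None, None, :] == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)   # attention scores, shaped as returned by forward
        attn = F.dropout(attn, p=p_drop)   # dropout on attention weights, cf. the dropout attribute
        context = attn @ v                 # (batch, heads, seq_len, head_dim)
        return context, attn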