bitorch_engine.layers.qmha.binary.layer.BMHA

class bitorch_engine.layers.qmha.binary.layer.BMHA(input_dim: int, hidden_dim: int, num_heads: int, dtype: torch.dtype = torch.float32, *args, **kwargs)[source]

Implements a binary version of multi-head attention (MHA) where the linear transformations are executed using binary operations to improve efficiency. This class is designed to work with binary weights and can be particularly useful for deployments in resource-constrained environments or for models where computational efficiency is crucial.
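Example (a minimal usage sketch, not an official snippet from the library: it assumes a CUDA device is available for the CUTLASS-backed binary linear layers, and uses only the constructor and forward signatures documented on this page):

>>> import torch
>>> from bitorch_engine.layers.qmha.binary.layer import BMHA
>>> # Illustrative sizes; hidden_dim must be a multiple of num_heads.
>>> bmha = BMHA(input_dim=256, hidden_dim=256, num_heads=8).to("cuda")
>>> hidden_states = torch.randn(2, 16, 256, device="cuda")  # (batch, seq_len, input_dim)
>>> output, attn_scores = bmha(hidden_states)
>>> output.shape                                            # (batch, seq_len, input_dim)
torch.Size([2, 16, 256])
>>> attn_scores.shape                                       # (batch, num_heads, seq_len, seq_len)
torch.Size([2, 8, 16, 16])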

dtype

Data type for the computations, typically float or binary.

Type:

torch.dtype

num_heads

Number of attention heads.

Type:

int

head_dim

Dimension of each attention head.

Type:

int

hidden_dim

Dimension of the hidden layer.

Type:

int

input_dim

Dimension of the input layer.

Type:

int

q_linear

Linear transformation for the query vector using binary operations.

Type:

BinaryLinearCutlass

v_linear

Linear transformation for the value vector using binary operations.

Type:

BinaryLinearCutlass

k_linear

Linear transformation for the key vector using binary operations.

Type:

BinaryLinearCutlass

dropout

Dropout layer to prevent overfitting.

Type:

nn.Dropout

out

Final linear layer to project the attention output back to input dimensionality.

Type:

BinaryLinearCutlass

Raises:

ValueError – If the hidden size is not a multiple of the number of attention heads.

Note: The implementation of this class is still in the EXPERIMENTAL STAGE!

Methods

__init__

Initializes the BMHA module with the specified parameters.

forward

Forward pass for the binary multi-head attention layer.

Attributes

training

__init__(input_dim: int, hidden_dim: int, num_heads: int, dtype: torch.dtype = torch.float32, *args, **kwargs)[source]

Initializes the BMHA module with the specified parameters. It sets up the binary linear layers for the queries, keys, and values, along with the output projection layer. It also validates that the hidden dimension is evenly divisible by the number of heads to ensure that the dimensions align properly for multi-head attention.

Parameters:
  • input_dim (int) – The size of each input vector.

  • hidden_dim (int) – The size of the hidden dimension. Must be divisible by num_heads.

  • num_heads (int) – The number of attention heads.

  • dtype (torch.dtype, optional) – The data type for computations. Defaults to torch.float32.

  • *args – Variable length argument list for the parent class.

  • **kwargs – Arbitrary keyword arguments for the parent class.

Raises:

ValueError – If hidden_dim is not divisible by num_heads.
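For illustration (a hedged sketch; the exact error message is implementation-defined, and the head_dim value assumes the usual hidden_dim // num_heads split implied by the divisibility check):

>>> from bitorch_engine.layers.qmha.binary.layer import BMHA
>>> BMHA(input_dim=256, hidden_dim=250, num_heads=8)   # 250 is not divisible by 8
Traceback (most recent call last):
    ...
ValueError: ...
>>> bmha = BMHA(input_dim=256, hidden_dim=256, num_heads=8)
>>> bmha.head_dim                                      # assumed: hidden_dim // num_heads
32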

forward(hidden_states: Tensor, mask: Tensor | None = None) → Tuple[Tensor, Tensor][source]

Forward pass for the binary multi-head attention layer.

Parameters:
  • hidden_states (torch.Tensor) – Input tensor of shape (batch_size, sequence_length, input_dim).

  • mask (torch.Tensor, optional) – Optional mask tensor to exclude certain positions from the attention mechanism. Shape (batch_size, sequence_length).

Returns:

A tuple containing:
  • Output tensor of shape (batch_size, sequence_length, input_dim).

  • Attention scores tensor of shape (batch_size, num_heads, sequence_length, sequence_length).

Return type:

Tuple[torch.Tensor, torch.Tensor]
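A hedged sketch of a masked call (the mask shape follows the documentation above, but the masking convention is not specified there; this sketch assumes zeros mark positions to exclude, e.g. padding, and again assumes a CUDA device):

>>> import torch
>>> from bitorch_engine.layers.qmha.binary.layer import BMHA
>>> bmha = BMHA(input_dim=256, hidden_dim=256, num_heads=8).to("cuda")
>>> hidden_states = torch.randn(2, 16, 256, device="cuda")
>>> mask = torch.ones(2, 16, device="cuda")   # (batch_size, sequence_length)
>>> mask[:, 12:] = 0                          # assumed convention: 0 = excluded position
>>> output, attn_scores = bmha(hidden_states, mask=mask)
>>> output.shape
torch.Size([2, 16, 256])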