bitorch_engine.layers.qlinear.binary.cutlass.layer.BinaryMatMul
- class bitorch_engine.layers.qlinear.binary.cutlass.layer.BinaryMatMul(dtype=torch.float32, *args, **kwargs)
A PyTorch module for binary matrix multiplication. It wraps BinaryMatMulFunction and holds, as parameters, the clipping values used by the binary operations. It is intended for neural networks in which binarized operands let cheap bitwise operations replace full-precision multiply-accumulate work.
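The computational benefit of binary matrix multiplication comes from a general identity: once operands are binarized to {-1, +1}, a dot product of length n equals n - 2 * popcount(a XOR b), where a and b are the packed sign bits. The pure-Python sketch below checks that identity; pack_signs and binary_dot are hypothetical helper names, treating zero as +1 is an assumed convention, and this does not reproduce whatever kernel this module actually dispatches to.

```python
def pack_signs(values):
    """Pack floats into an int bit mask: bit i is 1 when values[i] < 0.

    Assumption: zero is treated as +1, a common binarization convention.
    """
    mask = 0
    for i, v in enumerate(values):
        if v < 0:
            mask |= 1 << i
    return mask


def binary_dot(a_bits, b_bits, n):
    """Dot product of two length-n {-1, +1} vectors from their packed sign masks.

    Agreeing signs contribute +1 and differing signs contribute -1,
    so the result is n - 2 * popcount(a XOR b).
    """
    return n - 2 * bin(a_bits ^ b_bits).count("1")


def sign(v):
    # Reference binarization matching pack_signs (zero maps to +1).
    return -1 if v < 0 else 1


x = [0.7, -1.2, 0.1, -0.3]
y = [-0.5, -2.0, 0.4, 0.9]
n = len(x)

reference = sum(sign(a) * sign(b) for a, b in zip(x, y))
bitwise = binary_dot(pack_signs(x), pack_signs(y), n)
print(reference, bitwise)  # 0 0
```

On hardware, the XOR and popcount run over whole machine words at once, which is where the speedup over float multiply-accumulate comes from.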
- dtype
Data type of the clipping parameters. Defaults to torch.float32.
- Type:
torch.dtype
- x_clip
Clipping parameter for the first input tensor.
- Type:
torch.nn.Parameter
- y_clip
Clipping parameter for the second input tensor.
- Type:
torch.nn.Parameter
Methods
- __init__([dtype]) – Initializes the BinaryMatMul module with an optional dtype argument for the clipping parameters.
- forward(x, y) – Forward pass of the BinaryMatMul module.
- set_activation_scale(x, y) – Sets the clipping parameters based on the input tensors.
Attributes
- training – Inherited from torch.nn.Module; True when the module is in training mode.
- __init__(dtype=torch.float32, *args, **kwargs)
Initializes the BinaryMatMul module with an optional dtype argument for the clipping parameters.
- Parameters:
dtype (torch.dtype, optional) – The data type for the clipping parameters. Defaults to torch.float32.
- forward(x: Tensor, y: Tensor) → Tensor
Forward pass of the BinaryMatMul module. Ensures input tensors are of appropriate dimensions and applies the binary matrix multiplication function.
- Parameters:
x (torch.Tensor) – The first input tensor.
y (torch.Tensor) – The second input tensor.
- Returns:
The result of the binary matrix multiplication.
- Return type:
torch.Tensor
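How the clipping parameters enter the computation is not spelled out here. One common scheme in binary networks, and a plausible reading of "clipping" in this context, is sign binarization in the forward pass combined with a straight-through estimator (STE) that zeroes gradients wherever |input| exceeds the clip value. The sketch below is a hypothetical pure-PyTorch reference under that assumption; SignSTE, binarize, and reference_binary_matmul are illustrative names, not part of bitorch_engine, and the operand layout x @ y (shapes (..., M, K) and (..., K, N)) is also assumed.

```python
import torch


def binarize(t):
    # Map to {-1, +1}; zero is treated as +1 (assumed convention).
    return torch.where(t >= 0, torch.ones_like(t), -torch.ones_like(t))


class SignSTE(torch.autograd.Function):
    """Sign binarization with a clipped straight-through estimator (hypothetical)."""

    @staticmethod
    def forward(ctx, t, clip):
        ctx.save_for_backward(t, clip)
        return binarize(t)

    @staticmethod
    def backward(ctx, grad_out):
        t, clip = ctx.saved_tensors
        # Pass the gradient through only where |t| <= clip; no gradient for clip.
        return grad_out * (t.abs() <= clip).to(grad_out.dtype), None


def reference_binary_matmul(x, y, x_clip, y_clip):
    # Assumed semantics: sign(x) @ sign(y), each operand with its own clip.
    return SignSTE.apply(x, x_clip) @ SignSTE.apply(y, y_clip)


x = torch.tensor([[-2.0, 0.5], [0.1, 0.3]], requires_grad=True)
y = torch.tensor([[1.0, 2.0], [-0.2, 0.4]], requires_grad=True)
clip = torch.tensor(1.0)
out = reference_binary_matmul(x, y, clip, clip)
out.sum().backward()
# x.grad is zero at x[0, 0] because |-2.0| exceeds the clip value.
```

A reference like this can be useful for checking a fused kernel's outputs on small inputs, though any real comparison should follow the library's actual binarization and layout conventions.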
- set_activation_scale(x: Tensor, y: Tensor) → None
Sets the clipping parameters (x_clip and y_clip) based on the input tensors. A clipping value that is not already set is initialized from the mean absolute value of the corresponding input.
- Parameters:
x (torch.Tensor) – The first input tensor.
y (torch.Tensor) – The second input tensor.
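The lazy initialization described for set_activation_scale can be sketched as follows. This assumes a clip value of zero marks it as unset and that an unset clip is set to exactly the mean absolute value of its input; ClipInitSketch is a hypothetical stand-in, not the library class, and the real method may scale the value or detect the unset state differently.

```python
import torch
from torch import nn


class ClipInitSketch(nn.Module):
    """Hypothetical module illustrating lazy clip initialization.

    Assumptions: a clip of zero means "not set", and an unset clip is filled
    with the mean absolute value of its input.
    """

    def __init__(self, dtype=torch.float32):
        super().__init__()
        # Scalar clipping parameters stored in the requested dtype.
        self.x_clip = nn.Parameter(torch.zeros(1, dtype=dtype))
        self.y_clip = nn.Parameter(torch.zeros(1, dtype=dtype))

    @torch.no_grad()
    def set_activation_scale(self, x, y):
        # Only initialize clips that are still unset, so trained values are kept.
        if self.x_clip.item() == 0:
            self.x_clip.fill_(x.abs().mean().item())
        if self.y_clip.item() == 0:
            self.y_clip.fill_(y.abs().mean().item())


m = ClipInitSketch()
m.set_activation_scale(torch.tensor([1.0, -3.0]), torch.tensor([[0.5, -0.5]]))
print(m.x_clip.item(), m.y_clip.item())  # 2.0 0.5
m.set_activation_scale(torch.tensor([10.0]), torch.tensor([10.0]))
print(m.x_clip.item())  # 2.0, unchanged because the clip was already set
```

Initializing from the first batch's statistics gives the clip a scale that matches the data, after which it can be trained like any other parameter.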