bitorch_engine.layers.qlinear.binary.cutlass.layer.BinaryLinearForward
- class bitorch_engine.layers.qlinear.binary.cutlass.layer.BinaryLinearForward(*args, **kwargs)
Implements the forward and backward passes of a binary linear layer.
This class supports binary linear operations where both the inputs and weights are binarized. It integrates scaling and value range conversion directly within the forward pass, and provides gradient calculations for backpropagation with respect to the input, weights, and scaling factors in the backward pass.
The forward method accepts inputs of any shape by flattening them, and it saves tensors for the backward pass only when is_train is True.
- Parameters:
input (torch.Tensor) – The input tensor.
weight (torch.Tensor) – The weight tensor, expected to be binary.
scale_a (torch.Tensor) – The scaling factor for the activation.
scale_w (torch.Tensor) – The scaling factor for the weights.
gemm_kernel_id (int) – An identifier for the GEMM kernel to be used in the CUTLASS library.
is_train (bool) – Flag indicating whether the operation is in training mode.
- Returns:
- The output tensor from the binary linear operation, scaled and converted back to the input’s dtype.
- Return type:
torch.Tensor
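As a point of reference, the computation described above can be approximated in plain PyTorch. This is a sketch under stated assumptions, not the CUTLASS kernel: the helper name reference_binary_linear is hypothetical, weight is assumed to have shape (out_features, in_features), both operands are assumed to be binarized to {-1, +1} with zeros mapped to +1, and the GEMM result is assumed to be multiplied by scale_a * scale_w. The real layer may scale or convert value ranges differently.

```python
import torch


def reference_binary_linear(input: torch.Tensor, weight: torch.Tensor,
                            scale_a: torch.Tensor, scale_w: torch.Tensor) -> torch.Tensor:
    """Hypothetical float emulation of a binary linear forward (not the CUTLASS kernel)."""
    # Flatten all leading dimensions so any input shape becomes a 2D GEMM operand
    out_shape = input.shape[:-1] + (weight.shape[0],)
    x = input.reshape(-1, input.shape[-1]).float()

    # Assumed value-range conversion: map every element to {-1, +1} (zero -> +1)
    x_bin = (x / scale_a >= 0).float() * 2 - 1
    w_bin = (weight.float() >= 0).float() * 2 - 1

    # Binary GEMM emulated in float32, then rescaled by the assumed scale factors
    out = (x_bin @ w_bin.t()) * scale_a * scale_w

    # Restore the leading dimensions and the caller's dtype
    return out.reshape(out_shape).to(input.dtype)
```

For example, an input of shape (batch, seq_len, in_features) yields an output of shape (batch, seq_len, out_features) with the same dtype as the input.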
Note
This implementation relies on the CUTLASS library for efficient binary linear operations.
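Binary GEMM kernels are usually efficient because the dot product of two ±1 vectors reduces to bitwise operations: encode +1 as bit 1 and -1 as bit 0, count the mismatching bits with XOR and popcount, and the dot product equals n - 2 * mismatches. The pure-Python sketch below illustrates that identity; whether this particular CUTLASS kernel uses exactly this encoding is an assumption, and pack_bits and binary_dot are hypothetical helpers.

```python
def pack_bits(values):
    """Pack a list of +1/-1 values into an int, one bit per element (+1 -> 1, -1 -> 0)."""
    word = 0
    for i, v in enumerate(values):
        if v == 1:
            word |= 1 << i
    return word


def binary_dot(a, b):
    """Dot product of two equal-length +1/-1 vectors via XOR and popcount."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    n = len(a)
    mismatches = bin(pack_bits(a) ^ pack_bits(b)).count("1")  # popcount of the XOR
    # Matching positions contribute +1, mismatching ones -1: (n - m) - m
    return n - 2 * mismatches
```

On hardware, the XOR and popcount run on whole machine words at once, which is where the speedup over a floating-point dot product comes from.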
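Methods
backward(ctx, output_gradient) – Implements the backward pass of a binary linear layer.
forward(ctx, input, weight, scale_a, scale_w, gemm_kernel_id, is_train) – Define the forward of the custom autograd Function.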
- static backward(ctx: BackwardCFunction, output_gradient: Tensor) → Tuple[Tensor, ...]
Implements the backward pass of a binary linear layer.
This method calculates the gradients with respect to the input tensor, weight tensor, and the scaling factor for the activation. It also handles gradient clipping for the input tensor based on its value range after scaling.
- Parameters:
ctx (Any) – The autograd context, storing saved tensors from the forward pass.
output_gradient (torch.Tensor) – The gradient of the loss with respect to the output of the layer.
- Returns:
- A tuple containing the gradients with respect to the input tensor, the weight tensor, and the scaling factor for the activation, followed by three None placeholders for the inputs that receive no gradient (scale_w, gemm_kernel_id, is_train).
- Return type:
tuple
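The return contract above (one entry per forward input, with None for the last three) can be seen in a small CPU-only stand-in. ToyBinaryLinear and _to_pm1 are hypothetical: the class mirrors the six-argument signature of BinaryLinearForward.forward, emulates the binary GEMM in float, ignores gemm_kernel_id, and uses an assumed straight-through estimator whose input gradient is clipped where |input / scale_a| > 1. The real backward may use different gradient formulas.

```python
import torch


def _to_pm1(t: torch.Tensor) -> torch.Tensor:
    # Map elements to {-1, +1}, sending zero to +1 (assumed binarization)
    return (t >= 0).to(t.dtype) * 2 - 1


class ToyBinaryLinear(torch.autograd.Function):
    """Hypothetical CPU stand-in mirroring BinaryLinearForward's argument list."""

    @staticmethod
    def forward(ctx, input, weight, scale_a, scale_w, gemm_kernel_id, is_train):
        x_bin = _to_pm1(input / scale_a)
        w_bin = _to_pm1(weight)
        out_raw = x_bin @ w_bin.t()  # emulated binary GEMM; gemm_kernel_id is unused
        if is_train:
            # In this sketch, backward is only valid when is_train=True
            ctx.save_for_backward(input, x_bin, w_bin, out_raw, scale_a, scale_w)
        return out_raw * scale_a * scale_w

    @staticmethod
    def backward(ctx, output_gradient):
        input, x_bin, w_bin, out_raw, scale_a, scale_w = ctx.saved_tensors
        # Straight-through estimator: treat sign() as identity where |input / scale_a| <= 1.
        # The scale_a on the output cancels the 1 / scale_a from input / scale_a.
        inside = (input / scale_a).abs() <= 1
        grad_input = (output_gradient @ w_bin) * scale_w * inside
        # STE for the weights: pass the gradient straight through sign()
        grad_weight = (output_gradient * scale_a * scale_w).t() @ x_bin
        # y = out_raw * scale_a * scale_w, so dL/dscale_a = sum(dL/dy * out_raw * scale_w)
        grad_scale_a = (output_gradient * out_raw * scale_w).sum()
        # One entry per forward input: no gradients for scale_w, gemm_kernel_id, is_train
        return grad_input, grad_weight, grad_scale_a, None, None, None
```

Calling ToyBinaryLinear.apply(x, w, s_a, s_w, 0, True) and then backward() populates x.grad, w.grad and s_a.grad, while the integer and boolean arguments silently receive no gradient.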
Note
This method assumes that the weight tensor has dtype int8 during the gradient calculation, and it applies the sign operation to the weights so that they remain binarized. The method supports backpropagation through layers that use binary weights and activations.
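The int8 and sign behavior mentioned in the note can be checked directly in PyTorch: torch.sign keeps the int8 dtype, so its result has to be cast to a floating-point dtype before a matmul with float gradients. torch.sign also maps 0 to 0, so a strictly ±1 encoding would need zeros handled separately. The helper name sign_weight_for_grad and the float32 target dtype are assumptions for illustration.

```python
import torch


def sign_weight_for_grad(weight: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Hypothetical helper: sign() of an int8 weight, cast for a floating-point GEMM."""
    if weight.dtype != torch.int8:
        raise TypeError(f"expected an int8 weight, got {weight.dtype}")
    w_sign = torch.sign(weight)  # still int8, values in {-1, 0, 1}
    return w_sign.to(dtype)      # a matmul with float gradients needs a float dtype


weight = torch.tensor([[3, -2, 0], [-1, 1, 5]], dtype=torch.int8)
print(torch.sign(weight).dtype)  # torch.int8
print(sign_weight_for_grad(weight))
```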
- static forward(ctx, input: Tensor, weight: Tensor, scale_a: Tensor, scale_w: Tensor, gemm_kernel_id: int, is_train: bool) → Tensor
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). See combining-forward-context for more details.
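A minimal Usage 1 Function, modeled on the MulConstant example in the PyTorch "Extending torch.autograd" notes: forward takes ctx first and stores a non-tensor constant on it, and backward returns None for the non-tensor argument. This is the same convention BinaryLinearForward.backward follows for gemm_kernel_id and is_train.

```python
import torch


class MulConstant(torch.autograd.Function):
    """Usage 1: forward receives ctx as its first argument."""

    @staticmethod
    def forward(ctx, tensor, constant):
        ctx.constant = constant  # non-tensor data can live directly on ctx
        return tensor * constant

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; the Python float gets None
        return grad_output * ctx.constant, None


x = torch.tensor([1.0, 2.0], requires_grad=True)
MulConstant.apply(x, 3.0).sum().backward()
print(x.grad)  # tensor([3., 3.])
```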
Usage 2 (Separate forward and ctx):
    @staticmethod
    def forward(*args: Any, **kwargs: Any) -> Any:
        pass

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
        pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
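A small Function written in the Usage 2 style: forward has no ctx, and setup_context receives the original inputs and the output and decides what to save. This split style requires PyTorch 2.0 or later; Square is an illustrative name, not part of bitorch_engine.

```python
import torch


class Square(torch.autograd.Function):
    """Usage 2: forward without ctx, plus a separate setup_context."""

    @staticmethod
    def forward(x):
        return x * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs  # inputs is the tuple of arguments passed to forward
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output  # d(x^2)/dx = 2x


x = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
Square.apply(x).sum().backward()
print(x.grad)
```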
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
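The difference between the two save methods shows up when a Function supports both reverse mode (backward) and forward mode (jvp). The sketch below, with the illustrative name SquareFwdRev, saves x with both methods and checks the forward-mode tangent through torch.autograd.forward_ad; it assumes a recent PyTorch release with forward-mode AD support for custom Functions.

```python
import torch
import torch.autograd.forward_ad as fwAD


class SquareFwdRev(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # read in backward (reverse mode, vjp)
        ctx.save_for_forward(x)   # read in jvp (forward mode)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output

    @staticmethod
    def jvp(ctx, x_tangent):
        (x,) = ctx.saved_tensors
        return 2 * x * x_tangent


primal = torch.tensor([1.0, -2.0, 3.0])
with fwAD.dual_level():
    # Seed the tangent with ones so the output tangent equals d(x^2)/dx elementwise
    dual_x = fwAD.make_dual(primal, torch.ones_like(primal))
    tangent = fwAD.unpack_dual(SquareFwdRev.apply(dual_x)).tangent
```

Both modes yield 2x for this Function; forward mode is the cheaper choice when there are few inputs and many outputs, reverse mode in the opposite case.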