bitorch_engine.layers.qlinear.binary.cutlass.layer.BinaryLinearForward

class bitorch_engine.layers.qlinear.binary.cutlass.layer.BinaryLinearForward(*args, **kwargs)[source]

Implements the forward and backward passes of a binary linear layer.

This class supports binary linear operations in which both the inputs and the weights are binarized. It integrates scaling and value-range conversion directly within the forward pass, and in the backward pass it provides gradient calculations with respect to the input, the weights, and the activation scaling factor.

The forward method supports any input shape by flattening the input, and it can optionally save tensors for the backward pass if training is enabled.

Parameters:
  • input (torch.Tensor) – The input tensor.

  • weight (torch.Tensor) – The weight tensor, expected to be binary.

  • scale_a (torch.Tensor) – The scaling factor for the activation.

  • scale_w (torch.Tensor) – The scaling factor for the weights.

  • gemm_kernel_id (int) – An identifier for the GEMM kernel to be used in the CUTLASS library.

  • is_train (bool) – Flag indicating whether the operation is in training mode.

Returns:

The output tensor from the binary linear operation, scaled and converted back to the input’s dtype.

Return type:

torch.Tensor

Note

This implementation relies on the CUTLASS library for efficient binary linear operations.
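For orientation, a minimal usage sketch follows. Everything concrete in it (the tensor shapes, the {-1, +1} int8 weight encoding, the (out_features, in_features) weight layout, the scalar scaling factors, and kernel id 0) is an assumption made for illustration; a CUDA device is assumed because the CUTLASS kernels run on the GPU.

import torch
from bitorch_engine.layers.qlinear.binary.cutlass.layer import BinaryLinearForward

batch, in_features, out_features = 8, 256, 128  # illustrative shapes

x = torch.randn(batch, in_features, device="cuda")

# Binarized weight: int8 values in {-1, +1}; the layout is an assumption.
w = torch.randint(0, 2, (out_features, in_features), device="cuda", dtype=torch.int8) * 2 - 1

scale_a = torch.tensor(1.0, device="cuda")  # activation scaling factor (assumed scalar)
scale_w = torch.tensor(1.0, device="cuda")  # weight scaling factor (assumed scalar)

# gemm_kernel_id selects a CUTLASS kernel; 0 is only a placeholder here.
y = BinaryLinearForward.apply(x, w, scale_a, scale_w, 0, False)
print(y.shape, y.dtype)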

Methods

  • backward – Implements the backward pass of a binary linear layer.

  • forward – Define the forward of the custom autograd Function.


static backward(ctx: BackwardCFunction, output_gradient: Tensor) → Tuple[Tensor, ...][source]

Implements the backward pass of a binary linear layer.

This method calculates the gradients with respect to the input tensor, weight tensor, and the scaling factor for the activation. It also handles gradient clipping for the input tensor based on its value range after scaling.

Parameters:
  • ctx (Any) – The autograd context, storing saved tensors from the forward pass.

  • output_gradient (torch.Tensor) – The gradient of the loss with respect to the output of the layer.

Returns:

A tuple containing gradients with respect to the input tensor, weight tensor, and scaling factor for the activation, followed by three None placeholders for gradients that are not calculated (scale_w, gemm_kernel_id, is_train).

Return type:

tuple

Note

This method assumes that the weight tensor is of type int8 during the gradient calculation and applies the sign operation to the weights to preserve their binarized values. It supports backpropagation through layers that use binary weights and activations.
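The sketch below shows how gradients reach the floating-point inputs when is_train is True; it reuses the illustrative shapes, weight encoding, and placeholder kernel id from the forward sketch above, and again assumes a CUDA device.

import torch
from bitorch_engine.layers.qlinear.binary.cutlass.layer import BinaryLinearForward

x = torch.randn(8, 256, device="cuda", requires_grad=True)
w = torch.randint(0, 2, (128, 256), device="cuda", dtype=torch.int8) * 2 - 1
scale_a = torch.tensor(1.0, device="cuda", requires_grad=True)
scale_w = torch.tensor(1.0, device="cuda")

# is_train=True so the forward pass saves the tensors needed for backward.
y = BinaryLinearForward.apply(x, w, scale_a, scale_w, 0, True)
y.sum().backward()

# Gradients are populated for the floating-point inputs that require them; the
# remaining arguments correspond to the None placeholders described above.
print(x.grad.shape, scale_a.grad)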

static forward(ctx, input: Tensor, weight: Tensor, scale_a: Tensor, scale_w: Tensor, gemm_kernel_id: int, is_train: bool) → Tensor[source]

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See the PyTorch autograd documentation (combining-forward-context) for more details; a short sketch of this usage follows below.
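A minimal, self-contained sketch of Usage 1, using only standard PyTorch; the scaled-sign operation and its straight-through gradient are toy choices for illustration and are not part of bitorch_engine.

import torch

class ScaledSign(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # ctx is the first argument; save what backward needs via save_for_backward.
        ctx.save_for_backward(input, scale)
        return input.sign() * scale

    @staticmethod
    def backward(ctx, output_gradient: torch.Tensor):
        input, scale = ctx.saved_tensors
        # Straight-through estimator for sign(); chain rule for the scale.
        grad_input = output_gradient * scale * (input.abs() <= 1).to(output_gradient.dtype)
        grad_scale = (output_gradient * input.sign()).sum()
        return grad_input, grad_scale

x = torch.randn(4, requires_grad=True)
s = torch.tensor(2.0, requires_grad=True)
ScaledSign.apply(x, s).sum().backward()
print(x.grad, s.grad)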

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass

  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. Here, output is the output of the forward, and inputs is a tuple of the inputs passed to the forward.

  • See the PyTorch documentation on extending autograd (extending-autograd) for more details; a short sketch of this usage follows below.
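A corresponding sketch of Usage 2, again with a toy squaring operation and using only standard PyTorch (setup_context requires PyTorch 2.0 or newer).

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(input: torch.Tensor) -> torch.Tensor:
        # No ctx argument in Usage 2.
        return input ** 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        # inputs is the tuple of arguments passed to forward; output is its result.
        (input,) = inputs
        ctx.save_for_backward(input)

    @staticmethod
    def backward(ctx, output_gradient: torch.Tensor):
        (input,) = ctx.saved_tensors
        # d(x^2)/dx = 2x
        return output_gradient * 2 * input

x = torch.randn(3, requires_grad=True)
Square.apply(x).sum().backward()
print(torch.allclose(x.grad, 2 * x.detach()))  # True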

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or with ctx.save_for_forward() if they are intended to be used in jvp.