bitorch_engine.layers.qlinear.binary.cuda.layer.BinaryLinearForward

class bitorch_engine.layers.qlinear.binary.cuda.layer.BinaryLinearForward(*args, **kwargs)[source]

Implements the forward and backward passes for binary linear operation.

This class performs the binary linear transformation using custom CUDA operations, suitable for efficiently processing binary weights and activations in deep learning models.

ctx

The context object used to save information for backward computation.

input

The input tensor for the linear operation.

weight

The binary weight tensor.

bmm_type

An enumeration indicating the type of binary matrix multiplication kernel to use.

scale_a

The scaling factor for the input activation.

scale_w

The scaling factor for the weight.

is_train

A boolean flag indicating whether the operation runs in training or evaluation mode.

Methods

backward

Backward pass for the binary linear operation, computing gradients for input, weight, and scale factors based on the output gradient received from subsequent layers.

forward

Forward pass for binary linear operation.


static backward(ctx: BackwardCFunction, output_gradient: Tensor) Tuple[Tensor, ...][source]

Backward pass for the binary linear operation, computing gradients for input, weight, and scale factors based on the output gradient received from subsequent layers.

Parameters:
  • ctx (torch.autograd.function.BackwardCFunction) – The autograd context that stores information from the forward pass, including saved tensors for use in gradient calculations.

  • output_gradient (torch.Tensor) – The gradient of the loss with respect to the output of the binary linear operation. This tensor is used to compute the gradients for the input and weight tensors.

Returns:

A tuple containing the gradient with respect to the input tensor (grad_input), the gradient with respect to the weight tensor (grad_weight), a None placeholder for the bias (this operation has no bias term), and the gradient of the activation scaling factor (grad_scale_a). The gradient for the weight scaling factor is not computed and is returned as None.

Return type:

Tuple[torch.Tensor, …]

Note

  • The gradients computation involves sign operations and scaling by the respective scaling factors (scale_w for weights and scale_a for activations) to account for the binarization effect in the forward pass.

  • Gradient clipping is applied to grad_input to ensure the updates remain within the expected range, reflecting the constraints of using binary weights and activations.

  • This method requires a custom optimizer capable of handling int8 gradients and weights, as the standard optimizers may not support direct updates with int8 tensors.
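The gradient rules in the notes above can be illustrated with a minimal pure-PyTorch sketch. This is an assumption-laden reference implementation of a common straight-through-estimator scheme, not the library's actual CUDA kernel math; the function name, the exact clipping range, and the reduction used for grad_scale_a are illustrative.

```python
import torch

def binary_linear_backward_ref(output_gradient, input, weight, scale_a, scale_w):
    # Hypothetical sketch of a straight-through estimator (STE) backward pass:
    # sign() is treated as identity for gradient flow, scaled by the
    # respective scaling factors (scale_w for input grads, scale_a for weight grads).
    grad_input = (output_gradient @ torch.sign(weight)) * scale_w
    # Clip grad_input so updates stay within the range expected for binarized layers.
    grad_input = grad_input.clamp_(-1.0, 1.0)
    grad_weight = (output_gradient.t() @ torch.sign(input)) * scale_a
    # One plausible gradient for the activation scale: accumulate the output
    # gradient against the unscaled binary matmul result.
    grad_scale_a = (output_gradient * (torch.sign(input) @ torch.sign(weight).t())).sum()
    # None placeholders for the bias and for scale_w, whose gradient is not computed.
    return grad_input, grad_weight, None, grad_scale_a, None
```

In the real Function, the returned tuple must line up with the forward signature (one entry per forward argument), with None for non-differentiable arguments such as bmm_type and is_train.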

static forward(ctx, input: Tensor, weight: Tensor, bmm_type: BMM, scale_a: Tensor, scale_w: Tensor, is_train: bool) Tensor[source]

Forward pass for binary linear operation.

Parameters:
  • ctx – The context object.

  • input – The input tensor.

  • weight – The binary weight tensor.

  • bmm_type – The binary matrix multiplication kernel type.

  • scale_a – The scaling factor for the input activation.

  • scale_w – The scaling factor for the weight.

  • is_train – Whether the operation runs in training or evaluation mode.

Returns:

The result of the binary linear operation scaled by the activation and weight scaling factors.
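For intuition, the forward computation described above can be sketched in pure PyTorch: binarize input and weight with sign(), multiply, and rescale by the two scaling factors. This is a dense reference sketch under assumptions; the actual kernel bit-packs the ±1 values and dispatches to the CUDA implementation selected by bmm_type, neither of which is modeled here, and the function name is illustrative.

```python
import torch

def binary_linear_forward_ref(input, weight, scale_a, scale_w):
    # Binarize activations and weights to {-1, +1} via sign().
    # (torch.sign maps exact zeros to 0; the real kernel works on packed bits.)
    input_b = torch.sign(input)
    weight_b = torch.sign(weight)
    # Binary matmul, rescaled by the activation and weight scaling factors.
    return (input_b @ weight_b.t()) * scale_a * scale_w

x = torch.randn(4, 8)       # batch of 4 activations, 8 features
w = torch.randn(16, 8)      # 16 output features
out = binary_linear_forward_ref(x, w, torch.tensor(0.5), torch.tensor(0.1))
print(out.shape)  # torch.Size([4, 16])
```

Each output entry is a sum of 8 products of ±1 values, scaled by scale_a * scale_w, which matches the "result scaled by the activation and weight scaling factors" described in the Returns section.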