bitorch_engine.layers.qconv.nbit.cutlass.layer.Q4Conv2dCutlassForward
- class bitorch_engine.layers.qconv.nbit.cutlass.layer.Q4Conv2dCutlassForward(*args, **kwargs)
A custom autograd function for 4-bit quantized convolution using CUTLASS kernels.
This class implements the forward and backward passes for a 4-bit quantized convolution operation, intended for use with PyTorch’s autograd mechanism. It utilizes CUTLASS low-precision computation primitives for efficient GPU acceleration.
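To make "4-bit quantized" concrete, the sketch below maps real values to signed 4-bit integer codes and back using a scale factor, which is the role the `scale_a` and `scale_w` arguments of `forward` appear to play. This page does not say which quantization scheme the CUTLASS kernel uses, so symmetric signed quantization with round-to-nearest and saturation is an assumption made purely for illustration. `quantize_int4` and `dequantize_int4` are hypothetical helpers, not part of bitorch_engine.

```python
def quantize_int4(value: float, scale: float) -> int:
    """Map a real value to a signed 4-bit code (assumed symmetric scheme)."""
    code = round(value / scale)
    # Signed 4-bit integers cover [-8, 7]; out-of-range values saturate.
    return max(-8, min(7, code))


def dequantize_int4(code: int, scale: float) -> float:
    """Recover the approximate real value from its 4-bit code."""
    return code * scale


scale = 0.5  # illustrative value; in the layer this would come from scale_a or scale_w
for value in (0.2, 1.3, -2.6, 10.0):
    code = quantize_int4(value, scale)
    print(value, "->", code, "->", dequantize_int4(code, scale))
```

With only 16 representable codes, the choice of scale trades resolution against range: a smaller scale preserves small values more accurately but saturates large ones sooner.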
Methods
backward – Backward pass for the 4-bit quantized convolution.
forward – Forward pass for the 4-bit quantized convolution.
- static backward(ctx: BackwardCFunction, output_gradient: Tensor) → Tuple[Tensor, ...]
Backward pass for the 4-bit quantized convolution.
This method computes the gradients with respect to the input tensor and the weight tensor from the gradient of the layer’s output. It also calculates the gradient of the scale factor used in quantization. The computation relies on the tensors and attributes saved on ctx during the forward pass.
- Parameters:
ctx (torch.autograd.function.BackwardCFunction) – Context object with saved tensors and options.
output_gradient (torch.Tensor) – Gradient of the loss with respect to the output of the forward pass.
- Returns:
Gradients of the input tensor, weight tensor, and scale factors, with None placeholders for non-differentiable parameters.
- Return type:
Tuple[torch.Tensor, …]
- static forward(ctx, x: Tensor, weight: Tensor, scale_a: Tensor, scale_w: Tensor, is_train: bool, kernel_size: int, stride: int, padding: int, dilation: int) → Tensor
Forward pass for the 4-bit quantized convolution.
- Parameters:
ctx (torch.autograd.function.FunctionCtx) – Context object to save information for backward computation.
x (torch.Tensor) – Input tensor.
weight (torch.Tensor) – Weight tensor.
scale_a (torch.Tensor) – Scaling factor for input quantization.
scale_w (torch.Tensor) – Scaling factor for weight quantization.
is_train (bool) – Flag indicating if the model is in training mode.
kernel_size (int) – Size of the convolution kernel.
stride (int) – Stride of the convolution.
padding (int) – Padding added to both sides of the input.
dilation (int) – Spacing between kernel elements.
- Returns:
The result of the 4-bit quantized convolution operation.
- Return type:
torch.Tensor
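The `kernel_size`, `stride`, `padding`, and `dilation` arguments together determine the spatial size of the returned tensor. The sketch below assumes this layer follows the same output-size formula that `torch.nn.Conv2d` documents, which this page does not state explicitly. `conv2d_output_size` is a hypothetical helper written for this example, not a bitorch_engine function.

```python
def conv2d_output_size(in_size: int, kernel_size: int, stride: int = 1,
                       padding: int = 0, dilation: int = 1) -> int:
    """Output length along one spatial dimension (torch.nn.Conv2d formula)."""
    # Effective kernel extent once dilation spreads the taps apart.
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (in_size + 2 * padding - effective_kernel) // stride + 1


print(conv2d_output_size(32, kernel_size=3, stride=1, padding=1))  # 32
print(conv2d_output_size(32, kernel_size=3, stride=2, padding=1))  # 16
print(conv2d_output_size(32, kernel_size=3, stride=1, dilation=2))  # 28
```

Since the signature takes a single integer for each argument, the same value presumably applies to both height and width, so the helper would be called once per spatial dimension.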