Source code for bitorch_engine.layers.qconv.binary.cpp.layer

import torch
from torch.autograd import Function

from bitorch_engine.utils.safe_import import import_extension

binary_conv_cpp = import_extension("binary_conv_cpp")


from bitorch_engine.utils.quant_operators import get_binary_row
from ..layer import BinaryConv2dBase

class BinaryConv2dForward(Function):
    """
    Custom autograd function that runs the forward pass of a 2D binary convolution.

    The computation is delegated to the C++ extension ``binary_conv_cpp``
    (its ``forward`` entry point). Only ``forward`` is defined on this class;
    no ``backward`` is implemented here, so back-propagating through it is not
    supported by this class.

    Attributes:
        - No class-level attributes.

    Methods:
        - forward: Performs the forward pass of the binary convolution.
    """

    @staticmethod
    def forward(ctx, activations: torch.Tensor, weights: torch.Tensor, m: int, n: int, k: int,
                kernel_size: int, stride: int, padding: int, dilation: int,
                output_edge: int) -> torch.Tensor:
        """
        Forward pass for the 2D binary convolution.

        All arguments are forwarded unchanged to ``binary_conv_cpp.forward`` and
        its result is returned as-is.

        Parameters:
            - ctx: Autograd context object supplied by ``Function.apply``. It is
              not used here because no ``backward`` is defined.
            - activations (Tensor): The input feature map / activation tensor.
            - weights (Tensor): The (binarized / packed) weight tensor.
            - m, n, k (int): GEMM-style problem dimensions. As computed by
              ``BinaryConv2dCPP.forward`` in this module:
                - m: number of output channels,
                - n: number of output pixels per channel (``output_edge ** 2``),
                - k: number of input channels times ``kernel_size ** 2``.
            - kernel_size (int): Edge length of the (square) convolution kernel.
            - stride (int): Stride of the convolution.
            - padding (int): Zero-padding added to both sides of the input.
            - dilation (int): Spacing between kernel elements.
            - output_edge (int): Edge length of the (square) output feature map.

        Returns:
            - Tensor: The output feature map produced by the C++ kernel.
        """
        return binary_conv_cpp.forward(activations, weights, m, n, k, kernel_size,
                                       stride, padding, dilation, output_edge)
class BinaryConv2dCPP(BinaryConv2dBase):
    """
    Binary 2D convolution layer whose forward pass runs on the ``binary_conv_cpp``
    C++ extension (through ``BinaryConv2dForward``).

    It inherits common binary-convolution functionality from ``BinaryConv2dBase``.

    Attributes:
        bits_binary_word (int): Number of binarized weights packed into one
            ``torch.uint8`` element of ``qweight``; set to 8 in ``__init__``.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the layer. All arguments are forwarded to ``BinaryConv2dBase``,
        and then the binary word size used for weight packing is set.

        Args:
            *args: Positional arguments passed to ``BinaryConv2dBase``.
            **kwargs: Keyword arguments passed to ``BinaryConv2dBase``.
        """
        super().__init__(*args, **kwargs)
        self.bits_binary_word = 8

    def prepare_params(self) -> None:
        """
        Hook for preparing/initializing the layer's parameters; this CPU layer
        needs no preparation, so it does nothing.

        One can use the "prepare_bie_layers" method from
        project_root.utils.model_helper to call this function.
        """
        pass

    def generate_quantized_weight(self, qweight_only: bool = False) -> None:
        """
        Binarize ``self.weight`` and store the packed result in ``self.qweight``.

        The flattened weights are packed ``bits_binary_word`` at a time into a
        ``torch.uint8`` tensor by ``get_binary_row``. The result is stored as a
        ``torch.nn.Parameter`` with ``requires_grad=False``.

        Args:
            qweight_only (bool): If True, ``self.weight`` is set to None after
                packing to save memory. Defaults to False.

        Raises:
            ValueError: If the total number of weights is not a multiple of
                ``bits_binary_word``. Such weights cannot be packed into whole
                words without dropping or over-reading elements.
        """
        num_weights = self.out_channels * self.in_channels * self.kernel_size * self.kernel_size
        if num_weights % self.bits_binary_word != 0:
            raise ValueError(
                f"Number of weights ({num_weights}) must be a multiple of "
                f"bits_binary_word ({self.bits_binary_word}) to be bit-packed."
            )
        # Integer word count: get_binary_row receives integer sizes, not floats.
        num_words = num_weights // self.bits_binary_word
        packed = get_binary_row(
            self.weight.reshape(-1),
            torch.empty(num_words, dtype=torch.uint8),
            num_words * self.bits_binary_word,
            self.bits_binary_word,
        )
        self.qweight = torch.nn.Parameter(packed, requires_grad=False)
        if qweight_only:
            self.weight = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the binary convolution on ``x`` using ``self.opt_weight``.

        Args:
            x (torch.Tensor): Input tensor, expected shape (N, C_in, H, W).
                Only dimension 2 is used to size the output, so a square input
                (H == W) is assumed (NOTE(review): confirm with callers).

        Returns:
            torch.Tensor: Output of ``BinaryConv2dForward.apply``.
        """
        self._check_forward(x)
        # GEMM-style dimensions expected by the C++ kernel.
        m = self.out_channels  # number of output channels
        k = x.size(dim=1) * self.kernel_size * self.kernel_size  # input channels * kernel area
        # Output edge length, same formula as torch.nn.Conv2d:
        #   floor((H + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
        # For dilation == 1 this reduces to floor((H - kernel + 2*pad) / stride) + 1.
        output_edge = (
            x.size(dim=2) + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1
        ) // self.stride + 1
        n = output_edge * output_edge  # number of output pixels per channel
        return BinaryConv2dForward.apply(x, self.opt_weight, m, n, k, self.kernel_size,
                                         self.stride, self.padding, self.dilation, output_edge)