Source code for bitorch_engine.layers.qlinear.binary.cpp.layer

from typing import Any

import torch
from bitorch import RuntimeMode
from bitorch.layers import QLinearBase
from bitorch.layers.extensions import LayerRecipe
from bitorch.layers.register import QLinearImplementation
from torch.autograd import Function

from bitorch_engine.utils.safe_import import import_extension
from ..binary_implementation import BinaryLinearImplementationMixin
from ..layer import BinaryLinearBase
from bitorch_engine.utils.model_helper import flatten_x, unflatten_x

binary_linear_cpp = import_extension("binary_linear_cpp")


[docs] class BinaryLinearForward(Function): """ A custom autograd function for performing forward pass of binary linear layer. This function uses a custom C++ backend for efficient computation. Args: ctx (torch.autograd.function.FunctionCtx): The context for storing information for backward computation. input (torch.Tensor): The input tensor. weights (torch.Tensor): The binary weights tensor. m (int): The batch size. n (int): The number of output features. k (int): The number of input features. Returns: torch.Tensor: The output tensor after applying the binary linear transformation. """
[docs] @staticmethod def forward(ctx, input: torch.Tensor, weights: torch.Tensor, m: int, n: int, k: int) -> torch.Tensor: input, shape = flatten_x(input) output = binary_linear_cpp.forward(input, weights, m, n, k) output = unflatten_x(output, shape) return output
[docs] @QLinearImplementation(RuntimeMode.CPU) class BinaryLinearCPP(BinaryLinearImplementationMixin, BinaryLinearBase): """ A class representing the binary linear layer implemented in C++ for CPU runtime mode. Inherits from BinaryLinearBase and mixes in BinaryLinearImplementationMixin for common functionality. This class supports creating a clone of itself from a given LayerRecipe, allowing for easy replication and modification of layer parameters. """
[docs] @classmethod def create_clone_from(cls, recipe: LayerRecipe) -> Any: """ Creates a clone of this layer based on the provided LayerRecipe. Args: recipe (LayerRecipe): The recipe containing the parameters for the clone. Returns: Any: A new instance of this class with parameters derived from the recipe. """ args = QLinearBase.get_args_as_kwargs(recipe) input_features, output_features = args["in_features"], args["out_features"] input_features //= 8 new_layer = cls(input_features, output_features) new_layer.set_weight_data(recipe.layer.weight.data) new_layer.generate_quantized_weight(qweight_only=True) return new_layer
[docs] def __init__( self, input_features: int, out_features: int, device: torch.device = None, ) -> None: """ Initializes the BinaryLinearCPP layer. Args: input_features (int): The number of input features (divided by 8 for binary). out_features (int): The number of output features. device (torch.device, optional): The device on which to perform computations. """ super().__init__(input_features, out_features, device)
[docs] def prepare_params(self) -> None: """ Prepares and initializes the model parameters for training. One can use "prepare_bie_layers" method from project_root.utils.model_helper to call this function. """ pass
[docs] def generate_quantized_weight(self, qweight_only: bool = False) -> None: """ Generates the quantized weight matrix for this layer and optionally clears the original weight. Args: qweight_only (bool, optional): If True, the original weight matrix is cleared to save memory. """ # Generate packed weight using custom C++ function self.qweight = binary_linear_cpp.w_pack( self.weight, # Original weight self.output_features, # n self.input_features, # k ) if qweight_only: self.weight = None # Clear the original weight matrix if specified
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """ Defines the forward pass of the binary linear layer. Args: x (torch.Tensor): The input tensor. Returns: torch.Tensor: The output tensor after applying the binary linear transformation. """ # Check input validity self._check_forward(x) # pass m, n, k m = x.size(dim=0) # batch size k = x.size(dim=1) # input features n = self.output_features # output features return BinaryLinearForward.apply(x, self.opt_weight, m, n, k)