bitorch_engine.utils.model_helper.binary_matmul_forward_post_processing

bitorch_engine.utils.model_helper.binary_matmul_forward_post_processing(tensor: Tensor, shape_pre: list, x_pad_sec_last: int, y_pad_sec_last: int, k: int) → Tensor

Post-processes the output tensor of a binary matrix multiplication operation.

This function applies three post-processing steps to the result of a binary matrix multiplication. It truncates the padding elements that were added for the operation, reshapes the result back to the original shape (with the last two dimensions taken from the truncated result), and converts the binary values back to their original data domain.

Args:
- tensor (torch.Tensor): The output tensor of a binary matrix multiplication, to be post-processed.
- shape_pre (list): The original shape of the tensor before the binary matrix multiplication. The output is reshaped to this shape, with its last two dimensions replaced by the last two dimensions of the truncated result.
- x_pad_sec_last (int): The number of padding elements added to the second-to-last dimension of the tensor for the binary matrix multiplication. These elements are removed.
- y_pad_sec_last (int): The number of padding elements added to the last dimension of the tensor for the binary matrix multiplication. These elements are removed.
- k (int): A constant used to convert the binary results back to their original data domain, using the formula k - 2 * tensor.
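The three documented steps can be sketched in plain PyTorch as follows. This is an illustrative sketch, not the library's source: the name `post_process_sketch` is hypothetical, it assumes the padding was appended at the end of the last two dimensions, and it reads `shape_pre` literally as described above (its last two entries are replaced by the trimmed sizes).

```python
import torch


def post_process_sketch(tensor: torch.Tensor, shape_pre: list,
                        x_pad_sec_last: int, y_pad_sec_last: int,
                        k: int) -> torch.Tensor:
    """Hypothetical re-implementation of the documented post-processing.

    Assumes padding sits at the END of the last two dimensions.
    """
    # Step 1a: drop padded rows from the second-to-last dimension.
    if x_pad_sec_last > 0:
        tensor = tensor[..., : tensor.size(-2) - x_pad_sec_last, :]
    # Step 1b: drop padded columns from the last dimension.
    if y_pad_sec_last > 0:
        tensor = tensor[..., : tensor.size(-1) - y_pad_sec_last]
    # Step 2: restore the original shape, replacing its last two entries
    # with the trimmed sizes. reshape() copies if slicing made it non-contiguous.
    new_shape = list(shape_pre[:-2]) + [tensor.size(-2), tensor.size(-1)]
    tensor = tensor.reshape(new_shape)
    # Step 3: map the binary results back to the original domain.
    return k - 2 * tensor


# A (6, 5, 7) result with one padded row and two padded columns.
result = torch.zeros(6, 5, 7)
out = post_process_sketch(result, [2, 3, 4, 8], x_pad_sec_last=1,
                          y_pad_sec_last=2, k=8)
print(out.shape)  # prints: torch.Size([2, 3, 4, 5])
```

Here the padded (6, 5, 7) result is trimmed to (6, 4, 5) and then reshaped to (2, 3, 4, 5). Using `reshape` rather than `view` is a deliberate choice in the sketch, because slicing off the padding generally leaves a non-contiguous tensor.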

Returns:
- torch.Tensor: The post-processed tensor, with padding removed, reshaped to the original shape (with the adjusted last two dimensions), and converted back to the original data domain.

Note:
- This function is intended specifically for tensors produced by binary matrix multiplications that involve padding and that must be post-processed to restore their original format and data domain.
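The `k - 2 * tensor` conversion matches the standard XNOR/popcount identity for ±1 values: if two length-k vectors with entries in {-1, +1} differ in m positions, their dot product is (k - m) - m = k - 2m. The page does not say what the raw kernel output counts, so reading `tensor` as a per-entry mismatch count is an assumption here. The check below uses a hypothetical 1-bit encoding (+1 maps to True, -1 to False) and computes everything with plain PyTorch ops.

```python
import torch

# Hand-checkable case with k = 3: signs differ at positions 1 and 2.
a = torch.tensor([1, -1, 1])
b = torch.tensor([1, 1, -1])
mism = int(((a > 0) != (b > 0)).sum())  # number of mismatching bits
print(int((a * b).sum()), 3 - 2 * mism)  # prints: -1 -1

# Matrix case: 3 rows of x against 4 columns of y, each of length k.
torch.manual_seed(0)
k = 64
x = torch.randint(0, 2, (3, k)) * 2 - 1   # entries in {-1, +1}
y = torch.randint(0, 2, (k, 4)) * 2 - 1
xb, yb = x > 0, y > 0                     # hypothetical encoding: +1 -> True
# (3, k, 1) vs (1, k, 4) broadcasts to (3, k, 4); count mismatches over k.
mismatches = (xb.unsqueeze(2) != yb.unsqueeze(0)).sum(dim=1)
# Exact +1/-1 matrix product computed elementwise, shape (3, 4).
dot = (x.unsqueeze(2) * y.unsqueeze(0)).sum(dim=1)
print(torch.equal(k - 2 * mismatches, dot))  # prints: True
```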