Source code for bitorch_engine.utils.cuda_extension

import os
import subprocess
import warnings
from pathlib import Path
from typing import Dict, Any

from torch.utils.cpp_extension import CUDAExtension

from bitorch_engine.extensions import EXTENSION_PREFIX


SUPPORTED_CUDA_ARCHS = ["sm_80", "sm_75"]


[docs] def get_cuda_arch(): cuda_arch = os.environ.get("BIE_CUDA_ARCH", "sm_80") if cuda_arch not in SUPPORTED_CUDA_ARCHS: warnings.warn(f"Warning: BIE_CUDA_ARCH={cuda_arch} may not be supported yet.") return cuda_arch
[docs] def get_cuda_extension(root_path: Path, relative_name: str, relative_sources) -> Any: if os.environ.get("BIE_BUILD_SEPARATE_CUDA_ARCH", "false") == "true": relative_name = f"{relative_name}-{get_cuda_arch()}" return CUDAExtension( name=EXTENSION_PREFIX + relative_name, sources=[str(root_path / rel_path) for rel_path in relative_sources], **get_kwargs() )
[docs] def gcc_version(): """ Determines the version of GCC (GNU Compiler Collection) installed on the system. This function executes the 'gcc --version' command using subprocess.run and parses the output to extract the GCC version number. If GCC is not found or an error occurs during parsing, it returns a default version number of 0.0.0. The function checks if the output contains the string "clang" to identify if clang is masquerading as gcc, in which case it also returns 0.0.0, assuming GCC is not installed. Returns: tuple: A tuple containing three integers (major, minor, patch) representing the version of GCC. Returns (0, 0, 0) if GCC is not found or an error occurs. """ output = subprocess.run(['gcc', '--version'], check=True, capture_output=True, text=True) if output.returncode > 0 or "clang" in output.stdout: return 0, 0, 0 first_line = output.stdout.split("\n")[0] try: version = first_line.split(" ")[-1] major, minor, patch = list(map(int, version.split("."))) return major, minor, patch except: return 0, 0, 0
[docs] def get_kwargs() -> Dict[str, Any]: """ Generates keyword arguments for compilation based on the GCC version and CUDA architecture. This function dynamically constructs a dictionary of extra compilation arguments for both C++ and CUDA (nvcc) compilers. It includes flags to suppress deprecated declarations warnings, specify the OpenMP library path, and set the CUDA architecture. Additionally, it conditionally adjusts the nvcc host compiler to GCC 11 if the detected GCC version is greater than 11. Returns: Dict[str, Any]: A dictionary containing the 'extra_compile_args' key with nested dictionaries for 'cxx' and 'nvcc' compilers specifying their respective extra compilation arguments. """ # Retrieve the current GCC version to determine compatibility and required flags major, minor, patch = gcc_version() # Initialize the base kwargs dict with common compile arguments for cxx and nvcc kwargs = { "extra_compile_args": { "cxx": [ "-Wno-deprecated-declarations", # Suppress warnings for deprecated declarations "-L/usr/lib/gcc/x86_64-pc-linux-gnu/10.3.0/libgomp.so", # Specify libgomp library path for OpenMP "-fopenmp", # Enable OpenMP support ], "nvcc": [ "-Xcompiler", "-fopenmp", # Pass -fopenmp to the host compiler via nvcc f"-arch={get_cuda_arch()}", # Set CUDA architecture, default to 'sm_80' "-DARCH_SM_75" if get_cuda_arch() == 'sm_75' else "-DARCH_SM_80", # set architecture macro ], }, } # If the GCC version is greater than 11, adjust the nvcc host compiler settings if major > 11: print("Using GCC 11 host compiler for nvcc.") # Specify GCC 11 as the host compiler for nvcc: kwargs["extra_compile_args"]["nvcc"].append("-ccbin=/usr/bin/gcc-11") return kwargs