Source code for bitorch_engine.utils.mlx_extension

from pathlib import Path

from torch.utils.cpp_extension import CppExtension

from bitorch_engine.extensions import EXTENSION_PREFIX

from .mlx_path import get_mlx_include_path, get_mlx_lib_path


def get_mlx_extension(root_path: Path, relative_name: str, relative_sources) -> CppExtension:
    """
    Build a CppExtension description for a C++ extension that links against MLX.

    The MLX include directory and library directory are looked up through
    `get_mlx_include_path()` and `get_mlx_lib_path()`, so callers only need to
    say where their sources live and what the extension should be called.

    Parameters:
        root_path (Path): Directory that the entries of `relative_sources`
            are resolved against.
        relative_name (str): Extension name without the package prefix;
            `EXTENSION_PREFIX` is prepended to form the final name.
        relative_sources (Iterable[str]): Source file paths relative to
            `root_path`.

    Returns:
        CppExtension: Extension configured with the MLX include directory,
        the MLX library directory (for both linking and runtime lookup),
        the `mlx` library, and C++17 as the language standard.
    """
    mlx_include_dir = get_mlx_include_path()
    mlx_lib_dir = get_mlx_lib_path()

    # Turn every relative source into a string path rooted at `root_path`.
    source_files = []
    for source in relative_sources:
        source_files.append(str(root_path / source))

    extension_name = EXTENSION_PREFIX + relative_name

    return CppExtension(
        name=extension_name,
        sources=source_files,
        include_dirs=[mlx_include_dir],
        # The same directory is passed twice on purpose, as two separate
        # lists: once so the linker finds mlx at build time, and once so the
        # built extension can locate it when it is loaded.
        library_dirs=[mlx_lib_dir],
        runtime_library_dirs=[mlx_lib_dir],
        libraries=['mlx'],
        extra_compile_args=['-std=c++17'],
    )