Source code for bitorch_engine.utils.safe_import

import importlib
from typing import Any

from bitorch_engine.extensions import EXTENSION_PREFIX

MESSAGE = """The extension '{}' could not be imported. It is either not yet implemented or was not build correctly.
This message is expected during the build process. If it appears later on, try installing the package again."""


[docs] class ExtensionModulePlaceholder: """ A placeholder class for extension modules. This class serves as a placeholder for dynamically loaded extension modules. It is designed to intercept attribute access and modifications, raising a runtime error if any operation other than setting the initial name is attempted. This behavior ensures that any misuse of the placeholder, such as accessing or modifying attributes that are not supposed to exist in the placeholder state, is promptly identified during development. Attributes: _name (str): The name of the extension module this placeholder represents. Methods: __init__: Initializes the placeholder with the name of the extension module. __getattr__: Intercepts attribute access attempts, raising a RuntimeError. __setattr__: Intercepts attribute modification attempts, allowing only the _name attribute to be set. """
[docs] def __init__(self, name: str) -> None: """ Initializes the ExtensionModulePlaceholder with a specified name. Args: name (str): The name of the extension module this placeholder represents. """ self._name = name
[docs] def __getattr__(self, item: str) -> Any: """ Handles attribute access attempts for the placeholder. This method raises a RuntimeError to indicate that the attempted access is invalid, as the placeholder should not be used for accessing attributes. Args: item (str): The name of the attribute being accessed. Returns: Any: This method does not return but raises RuntimeError instead. Raises: RuntimeError: Indicates an invalid attribute access attempt. """ raise RuntimeError(MESSAGE.format(self._name))
[docs] def __setattr__(self, key: Any, value: Any) -> None: """ Handles attribute modification attempts for the placeholder. This method allows setting the _name attribute only. Any other attempt to modify attributes raises a RuntimeError to indicate that the operation is invalid. Args: key (Any): The name of the attribute to be modified. value (Any): The new value for the attribute. Raises: RuntimeError: Indicates an invalid attribute modification attempt, except for the _name attribute. """ if key == "_name": self.__dict__["_name"] = value return raise RuntimeError(MESSAGE.format(self._name))
[docs] def import_extension(module_name: str, not_yet_implemented: bool = False) -> Any: """ Dynamically imports a Python extension module by name, providing a safe mechanism to handle cases where the module is not yet built or not yet implemented. This function is particularly useful for conditionally importing modules that provide optional functionality or are platform-specific. If the module is marked as not yet implemented (not_yet_implemented=True), or if the module cannot be found during import, the function returns a placeholder object instead of raising an ImportError. This allows the application to continue running and gracefully handle the absence of the module. Args: module_name (str): The name of the module to be imported. The actual module name will be prefixed with a predefined prefix defined in `EXTENSION_PREFIX` to form the full module name. not_yet_implemented (bool, optional): A flag indicating whether the module is known to be not yet implemented. If True, the function immediately returns a placeholder without attempting to import the module. Defaults to False. Returns: Any: An imported module if successful, or an instance of `ExtensionModulePlaceholder` if the module is not implemented or cannot be found. Example: binary_linear_cuda = import_extension("binary_linear_cuda") This example attempts to import a module named "binary_linear_cuda" (prefixed appropriately), returning the module if found, or a placeholder if not found or not implemented. Note: This function prints a warning message to the console if the module cannot be found, informing the user of the issue without interrupting the execution of the program. """ if not_yet_implemented: return ExtensionModulePlaceholder(module_name) try: return importlib.import_module(EXTENSION_PREFIX + module_name) except ModuleNotFoundError: print("Warning:", MESSAGE.format(module_name)) return ExtensionModulePlaceholder(module_name)