bitorch_engine.utils.convert.quantize_linear_with_mpq_linear_cuda

bitorch_engine.utils.convert.quantize_linear_with_mpq_linear_cuda(module: Module, names_to_replace: Iterable[str], mpq_strategy: str | None = None, dtype: dtype = torch.bfloat16, parent_name: str = '')[source]

Replace all layers listed in names_to_replace within the given module with MPQLinearCuda layers.

Parameters:
    module – the module which contains the Linear layers
    names_to_replace – the list of layer names to be replaced
    mpq_strategy – which MPQ strategy to use, see get_mpq_config
    dtype – the dtype of the new modules
    parent_name – the name of the parent module (usually empty when called directly)

Returns:
    the list of names of the layers which were replaced
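The general pattern this function implements can be sketched without torch or CUDA: walk the module tree, build each child's qualified name, and swap listed Linear children for quantized replacements. A minimal stdlib-only sketch follows; Module, Linear, and MPQLinearCuda here are hypothetical stand-ins for illustration, not the real torch or bitorch_engine classes.

```python
class Linear:
    """Stand-in for a dense layer (hypothetical, not torch.nn.Linear)."""
    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features

class MPQLinearCuda:
    """Stand-in for the quantized replacement layer (hypothetical)."""
    def __init__(self, linear, mpq_strategy=None, dtype="bfloat16"):
        # The real layer would quantize the weights of `linear` here.
        self.in_features = linear.in_features
        self.out_features = linear.out_features
        self.mpq_strategy = mpq_strategy
        self.dtype = dtype

class Module:
    """Stand-in container holding named child layers/modules."""
    def __init__(self, **children):
        self.children = dict(children)

def replace_named_layers(module, names_to_replace, mpq_strategy=None,
                         dtype="bfloat16", parent_name=""):
    """Recursively replace Linear children whose qualified name is listed,
    returning the list of names that were replaced."""
    replaced = []
    for name, child in module.children.items():
        # Qualified name, e.g. "attn.q_proj"; parent_name is empty at the top.
        full_name = f"{parent_name}.{name}" if parent_name else name
        if isinstance(child, Linear) and full_name in names_to_replace:
            module.children[name] = MPQLinearCuda(child, mpq_strategy, dtype)
            replaced.append(full_name)
        elif isinstance(child, Module):
            replaced += replace_named_layers(child, names_to_replace,
                                             mpq_strategy, dtype,
                                             parent_name=full_name)
    return replaced

net = Module(attn=Module(q_proj=Linear(64, 64)), head=Linear(64, 10))
replaced = replace_named_layers(net, {"attn.q_proj"})
# replaced == ["attn.q_proj"]; net.children["head"] is left untouched
```

The recursion threads parent_name through each level so layers can be addressed by their fully qualified name, which matches the parent_name parameter in the signature above being "usually empty when called directly".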