bitorch_engine.utils.convert.replace_layers

bitorch_engine.utils.convert.replace_layers(module: Module, names_to_replace: Iterable[str], class_: Type, replace_fn: Callable[[Module], Module], parent_name: str = '')

This function recursively replaces all layers within the given module whose names are included in names_to_replace. It requires a function replace_fn that constructs the replacement object for each such layer. A list of layers can be created with the collect_layers function.
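To illustrate the recursion, here is a minimal sketch written against plain PyTorch. It is not the bitorch_engine source: `replace_layers_sketch`, `ReplacementLinear`, and `to_replacement` are hypothetical names. It also assumes that layer names are the dotted paths produced by `named_modules()` (for example "2.0"), that `parent_name` carries the path prefix during recursion, and that `class_` is only used to check the object returned by `replace_fn`. The sketch modifies `module` in place.

```python
from typing import Callable, Iterable, Type

import torch
import torch.nn as nn


def replace_layers_sketch(
    module: nn.Module,
    names_to_replace: Iterable[str],
    class_: Type,
    replace_fn: Callable[[nn.Module], nn.Module],
    parent_name: str = "",
) -> None:
    """Hypothetical re-implementation for illustration, not the library code."""
    names = set(names_to_replace)  # fast membership tests at every level
    for name, child in module.named_children():
        # Dotted path of this child, matching the names from named_modules().
        full_name = f"{parent_name}.{name}" if parent_name else name
        if full_name in names:
            new_layer = replace_fn(child)
            # Assumption: class_ serves as a sanity check on the replacement.
            assert isinstance(new_layer, class_)
            # nn.Module.__setattr__ re-registers the child under the same key.
            setattr(module, name, new_layer)
        else:
            # Descend into containers such as nn.Sequential or custom blocks.
            replace_layers_sketch(child, names, class_, replace_fn, full_name)


class ReplacementLinear(nn.Linear):
    """Hypothetical stand-in for a replacement layer class."""


def to_replacement(layer: nn.Linear) -> ReplacementLinear:
    # Build a layer of the same shape and copy the original weights over.
    new = ReplacementLinear(
        layer.in_features, layer.out_features, bias=layer.bias is not None
    )
    new.load_state_dict(layer.state_dict())
    return new


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Sequential(nn.Linear(8, 2)))
x = torch.randn(3, 4)
before = model(x)

# Dotted names of every nn.Linear: ["0", "2.0"]
to_replace = [n for n, m in model.named_modules() if isinstance(m, nn.Linear)]
replace_layers_sketch(model, to_replace, ReplacementLinear, to_replacement)
after = model(x)  # same weights, so the outputs should match
```

Because matching uses full dotted names rather than attribute names alone, two layers that share a local name (such as "0") in different containers can be targeted independently.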

Parameters:
  • module – the (sub-)network module in which the layers should be replaced

  • names_to_replace – the names of all layers to be replaced

  • class_ – the replacement class

  • replace_fn – function that takes an original layer and returns an instance of the replacement class

  • parent_name – the name of the parent (usually empty when called directly)

Returns: