Source code for xpm_torch.module

from dataclasses import dataclass
from typing import (
    List,
    Dict,
    Optional,
)
from pathlib import Path
import torch
import logging
import torch.nn as nn
from experimaestro import (
    field,
    Config,
    DataPath,
    Param,
    SerializationLWTask,
)
from xpm_torch.utils.utils import Initializable

# Module-level logger; records are emitted under this module's dotted name.
logger = logging.getLogger(__name__)


def initialized(method):
    """Decorator that ensures ``initialize()`` is called before the first
    invocation, then replaces itself with the original method so subsequent
    calls have zero overhead.

    On the first call through the wrapper, ``self.initialize()`` runs if
    ``self._initialized`` is falsy. The original method is then bound to the
    instance and stored as an instance attribute under the same name, so later
    ``instance.<name>`` lookups bypass the wrapper entirely.

    Usage::

        class MyModule(Module):
            @initialized
            def forward(self, x):
                ...

    Note:
        Python looks up special ("dunder") methods on the type, not the
        instance, so for a decorated dunder the wrapper keeps running on every
        call. It stays correct, just not overhead-free.
    """
    from functools import wraps

    # ``wraps`` copies __name__, __qualname__, __doc__, __module__ and sets
    # __wrapped__, so introspection (inspect.signature, Sphinx) sees the
    # decorated method rather than ``wrapper``.
    @wraps(method)
    def wrapper(self, *args, **kwargs):
        if not self._initialized:
            self.initialize()
        # Shadow the class-level wrapper with the plain bound method on this
        # instance so the initialization check is skipped on later calls.
        bound = method.__get__(self, type(self))
        setattr(self, method.__name__, bound)
        return method(self, *args, **kwargs)

    return wrapper


class Module(Config, Initializable, nn.Module):
    """Base class for all modules containing parameters"""

    def __init__(self):
        # Both bases are initialized explicitly rather than via super();
        # presumably Config's construction path must not be re-entered here.
        Initializable.__init__(self)
        torch.nn.Module.__init__(self)

    def __initialize__(self):
        """Initialize a module (structure only, no weight loading)"""
        pass

    def __call__(self, *args, **kwargs):
        # Route calls to nn.Module.__call__ explicitly (hooks + forward),
        # presumably so that any __call__ defined by Config, which precedes
        # nn.Module in the MRO, is not used.
        return torch.nn.Module.__call__(self, *args, **kwargs)

    @property
    def device(self) -> torch.device:
        """Device of the module's first parameter, or else its first buffer.

        A module with neither parameters nor buffers reports the CPU device
        instead of leaking a ``StopIteration`` out of the property.
        """
        for param in self.parameters():
            return param.device
        for buffer in self.buffers():
            return buffer.device
        return torch.device("cpu")

    def count_parameters(self):
        """Count the number of trainable parameters (``requires_grad=True``)
        in the model."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_forward_methods(self) -> list:
        """Return the names of *additional* forward-like methods for Fabric.

        Each returned name is registered with ``mark_forward_method`` on the
        Fabric-wrapped module, so calls to it go through the strategy wrapper
        (e.g. DDP) the same way ``forward`` does. ``forward`` itself is always
        handled by Fabric and must not be listed, because Fabric's
        ``mark_forward_method`` rejects it. The default is therefore an empty
        list; subclasses override this to expose extra methods (e.g. for
        different scoring strategies).
        """
        return []

    def save_model(self, path: Path):
        """Save model parameters to a directory or a file using safetensors.

        If the path ends with ``.safetensors``, it saves directly to that file.
        Otherwise, it creates a directory and saves as ``model.safetensors``
        inside. Missing parent directories are created.

        Args:
            path: Target file or directory; plain strings are accepted too.
        """
        from safetensors.torch import save_file

        path = Path(path)
        if path.suffix == ".safetensors":
            path.parent.mkdir(parents=True, exist_ok=True)
            save_file(self.state_dict(), str(path))
        else:
            path.mkdir(parents=True, exist_ok=True)
            save_file(self.state_dict(), str(path / "model.safetensors"))

    def load_model(self, path: Path):
        """Load model parameters from a safetensors file, or from a directory
        containing ``model.safetensors``.

        Args:
            path: File or directory; plain strings are accepted too.

        Raises:
            FileNotFoundError: if ``path`` is neither a file nor a directory
                holding ``model.safetensors``.
        """
        from safetensors.torch import load_file

        path = Path(path)
        if path.is_file():
            self.load_state_dict(load_file(str(path)))
        elif (path / "model.safetensors").exists():
            self.load_state_dict(load_file(str(path / "model.safetensors")))
        else:
            raise FileNotFoundError(f"Could not find model weights at {path}")

    def loader_config(
        self, path: Path, *, settings: Optional[Config] = None
    ) -> "ModuleLoader":
        """Returns a ModuleLoader config that knows how to load this model from path.

        The default returns a :class:`SimpleModuleLoader` with a single
        ``path`` DataPath. Subclasses override to return loaders with
        different DataPath layouts (e.g. separate encoder paths).

        Args:
            path: The base checkpoint path containing the model/ directory.
            settings: Optional metadata to attach to the loader.
        """
        return SimpleModuleLoader.C(value=self, path=path, settings=settings)

    def export_action(self, loader: "ModuleLoader", **kwargs):
        """Returns an ExportAction config for this model.

        Subclasses override to return library-specific actions (e.g. with
        xpmir README sections, paper metadata).

        Args:
            loader: The ModuleLoader to export.
            **kwargs: Extra params passed to the action (e.g. default_name).
        """
        from xpm_torch.actions import ExportAction

        return ExportAction.C(loader=loader, **kwargs)

    def to(self, *args, **kwargs):
        # Delegate explicitly to nn.Module.to, presumably so that a ``to``
        # defined on an earlier base class does not take precedence.
        return torch.nn.Module.to(self, *args, **kwargs)
class ModuleList(Module, Initializable):
    """Bundle several :class:`Module` configs into one unit for the Learner.

    Every entry of ``sub_modules`` is registered as an ``nn.Module`` child
    named after its position (``"0"``, ``"1"``, ...). Initializing the list
    initializes each entry in turn. The bundle itself cannot be called.
    """

    sub_modules: Param[List[Module]]

    def __post_init__(self):
        # Attach each configured module as a positional child so torch tracks it.
        for position, child in enumerate(self.sub_modules):
            self.add_module(f"{position}", child)

    def __initialize__(self):
        for child in self.sub_modules:
            child.initialize()

    def __call__(self, *args, **kwargs):
        # A plain grouping container: there is no meaningful forward pass.
        raise AssertionError("This module cannot be used as such")

    def to(self, *args, **kwargs):
        return torch.nn.Module.to(self, *args, **kwargs)
@dataclass
class ReadmeSection:
    """One keyed chunk of Markdown for the HF Hub README.

    Chunks are concatenated in order. ``before``/``after`` optionally anchor a
    chunk relative to the chunk carrying a given ``key``.
    """

    #: Unique identifier of this section; other sections may anchor on it.
    key: str
    #: Markdown body of this section.
    content: str
    #: When set, position this section before the section with this key.
    before: Optional[str] = None
    #: When set, position this section after the section with this key.
    after: Optional[str] = None
def assemble_readme_sections(
    base: List["ReadmeSection"], extra: List["ReadmeSection"]
) -> str:
    """Merge ``extra`` sections into ``base`` and join all contents.

    Extras are processed in order, and each one may anchor on any section
    already placed (from ``base`` or an earlier extra):

    * ``before`` set: inserted just before the first section with that key.
      If no section has that key, it is appended.
    * otherwise ``after`` set: inserted after the first section with that key.
      Extras naming the same ``after`` anchor keep their relative order from
      ``extra``. If no section has that key, it is appended.
    * neither set: appended.

    If both anchors are set, only ``before`` is honoured. ``base`` is not
    mutated.

    Returns:
        The section contents joined with newlines.
    """
    sections = list(base)
    # anchor key -> last extra placed after that anchor. Inserting after it,
    # rather than directly after the anchor, stops extras that share an
    # anchor from coming out in reverse order.
    last_after: Dict[str, "ReadmeSection"] = {}

    def _position(predicate) -> Optional[int]:
        """Index of the first placed section satisfying ``predicate``, or None."""
        return next((i for i, sec in enumerate(sections) if predicate(sec)), None)

    for s in extra:
        if s.before:
            idx = _position(lambda sec: sec.key == s.before)
            sections.insert(len(sections) if idx is None else idx, s)
        elif s.after:
            prev = last_after.get(s.after)
            if prev is not None:
                idx = _position(lambda sec: sec is prev)
            else:
                idx = _position(lambda sec: sec.key == s.after)
            if idx is None:
                sections.append(s)
            else:
                sections.insert(idx + 1, s)
                last_after[s.after] = s
        else:
            sections.append(s)
    return "\n".join(s.content for s in sections)
class ModuleLoader(SerializationLWTask):
    """Base class for tasks that load a model's weights from a checkpoint.

    Concrete loaders declare their own ``DataPath`` fields describing where
    the model files live (a single ``path``, or e.g. separate
    ``encoder_path`` / ``query_encoder_path``) and implement :meth:`execute`.

    :attr:`settings` holds opaque metadata (e.g. a validation key or a
    checkpoint epoch) so that loaders sharing the same model and paths can
    still be told apart.

    HuggingFace Hub export is customised by overriding
    :meth:`write_hub_extras` and :meth:`hub_readme_sections`. By default both
    delegate to the model when it defines a method of the same name.
    :attr:`model` is an alias for ``value``.
    """

    settings: Param[Optional[Config]] = field(default=None, ignore_default=True)
    """Optional metadata (validation info, checkpoint epoch, etc.)"""

    @property
    def model(self):
        """Alias for ``value``: the model config this loader targets."""
        return self.value

    def write_hub_extras(self, save_directory: Path):
        """Hook for writing extra files into a HuggingFace Hub export.

        Invoked by ``ExperimaestroHFHub._save_pretrained`` once the main
        serialization is done. The default forwards to
        ``model.write_hub_extras`` when the model defines it and does nothing
        otherwise. Subclasses override it to emit format-specific files (e.g.
        sentence-transformers configs).

        Args:
            save_directory: The Hub export directory.
        """
        if not hasattr(self.model, "write_hub_extras"):
            return
        self.model.write_hub_extras(save_directory)

    def hub_readme_sections(self) -> List[ReadmeSection]:
        """Extra :class:`ReadmeSection` entries for the HF Hub README.

        Entries may use ``before``/``after`` to position themselves relative
        to the base sections (``frontmatter``, ``description``, ``usage``,
        ``results``). The default asks the model via its own
        ``hub_readme_sections`` when it has one and returns an empty list
        otherwise.
        """
        if not hasattr(self.model, "hub_readme_sections"):
            return []
        return self.model.hub_readme_sections()

    def execute(self):
        raise NotImplementedError("Subclasses must implement execute()")
class SimpleModuleLoader(ModuleLoader):
    """Default :class:`ModuleLoader` with a single ``path`` DataPath.

    :meth:`execute` initializes the model and passes ``path`` to the model's
    ``load_model``, so the on-disk layouts it accepts are whatever that method
    supports. With the base :meth:`Module.load_model`, that is either a
    safetensors file or a directory containing ``model.safetensors``.
    """

    path: DataPath
    """Path to the model weights: a safetensors file or a checkpoint directory"""

    def __xpm_serialize__(self, context):
        """Serialize ``path`` under the name ``model.safetensors``.

        When ``path`` is a directory holding ``model.safetensors``, that file
        (not the directory) is what gets serialized. Any other existing path
        is serialized as-is under the same name.

        Raises:
            FileNotFoundError: if ``path`` does not exist or resolves to the
                current working directory (which is what an empty path
                resolves to).
        """
        path = Path(self.path)
        # An empty path resolves to the CWD; refuse it as well as missing paths.
        if not path.exists() or path.resolve() == Path.cwd().resolve():
            raise FileNotFoundError(
                f"Cannot serialize SimpleModuleLoader: path '{self.path}' does not exist or is the current directory"
            )

        # Point at the weights file itself so it is serialized as a file
        # rather than as a directory.
        if path.is_dir() and (path / "model.safetensors").exists():
            path = path / "model.safetensors"

        return {
            "path": context.serialize(
                context.var_path + ["model.safetensors"], path, self
            )
        }

    def execute(self):
        """Initialize the model structure, then load its weights from ``path``."""
        self.value.initialize()
        path = Path(self.path)
        logger.info("Loading model from disk: %s", path)
        self.value.load_model(path)
class ModuleContainer(nn.Module):
    """A container for Modules, exposing only nn.Module attributes that
    actually contain state (parameters or buffers).

    Example::

        class MyRetriever(ModuleContainer):
            def __init__(self):
                super().__init__()
                self.encoder = nn.Linear(128, 64)  # Has params -> wrapped
                self.activ = nn.ReLU()  # No params -> ignored

        retriever = MyRetriever()
        retriever.setup_with_fabric(fabric)
    """

    def __init__(self):
        super().__init__()

    def get_manageable_modules(self) -> Dict[str, nn.Module]:
        """Return the immediate children that hold at least one non-empty
        parameter or buffer, anywhere in their own sub-tree.

        Returns:
            A mapping from child attribute name to child module, in
            registration order.
        """
        manageable = {}
        # named_children() lists only immediate children (those registered
        # through nn.Module attribute assignment).
        for name, module in self.named_children():
            has_state = any(p.numel() > 0 for p in module.parameters()) or any(
                b.numel() > 0 for b in module.buffers()
            )
            if has_state:
                manageable[name] = module
        return manageable

    def setup_with_fabric(self, fabric) -> None:
        """Wrap every stateful child with ``fabric.setup`` and re-assign it.

        Children that define ``get_forward_methods()`` (e.g. :class:`Module`)
        have each returned method name registered with
        ``mark_forward_method`` on their wrapper. Plain ``nn.Module`` children,
        which lack that method, are wrapped with no extra forward methods.

        Args:
            fabric: presumably a Lightning ``Fabric``; this method uses its
                ``setup`` method and ``device`` attribute. It is also stored as
                ``self.fabric``.
        """
        self.fabric = fabric
        modules_to_wrap = self.get_manageable_modules()
        if not modules_to_wrap:
            logger.debug("No stateful modules found. Skipping Fabric setup.")
            return

        for name, module in modules_to_wrap.items():
            wrapped = fabric.setup(module)
            # Only xpm_torch Modules define get_forward_methods; calling it
            # unconditionally crashed on plain nn.Module children.
            get_methods = getattr(module, "get_forward_methods", None)
            forward_methods = get_methods() if callable(get_methods) else []
            for method_name in forward_methods:
                wrapped.mark_forward_method(method_name)
            # Re-assigning replaces the registered child with its wrapper.
            setattr(self, name, wrapped)
            logger.info(
                "Registered %s (type: %s) with Fabric on %s",
                name,
                type(module).__name__,
                fabric.device,
            )
def find_module_attributes(obj) -> dict:
    """Return the instance attributes of ``obj`` that are :class:`Module` instances.

    Only the attributes stored in ``obj``'s instance namespace (``vars(obj)``)
    are inspected. The search is not recursive, and class attributes and
    properties are not considered.

    Args:
        obj: Any object with a ``__dict__``; ``vars`` raises ``TypeError``
            otherwise.

    Returns:
        A dict mapping attribute name to the :class:`Module` instance, in the
        instance namespace's insertion order.
    """
    # The namespace is only read (never mutated) during iteration, so no
    # defensive copy of vars(obj) is needed.
    return {
        attr_name: attr_value
        for attr_name, attr_value in vars(obj).items()
        if isinstance(attr_value, Module)
    }