Source code for xpm_torch.huggingface

"""HuggingFace Hub integration for xpm-torch.

Provides :class:`TorchHFHub` for exporting ModuleLoaders to the Hub
(calls ``write_hub_extras`` and ``hub_readme_sections``), plus utility
functions for cache checking and downloading.
"""

import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Optional, Union, Type, Dict, TypeVar

from experimaestro.huggingface import ExperimaestroHFHub
from huggingface_hub import hf_hub_download, snapshot_download
from huggingface_hub.errors import EntryNotFoundError, RepositoryNotFoundError
from experimaestro.core.context import SerializedPath
from experimaestro.core.objects import ConfigInformation

from xpm_torch.module import Module, ModuleLoader, assemble_readme_sections

import logging
logger = logging.getLogger(__name__)

# Generic variable that is either ModelHubMixin or a subclass thereof
T = TypeVar("T", bound="TorchHFHub")

[docs] @lru_cache def prepare_hf_model(model_id: str) -> bool: """Check if model and tokenizer are in cache, if not, download all necessary files. Args: model_id: The ID of the model to check. Returns: True if both model and tokenizer are in cache or after downloading, False if download fails. """ model_in_cache = check_hf_cache(model_id, is_model=True) tokenizer_in_cache = check_hf_cache(model_id, is_model=False) logger.info("Preparing model %s ...", model_id) if model_in_cache and tokenizer_in_cache: logger.info("Model and tokenizer for %s are already in cache.", model_id) return True logger.info("Downloading missing files for %s...", model_id) try: if not model_in_cache: snapshot_download(repo_id=model_id) if not tokenizer_in_cache: snapshot_download(repo_id=model_id) logger.info("Successfully downloaded missing files for %s.", model_id) return True except Exception as e: logger.error("Failed to download files for %s: %s", model_id, e) return False
[docs] def check_hf_cache(model_id: str, is_model: bool = True) -> bool: """Check if the model or tokenizer is already downloaded in the cache. Args: model_id: The ID of the model or tokenizer to check. is_model: If True, checks for model files. If False, checks for tokenizer files. Returns: True if the model or tokenizer is already downloaded, False otherwise. """ model_files = [ "config.json", "pytorch_model.bin", "tf_model.h5", "model.safetensors", ] tokenizer_files = [ "tokenizer.json", "tokenizer_config.json", "vocab.json", "merges.txt", ] files_to_check = model_files if is_model else tokenizer_files for filename in files_to_check: try: hf_hub_download( repo_id=model_id, filename=filename, local_files_only=True ) return True except (EntryNotFoundError, RepositoryNotFoundError): continue return False
[docs] def get_hf_config(repo_id: str) -> dict: """Pull config.json from HF Hub without importing transformers.""" config_path = hf_hub_download(repo_id=repo_id, filename="config.json") with open(config_path) as f: return json.load(f)
[docs] def download_huggingface_model( model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, cache_dir: Optional[str] = None, ) -> Path: """Download a model file from HuggingFace Hub, using local cache if available. Args: model_id: The model ID on HuggingFace Hub (e.g., "bert-base-uncased"). filename: The specific file name to download. subfolder: A subfolder in the model repository where the file is located. revision: The specific model version to use. cache_dir: Path to the folder where cached files are stored. Returns: The local path to the downloaded (or already cached) model file. Raises: ValueError: If the model file cannot be found locally or downloaded. """ # First, try to load from local cache only try: local_path = hf_hub_download( repo_id=model_id, filename=filename, subfolder=subfolder, revision=revision, cache_dir=cache_dir, local_files_only=True, ) logger.info( "Model file '%s' for '%s' found in local cache: %s", filename, model_id, local_path, ) return Path(local_path) except ValueError: logger.info( "Model file '%s' for '%s' not found in local cache. " "Attempting download from HuggingFace Hub.", filename, model_id, ) # If not in local cache, try to download from the hub try: local_path = hf_hub_download( repo_id=model_id, filename=filename, subfolder=subfolder, revision=revision, cache_dir=cache_dir, local_files_only=False, ) logger.info( "Successfully downloaded model file '%s' for '%s' to: %s", filename, model_id, local_path, ) return Path(local_path) except ValueError as e: logger.error( "Failed to download model file '%s' for '%s' from HuggingFace Hub: %s", filename, model_id, e, ) raise
[docs] class TorchHFHub(ExperimaestroHFHub): """HF Hub integration for xpm-torch ModuleLoaders. Extends :class:`~experimaestro.huggingface.ExperimaestroHFHub` to call :meth:`~xpm_torch.module.ModuleLoader.write_hub_extras` after serialization and build the README. :meth:`~xpm_torch.module.ModuleLoader.hub_readme_sections`. Subclass this (e.g. ``XPMIRHFHub``) to add library-specific README sections, TensorBoard logs, etc. """ def _save_pretrained(self, save_directory: Union[str, Path]): save_directory = Path(save_directory) super()._save_pretrained(save_directory) # Call ModuleLoader hub hooks self.config.write_hub_extras(save_directory) # Build README from loader sections loader_sections = self.config.hub_readme_sections() base_sections = self._readme_base_sections() if base_sections or loader_sections: readme = assemble_readme_sections(base_sections, loader_sections) (save_directory / "README.md").write_text(readme) @classmethod def _from_pretrained( cls, model_id, revision=None, cache_dir=None, force_download=False, proxies=None, resume_download=None, local_files_only=False, token=None, **model_kwargs ) -> Module: """ This overrides `ExperimaestroHFHub._from_pretrained` outputs directly the model instance instead of the loader. """ if os.path.isdir(model_id): save_directory = Path(model_id) def loader_path(path: Path): if isinstance(path, SerializedPath): path = path.path else: path = Path(path) return save_directory / path else: def loader_path(s_path: Union[Path, str, SerializedPath]) -> Path: if not isinstance(s_path, SerializedPath): s_path = SerializedPath(Path(s_path), False) path = s_path.path # Folder if s_path.is_folder: hf_path = snapshot_download( repo_id=model_id, allow_patterns=f"{s_path.path}/**", revision=revision, cache_dir=cache_dir, proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, ) return Path(hf_path) / path hf_path = Path( hf_hub_download( repo_id=model_id, filename=str(path), revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, ) ) return hf_path loader: ModuleLoader = ConfigInformation.deserialize( loader_path, as_instance=True, partial_loading=True, definition_filename=cls.definition_filename, ) #execute the moduleLoader Instance -> loads the model loader.execute() return loader.model
[docs] @classmethod def pretrained_loader( cls: Type[T], pretrained_model_name_or_path: Union[str, Path], *, force_download: bool = False, resume_download: Optional[bool] = None, proxies: Optional[Dict] = None, token: Optional[Union[str, bool]] = None, cache_dir: Optional[Union[str, Path]] = None, local_files_only: bool = False, revision: Optional[str] = None, as_instance: bool = False, #specific to this Class **model_kwargs, ) -> ModuleLoader: """ Download a model _loader_ from the Huggingface Hub. """ # Call parent's _from_pretrained directly to avoid the overridden version return ExperimaestroHFHub.from_pretrained( pretrained_model_name_or_path=pretrained_model_name_or_path, revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, token=token, as_instance=as_instance, # pass to super but it will be ignored **model_kwargs, )
def _readme_base_sections(self): """Return base README sections. Override in subclasses to add library-specific content (description, usage examples, results).""" return []