# Source code for xpm_torch.trainers.context

from typing import (
    List,
    Optional,
    TYPE_CHECKING,
)
from pathlib import Path
import os
import json
import shutil
import torch
import logging

from torch.utils.tensorboard.writer import SummaryWriter
from shutil import rmtree
from contextlib import contextmanager
import lightning as L

from experimaestro.utils import cleanupdir
from xpm_torch.context import InitializationHook, Hook, Context
from xpm_torch.metrics import Metric, Metrics
from xpm_torch.losses import Loss


if TYPE_CHECKING:
    from xpm_torch.learner import ScheduledOptimizer, Module
    from xpm_torch.trainers import Trainer

# Module-level logger shared by the training state and context classes below.
# NOTE(review): the name uses "trainer" (singular) while this file is presented
# as xpm_torch.trainers.context — confirm whether the mismatch is intended
# before relying on name-based logging configuration.
logger = logging.getLogger("xpm_torch.trainer.context")


class TrainState:
    """Represents a training state for serialization

    A state bundles the model, the trainer and the optimizer together with
    the progress counters (``epoch`` and ``steps``). Only the counters are
    part of :meth:`state_dict`; the three components are persisted through
    their own save/``state_dict`` methods in :meth:`save`.
    """

    MODEL_PATH = "model.pth"
    MODEL_DIR = "model"

    epoch: int
    """The epoch"""

    steps: int
    """The number of steps (each epoch is composed of steps)"""

    def __init__(
        self,
        model: "Module",
        trainer: "Trainer",
        optimizer: "ScheduledOptimizer",
        epoch=0,
        steps=0,
    ):
        """Create a new in-memory state

        :param model: The model being trained
        :param trainer: The trainer (its ``state_dict`` is checkpointed)
        :param optimizer: The optimizer (its ``state_dict`` is checkpointed)
        :param epoch: Initial epoch counter
        :param steps: Initial step counter
        """
        # Initialize the state
        self.model = model
        self.trainer = trainer
        self.optimizer = optimizer

        self.epoch = epoch
        self.steps = steps

        # Was it loaded from disk?
        self.cached = False

        # Was it saved? (directory of the checkpoint, or None)
        self.path = None

    def copy(self):
        """Return a new state sharing the same model/trainer/optimizer

        The counters are copied; ``path`` and ``cached`` are reset, so the
        copy is considered neither saved nor loaded.
        """
        return TrainState(self.model, self.trainer, self.optimizer, **self.state_dict())

    def state_dict(self):
        """Return the JSON-serializable part of the state (the counters)"""
        return {
            "epoch": self.epoch,
            "steps": self.steps,
        }

    @property
    def step(self):
        """Returns the step for logging (number of steps)"""
        return self.steps

    def load_state_dict(self, state):
        """Restore the counters from ``state``; missing keys default to 0"""
        self.epoch = state.get("epoch", 0)
        self.steps = state.get("steps", 0)

    def save(self, path):
        """Save the state

        The target directory is first passed to ``cleanupdir``, then
        ``info.json``, the model directory, ``trainer.pth`` and
        ``optimizer.pth`` are written into it.

        :param path: The checkpoint directory
        """
        cleanupdir(path)

        with (path / "info.json").open("wt") as fp:
            json.dump(self.state_dict(), fp)

        model_dir = path / self.MODEL_DIR
        model_dir.mkdir()
        self.model.save_model(model_dir)

        torch.save(self.trainer.state_dict(), path / "trainer.pth")
        torch.save(self.optimizer.state_dict(), path / "optimizer.pth")

        # Only mark the state as saved once everything has been written
        self.path = path

    def load(self, path, onlyinfo=False):
        """Loads the state from disk

        :param path: The checkpoint directory
        :param onlyinfo: When true, only the counters stored in ``info.json``
            are restored; model, trainer and optimizer are left untouched
        """
        if not onlyinfo:
            model_dir = path / self.MODEL_DIR
            if model_dir.exists():
                self.model.load_model(model_dir)
            else:
                # Backward compat: load from legacy model.pth
                self.model.load_state_dict(
                    torch.load(
                        path / self.MODEL_PATH,
                        map_location="cpu",
                        weights_only=True,
                    )
                )
            self.trainer.load_state_dict(torch.load(path / "trainer.pth"))
            self.optimizer.load_state_dict(torch.load(path / "optimizer.pth"))

        with (path / "info.json").open("rt") as fp:
            self.load_state_dict(json.load(fp))

        self.path = path
        self.cached = True

    @staticmethod
    def _link_or_copy(source, target):
        """Hard-link ``source`` to ``target``, copying when linking fails"""
        try:
            os.link(source, target)
        except (FileNotFoundError, FileExistsError):
            # Genuine errors (missing source / target already there): a copy
            # would either fail as well or silently overwrite the target
            raise
        except OSError:
            # Cross-device link, or a filesystem without hard-link support
            shutil.copy2(source, target)

    def copy_model(self, path: Path):
        """Copy ``info.json`` and the model of a saved state into ``path``

        Files are hard-linked when possible and copied otherwise. The trainer
        and optimizer files are not copied.

        :param path: An existing directory receiving the files
        """
        assert self.path is not None, "The state has not been saved or loaded"

        # Copy info.json
        self._link_or_copy(self.path / "info.json", path / "info.json")

        # Copy model dir or legacy model.pth
        model_dir = self.path / self.MODEL_DIR
        if model_dir.exists():
            shutil.copytree(model_dir, path / self.MODEL_DIR)
        else:
            self._link_or_copy(self.path / self.MODEL_PATH, path / self.MODEL_PATH)
class TrainingHook(Hook):
    """Common ancestor of every training hook

    The class adds nothing to ``Hook``; it only serves as a marker type.
    """
class ValidationHook(Hook):
    """Common ancestor of every validation hook

    Both callbacks do nothing here; subclasses override the ones they need.
    """

    def after(self, state: "TrainerContext"):
        """Callback invoked once a validation step is over"""
        return None

    def before(self, state: "TrainerContext"):
        """Callback invoked right before a validation step"""
        return None
class StepTrainingHook(TrainingHook):
    """Common ancestor of hooks triggered around each step (before/after)

    Both callbacks do nothing here; subclasses override the ones they need.
    """

    def after(self, state: "TrainerContext"):
        """Callback invoked once a training step is over"""
        return None

    def before(self, state: "TrainerContext"):
        """Callback invoked right before a training step"""
        return None
class InitializationTrainingHook(TrainingHook, InitializationHook):
    """Common ancestor of hooks triggered at initialization

    Both callbacks do nothing here; subclasses override the ones they need.
    """

    def after(self, state: "TrainerContext"):
        # Intentionally empty: default implementation
        return None

    def before(self, state: "TrainerContext"):
        # Intentionally empty: default implementation
        return None
class TrainerContext(Context):
    """Contains all the information about the training context for a specific
    training run

    This object is used when training to provide models and losses with
    extra information - as well as the possibility to add regularization
    losses. It also owns the current :class:`TrainState` and handles
    checkpoints, stored in directories named ``epoch-XXXXXXXX``.
    """

    metrics: Optional[Metrics]
    """Metrics to be reported"""

    _losses: Optional[List[Loss]]
    """Regularization losses to be added to the main loss"""

    _scope: List[str]
    """Scope for metric names"""

    oldstate: Optional["TrainState"]
    """State of the previous epoch (if any), removed after the next save"""

    PREFIX = "epoch-"

    def __init__(
        self,
        logpath: Path,
        path: Path,
        max_epoch: int,
        steps_per_epoch: int,
        trainer,
        model: "Module",
        optimizer: "ScheduledOptimizer",
        fabric: L.Fabric,
    ):
        """Create the context

        :param logpath: Directory given to the tensorboard writer
        :param path: Directory containing the checkpoints
        :param max_epoch: Maximum number of epochs
        :param steps_per_epoch: Number of steps in each epoch
        :param trainer: The trainer
        :param model: The model being trained
        :param optimizer: The optimizer
        :param fabric: Fabric object used for the backward pass (when set)
        """
        super().__init__()
        self.path = path
        self.logpath = logpath
        self.max_epoch = max_epoch
        self.steps_per_epoch = steps_per_epoch
        self._writer = None
        self._scope = []
        self._losses = None
        self.fabric = fabric

        # Defined up front so that add_metric() and save_checkpoint() can be
        # called before the first step() / nextepoch() without raising an
        # AttributeError
        self.metrics = None
        self.oldstate = None

        self.state = TrainState(model, trainer, optimizer)

    @property
    def writer(self):
        """Returns a tensorboard writer

        by default, purges the entries beside the current epoch
        """
        if self._writer is None:
            self._writer = SummaryWriter(self.logpath, purge_step=self.state.step)
        return self._writer

    @property
    def epoch(self):
        """The epoch of the current state"""
        return self.state.epoch

    @property
    def steps(self):
        """The number of steps of the current state"""
        return self.state.steps

    def nextepoch(self):
        """Move to the next epoch, keeping the previous state in ``oldstate``"""
        self.oldstate = self.state
        self.state = self.oldstate.copy()
        self.state.epoch += 1

    def nextbatch(self):
        """Increment the step counter of the current state"""
        self.state.steps += 1

    def load_bestcheckpoint(self, max_epoch: int):
        """Try to find the best checkpoint to load (the highest lower than
        the epoch target)

        Checkpoints that cannot be loaded are removed from disk, and the next
        candidate is tried.

        :param max_epoch: Checkpoints above this epoch are ignored
        :return: True if a checkpoint was loaded, False otherwise
        :raises NotImplementedError: when raised while loading (the
            checkpoint is kept in that case)
        """
        # Find all the potential epochs
        prefix_length = len(TrainerContext.PREFIX)
        candidates = []
        for entry in self.path.glob(f"{TrainerContext.PREFIX}*"):
            try:
                epoch = int(entry.name[prefix_length:])
            except ValueError:
                # Not a checkpoint (e.g. temporary or backup entry)
                logger.warning("Ignoring %s: not a checkpoint name", entry)
                continue
            if epoch <= max_epoch:
                candidates.append((epoch, entry))
        candidates.sort(key=lambda candidate: candidate[0], reverse=True)

        # Try to load the first one; the path found on disk is used as is, so
        # that the directory that is loaded (or removed) is one that exists
        for epoch, path in candidates:
            logger.info("Loading from checkpoint at epoch %d", epoch)
            try:
                self.state.load(path)
                return True
            except NotImplementedError:
                logger.error("Not removing checkpoint")
                raise
            except Exception:
                # Report first: the traceback is not lost if removal fails
                logger.exception("Cannot load from epoch %d", epoch)
                rmtree(path)

        return False

    @staticmethod
    def get_checkpoint_path(checkpointspath: Path, epoch: int) -> Path:
        """Return the checkpoint directory for ``epoch`` in ``checkpointspath``"""
        return checkpointspath / f"{TrainerContext.PREFIX}{epoch:08d}"

    def save_checkpoint(self):
        """Save the current state, then remove the previous epoch checkpoint

        Nothing happens if the current state is already associated with a
        path (i.e. it was saved or loaded).
        """
        if self.state.path is not None:
            # No need to save twice
            return

        # Serialize
        path = TrainerContext.get_checkpoint_path(self.path, self.epoch)
        self.state.save(path)

        # Cleanup if needed
        if self.oldstate and self.oldstate.path:
            try:
                rmtree(self.oldstate.path)
            except OSError as e:
                # We continue the learning process in those cases
                logger.error("OS Error while trying to remove directory %s", e)

            self.oldstate = None

    def copy(self, path: Path):
        """Copy the state into another folder

        The state is saved first if needed.

        :param path: The target folder; nothing is copied when it is falsy
        """
        if self.state.path is None:
            self.save_checkpoint()

        assert self.state.path is not None

        if path:
            cleanupdir(path)
            self.state.copy_model(path)

    def add_loss(self, loss: Loss):
        """Add a regularization loss (only valid within :meth:`losses`)"""
        assert (
            self._losses is not None
        ), "This should be called in the context where loss is computed"
        self._losses.append(loss)

    @contextmanager
    def losses(self):
        """Context manager collecting the losses added with :meth:`add_loss`

        Yields the (initially empty) list of losses; the previous list is
        restored on exit, so that contexts can be nested.
        """
        previous = self._losses
        try:
            self._losses = []
            yield self._losses
        finally:
            self._losses = previous

    @contextmanager
    def step(self, metrics):
        """Context manager wrapping one optimization step

        Gradients are reset, then a fresh ``Metrics`` object is yielded. When
        the body completes without error, the optimizer and scheduler steps
        are performed and the step metrics are merged into ``metrics``.

        :param metrics: The metrics into which the step metrics are merged
        """
        try:
            self.state.optimizer.zero_grad()
            self.metrics = Metrics()
            yield self.metrics
            self.state.optimizer.optimizer_step(self)
            self.state.optimizer.scheduler_step(self)
            metrics.merge(self.metrics)
        finally:
            self.metrics = None

    def backward(self, loss: torch.Tensor):
        """Back-propagate ``loss``, through ``fabric`` when it is set"""
        if self.fabric:
            self.fabric.backward(loss)
        else:
            loss.backward()

    def add_metric(self, metric: Metric):
        """
        add a metric to be reported at the end of the step (e.g., for logging
        in tensorboard)

        The non-empty scope names (see :meth:`scope`), joined with "/", are
        prepended to the key of the metric.

        :param metric: The metric to be added
        :type metric: Metric
        """
        assert self.metrics is not None, "Not within an optimization step"

        # Only prefix when at least one scope name is non-empty: this avoids
        # keys starting with "/"
        prefix = "/".join(name for name in self._scope if name)
        if prefix:
            metric.key = f"{prefix}/{metric.key}"
        self.metrics.add(metric)

    @contextmanager
    def scope(self, name: str):
        """Context manager pushing ``name`` on the metric scope

        :param name: The scope name (empty names are ignored in metric keys)
        """
        try:
            self._scope.append(name)
            yield self
        finally:
            self._scope.pop()