from typing import (
List,
Optional,
TYPE_CHECKING,
)
from pathlib import Path
import os
import json
import shutil
import torch
import logging
from torch.utils.tensorboard.writer import SummaryWriter
from shutil import rmtree
from contextlib import contextmanager
import lightning as L
from experimaestro.utils import cleanupdir
from xpm_torch.context import InitializationHook, Hook, Context
from xpm_torch.metrics import Metric, Metrics
from xpm_torch.losses import Loss
if TYPE_CHECKING:
    # Only needed for the string type annotations below; presumably kept out
    # of runtime imports to avoid import cycles
    from xpm_torch.learner import Module, ScheduledOptimizer
    from xpm_torch.trainers import Trainer

# Logger shared by everything in this module
logger = logging.getLogger("xpm_torch.trainer.context")
class TrainState:
    """Serializable snapshot of a training run

    Bundles the model, trainer and optimizer with the epoch/step counters,
    and writes them to (or reads them from) a checkpoint directory.
    """

    # Legacy single-file model checkpoint; only read (by ``load`` and
    # ``copy_model``) for backward compatibility, never written by ``save``
    MODEL_PATH = "model.pth"
    # Sub-directory in which the model saves itself
    MODEL_DIR = "model"

    # Current epoch
    epoch: int
    # Number of steps performed so far (each epoch is composed of steps)
    steps: int

    def __init__(
        self,
        model: "Module",
        trainer: "Trainer",
        optimizer: "ScheduledOptimizer",
        epoch=0,
        steps=0,
    ):
        self.model, self.trainer, self.optimizer = model, trainer, optimizer
        self.epoch, self.steps = epoch, steps
        # True when the state was read from disk
        self.cached = False
        # Directory the state was saved to / loaded from (None if neither)
        self.path = None

    def copy(self):
        """Returns a new, unsaved state sharing the same model, trainer and
        optimizer and starting from the same counters"""
        counters = self.state_dict()
        return TrainState(self.model, self.trainer, self.optimizer, **counters)

    def state_dict(self):
        """Returns the counters (``epoch`` and ``steps``) as a dictionary"""
        return dict(epoch=self.epoch, steps=self.steps)

    @property
    def step(self):
        """The step used for logging, i.e. the number of steps so far"""
        return self.steps

    def load_state_dict(self, state):
        """Restores the counters from ``state``; missing keys default to 0"""
        read = state.get
        self.epoch = read("epoch", 0)
        self.steps = read("steps", 0)

    def save(self, path):
        """Saves the state into the directory ``path``

        ``path`` is first cleaned up with ``cleanupdir``; it then contains
        ``info.json`` (counters), the ``model`` directory, and the
        ``trainer.pth`` / ``optimizer.pth`` state dicts.
        """
        cleanupdir(path)

        with (path / "info.json").open("wt") as out:
            json.dump(self.state_dict(), out)

        target = path / self.MODEL_DIR
        target.mkdir()
        self.model.save_model(target)

        for component, filename in (
            (self.trainer, "trainer.pth"),
            (self.optimizer, "optimizer.pth"),
        ):
            torch.save(component.state_dict(), path / filename)

        self.path = path

    def load(self, path, onlyinfo=False):
        """Loads the state from the directory ``path``

        :param path: checkpoint directory (as written by :meth:`save`)
        :param onlyinfo: when True, only the counters (``info.json``) are
            read; the model, trainer and optimizer are left untouched
        """
        if not onlyinfo:
            model_dir = path / self.MODEL_DIR
            if not model_dir.exists():
                # Backward compatibility: older checkpoints stored the model
                # as a single state-dict file
                legacy_weights = torch.load(
                    path / self.MODEL_PATH, map_location="cpu", weights_only=True
                )
                self.model.load_state_dict(legacy_weights)
            else:
                self.model.load_model(model_dir)
            for component, filename in (
                (self.trainer, "trainer.pth"),
                (self.optimizer, "optimizer.pth"),
            ):
                component.load_state_dict(torch.load(path / filename))

        with (path / "info.json").open("rt") as src:
            self.load_state_dict(json.load(src))
        self.path = path
        self.cached = True

    def copy_model(self, path: Path):
        """Copies the counters and the model of the saved state into ``path``

        ``info.json`` (and a legacy ``model.pth``) are hard-linked; the model
        directory is copied. The state must have been saved or loaded.
        """
        assert self.path is not None
        source = self.path
        os.link(source / "info.json", path / "info.json")

        source_model_dir = source / self.MODEL_DIR
        if not source_model_dir.exists():
            os.link(source / self.MODEL_PATH, path / self.MODEL_PATH)
        else:
            shutil.copytree(source_model_dir, path / self.MODEL_DIR)
class TrainingHook(Hook):
    """Common base class of every hook used during training

    It adds nothing to :class:`Hook`; it only marks a hook as a training hook.
    """
class ValidationHook(Hook):
    """Base class of the hooks run around validation steps

    Both callbacks do nothing by default; subclasses override the ones they
    need.
    """

    def before(self, state: "TrainerContext"):
        """Called before a validation step (no-op by default)"""
        return None

    def after(self, state: "TrainerContext"):
        """Called after a validation step (no-op by default)"""
        return None
class StepTrainingHook(TrainingHook):
    """Base class of the training hooks run around each training step

    Both callbacks do nothing by default; subclasses override the ones they
    need.
    """

    def before(self, state: "TrainerContext"):
        """Called before a training step (no-op by default)"""
        return None

    def after(self, state: "TrainerContext"):
        """Called after a training step (no-op by default)"""
        return None
class InitializationTrainingHook(TrainingHook, InitializationHook):
    """Base class of the training hooks run at initialization

    ``before`` and ``after`` are no-ops here; subclasses override them.
    """

    def before(self, state: "TrainerContext"):
        """No-op by default"""

    def after(self, state: "TrainerContext"):
        """No-op by default"""
class TrainerContext(Context):
    """Contains all the information about the training context

    This object is used when training to provide models and losses with
    extra information (epoch, steps, metrics, tensorboard writer), as well as
    the possibility to add regularization losses. It also manages the
    checkpoints of the training state.
    """

    # Metrics collected during the current optimization step (None outside
    # of :meth:`step`)
    metrics: Optional[Metrics]

    # Regularization losses to be added to the main loss (None outside of
    # :meth:`losses`)
    _losses: Optional[List[Loss]]

    # Stack of name components used to prefix metric names (see :meth:`scope`)
    _scope: List[str]

    # Prefix of checkpoint directory names (followed by a zero-padded epoch)
    PREFIX = "epoch-"

    def __init__(
        self,
        logpath: Path,
        path: Path,
        max_epoch: int,
        steps_per_epoch: int,
        trainer,
        model: "Module",
        optimizer: "ScheduledOptimizer",
        fabric: L.Fabric,
    ):
        """
        :param logpath: directory of the tensorboard logs
        :param path: directory where checkpoints are stored
        :param max_epoch: maximum number of epochs
        :param steps_per_epoch: number of steps in each epoch
        :param trainer: the trainer (its state dict is checkpointed)
        :param model: the model being trained
        :param optimizer: the scheduled optimizer
        :param fabric: Lightning fabric used for the backward pass; when
            falsy, ``loss.backward()`` is used instead
        """
        super().__init__()
        self.path = path
        self.logpath = logpath
        self.max_epoch = max_epoch
        self.steps_per_epoch = steps_per_epoch
        self._writer = None
        self._scope = []
        self._losses = None
        self.fabric = fabric
        # Initialized here so that save_checkpoint/copy (before any call to
        # nextepoch) and add_metric (outside of step) do not raise an
        # AttributeError
        self.oldstate = None
        self.metrics = None
        self.state = TrainState(model, trainer, optimizer)

    @property
    def writer(self):
        """Returns a tensorboard writer, created lazily on first access

        The writer is created with ``purge_step`` set to the current step, so
        previously logged events at steps greater than or equal to it (e.g.
        from an interrupted run) are purged.
        """
        if self._writer is None:
            self._writer = SummaryWriter(self.logpath, purge_step=self.state.step)
        return self._writer

    @property
    def epoch(self):
        """Current epoch"""
        return self.state.epoch

    @property
    def steps(self):
        """Number of steps performed so far"""
        return self.state.steps

    def nextepoch(self):
        """Starts a new epoch

        The current state is kept as ``oldstate`` (its checkpoint is removed
        by the next :meth:`save_checkpoint`) and work continues on an unsaved
        copy with an incremented epoch.
        """
        self.oldstate = self.state
        self.state = self.oldstate.copy()
        self.state.epoch += 1

    def nextbatch(self):
        """Increments the step counter"""
        self.state.steps += 1

    def load_bestcheckpoint(self, max_epoch: int):
        """Try to load the most recent checkpoint whose epoch is lower than or
        equal to ``max_epoch``

        Checkpoints that fail to load are removed and the next most recent one
        is tried. Entries matching the prefix but not ending with an integer
        are ignored.

        :return: True if a checkpoint was loaded, False otherwise
        :raises NotImplementedError: re-raised from loading; the checkpoint is
            kept in that case
        """
        # Find all the potential checkpoints (epoch, directory)
        candidates = []
        for f in self.path.glob(f"{TrainerContext.PREFIX}*"):
            try:
                epoch = int(f.name[len(TrainerContext.PREFIX) :])
            except ValueError:
                logger.warning("Ignoring %s (not a checkpoint)", f)
                continue
            if epoch <= max_epoch:
                candidates.append((epoch, f))
        candidates.sort(key=lambda candidate: candidate[0], reverse=True)

        # Try the most recent ones first; load the directory actually found
        # (rather than a reconstructed name) so that rmtree targets it too
        for epoch, path in candidates:
            logger.info("Loading from checkpoint at epoch %d", epoch)
            try:
                self.state.load(path)
                return True
            except NotImplementedError:
                logger.error("Not removing checkpoint")
                raise
            except Exception:
                # Log first so that the cause is recorded even if the removal
                # fails
                logger.exception("Cannot load from epoch %d", epoch)
                rmtree(path)
        return False

    @staticmethod
    def get_checkpoint_path(checkpointspath: Path, epoch: int) -> Path:
        """Returns the checkpoint directory of ``epoch`` within
        ``checkpointspath``"""
        return checkpointspath / f"{TrainerContext.PREFIX}{epoch:08d}"

    def save_checkpoint(self):
        """Saves the current state (unless already saved), then removes the
        checkpoint of the previous state if there is one"""
        if self.state.path is not None:
            # No need to save twice
            return

        self.state.save(TrainerContext.get_checkpoint_path(self.path, self.epoch))

        # Cleanup the previous checkpoint if needed
        if self.oldstate and self.oldstate.path:
            try:
                rmtree(self.oldstate.path)
            except OSError as e:
                # We continue the learning process in those cases
                logger.error("OS Error while trying to remove directory %s", e)
            self.oldstate = None

    def copy(self, path: Path):
        """Copy the state into another folder

        The current state is checkpointed first if it has not been saved yet.

        :param path: target folder (cleaned up with ``cleanupdir`` first);
            when falsy, only the checkpoint is saved
        """
        if self.state.path is None:
            self.save_checkpoint()
        assert self.state.path is not None
        if path:
            cleanupdir(path)
            self.state.copy_model(path)

    def add_loss(self, loss: Loss):
        """Adds a regularization loss; must be called within :meth:`losses`"""
        assert (
            self._losses is not None
        ), "This should be called in the context where loss is computed"
        self._losses.append(loss)

    @contextmanager
    def losses(self):
        """Context manager collecting the losses added with :meth:`add_loss`

        :yields: the list of collected losses (the previous list, if any, is
            restored on exit)
        """
        previous = self._losses
        try:
            self._losses = []
            yield self._losses
        finally:
            self._losses = previous

    @contextmanager
    def step(self, metrics):
        """Context manager wrapping one optimization step

        Gradients are zeroed on entry. When the body completes without error,
        the optimizer and scheduler steps are performed and the metrics
        collected during the step are merged into ``metrics``.

        :param metrics: accumulator receiving this step's metrics
        :yields: the Metrics object collecting this step's metrics
        """
        try:
            self.state.optimizer.zero_grad()
            self.metrics = Metrics()
            yield self.metrics
            self.state.optimizer.optimizer_step(self)
            self.state.optimizer.scheduler_step(self)
            metrics.merge(self.metrics)
        finally:
            self.metrics = None

    def backward(self, loss: torch.Tensor):
        """Back-propagates ``loss`` (through the fabric when there is one)"""
        if self.fabric:
            self.fabric.backward(loss)
        else:
            loss.backward()

    def add_metric(self, metric: Metric):
        """Adds a metric to be reported at the end of the step (e.g., for
        logging in tensorboard)

        The metric key is prefixed with the non-empty components of the
        current scope, joined by "/".

        :param metric: The metric to be added (its key is modified in place)
        :type metric: Metric
        """
        assert self.metrics is not None, "Not within an optimization step"
        prefix = "/".join(s for s in self._scope if s)
        # Only prefix when some scope component is non-empty (otherwise the
        # key would start with "/")
        if prefix:
            metric.key = f"{prefix}/{metric.key}"
        self.metrics.add(metric)

    @contextmanager
    def scope(self, name: str):
        """Context manager pushing ``name`` onto the metric-name scope

        :yields: this context
        """
        self._scope.append(name)
        try:
            yield self
        finally:
            self._scope.pop()