import torch
import json
from typing import Dict, Any
from pathlib import Path
from experimaestro import field, Annotated, Param, Meta, pathgenerator
from xpm_torch import SampleIterator
from xpm_torch.batchers import Batcher
from xpm_torch.learner import (
Learner,
LearnerListener,
LearnerListenerStatus,
TrainerContext,
TrainState,
)
from xpm_torch.optim import ModuleLoader
from xpm_torch.trainers import LossTrainer
from xpm_torch.metrics import Metrics, ScalarMetric
import logging
logger = logging.getLogger(__name__)
# [docs]
class TrainerValidationLoss(LearnerListener):
    """Generic trainer-based loss validation

    Every ``validation_interval`` epochs, the listener runs ``trainer`` over
    ``data`` with gradients disabled and aggregates the reported losses.
    Whenever the validation loss decreases (lower is better), the training
    context is copied into ``bestpath`` and the value/epoch pair is written to
    ``info`` so that the state survives a restart. Optionally, training is
    stopped after ``early_stop`` epochs without improvement.
    """

    data: Param[SampleIterator]
    """The dataset to use"""

    batcher: Meta[Batcher] = field(default_factory=Batcher.C)
    """How to batch samples together"""

    batch_size: Meta[int]
    """Batch size"""

    trainer: Param[LossTrainer]
    """The trainer"""

    warmup: Param[int] = field(default=-1, ignore_default=True)
    """How many epochs before actually computing the validation loss"""

    bestpath: Annotated[Path, pathgenerator("best")]
    """Path to the best checkpoints"""

    info: Annotated[Path, pathgenerator("info.json")]
    """Path to the JSON file that contains the metric values at each epoch"""

    validation_interval: Param[int] = field(default=1, ignore_default=True)
    """Epochs between each validation"""

    early_stop: Param[int] = field(default=0, ignore_default=True)
    """Number of epochs without improvement after which we stop learning.
    Should be a multiple of validation_interval or 0 (no early stopping)"""

    def __validate__(self):
        """Checks that ``early_stop`` is a multiple of ``validation_interval``"""
        assert (
            self.early_stop % self.validation_interval == 0
        ), "Early stop should be a multiple of the validation interval"

    def initialize(self, learner: Learner, context: TrainerContext):
        """Prepares the listener and restores the best state recorded so far

        :param learner: The learner this listener is attached to
        :param context: The trainer context (forwarded to the parent class)
        """
        super().initialize(learner, context)

        self.scope = f"validation/{self.id}"
        self.bestpath.mkdir(exist_ok=True, parents=True)
        self.batcher_worker = self.batcher.initialize(self.batch_size)

        # Checkpoint start: reload the best value/epoch from a previous run
        try:
            with self.info.open("rt") as fp:
                self.top: Dict[str, Any] = json.load(fp)
        except FileNotFoundError:
            # First run: nothing has been recorded yet
            self.top = None
        except Exception:
            # Best effort: an unreadable or corrupt file must not prevent
            # training from (re)starting, but it should not go unnoticed
            logger.warning(
                "Could not read the validation state from %s: starting afresh",
                self.info,
                exc_info=True,
            )
            self.top = None

    def init_task(self, learner: "Learner", dep):
        """Returns a loader for the best model, wrapped with ``dep``

        :param learner: The learner whose model is to be loaded
        :param dep: Callable applied to the created :class:`ModuleLoader`
        """
        return dep(
            ModuleLoader(
                value=learner.model,
                path=self.bestpath / TrainState.MODEL_PATH,
            )
        )

    def update_metrics(self, metrics: Dict[str, float]):
        """Adds the best validation value (if any) to ``metrics``

        :param metrics: Dictionary updated in place
        """
        if self.top:
            # Just use another key
            metrics[f"{self.id}/final"] = self.top["value"]

    def should_stop(self, epoch=0):
        """Returns whether the learning process should stop

        :param epoch: The epoch to consider; when 0 (or omitted), the current
            epoch of the context is used instead
        :return: ``LearnerListenerStatus.STOP`` when early stopping is enabled
            and at least ``early_stop`` epochs have passed since the last
            improvement, ``LearnerListenerStatus.DONT_STOP`` otherwise
        """
        if self.early_stop > 0 and self.top:
            epochs_since_imp = (epoch or self.context.epoch) - self.top["epoch"]
            if epochs_since_imp >= self.early_stop:
                return LearnerListenerStatus.STOP

        return LearnerListenerStatus.DONT_STOP

    def reducer(self, records, metrics: Metrics):
        """Computes the losses for a batch of records and accumulates them

        Each loss reported by the trainer is added as ``loss/<name>``; when
        more than one loss is reported, their weighted sum is also added as
        ``loss``.

        :param records: The batch of records to process
        :param metrics: The accumulator, updated in place
        :return: The updated ``metrics``
        """
        # Restrict losses to this context
        with self.context.losses() as losses:
            # Compute the loss(es)
            self.trainer.train_batch(records)

            # Aggregate with previous values
            nrecords = len(records)
            total_loss = 0.0
            loss_count = 0

            for loss in losses:
                total_loss += loss.weight * loss.value
                loss_count += 1
                metrics.add(
                    ScalarMetric(
                        f"loss/{loss.name}", float(loss.value.item()), nrecords
                    )
                )

            # Reports the main metric (with a single loss, the only
            # ``loss/<name>`` entry is used as the main metric by __call__)
            if loss_count > 1:
                metrics.add(ScalarMetric("loss", float(total_loss.item()), nrecords))

        return metrics

    def __call__(self, state: TrainState):
        """Runs the validation (when due) and tracks the best checkpoint

        :param state: The current training state
        :return: A :class:`LearnerListenerStatus` telling whether to stop
        """
        # Check that we did not stop earlier (when loading from checkpoint / if other
        # listeners have not stopped yet)
        if self.should_stop(state.epoch - 1) == LearnerListenerStatus.STOP:
            return LearnerListenerStatus.STOP

        if state.epoch % self.validation_interval == 0:
            # Compute validation metrics
            metrics = Metrics()
            state.model.eval()  # nosec
            with torch.no_grad():
                for batch in self.data.__batch_iter__(self.batch_size):
                    self.batcher_worker.reduce(
                        batch, self.reducer, metrics, raise_oom=False
                    )

            metrics.report(state.step, self.context.writer, self.id)

            # Get the current value: the only metric when there is a single
            # loss, the aggregated "loss" otherwise (see reducer)
            if len(metrics.metrics) == 1:
                value = next(iter(metrics.metrics.values())).compute()
            else:
                value = metrics.metrics["loss"].compute()

            # Update the top validation
            if state.epoch >= self.warmup:
                topstate = self.top
                # The value is a loss: an improvement is a *lower* value
                if topstate is None or value < topstate["value"]:
                    # Save the new top JSON
                    self.top = {"value": value, "epoch": self.context.epoch}

                    # Copy in corresponding directory
                    logger.info("Saving the checkpoint %s", state.epoch)
                    self.context.copy(self.bestpath)

            # Update information
            with self.info.open("wt") as fp:
                json.dump(self.top, fp)

        # Early stopping?
        return self.should_stop()