Source code for xpm_torch.learner

from enum import Enum
from time import perf_counter
import torch
import numpy as np
from pathlib import Path
from typing import Dict, Iterator, List, NamedTuple, Any, Optional
from experimaestro import (
    Task,
    Config,
    Param,
    pathgenerator,
    Annotated,
    tqdm,
    field,
    Meta,
)

import lightning as L
from lightning.fabric.strategies.strategy import Strategy as l_Strategy

from xpm_torch import Random
from xpm_torch.configuration import FabricConfiguration
from xpm_torch.metrics import Metrics, ScalarMetric
from .batchers import RecoverableOOMError
from .optim import (
    Module,
    ModuleLoader,
    ParameterOptimizer,
    ScheduledOptimizer,
    OptimizationHook,
)
from xpm_torch.context import Hook, InitializationHook
from xpm_torch.utils.logging import EasyLogger

from xpm_torch.trainers.context import (
    StepTrainingHook,
    TrainState,
    TrainerContext,
)
from xpm_torch.trainers import Trainer

import logging
logger = logging.getLogger(__name__)


class Strategy(Config, l_Strategy):
    """A Lightning strategy

    Combines the experimaestro ``Config`` base with the Lightning Fabric
    ``Strategy`` base; the class itself adds no member.
    """
class LearnerListenerStatus(Enum):
    """Decision returned by a listener after an epoch.

    Members are ordered by their integer value; when two decisions are
    combined with :meth:`update`, the one with the highest value wins.
    """

    NO_DECISION = 0
    STOP = 1
    DONT_STOP = 2

    def update(self, other: "LearnerListenerStatus") -> "LearnerListenerStatus":
        """Combine this status with ``other``.

        :param other: The status to merge with this one
        :return: The member whose value is the greatest of the two
        """
        strongest = self.value if self.value >= other.value else other.value
        return type(self)(strongest)
class CheckpointSettings(Config):
    """Settings attached to a ModuleLoader that targets one specific checkpoint."""

    epoch: Param[Optional[int]] = field(default=None, ignore_default=True)
    """The epoch of the checkpoint (``None`` by default)"""
class LearnerListener(Config):
    """Hook for learner

    A listener is called with the training state produced by the learner and
    answers with a :class:`LearnerListenerStatus`. This base class never
    takes a decision and contributes no metric.
    """

    id: Meta[str]
    """Unique ID to identify the listener (ignored for signature)"""

    def initialize(self, learner: "Learner", context: TrainerContext) -> None:
        """Keep a reference to the learner and to its trainer context.

        :param learner: The learner this listener is attached to
        :param context: The trainer context of the current run
        """
        self.learner, self.context = learner, context

    def __call__(self, state: TrainState) -> LearnerListenerStatus:
        """Process a training state and tell whether training should stop.

        :param state: The training state to process
        :return: Always ``NO_DECISION`` in this base implementation
        """
        return LearnerListenerStatus.NO_DECISION

    def update_metrics(self, metrics: Dict[str, float]) -> None:
        """Add metrics to ``metrics``; the base implementation adds nothing.

        :param metrics: The dictionary of metrics to update in place
        """

    def init_task(self, learner: "Learner", dep, add_action):
        """Returns the initialization task that loads the associated checkpoint

        :param learner: The learner object
        :param dep: The function that adds a dependency
        :param add_action: Function to register an action
        :return: ``None`` in this base implementation
        """
        return None
class LearnerOutput(NamedTuple):
    """Output of a learner.

    - ``listeners`` maps the name (ID) of each listener to the output of
      that listener;
    - ``learned_model`` is the loader for the learned model;
    - ``checkpoints`` gives access to the checkpoints saved during the
      training.
    """

    listeners: Dict[str, Any]
    learned_model: ModuleLoader
    checkpoints: Dict[str, Any]
class Learner(Task, EasyLogger):
    """Model Learner

    The learner task is generic, and takes two main arguments: (1) the model
    defines the model (e.g. DRMM), and (2) the trainer defines how the model
    should be trained (e.g. pointwise, pairwise, etc.)

    When submitted, it returns a :class:`LearnerOutput` whose ``listeners``
    entry is a dictionary based on the `listeners`
    """

    # Training
    random: Param[Random]
    """The random generator"""

    trainer: Param[Trainer]
    """Specifies how to train the model"""

    model: Param[Module]
    """Defines the model to be learned. If multiple models are used, one can
    use MultipleModel.
    """

    max_epochs: Param[int] = field(default=1000, ignore_default=True)
    """Maximum number of epochs"""

    steps_per_epoch: Param[int] = field(default=128, ignore_default=True)
    """Number of steps for one epoch (after each epoch results are logged)"""

    optimizers: Param[List[ParameterOptimizer]]
    """The list of parameter optimizers"""

    listeners: Param[List[LearnerListener]]
    """Listeners are in charge of handling the validation of the model, and
    saving the relevant checkpoints"""

    checkpoint_interval: Param[int] = field(default=1, ignore_default=True)
    """Number of epochs between each checkpoint"""

    logpath: Annotated[Path, pathgenerator("runs")]
    """The path to the tensorboard logs"""

    checkpointspath: Annotated[Path, pathgenerator("checkpoints")]
    """The path to the checkpoints"""

    hooks: Param[List[Hook]] = field(default=[], ignore_default=True)
    """Global learning hooks

    :class:`Initialization hooks <xpm_torch.context.InitializationHook>` are
    called before and after the initialization of the trainer and listeners.
    """

    fabric_config: Param[FabricConfiguration] = field(
        default_factory=FabricConfiguration.C
    )
    """Runtime configuration, managed by Fabric"""

    def __validate__(self):
        """Check that at least one optimizer is set and listener IDs are unique."""
        assert self.optimizers, "At least one optimizer should be defined"
        assert len({listener.id for listener in self.listeners}) == len(
            self.listeners
        ), "IDs of listeners should be unique"
        return super().__validate__()

    def __submit__(self, dep, add_action):
        """Submit the learner task and register export actions.

        :param dep: The function that adds a dependency
        :param add_action: Function to register an action
        :return: A :class:`LearnerOutput` with the listeners' initialization
            tasks, the loader of the last checkpoint, and one loader per
            epoch in ``range(0, max_epochs, checkpoint_interval)``
        """
        learned_model = dep(
            self.model.loader_config(
                self.last_checkpoint_path / TrainState.MODEL_DIR
            )
        )

        # Register export action — the model controls the action type
        add_action(self.model.export_action(learned_model, default_name="last"))

        return LearnerOutput(
            listeners={
                listener.id: listener.init_task(self, dep, add_action=add_action)
                for listener in self.listeners
            },
            learned_model=learned_model,
            checkpoints={
                interval: dep(
                    self.model.loader_config(
                        TrainerContext.get_checkpoint_path(
                            self.checkpointspath, interval
                        )
                        / TrainState.MODEL_DIR,
                        settings=CheckpointSettings.C(epoch=interval),
                    )
                )
                for interval in range(0, self.max_epochs, self.checkpoint_interval)
            },
        )

    @property
    def last_checkpoint_path(self):
        """Path of the ``last`` checkpoint, inside ``checkpointspath``."""
        return self.checkpointspath / "last"

    def execute(self):
        """Main training loop, executed using the fabric context.

        The training process is stopped either by the listeners or when
        max_epoch is reached.
        """
        # 1. Launch Fabric
        fabric = self.fabric_config.get_fabric()
        fabric.launch()

        self.optimizer = ScheduledOptimizer()
        self.only_cached = False
        self.context = TrainerContext(
            self.logpath,
            self.checkpointspath,
            self.max_epochs,
            self.steps_per_epoch,
            self.trainer,
            self.model,
            self.optimizer,
            fabric=fabric,
        )

        for hook in self.hooks:
            self.context.add_hook(hook)

        # Call init hooks
        for hook in self.context.hooks(InitializationHook):
            hook.before(self.context)

        # Sets the random seed
        # WARNING - will still not be fully deterministic unless using (lot slower):
        # - torch.use_deterministic_algorithms(True) (PyTorch ≥1.8).
        # - torch.backends.cudnn.deterministic = True
        # - torch.backends.cudnn.benchmark = False.
        # can also use fabric.seed_everything(self.random.state.randint((2**32) - 1))
        seed = self.random.state.randint((2**32) - 1)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # Initialize the model and trainer
        with fabric.init_module():
            self.trainer.initialize(self.random.state, self.context)

        # Wrap dataloader with Fabric for device placement (if using new path)
        if hasattr(self.trainer, "dataloader") and self.trainer.dataloader is not None:
            self.trainer.dataloader = fabric.setup_dataloaders(self.trainer.dataloader)

        for listener in self.listeners:
            listener.initialize(self, self.context)

        num_training_steps = self.max_epochs * self.steps_per_epoch
        self.optimizer.initialize(
            self.optimizers,
            num_training_steps,
            self.model,
            use_scaler=False,
            hooks=[hook for hook in self.hooks if isinstance(hook, OptimizationHook)],
            trainer_context=self.context,
        )

        # wrap model and optimizers
        self.model, *self.optimizer.optimizers = fabric.setup(
            self.model, *self.optimizer.optimizers
        )

        self.logger.info(
            f"Model is on device {self.model.device} using dtype {next(self.model.parameters()).dtype}"
        )

        if torch.cuda.is_available():
            # This is the definitive check for BF16 support
            supports_bf16 = torch.cuda.is_bf16_supported()
            self.logger.info(f"Hardware supports BF16: {supports_bf16}")

            if not supports_bf16 and "bf16" in fabric.precision.precision:
                self.logger.error(
                    "CRITICAL: You are forcing BF16 on incompatible hardware!"
                )

        for hook in self.context.hooks(InitializationHook):
            hook.after(self.context)

        self.logger.info("Starting to train")

        current = 0
        state = None

        with tqdm(
            total=self.max_epochs, desc=f"Training ({self.max_epochs} epochs)"
        ) as tqdm_epochs:
            for state in self.iter_train(fabric):
                # Report progress
                tqdm_epochs.update(state.epoch - current)
                current = state.epoch

                if state.epoch >= 0 and not self.only_cached:
                    message = f"epoch {state.epoch}"
                    if state.cached:
                        self.logger.debug(f"[train] [cached] {message}")
                    else:
                        self.logger.debug(f"[train] {message}")

                if state.epoch == -1:
                    continue

                if not state.cached and state.epoch % self.checkpoint_interval == 0:
                    # Save checkpoint if needed
                    self.context.save_checkpoint()
                    self.context.copy(self.last_checkpoint_path)

                # Call listeners: each one returns a LearnerListenerStatus,
                # and the decisions are merged (the highest value wins)
                decision = LearnerListenerStatus.NO_DECISION
                for listener in self.listeners:
                    decision = decision.update(listener(state))

                if decision == LearnerListenerStatus.STOP:
                    # The message used to be logged with its raw, unformatted
                    # placeholders; report the actual epoch instead
                    self.logger.warning(
                        f"stopping after epoch {state.epoch} since "
                        "all listeners asked for it"
                    )
                    break

                # Stop if max epoch is reached
                if self.context.epoch >= self.max_epochs:
                    self.logger.warning(
                        f"stopping after epoch {self.max_epochs} (max_epoch)"
                    )
                    break

            # End of the learning process
            if state is not None and not state.cached:
                # Set the hyper-parameters
                metrics = {}
                for listener in self.listeners:
                    listener.update_metrics(metrics)

                self.context.writer.add_hparams(getattr(self, "__tags__", {}), metrics)

    def iter_train(self, fabric: L.Fabric) -> Iterator[TrainState]:
        """Infinite generator of training states: one per epoch, containing
        self.steps_per_epoch steps

        If a checkpoint can be loaded, its state is yielded first.

        :param fabric: The Fabric instance (not used by this method)
        :raises RuntimeError: if the trainer does not produce any batch
        """
        # Try to load a checkpoint
        if self.context.load_bestcheckpoint(self.max_epochs):
            yield self.context.state

        # Get an iterator over batches
        batch_iter = self.trainer.iter_batches()

        while True:
            # Step to the next epoch
            self.context.nextepoch()

            # Train for an epoch
            with tqdm(
                leave=False,
                total=self.steps_per_epoch,
                ncols=100,
                desc=f"Train - epoch #{self.context.epoch}",
            ) as pbar:
                # Put the model into training mode (just in case)
                self.context.state.model.train()

                # Epoch: loop over batches
                metrics = Metrics()
                start = perf_counter()
                for _ in range(self.steps_per_epoch):
                    # Get the next batch, recreate iterator on exhaustion
                    try:
                        batch = next(batch_iter)
                    except StopIteration:
                        batch_iter = self.trainer.iter_batches()
                        try:
                            batch = next(batch_iter)
                        except StopIteration:
                            # A StopIteration escaping from a generator is
                            # turned into an opaque RuntimeError (PEP 479):
                            # raise an explicit one instead
                            raise RuntimeError(
                                "The trainer did not produce any batch"
                            ) from None

                    self.context.nextbatch()

                    # Re-run the same batch until the step succeeds
                    while True:
                        try:
                            # Computes the gradient, takes a step and collect metrics
                            with self.context.step(metrics):
                                # Call step hooks
                                for hook in self.context.hooks(StepTrainingHook):
                                    hook.before(self.context)

                                # Computes the gradient
                                self.trainer.process_batch(batch)

                                # Update metrics and counter
                                pbar.update(1)
                                break
                        except RecoverableOOMError:
                            logger.warning(
                                "Recoverable OOM detected"
                                " - re-running the training step"
                            )

                    for hook in self.context.hooks(StepTrainingHook):
                        hook.after(self.context)

                metrics.add(
                    ScalarMetric(
                        "iter_per_seconds",
                        self.steps_per_epoch / (perf_counter() - start),
                        1,
                    )
                )

                # Yields the current state (after one epoch)
                # -> allows listeners to process it and decide whether to stop or not
                yield self.context.state

                # Report metrics over the epoch, and log them in tensorboard
                # Note that this is done after the listeners are called, so that
                # they can update the metrics if needed (e.g., with validation
                # results)
                # NOTE(review): when the consumer stops iterating right after a
                # yield (e.g. `break` in `execute`), this generator is not
                # resumed, so the metrics of that last epoch are presumably
                # never reported — confirm this is intended
                metrics.report(
                    self.context.state.step,
                    self.context.writer,
                    "train",
                )