from abc import abstractmethod
from typing import Dict, Iterator, List, Optional
from experimaestro import Config, Param, Meta, field
import torch
import torch.nn as nn
import numpy as np
from lightning.fabric.wrappers import _FabricDataLoader
from xpm_torch import Module, Sampler
from xpm_torch.metrics import ScalarMetric
from xpm_torch.utils.logging import EasyLogger
from xpm_torch.batchers import Batcher
from xpm_torch.trainers.context import (
TrainingHook,
TrainerContext,
)
from torchdata.stateful_dataloader import StatefulDataLoader
# [docs] (Sphinx source-link residue; not part of the module)
class Trainer(Config, EasyLogger):
    """Generic trainer

    Holds the model to optimize and the training hooks, and declares the
    methods (batch iteration, batch processing, state saving/loading) that
    concrete trainers have to provide.
    """

    hooks: Param[List[TrainingHook]] = field(default=[], ignore_default=True)
    """Hooks for this trainer: this includes the losses, but can be adapted for
    other uses

    The specific list of hooks depends on the specific trainer
    """

    model: Param[Optional[Module]] = field(default=None, ignore_default=True)
    """If the model to optimize is different from the model passed to Learn,
    this parameter can be used – initialization is still expected to be done at
    the learner level"""

    def initialize(
        self,
        random: np.random.RandomState,
        context: TrainerContext,
    ):
        """Bind the trainer to a random state and a training context.

        Args:
            random: Random state, stored as ``self.random``
            context: Training context, stored as ``self.context``; every hook
                listed in ``self.hooks`` is registered on it
        """
        self.random = random
        self.context = context

        # Generic style: when no specific model was configured, use the one
        # held by the training state
        if self.model is None:
            self.model = context.state.model

        # Old style (to be deprecated): alias kept for code that still reads
        # `ranker`
        self.ranker = self.model

        for training_hook in self.hooks:
            context.add_hook(training_hook)

    def to(self, device):
        """Change the computing device (if this is needed)

        Calls ``to(device)`` on every hook returned by
        ``self.context.hooks(nn.Module)`` (presumably the hooks that are
        torch modules -- confirm against ``TrainerContext.hooks``).
        """
        module_hooks = self.context.hooks(nn.Module)
        for module_hook in module_hooks:
            module_hook.to(device)

    @abstractmethod
    def iter_batches(self) -> Iterator:
        """Returns a (serializable) iterator over batches"""

    @abstractmethod
    def process_batch(self, batch):
        """Process a batch of records

        NOTE(review): this was documented as returning the loss value that
        will be backpropagated, but ``LossTrainer.process_batch`` in this file
        runs the backward pass itself and returns nothing -- confirm the
        intended contract.
        """

    @abstractmethod
    def load_state_dict(self, state: Dict):
        """Restore the trainer state from ``state`` (see ``state_dict``)"""

    @abstractmethod
    def state_dict(self):
        """Return the trainer state in a form ``load_state_dict`` accepts"""
# [docs] (Sphinx source-link residue; not part of the module)
class LossTrainer(Trainer):
    """Trainer based on a loss function

    Uses StatefulDataLoader + IterableDataset for data loading.
    """

    batcher: Meta[Batcher] = field(default_factory=Batcher.C)
    """How to batch samples together"""

    sampler: Param[Sampler]
    """The sampler to use"""

    batch_size: Param[int] = field(default=16, ignore_default=True)
    """Number of samples per batch"""

    num_workers: Param[int] = field(default=2, ignore_default=True)
    """Number of DataLoader workers"""

    dataloader: Optional[StatefulDataLoader] = None
    """StatefulDataLoader for training data"""

    def initialize(
        self,
        random: np.random.RandomState,
        context: TrainerContext,
    ):
        """Initialize the trainer, its sampler and its batcher worker

        The dataloader is not built here: see ``_create_dataloader``.

        Args:
            random: Random state for initialization
            context: TrainerContext for the training process
        """
        super().initialize(random, context)
        self.sampler.initialize(random)
        self.batcher_worker = self.batcher.initialize(self.batch_size)

    def _create_dataloader(self, dataset, collate_fn):
        """Build the StatefulDataLoader and store it as ``self.dataloader``.

        Args:
            dataset: Dataset handed to the loader
            collate_fn: Function used by the loader to assemble a batch
        """
        loader_options = {
            "batch_size": self.batch_size,
            "collate_fn": collate_fn,
            "num_workers": self.num_workers,
        }
        self.dataloader = StatefulDataLoader(dataset, **loader_options)

    def iter_batches(self) -> Iterator:
        """Returns an iterator over the batches served by the dataloader."""
        loader = self.dataloader
        assert loader is not None, (
            "dataloader not initialized — call _create_dataloader() first"
        )
        return iter(loader)

    def load_state_dict(self, state: Dict):
        """Restore the dataloader position from ``state["dataloader"]``.

        Nothing happens when the key is absent or when no dataloader has been
        created yet.
        """
        if "dataloader" not in state or self.dataloader is None:
            return

        loader = self.dataloader
        if isinstance(loader, _FabricDataLoader):
            # Fabric wraps the dataloader: the state belongs to the original
            # one
            loader = loader._dataloader
        loader.load_state_dict(state["dataloader"])

    def state_dict(self):
        """Return ``{"dataloader": <dataloader state>}``.

        The dataloader must have been created beforehand.
        """
        loader = self.dataloader
        assert loader is not None, "dataloader not initialized"
        if isinstance(loader, _FabricDataLoader):
            # Fabric wraps the dataloader: the state comes from the original
            # one
            loader = loader._dataloader
        return {"dataloader": loader.state_dict()}

    def process_batch(self, batch: list):
        """Compute loss for a given batch of records - called by the learner.

        Important: the batch is handed to the batcher worker, which calls
        ``process_microbatch`` (splitting the batch into microbatches when
        needed). Nothing is returned.
        """
        worker = self.batcher_worker
        worker.process(batch, self.process_microbatch, raise_oom=True)

    def process_microbatch(self, records: list):
        """Combines a forward and a backward pass

        Runs ``train_batch``, sums the weighted losses collected by the
        context, reports them as metrics and backpropagates the total.

        This method can be implemented by specific trainers that use the
        gradient. In that case the regularizer losses should be taken into
        account with `self.add_losses` (NOTE(review): no such method is
        visible in this file -- confirm the name).
        """
        ctx = self.context

        # Restrict losses to this context
        with ctx.losses() as losses:
            self.train_batch(records)

            nrecords = len(records)
            names = []
            total_loss = torch.tensor(0.0, device=ctx.fabric.device)

            for loss in losses:
                total_loss += loss.weight * loss.value
                names.append(loss.name)
                # One metric per individual (unweighted) loss value
                loss_metric = ScalarMetric(
                    f"{loss.name}", float(loss.value.item()), nrecords
                )
                ctx.add_metric(loss_metric)

            # Reports the main metric (only when several losses are combined)
            if len(names) > 1:
                names.sort()
                combined_metric = ScalarMetric(
                    "+".join(names), float(total_loss.item()), nrecords
                )
                ctx.add_metric(combined_metric)

            scaled_loss = ctx.state.optimizer.scale(total_loss)
            ctx.backward(scaled_loss)

    def train_batch(self, records):
        """Run the forward pass for ``records``; to be overridden.

        ``process_microbatch`` reads the losses collected through
        ``self.context.losses()`` while this method runs, so implementations
        are expected to report their losses to the context.
        """
        raise NotImplementedError