Training
Learner
- XPM Task xpm_torch.learner.Learner(*, random, trainer, model, max_epochs, steps_per_epoch, optimizers, listeners, checkpoint_interval, hooks)[source]
Model Learner
The learner task is generic and takes two main arguments: (1) model specifies the model to train (e.g. DRMM), and (2) trainer specifies how the model should be trained (e.g. pointwise, pairwise, etc.).
When submitted, it returns a dictionary based on the listeners.
- random: xpm_torch.base.Random
The random generator
- trainer: xpm_torch.trainers.Trainer
Specifies how to train the model
- model: xpm_torch.module.Module
Defines the model to be learned. If multiple models are used, one can use MultipleModel.
- max_epochs: int = 1000
Maximum number of epochs
- steps_per_epoch: int = 128
Number of steps for one epoch (after each epoch results are logged)
- optimizers: List[xpm_torch.optim.ParameterOptimizer]
The list of parameter optimizers
- listeners: List[xpm_torch.learner.LearnerListener]
Listeners are in charge of handling the validation of the model, and saving the relevant checkpoints
- checkpoint_interval: int = 1
Number of epochs between each checkpoint
- logpath: path (generated)
The path to the tensorboard logs
- checkpointspath: path (generated)
The path to the checkpoints
- hooks: List[xpm_torch.context.Hook] = []
Global learning hooks
Initialization hooks are called before and after the initialization of the trainer and listeners.
- fabric_config: xpm_torch.configuration.FabricConfiguration = xpm_torch.configuration.FabricConfiguration(precision=32-true, torch_fp32_precision=None, num_nodes=1, devices=auto, strategy=auto, accelerator=auto) (generated)
Runtime configuration, managed by Fabric
- XPM Config xpm_torch.learner.LearnerListener(*, id)[source]
Hook for learner
Performs some operations after a learning epoch
- id: str
Unique ID to identify the listener (ignored for signature)
Trainers
- XPM Config xpm_torch.trainers.Trainer(*, hooks, model)[source]
Generic trainer
- hooks: List[xpm_torch.trainers.context.TrainingHook] = []
Hooks for this trainer: this includes the losses, but can be adapted for other uses. The specific list of hooks depends on the specific trainer.
- model: xpm_torch.module.Module
If the model to optimize is different from the model passed to the Learner, this parameter can be used; initialization is still expected to be done at the learner level.
- XPM Config xpm_torch.trainers.LossTrainer(*, hooks, model, sampler, batch_size, num_workers)[source]
Trainer based on a loss function
Uses StatefulDataLoader + IterableDataset for data loading.
- hooks: List[xpm_torch.trainers.context.TrainingHook] = []
Hooks for this trainer: this includes the losses, but can be adapted for other uses. The specific list of hooks depends on the specific trainer.
- model: xpm_torch.module.Module
If the model to optimize is different from the model passed to the Learner, this parameter can be used; initialization is still expected to be done at the learner level.
- batcher: xpm_torch.batchers.Batcher (generated)
How to batch samples together
- sampler: xpm_torch.base.Sampler
The sampler to use
- batch_size: int = 16
Number of samples per batch
- num_workers: int = 2
Number of DataLoader workers
- XPM Config xpm_torch.trainers.multiple.MultipleTrainer(*, hooks, model, trainers)[source]
This trainer can be used to combine various trainers
- hooks: List[xpm_torch.trainers.context.TrainingHook] = []
Hooks for this trainer: this includes the losses, but can be adapted for other uses. The specific list of hooks depends on the specific trainer.
- model: xpm_torch.module.Module
If the model to optimize is different from the model passed to the Learner, this parameter can be used; initialization is still expected to be done at the learner level.
- trainers: Dict[str, xpm_torch.trainers.Trainer]
The trainers
Training State
- class xpm_torch.trainers.context.TrainState(model: Module, trainer: Trainer, optimizer: ScheduledOptimizer, epoch=0, steps=0)[source]
Represents a training state for serialization
- epoch: int
The epoch
- property step
Returns the step for logging (number of steps)
- steps: int
The number of steps (each epoch is composed of steps)
- class xpm_torch.trainers.context.TrainerContext(logpath: Path, path: Path, max_epoch: int, steps_per_epoch: int, trainer, model: Module, optimizer: ScheduledOptimizer, fabric: lightning.Fabric)[source]
Contains all the information about the training context for a specific
This object is used during training to provide models and losses with extra information, as well as the possibility to add regularization losses.
- add_metric(metric: Metric)[source]
Adds a metric to be reported at the end of the step (e.g., for logging in tensorboard)
- Parameters:
metric (Metric) – The metric to be added
- load_bestcheckpoint(max_epoch: int)[source]
Try to find the best checkpoint to load (the highest lower than the epoch target)
- metrics: Metrics | None
Metrics to be reported
- property writer
Returns a tensorboard writer.
By default, purges the entries beside the current epoch.
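The selection rule described for load_bestcheckpoint ("the highest lower than the epoch target") can be illustrated with a small standalone sketch. It does not use xpm_torch: the function name `best_checkpoint_epoch` is hypothetical, checkpoints are represented only by their epoch numbers, and treating the bound as inclusive is an assumption, since the description does not say whether a checkpoint exactly at the target epoch qualifies.

```python
def best_checkpoint_epoch(available_epochs, max_epoch):
    """Return the highest checkpoint epoch not exceeding max_epoch, or None.

    Hypothetical helper: the inclusive bound (<=) is an assumption.
    """
    candidates = [epoch for epoch in available_epochs if epoch <= max_epoch]
    return max(candidates, default=None)


# Checkpoints were saved at epochs 1, 5, 10 and 20; training targets epoch 12
chosen = best_checkpoint_epoch([1, 5, 10, 20], max_epoch=12)
```

When no saved checkpoint is at or below the target, the sketch returns `None`, which would correspond to starting from scratch.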
Training Hooks
- XPM Config xpm_torch.trainers.context.StepTrainingHook[source]
Base class for hooks called at each step (before/after)
- XPM Config xpm_torch.trainers.context.ValidationHook[source]
Base class for all the validation hooks
- XPM Config xpm_torch.trainers.context.InitializationTrainingHook[source]
Base class for hooks called at initialization
- XPM Config xpm_torch.trainers.hooks.LayerFreezer(*, selector)[source]
This training hook class can be used to freeze a subset of model parameters
- selector: xpm_torch.parameters.ParametersIterator
How to select the layers to freeze
This training hook class can be used to share parameters
The parameters to share
The parameters to be shared
Validation
- XPM Config xpm_torch.validation.ValidationSettings(*, listener, key)[source]
Settings for a validation-specific ModuleLoader.
Attached as settings on the loader to distinguish validation checkpoints from other loaders with the same model and path.
- listener: xpm_torch.learner.LearnerListener
The listener (kept to change the loader identifier based on the learner listener configuration)
- key: str
The metric key for this validation checkpoint
- XPM Config xpm_torch.trainers.validation.TrainerValidationLoss(*, id, data, batch_size, trainer, warmup, validation_interval, early_stop)[source]
Generic trainer-based loss validation
- id: str
Unique ID to identify the listener (ignored for signature)
- data: xpm_torch.base.SampleIterator
The dataset to use
- batcher: xpm_torch.batchers.Batcher (generated)
How to batch samples together
- batch_size: int
Batch size
- trainer: xpm_torch.trainers.LossTrainer
The trainer
- warmup: int = -1
How many epochs before actually computing the validation loss
- bestpath: path (generated)
Path to the best checkpoints
- info: path (generated)
Path to the JSON file that contains the metric values at each epoch
- validation_interval: int = 1
Epochs between each validation
- early_stop: int = 0
Number of epochs without improvement after which we stop learning. Should be a multiple of validation_interval or 0 (no early stopping)
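The interaction between warmup, validation_interval and early_stop can be sketched in plain Python. This is only an illustration under stated assumptions, not the library's implementation: the function `early_stop_epoch` is hypothetical, a lower validation loss is assumed to be better, and epochs up to and including `warmup` are assumed to be skipped.

```python
def early_stop_epoch(val_losses, validation_interval=1, early_stop=0, warmup=-1):
    """Return the epoch at which training would stop early, or None.

    val_losses[i] is the validation loss that would be measured after
    epoch i + 1 (hypothetical input; lower is assumed to be better).
    """
    best_loss = float("inf")
    best_epoch = None
    for epoch, loss in enumerate(val_losses, start=1):
        # Skip warmup epochs and epochs that are not validation epochs
        if epoch <= warmup or epoch % validation_interval != 0:
            continue
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif (
            early_stop > 0
            and best_epoch is not None
            and epoch - best_epoch >= early_stop
        ):
            return epoch
    return None


losses = [1.0, 0.8, 0.9, 0.85, 0.95, 0.7]
stopped_at = early_stop_epoch(losses, early_stop=2)
```

With these values the best loss is reached at epoch 2 and training stops at epoch 4, so the later improvement at epoch 6 is never seen. Because validation only happens every validation_interval epochs, a value of early_stop that is not a multiple of the interval would only trigger at the next validation epoch.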
Batching
- XPM Config xpm_torch.batchers.Batcher[source]
Responsible for micro-batching when the batch does not fit in memory
The base class just does nothing (no adaptation)
- XPM Config xpm_torch.batchers.PowerAdaptativeBatcher[source]
Starts with the provided batch size, and then splits the batch into 2, 3, etc. parts until no out-of-memory (OOM) error occurs
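The adaptive strategy described for PowerAdaptativeBatcher can be sketched without any dependency. This is a simplified illustration, not the library's code: `process_with_adaptive_split` and `limited_sum` are hypothetical names, and Python's built-in `MemoryError` stands in for the out-of-memory error an accelerator would raise.

```python
def process_with_adaptive_split(batch, process):
    """Process `batch` whole, then in 2, 3, ... parts until nothing runs out of memory.

    Returns (number of micro-batches used, list of per-micro-batch results).
    """
    for parts in range(1, len(batch) + 1):
        size = -(-len(batch) // parts)  # ceiling division
        chunks = [batch[i:i + size] for i in range(0, len(batch), size)]
        try:
            return len(chunks), [process(chunk) for chunk in chunks]
        except MemoryError:
            continue  # retry with a finer split
    raise MemoryError("batch does not fit in memory even at the finest split")


def limited_sum(chunk, limit=6):
    """Stand-in for a forward/backward pass that fails above `limit` samples."""
    if len(chunk) > limit:
        raise MemoryError("simulated OOM")
    return sum(chunk)


n_chunks, partial_sums = process_with_adaptive_split(list(range(16)), limited_sum)
```

Here a batch of 16 samples fails as one batch and as two halves of 8, then succeeds as three micro-batches of 6, 6 and 4 samples. A real batcher would presumably remember the successful split for later batches; the sketch restarts from scratch each time.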
Batchwise Losses
- XPM Config xpm_torch.losses.batchwise.BatchwiseLoss(*, weight)[source]
- weight: float = 1.0
The weight of this loss
Pairwise Losses
- XPM Config xpm_torch.losses.pairwise.PairwiseLoss(*, weight)[source]
Base class for any pairwise loss
- weight: float = 1.0
The weight \(w\) with which the loss is multiplied (useful when combining with other ones)
- XPM Config xpm_torch.losses.pairwise.CrossEntropyLoss(*, weight)[source]
Cross-Entropy Loss
Computes the cross-entropy loss
Classification loss (relevant vs non-relevant) where the logit is equal to the difference between the scores of the relevant and the non-relevant document (or equivalently, softmax then mean log probability of relevant documents). Reference: C. Burges et al., “Learning to rank using gradient descent,” 2005.
Warning: this loss assumes the score returned by the scorer is a logit
\[\frac{w}{N} \sum_{(s^+,s^-)} \log \frac{\exp(s^+)}{\exp(s^+)+\exp(s^-)}\]
- weight: float = 1.0
The weight \(w\) with which the loss is multiplied (useful when combining with other ones)
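The quantity inside the sum of the cross-entropy formula can be evaluated with the standard library alone. The sketch below is an illustration rather than the library's implementation: the function names are hypothetical, and it computes the expression exactly as printed (a weighted mean log-probability), so a quantity to minimize would be its negative. It relies on the identity \(\log \frac{\exp(s^+)}{\exp(s^+)+\exp(s^-)} = \log \sigma(s^+ - s^-)\), which is why the score difference acts as the logit.

```python
import math


def pairwise_log_prob(s_pos, s_neg):
    """log(exp(s_pos) / (exp(s_pos) + exp(s_neg))), computed stably."""
    diff = s_pos - s_neg
    # log sigmoid(diff); branch on the sign so exp() never overflows
    if diff >= 0:
        return -math.log1p(math.exp(-diff))
    return diff - math.log1p(math.exp(diff))


def mean_pairwise_log_prob(pairs, weight=1.0):
    """(w / N) * sum of pairwise log-probabilities over (s_pos, s_neg) pairs."""
    return weight * sum(pairwise_log_prob(p, n) for p, n in pairs) / len(pairs)


pairs = [(1.0, 0.0), (0.0, 0.0)]
value = mean_pairwise_log_prob(pairs)
```

For equal scores the log-probability is \(\log \tfrac{1}{2}\), and it approaches 0 as the relevant document's score grows relative to the non-relevant one.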
- XPM Config xpm_torch.losses.pairwise.HingeLoss(*, weight, margin)[source]
Hinge (or max-margin) loss
\[\frac{w}{N} \sum_{(s^+,s^-)} \max(0, m - (s^+ - s^-))\]
- weight: float = 1.0
The weight \(w\) with which the loss is multiplied (useful when combining with other ones)
- margin: float = 1.0
The margin for the Hinge loss
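As a worked example of the hinge formula, the following standalone sketch evaluates it on a few score pairs. It is an illustration only: `pairwise_hinge_loss` is a hypothetical name and the scores are made up.

```python
def pairwise_hinge_loss(pairs, margin=1.0, weight=1.0):
    """(w / N) * sum of max(0, m - (s_pos - s_neg)) over (s_pos, s_neg) pairs."""
    if not pairs:
        raise ValueError("at least one (positive, negative) score pair is required")
    total = sum(max(0.0, margin - (s_pos - s_neg)) for s_pos, s_neg in pairs)
    return weight * total / len(pairs)


# Per-pair terms with margin 1: 0.0 (margin satisfied), 0.75, 2.0 (wrong order)
pairs = [(2.0, 0.5), (0.25, 0.0), (0.0, 1.0)]
loss = pairwise_hinge_loss(pairs)
```

A pair contributes nothing once the relevant document outscores the non-relevant one by at least the margin, so only the second and third pairs are penalized here.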
- XPM Config xpm_torch.losses.pairwise.PointwiseCrossEntropyLoss(*, weight)[source]
Point-wise cross-entropy loss
This is a point-wise loss adapted as a pairwise one.
This loss adapts to the ranker output type:
If real, uses a BCEWithLogitsLoss (sigmoid transformation)
If probability, uses the BCELoss
If log probability, uses a BCEWithLogLoss
\[\frac{w}{2N} \sum_{(s^+,s^-)} \left[ \log \frac{\exp(s^+)}{\exp(s^+)+\exp(s^-)} + \log \frac{\exp(s^-)}{\exp(s^+)+\exp(s^-)} \right]\]
- weight: float = 1.0
The weight \(w\) with which the loss is multiplied (useful when combining with other ones)
Fabric Configuration
- XPM Config xpm_torch.configuration.FabricConfiguration(*, precision, torch_fp32_precision, num_nodes, devices, strategy, accelerator)[source]
Describes the computation device
The backend is fabric, so the complete documentation can be found on https://lightning.ai/docs/fabric/stable/api/fabric_args.html
- precision: str = 32-true
Precision to use, e.g., ‘16-mixed’, ‘bf16-mixed’, ‘32-true’: see Lightning documentation at https://lightning.ai/docs/fabric/stable/api/fabric_args.html#precision
- torch_fp32_precision: str
Torch precision for torch.float32 operations, see https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision Automatically set depending on fabric_config.precision if not set, but can be overridden if needed (e.g., to force TF32 on Ampere GPUs while using bf16 precision for other operations)
- num_nodes: int = 1
Number of nodes
- devices: str = auto
Configure the devices to run on. See https://lightning.ai/docs/fabric/stable/api/fabric_args.html#devices for more details and options. Note that for multi-node training, you should specify the devices per node, e.g., devices="4" for 4 GPUs per node, not devices="16" for a total of 16 GPUs across 4 nodes.
- strategy: str = auto
The strategy to use. See https://lightning.ai/docs/fabric/stable/api/fabric_args.html#strategy for more details and options.
- accelerator: str = auto
The accelerator to use. See https://lightning.ai/docs/fabric/stable/api/fabric_args.html#accelerator for more details and options.
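The per-node convention for devices can be checked with simple arithmetic. The helper below is a hypothetical illustration and does not call Fabric or xpm_torch; it only handles an explicit device count given as a string, since a value such as "auto" is resolved at runtime.

```python
def total_processes(num_nodes: int, devices: str) -> int:
    """Total device count across the cluster for an explicit per-node count."""
    if not devices.isdigit():
        raise ValueError(f"cannot compute a total for devices={devices!r}")
    return num_nodes * int(devices)


# 4 nodes with devices="4" each give 16 GPUs in total
total = total_processes(num_nodes=4, devices="4")
```

Passing devices="16" with num_nodes=4 would therefore request 64 devices in total, which is the mistake the note on devices warns against.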