Training

Learner

XPM Task xpm_torch.learner.Learner(*, random, trainer, model, max_epochs, steps_per_epoch, optimizers, listeners, checkpoint_interval, hooks)

Model Learner

The learner task is generic and takes two main arguments: (1) model defines the model to learn (e.g. DRMM), and (2) trainer defines how the model should be trained (e.g. pointwise, pairwise, etc.).

When submitted, it returns a dictionary based on the listeners.

random: xpm_torch.base.Random

The random number generator

trainer: xpm_torch.trainers.Trainer

Specifies how to train the model

model: xpm_torch.module.Module

Defines the model to be learned. If multiple models are used, one can use MultipleModel.

max_epochs: int = 1000

Maximum number of epochs

steps_per_epoch: int = 128

Number of steps per epoch (results are logged after each epoch)

optimizers: List[xpm_torch.optim.ParameterOptimizer]

The list of parameter optimizers

listeners: List[xpm_torch.learner.LearnerListener]

Listeners are in charge of validating the model and saving the relevant checkpoints.

checkpoint_interval: int = 1

Number of epochs between each checkpoint
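As a rough illustration of how max_epochs, steps_per_epoch and checkpoint_interval interact, the sketch below computes the maximum number of optimization steps and the epochs at which checkpoints would be written. It is plain Python, not part of xpm_torch: the helper name `training_schedule` is hypothetical, and it assumes that epochs are numbered from 1 and that a checkpoint is written at the end of every epoch whose number is a multiple of `checkpoint_interval`. The step count is an upper bound, since a listener may stop training early.

```python
from typing import List, Tuple


def training_schedule(
    max_epochs: int, steps_per_epoch: int, checkpoint_interval: int
) -> Tuple[int, List[int]]:
    """Return (maximum number of steps, epochs at which a checkpoint is written).

    Assumption (not taken from xpm_torch): epochs are numbered from 1 and a
    checkpoint is written whenever epoch % checkpoint_interval == 0.
    """
    max_steps = max_epochs * steps_per_epoch
    checkpoint_epochs = [
        epoch
        for epoch in range(1, max_epochs + 1)
        if epoch % checkpoint_interval == 0
    ]
    return max_steps, checkpoint_epochs


# With the defaults (1000 epochs of 128 steps, a checkpoint every epoch):
max_steps, epochs = training_schedule(1000, 128, 1)
print(max_steps, len(epochs))  # 128000 1000

# Checkpointing every 250 epochs instead:
print(training_schedule(1000, 128, 250)[1])  # [250, 500, 750, 1000]
```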

logpath: path (generated)

The path to the TensorBoard logs

checkpointspath: path (generated)

The path to the checkpoints

hooks: List[xpm_torch.context.Hook] = []

Global learning hooks. Initialization hooks are called before and after the initialization of the trainer and listeners.

fabric_config: xpm_torch.configuration.FabricConfiguration = xpm_torch.configuration.FabricConfiguration(precision='32-true', torch_fp32_precision=None, num_nodes=1, devices='auto', strategy='auto', accelerator='auto') (generated)

Runtime configuration, managed by Fabric

XPM Config xpm_torch.learner.LearnerListener(*, id)

Hook for the learner

Performs some operations after a learning epoch

id: str

Unique ID to identify the listener (ignored for signature)

XPM Config xpm_torch.learner.CheckpointSettings(*, epoch)

Settings for a checkpoint-specific ModuleLoader.

epoch: int

The epoch of the checkpoint

XPM Config xpm_torch.learner.Strategy

A Lightning strategy

Trainers

XPM Config xpm_torch.trainers.Trainer(*, hooks, model)

Generic trainer

hooks: List[xpm_torch.trainers.context.TrainingHook] = []

Hooks for this trainer: these include the losses, but can be adapted for other uses. The specific list of hooks depends on the specific trainer.

model: xpm_torch.module.Module

If the model to optimize is different from the model passed to the Learner, this parameter can be used; initialization is still expected to be done at the learner level.

XPM Config xpm_torch.trainers.LossTrainer(*, hooks, model, sampler, batch_size, num_workers)

Trainer based on a loss function

Uses a StatefulDataLoader over an IterableDataset for data loading.

hooks: List[xpm_torch.trainers.context.TrainingHook] = []

Hooks for this trainer: these include the losses, but can be adapted for other uses. The specific list of hooks depends on the specific trainer.

model: xpm_torch.module.Module

If the model to optimize is different from the model passed to the Learner, this parameter can be used; initialization is still expected to be done at the learner level.

batcher: xpm_torch.batchers.Batcher (generated)

How to batch samples together

sampler: xpm_torch.base.Sampler

The sampler to use

batch_size: int = 16

Number of samples per batch

num_workers: int = 2

Number of DataLoader workers

XPM Config xpm_torch.trainers.multiple.MultipleTrainer(*, hooks, model, trainers)

This trainer can be used to combine various trainers

hooks: List[xpm_torch.trainers.context.TrainingHook] = []

Hooks for this trainer: these include the losses, but can be adapted for other uses. The specific list of hooks depends on the specific trainer.

model: xpm_torch.module.Module

If the model to optimize is different from the model passed to the Learner, this parameter can be used; initialization is still expected to be done at the learner level.

trainers: Dict[str, xpm_torch.trainers.Trainer]

The trainers to combine, indexed by name

Training State

class xpm_torch.trainers.context.TrainState(model: Module, trainer: Trainer, optimizer: ScheduledOptimizer, epoch=0, steps=0)

Represents a training state for serialization

epoch: int

The epoch

load(path, onlyinfo=False)

Loads the state from disk

save(path)

Saves the state

property step

Returns the step for logging (number of steps)

steps: int

The number of steps (each epoch is composed of steps_per_epoch steps)

class xpm_torch.trainers.context.TrainerContext(logpath: Path, path: Path, max_epoch: int, steps_per_epoch: int, trainer, model: Module, optimizer: ScheduledOptimizer, fabric: lightning.Fabric)

Contains all the information about a specific training context.

This object is used during training to provide models and losses with extra information, as well as the possibility to add regularization losses.

add_metric(metric: Metric)

Adds a metric to be reported at the end of the step (e.g., for logging in TensorBoard).

Parameters:

metric (Metric) – The metric to be added

copy(path: Path)

Copy the state into another folder

load_bestcheckpoint(max_epoch: int)

Tries to find the best checkpoint to load (the one with the highest epoch lower than the target epoch).
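The selection rule can be sketched in plain Python. This is not the xpm_torch implementation: the `epoch-NNNNN` folder naming and the helper name `find_best_checkpoint` are hypothetical, and the sketch assumes that a checkpoint at exactly the target epoch is acceptable (an inclusive bound).

```python
import re
from typing import Iterable, Optional

# Hypothetical naming scheme: one checkpoint folder per epoch, e.g. "epoch-00042".
CHECKPOINT_NAME = re.compile(r"^epoch-(\d+)$")


def find_best_checkpoint(names: Iterable[str], max_epoch: int) -> Optional[str]:
    """Return the checkpoint with the highest epoch <= max_epoch, or None if none qualifies."""
    best_epoch, best_name = -1, None
    for name in names:
        match = CHECKPOINT_NAME.match(name)
        if match is None:
            continue  # not a checkpoint folder (e.g. TensorBoard logs)
        epoch = int(match.group(1))
        if best_epoch < epoch <= max_epoch:
            best_epoch, best_name = epoch, name
    return best_name


print(find_best_checkpoint(["epoch-00010", "epoch-00020", "epoch-00030", "logs"], 25))
# epoch-00020
```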

metrics: Metrics | None

Metrics to be reported

property writer

Returns a TensorBoard writer.

By default, it purges the entries beyond the current epoch.

Training Hooks

XPM Config xpm_torch.trainers.context.TrainingHook

Base class for all training hooks

XPM Config xpm_torch.trainers.context.StepTrainingHook

Base class for hooks called at each step (before/after)

XPM Config xpm_torch.trainers.context.ValidationHook

Base class for all the validation hooks

XPM Config xpm_torch.trainers.context.InitializationTrainingHook

Base class for hooks called at initialization

XPM Config xpm_torch.trainers.hooks.LayerFreezer(*, selector)

This training hook class can be used to freeze a subset of model parameters

selector: xpm_torch.parameters.ParametersIterator

How to select the layers to freeze
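In PyTorch, freezing a parameter usually means calling `requires_grad_(False)` on it so that no gradient is computed for it. The sketch below mimics this on a plain dictionary mapping parameter names to their `requires_grad` flag, so that it runs without PyTorch. The glob patterns stand in for a ParametersIterator selector; they are a hypothetical choice for illustration, not the selector API of xpm_torch.

```python
from fnmatch import fnmatch
from typing import Dict, Iterable


def freeze(requires_grad: Dict[str, bool], patterns: Iterable[str]) -> Dict[str, bool]:
    """Return a copy of the flags where every parameter matching a pattern is frozen."""
    patterns = list(patterns)
    return {
        # A parameter stays trainable only if it was trainable and matches no pattern.
        name: trainable and not any(fnmatch(name, p) for p in patterns)
        for name, trainable in requires_grad.items()
    }


flags = {"encoder.layer0.weight": True, "encoder.layer1.weight": True, "head.weight": True}
print(freeze(flags, ["encoder.layer0.*"]))
# {'encoder.layer0.weight': False, 'encoder.layer1.weight': True, 'head.weight': True}
```

With a real model, the same loop would iterate over `model.named_parameters()` and call `param.requires_grad_(False)` on each matching parameter.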

XPM Config xpm_torch.trainers.hooks.LayerSharer(*, source, target)

This training hook class can be used to share parameters

source: xpm_torch.parameters.ParametersIterator

The parameters to share

target: xpm_torch.parameters.ParametersIterator

The parameters that will share the source parameters

Validation

XPM Config xpm_torch.validation.ValidationSettings(*, listener, key)

Settings for a validation-specific ModuleLoader.

Attached as settings on the loader to distinguish validation checkpoints from other loaders with the same model and path.

listener: xpm_torch.learner.LearnerListener

The listener (kept to change the loader identifier based on the learner listener configuration)

key: str

The metric key for this validation checkpoint

XPM Config xpm_torch.trainers.validation.TrainerValidationLoss(*, id, data, batch_size, trainer, warmup, validation_interval, early_stop)

Generic trainer-based loss validation

id: str

Unique ID to identify the listener (ignored for signature)

data: xpm_torch.base.SampleIterator

The dataset to use

batcher: xpm_torch.batchers.Batcher (generated)

How to batch samples together

batch_size: int

Batch size

trainer: xpm_torch.trainers.LossTrainer

The trainer

warmup: int = -1

Number of epochs to wait before computing the validation loss

bestpath: path (generated)

Path to the best checkpoints

info: path (generated)

Path to the JSON file that contains the metric values at each epoch

validation_interval: int = 1

Number of epochs between two validations

early_stop: int = 0

Number of epochs without improvement after which training stops. Should be a multiple of validation_interval, or 0 (no early stopping).
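The interplay of warmup, validation_interval and early_stop can be simulated offline. The sketch below is not the TrainerValidationLoss implementation and the helper name is hypothetical; it assumes that epochs are numbered from 1, that validation runs at epochs greater than `warmup` that are multiples of `validation_interval`, that a lower validation loss is better, and that training stops as soon as `early_stop` epochs have passed since the best validation loss.

```python
from typing import Dict, Optional


def early_stopping_epoch(
    losses: Dict[int, float], warmup: int, validation_interval: int, early_stop: int
) -> Optional[int]:
    """Return the epoch at which training would stop, or None if it never stops early.

    `losses` maps an epoch number to the validation loss that would be measured then.
    """
    if early_stop == 0:
        return None  # early stopping disabled
    best_loss: Optional[float] = None
    best_epoch = 0
    for epoch in sorted(losses):
        if epoch <= warmup or epoch % validation_interval != 0:
            continue  # no validation at this epoch
        loss = losses[epoch]
        if best_loss is None or loss < best_loss:
            best_loss, best_epoch = loss, epoch  # improvement: restart the count
        elif epoch - best_epoch >= early_stop:
            return epoch  # early_stop epochs without improvement
    return None


# Validate every epoch, stop after 2 epochs without improvement.
print(early_stopping_epoch({1: 0.9, 2: 0.7, 3: 0.75, 4: 0.72, 5: 0.6}, -1, 1, 2))  # 4
```

Under these assumptions, an early_stop that is not a multiple of validation_interval is effectively rounded up to the next validation epoch, which is consistent with the recommendation above.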

Batching

XPM Config xpm_torch.batchers.Batcher

Responsible for micro-batching when the batch does not fit in memory

The base class does nothing (no adaptation).

XPM Config xpm_torch.batchers.PowerAdaptativeBatcher

Starts with the provided batch size, then splits the batch into 2, 3, etc. micro-batches until there is no more out-of-memory (OOM) error.
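The splitting strategy can be sketched as follows. The `fits` callback is a hypothetical stand-in for "processing a micro-batch of this size does not raise an out-of-memory error"; with PyTorch it would typically wrap the forward/backward pass in a `try`/`except torch.cuda.OutOfMemoryError`. Spreading the samples as evenly as possible over the micro-batches is an assumption of this sketch, not necessarily what PowerAdaptativeBatcher does.

```python
import math
from typing import Callable, List


def adaptive_split(batch_size: int, fits: Callable[[int], bool]) -> List[int]:
    """Split a batch into n = 1, 2, 3, ... micro-batches until the largest one fits."""
    for n in range(1, batch_size + 1):
        largest = math.ceil(batch_size / n)
        if fits(largest):
            # Spread the samples as evenly as possible over the n micro-batches.
            base, extra = divmod(batch_size, n)
            return [base + 1 if i < extra else base for i in range(n)]
    raise MemoryError("even a single sample does not fit in memory")


# Pretend that at most 5 samples fit in memory at once.
print(adaptive_split(16, lambda size: size <= 5))  # [4, 4, 4, 4]
print(adaptive_split(10, lambda size: size <= 4))  # [4, 3, 3]
```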

Batchwise Losses

XPM Config xpm_torch.losses.batchwise.BatchwiseLoss(*, weight)

weight: float = 1.0

The weight of this loss

XPM Config xpm_torch.losses.batchwise.CrossEntropyLoss(*, weight)

weight: float = 1.0

The weight of this loss

XPM Config xpm_torch.losses.batchwise.SoftmaxCrossEntropy(*, weight)

weight: float = 1.0

The weight of this loss

Pairwise Losses

XPM Config xpm_torch.losses.pairwise.PairwiseLoss(*, weight)

Base class for any pairwise loss

weight: float = 1.0

The weight \(w\) with which the loss is multiplied (useful when combining with other ones)

XPM Config xpm_torch.losses.pairwise.CrossEntropyLoss(*, weight)

Cross-Entropy Loss

Computes the cross-entropy loss

Classification loss (relevant vs. non-relevant) where the logit is the difference between the scores of the relevant and the non-relevant document (or, equivalently, a softmax over the pair followed by the mean log-probability of the relevant documents). Reference: C. Burges et al., “Learning to rank using gradient descent,” 2005.

Warning: this loss assumes that the score returned by the scorer is a logit.

\[-\frac{w}{N} \sum_{(s^+,s^-)} \log \frac{\exp(s^+)}{\exp(s^+)+\exp(s^-)}\]

weight: float = 1.0

The weight \(w\) with which the loss is multiplied (useful when combining with other ones)
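For a quick sanity check, the loss can be computed in plain Python. The sketch assumes real-valued logits and uses the identity \(\log \frac{\exp(s^+)}{\exp(s^+)+\exp(s^-)} = -\log(1 + \exp(s^- - s^+))\), so that the minimized quantity (the negative mean log-probability that the relevant document wins its pair) becomes a numerically stable mean of softplus values. The function names are illustrative; in PyTorch, the same value would be `weight * torch.nn.functional.softplus(s_neg - s_pos).mean()`.

```python
import math
from typing import Sequence, Tuple


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_cross_entropy(
    pairs: Sequence[Tuple[float, float]], weight: float = 1.0
) -> float:
    """Mean over (s+, s-) pairs of -log(exp(s+) / (exp(s+) + exp(s-))), times weight."""
    total = sum(softplus(s_neg - s_pos) for s_pos, s_neg in pairs)
    return weight * total / len(pairs)


# Equal scores: the relevant document wins with probability 1/2, so the loss is log 2.
print(pairwise_cross_entropy([(0.0, 0.0)]))  # ≈ 0.6931 (log 2)
```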

XPM Config xpm_torch.losses.pairwise.HingeLoss(*, weight, margin)

Hinge (or max-margin) loss

\[\frac{w}{N} \sum_{(s^+,s^-)} \max(0, m - (s^+ - s^-))\]

weight: float = 1.0

The weight \(w\) with which the loss is multiplied (useful when combining with other ones)

margin: float = 1.0

The margin for the Hinge loss
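A plain-Python version of the hinge formula can help check values by hand; the function name is illustrative and the scores are assumed to be given as \((s^+, s^-)\) pairs.

```python
from typing import Sequence, Tuple


def pairwise_hinge(
    pairs: Sequence[Tuple[float, float]], weight: float = 1.0, margin: float = 1.0
) -> float:
    """(w / N) * sum of max(0, m - (s+ - s-)) over the (s+, s-) pairs."""
    total = sum(max(0.0, margin - (s_pos - s_neg)) for s_pos, s_neg in pairs)
    return weight * total / len(pairs)


# The first pair already beats the margin (2.0 - 0.5 >= 1), the second does not.
print(pairwise_hinge([(2.0, 0.5), (0.5, 0.2)]))  # ≈ 0.35
```

Only pairs whose score difference is below the margin contribute to the loss, which is why the first pair adds nothing.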

XPM Config xpm_torch.losses.pairwise.PointwiseCrossEntropyLoss(*, weight)

Point-wise cross-entropy loss

This is a point-wise loss adapted as a pairwise one.

This loss adapts to the ranker output type:

  • If real, uses a BCEWithLogitsLoss (sigmoid transformation)

  • If probability, uses the BCELoss

  • If log probability, uses a BCEWithLogLoss

For real-valued scores, with \(\sigma(s) = 1/(1+\exp(-s))\) the sigmoid function:

\[-\frac{w}{2N} \sum_{(s^+,s^-)} \left[\log \sigma(s^+) + \log\left(1 - \sigma(s^-)\right)\right]\]

weight: float = 1.0

The weight \(w\) with which the loss is multiplied (useful when combining with other ones)
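For real-valued scores, this amounts to binary cross-entropy on the 2N individual scores, with label 1 for relevant documents and 0 for non-relevant ones. The plain-Python sketch below computes that quantity using \(-\log \sigma(s) = \log(1+\exp(-s))\) and \(-\log(1-\sigma(s)) = \log(1+\exp(s))\); it illustrates this reading of the loss rather than reproducing the xpm_torch code, and it does not cover the probability and log-probability variants. In PyTorch, the same value would be obtained by applying `torch.nn.functional.binary_cross_entropy_with_logits` to the concatenated scores and labels and multiplying by the weight.

```python
import math
from typing import Sequence, Tuple


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pointwise_bce_with_logits(
    pairs: Sequence[Tuple[float, float]], weight: float = 1.0
) -> float:
    """Binary cross-entropy over all scores: s+ labelled 1, s- labelled 0."""
    # -log sigmoid(s+) = softplus(-s+) and -log(1 - sigmoid(s-)) = softplus(s-)
    total = sum(softplus(-s_pos) + softplus(s_neg) for s_pos, s_neg in pairs)
    # Each pair contributes two samples, hence the 2N normalization.
    return weight * total / (2 * len(pairs))


# A logit of 0 is a predicted probability of 1/2, so each term costs log 2.
print(pointwise_bce_with_logits([(0.0, 0.0)]))  # ≈ 0.6931 (log 2)
```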

Fabric Configuration

XPM Config xpm_torch.configuration.Strategy

XPM Config xpm_torch.configuration.FabricConfigurationBase

XPM Config xpm_torch.configuration.FabricConfiguration(*, precision, torch_fp32_precision, num_nodes, devices, strategy, accelerator)

Describes the computation device.

The backend is Lightning Fabric, so the complete documentation can be found at https://lightning.ai/docs/fabric/stable/api/fabric_args.html

precision: str = '32-true'

Precision to use, e.g., ‘16-mixed’, ‘bf16-mixed’, ‘32-true’: see Lightning documentation at https://lightning.ai/docs/fabric/stable/api/fabric_args.html#precision

torch_fp32_precision: str

Torch precision for torch.float32 operations (see https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision). If not set, it is derived automatically from fabric_config.precision; it can be overridden if needed (e.g., to force TF32 on Ampere GPUs while using bf16 precision for other operations).

num_nodes: int = 1

Number of nodes

devices: str = 'auto'

The devices to run on. See https://lightning.ai/docs/fabric/stable/api/fabric_args.html#devices for more details and options. For multi-node training, specify the devices per node: e.g., devices="4" for 4 GPUs per node, not devices="16" for a total of 16 GPUs across 4 nodes.
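The arithmetic behind this note is simple: with devices counted per node, the total number of devices (and of training processes, assuming one process per device) is the product of num_nodes and the per-node count. The helper below is purely illustrative and not part of xpm_torch or Lightning.

```python
def total_devices(num_nodes: int, devices_per_node: int) -> int:
    """Total device count when `devices` is specified per node."""
    if num_nodes < 1 or devices_per_node < 1:
        raise ValueError("num_nodes and devices_per_node must be positive")
    return num_nodes * devices_per_node


# devices="4" on 4 nodes gives 16 GPUs in total, which is why devices must not be "16".
print(total_devices(num_nodes=4, devices_per_node=int("4")))  # 16
```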

strategy: str = 'auto'

The strategy to use. See https://lightning.ai/docs/fabric/stable/api/fabric_args.html#strategy for more details and options.

accelerator: str = 'auto'

The accelerator to use. See https://lightning.ai/docs/fabric/stable/api/fabric_args.html#accelerator for more details and options.