Source code for xpm_torch.losses.pairwise


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.functional import Tensor
from experimaestro import field, Config, Param

from xpm_torch.utils.logging import EasyLogger
from xpm_torch.losses import Loss, bce_with_logits_loss, ModuleOutputType
from xpm_torch.trainers import TrainerContext


[docs] class PairwiseLoss(Config, nn.Module): """Base class for any pairwise loss""" NAME = "?" weight: Param[float] = field(default=1.0, ignore_default=True) """The weight :math:`w` with which the loss is multiplied (useful when combining with other ones)""" def initialize(self, ranker): pass def process(self, scores: Tensor, context: TrainerContext): value = self.compute(scores, context) context.add_loss(Loss(f"pair-{self.NAME}", value, self.weight)) def compute(self, scores: Tensor, info: TrainerContext) -> Tensor: """Computes the loss :param scores: A (batch x 2) tensor (positive/negative) :param info: the trainer context :return: a torch scalar """ raise NotImplementedError()
### Child Losses
[docs] class CrossEntropyLoss(PairwiseLoss): r"""Cross-Entropy Loss Computes the cross-entropy loss Classification loss (relevant vs non-relevant) where the logit is equal to the difference between the relevant and the non relevant document (or equivalently, softmax then mean log probability of relevant documents) Reference: C. Burges et al., “Learning to rank using gradient descent,” 2005. *warning*: this loss assumes the score returned by the scorer is a logit .. math:: \frac{w}{N} \sum_{(s^+,s-)} \log \frac{\exp(s^+)}{\exp(s^+)+\exp(s^-)} """ NAME = "cross-entropy" def compute(self, rel_scores_by_record, info: TrainerContext): target = ( torch.zeros(rel_scores_by_record.shape[0]) .long() .to(rel_scores_by_record.device) ) return F.cross_entropy(rel_scores_by_record, target, reduction="mean")
[docs] class HingeLoss(PairwiseLoss): r"""Hinge (or max-margin) loss .. math:: \frac{w}{N} \sum_{(s^+,s-)} \max(0, m - (s^+ - s^-)) """ NAME = "hinge" margin: Param[float] = field(default=1.0, ignore_default=True) """The margin for the Hinge loss""" def compute(self, rel_scores_by_record, info: TrainerContext): return F.relu( self.margin - rel_scores_by_record[:, 0] + rel_scores_by_record[:, 1] ).mean()
[docs] class PointwiseCrossEntropyLoss(PairwiseLoss, EasyLogger): r"""Point-wise cross-entropy loss This is a point-wise loss adapted as a pairwise one. This loss adapts to the ranker output type: - If real, uses a BCELossWithLogits (sigmoid transformation) - If probability, uses the BCELoss - If log probability, uses a BCEWithLogLoss .. math:: \frac{w}{2N} \sum_{(s^+,s-)} \log \frac{\exp(s^+)}{\exp(s^+)+\exp(s^-)} + \log \frac{\exp(s^-)}{\exp(s^+)+\exp(s^-)} """ NAME = "pointwise-cross-entropy" def initialize(self, ranker): super().initialize(ranker) self.rankerOutputType = ranker.outputType if ranker.outputType == ModuleOutputType.REAL: self.logger.info("Ranker outputs logits: using BCEWithLogitsLoss") self.loss = nn.BCEWithLogitsLoss() elif ranker.outputType == ModuleOutputType.PROBABILITY: self.logger.info("Ranker outputs probabilities: using BCELoss") self.loss = nn.BCELoss() elif ranker.outputType == ModuleOutputType.LOG_PROBABILITY: self.logger.info("Ranker outputs probabilities: using BCEWithLogLoss") self.loss = bce_with_logits_loss else: raise Exception("Not implemented") def compute(self, rel_scores_by_record, info: TrainerContext): device = rel_scores_by_record.device dim = rel_scores_by_record.shape[0] target = torch.cat( (torch.ones(dim, device=device), torch.zeros(dim, device=device)) ) return self.loss(rel_scores_by_record.T.flatten(), target)