import threading
import logging
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Union
from pathlib import Path
import torch
import re
from experimaestro import (
Config,
Param,
Constant,
tagspath,
Task,
experiment,
RunMode,
field,
)
from experimaestro.scheduler import Job, Listener
from experimaestro.utils import cleanupdir
from experimaestro.scheduler.services import WebService
from xpm_torch.utils.logging import LazyJoin
from xpm_torch.context import Hook, Context
from xpm_torch.utils.utils import foreach
from xpm_torch.metrics import ScalarMetric
from .schedulers import Scheduler
from .module import Module, ModuleLoader as ModuleLoader
if TYPE_CHECKING:
from xpm_torch.trainers import TrainerContext
# Logger named after this module; used by the optimizer and tensorboard helpers below
logger = logging.getLogger(__name__)
[docs]
class Optimizer(Config):
    """Abstract configuration of a PyTorch optimizer

    Calling the configuration with an iterable of parameters must return
    the corresponding :class:`torch.optim.Optimizer`; concrete subclasses
    override :meth:`__call__`.
    """

    def __call__(self, parameters) -> torch.optim.Optimizer:
        raise NotImplementedError
[docs]
class SGD(Optimizer):
    """Stochastic gradient descent (wraps :class:`torch.optim.SGD`)"""

    lr: Param[float] = field(default=1e-5, ignore_default=True)
    """Learning rate"""
    weight_decay: Param[float] = field(default=0.0, ignore_default=True)
    """Weight decay (L2)"""

    def __call__(self, parameters):
        # Only the learning rate and weight decay are configurable here
        settings = {"lr": self.lr, "weight_decay": self.weight_decay}
        return torch.optim.SGD(parameters, **settings)
[docs]
class Adafactor(Optimizer):
    """Adafactor optimizer from the Transformers library

    See :class:`transformers.optimization.Adafactor` for full documentation
    """

    lr: Param[Optional[float]] = field(default=None, ignore_default=True)
    """Learning rate"""
    weight_decay: Param[float] = field(default=0.0, ignore_default=True)
    """Weight decay (L2)"""
    relative_step: Param[bool] = field(default=True, ignore_default=True)
    """If true, time-dependent learning rate is computed instead of external
    learning rate"""

    def __call__(self, parameters):
        # transformers is only imported when this optimizer is actually built
        from transformers import optimization

        return optimization.Adafactor(
            parameters,
            relative_step=self.relative_step,
            weight_decay=self.weight_decay,
            lr=self.lr,
        )
[docs]
class Adam(Optimizer):
    """Adam optimizer (wraps :class:`torch.optim.Adam`)"""

    lr: Param[float] = field(default=1e-3, ignore_default=True)
    """Learning rate"""
    weight_decay: Param[float] = field(default=0.0, ignore_default=True)
    """Weight decay (L2)"""
    # Forwarded as the ``eps`` argument of torch.optim.Adam
    eps: Param[float] = field(default=1e-8, ignore_default=True)

    def __call__(self, parameters):
        return torch.optim.Adam(
            parameters,
            eps=self.eps,
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
[docs]
class AdamW(Optimizer):
    """Adam with decoupled weight decay (wraps :class:`torch.optim.AdamW`)

    See the `PyTorch documentation
    <https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html>`_
    """

    # Learning rate
    lr: Param[float] = field(default=1e-3, ignore_default=True)
    # Decoupled weight decay coefficient
    weight_decay: Param[float] = field(default=1e-2, ignore_default=True)
    # Forwarded as the ``eps`` argument of torch.optim.AdamW
    eps: Param[float] = field(default=1e-8, ignore_default=True)

    def __call__(self, parameters):
        return torch.optim.AdamW(
            parameters,
            eps=self.eps,
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
[docs]
class ParameterFilter(Config):
    """Parameter filter that selects every parameter

    A filter is called with a parameter name and the parameter itself, and
    tells whether the parameter should be optimized by the associated
    optimizer. This base implementation accepts everything.
    """

    def __call__(self, name, params) -> bool:
        """Returns whether the parameter should be optimized with the
        associated optimizer (always the case for this filter)"""
        accepted = True
        return accepted
[docs]
class RegexParameterFilter(ParameterFilter):
    """Select parameters by matching their names against regular expressions

    Patterns are matched with :func:`re.search`:

    - a parameter whose name matches one of ``includes`` is selected;
    - otherwise, if ``excludes`` is empty/not given, it is rejected;
    - otherwise, it is selected unless its name matches one of ``excludes``.

    At least one of ``includes`` and ``excludes`` must be given.
    """

    # Regular expressions selecting the parameter names to optimize
    includes: Param[Optional[List[str]]] = field(default=None, ignore_default=True)
    # Regular expressions selecting the parameter names to leave out
    excludes: Param[Optional[List[str]]] = field(default=None, ignore_default=True)

    def __init__(self):
        # NOTE(review): `name` is not read anywhere in this class; kept in
        # case external code relies on it — confirm before removing
        self.name = set()

    def __validate__(self):
        """Check the configuration.

        :raises ValueError: when neither ``includes`` nor ``excludes`` is
            given (returning a value does not make validation fail)
        """
        if not (self.includes or self.excludes):
            raise ValueError(
                "RegexParameterFilter requires at least one of includes or excludes"
            )

    def __repr__(self) -> str:
        return f"RegexParameterFilter({self.includes}, {self.excludes})"

    def __call__(self, name, params) -> bool:
        """Returns True if the parameter named ``name`` should be optimized"""
        # An explicit inclusion always wins
        if self.includes and any(re.search(regex, name) for regex in self.includes):
            return True
        # Without exclusion rules, only explicitly included names are kept
        if not self.excludes:
            return False
        # Otherwise keep everything that is not explicitly excluded
        return not any(re.search(regex, name) for regex in self.excludes)
[docs]
class ParameterOptimizer(Config):
    """Associates an optimizer (and optional scheduler) with a subset of the
    parameters of a module"""

    optimizer: Param[Optimizer]
    """The optimizer"""
    scheduler: Param[Optional[Scheduler]]
    """The optional scheduler"""
    module: Param[Optional[Module]]
    """The module from which parameters should be extracted"""
    filter: Param[Optional[ParameterFilter]] = field(default_factory=ParameterFilter.C)
    """How parameters should be selected for this (by default, use them all)"""

    def create_optimizer(
        self, module: Module, filter: Callable[[str, Any], bool]
    ) -> torch.optim.Optimizer:
        """Build the PyTorch optimizer for the selected parameters

        :param module: module whose parameters are used when this
            configuration does not define its own ``module``
        :param filter: extra predicate ``(name, parameter) -> bool``; it is
            only called for parameters already accepted by ``self.filter``
        :return: the optimizer built by ``self.optimizer``
        :raises RuntimeError: when no parameter is selected
        """
        source = self.module if self.module else module

        def accepted(param_name, tensor) -> bool:
            # self.filter goes first so that the (possibly stateful) caller
            # filter only sees parameters this configuration selects
            if self.filter is not None and not self.filter(param_name, tensor):
                return False
            return bool(filter(param_name, tensor))

        selected = {}
        for param_name, tensor in source.named_parameters():
            if accepted(param_name, tensor):
                selected[param_name] = tensor

        if len(selected) == 0:
            all_names = [param_name for param_name, _ in source.named_parameters()]
            logger.warning("Parameter list: %s", all_names)
            raise RuntimeError(f"Parameter list is empty with {self.filter}")

        logger.debug(
            "Optimizing with %s parameters [%s]",
            self.filter,
            LazyJoin(",", selected.keys()),
        )
        return self.optimizer(selected.values())
class DuplicateParameterFilter:
    """Accepts each parameter object only the first time it is seen

    Shared between successive optimizers so that a parameter already claimed
    by one optimizer is not optimized a second time.
    """

    def __init__(self):
        self.parameters = set()

    def __call__(self, name, params):
        seen = self.parameters
        if params not in seen:
            seen.add(params)
            return True
        return False
[docs]
class OptimizationHook(Hook):
    """Common base class of the hooks attached to the optimization process"""
[docs]
class GradientHook(OptimizationHook):
    """Hook called once gradients are computed, before the optimizer step

    Hooks are called with the :class:`ScheduledOptimizer`. When a gradient
    scaler is used, gradients are unscaled before the hooks run.
    """
[docs]
class GradientClippingHook(GradientHook):
    """Clip gradients by norm, once per optimizer, through the trainer
    context's ``fabric`` object"""

    max_norm: Param[float]
    """Maximum norm for gradient clipping"""

    version: Constant[str] = "1"
    """version of the Hook"""

    def __call__(self, main: "ScheduledOptimizer"):
        for optimizer in main.optimizers:
            fabric = main.trainer_context.fabric
            fabric.clip_gradients(main.module, optimizer, max_norm=self.max_norm)
[docs]
class GradientLogHook(GradientHook):
    """Log the global L2 norm of the gradients

    The logged value is ``sqrt(sum_p ||grad_p||^2)`` over all parameters of
    the optimized module that have a gradient, i.e. the norm of all gradients
    viewed as a single vector (the quantity reported by
    :func:`torch.nn.utils.clip_grad_norm_`). The previous formula,
    ``sum_p numel_p * ||grad_p||^2 / sum_p numel_p``, was not a norm.
    """

    # Name of the logged scalar
    name: Param[str] = field(default="gradient_norm", ignore_default=True)

    def __call__(self, main: "ScheduledOptimizer"):
        with torch.no_grad():
            grads = [
                param.grad
                for param in main.module.parameters()
                if param.grad is not None
            ]
            assert (
                grads
            ), "No parameters with gradients found for logging the gradient norm"

            # Per-tensor norms (in float32, to avoid half-precision overflow)
            # are gathered on one device so that parameters spread over
            # several devices can be combined
            device = grads[0].device
            norms = torch.stack(
                [
                    torch.linalg.vector_norm(grad, 2.0, dtype=torch.float32).to(device)
                    for grad in grads
                ]
            )
            total_norm = torch.linalg.vector_norm(norms, 2.0)

        main.trainer_context.writer.add_scalar(
            self.name, total_norm, main.trainer_context.state.step
        )
class ScheduledOptimizer:
    """Groups the optimizers (and their optional schedulers) used for training

    The object is configured through :meth:`initialize` (not ``__init__``).
    One PyTorch optimizer is created per :class:`ParameterOptimizer`; a
    parameter is claimed by at most one of them. A gradient scaler can be
    used for mixed-precision training.
    """

    def initialize(
        self,
        param_optimizers: List[ParameterOptimizer],
        num_training_steps: int,
        module: Module,
        use_scaler: bool,
        hooks: Optional[List[OptimizationHook]] = None,
        trainer_context: Optional["TrainerContext"] = None,
    ):
        """Create the optimizers, schedulers and gradient scaler

        :param param_optimizers: parameter/optimizer associations; when
            several select the same parameter, the first one gets it
        :param num_training_steps: passed to each scheduler factory
        :param module: module whose parameters are optimized (unless a
            :class:`ParameterOptimizer` defines its own module)
        :param use_scaler: whether to use a gradient scaler
        :param hooks: optimization hooks (e.g. :class:`GradientHook`);
            defaults to no hook
        :param trainer_context: context used by hooks (e.g. its ``fabric``
            or ``writer``)
        :raises RuntimeError: if ``module`` has no parameter
        """
        self.schedulers = []
        self.scheduler_factories = []
        self.optimizers = []
        # Number of scheduler steps taken minus one; passed as `last_epoch`
        # when (re)creating schedulers (-1 = fresh start)
        self.scheduler_steps = -1
        self.num_training_steps = num_training_steps
        self.module = module
        # A fresh list per call: a shared `[]` default would be aliased
        # between every ScheduledOptimizer
        self.context = Context(hooks if hooks is not None else [])
        self.trainer_context = trainer_context

        try:
            next(module.parameters())
        except StopIteration:
            raise RuntimeError(f"No parameters to optimize in the module {module}")

        # Shared across optimizers so each parameter is optimized only once
        filter = DuplicateParameterFilter()
        for param_optimizer in param_optimizers:
            optimizer = param_optimizer.create_optimizer(module, filter)
            self.optimizers.append(optimizer)
            self.scheduler_factories.append(param_optimizer.scheduler)

        self.reset_schedulers()
        assert len(self.schedulers) == len(self.optimizers)

        if use_scaler:
            logger.info("Using GradScaler when optimizing")
        # NOTE(review): recent PyTorch releases deprecate torch.cuda.amp.GradScaler
        # in favor of torch.amp.GradScaler("cuda"); migrate once the minimum
        # supported version allows it
        self.scaler = torch.cuda.amp.GradScaler() if use_scaler else None

    def load_state_dict(self, state):
        """Restore optimizers, scaler and schedulers from a :meth:`state_dict`"""
        optimizer_states = state["optimizers"]
        if len(optimizer_states) != len(self.optimizers):
            # zip() below silently leaves the extra optimizers un-restored
            logger.warning(
                "Checkpoint contains %d optimizer state(s) but %d optimizer(s) "
                "are configured",
                len(optimizer_states),
                len(self.optimizers),
            )
        for optimizer, optimizer_state in zip(self.optimizers, optimizer_states):
            optimizer.load_state_dict(optimizer_state)

        if self.scaler is not None:
            scaler_state = state.get("scaler")
            if scaler_state is None:
                # Checkpoint saved without a scaler: keep the fresh one
                logger.warning(
                    "No gradient scaler state in checkpoint: using a fresh scaler"
                )
            else:
                self.scaler.load_state_dict(scaler_state)

        # Re-create schedulers so that they resume from the saved step
        self.scheduler_steps = state["scheduler_steps"]
        self.reset_schedulers()

    def reset_schedulers(self):
        """(Re)create one scheduler per optimizer (None when not configured)"""
        self.schedulers = []
        for optimizer, scheduler_factory in zip(
            self.optimizers, self.scheduler_factories
        ):
            if scheduler_factory is None:
                self.schedulers.append(None)
            else:
                self.schedulers.append(
                    scheduler_factory(
                        optimizer,
                        self.num_training_steps,
                        last_epoch=self.scheduler_steps,
                    )
                )

    def state_dict(self):
        """Returns the optimizer states, scaler state and scheduler step"""
        return {
            "optimizers": [optimizer.state_dict() for optimizer in self.optimizers],
            "scaler": None if self.scaler is None else self.scaler.state_dict(),
            "scheduler_steps": self.scheduler_steps,
        }

    def scale(self, loss: torch.Tensor):
        """Scale the loss with the gradient scaler (identity without scaler)"""
        if self.scaler is None:
            return loss
        return self.scaler.scale(loss)

    def zero_grad(self):
        """Zero-grad for all optimizers"""
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def _apply_gradient_hooks(self):
        """Call every registered GradientHook with this object"""
        foreach(
            self.context.hooks(GradientHook),
            lambda hook: hook(self),
        )

    def optimizer_step(self, context: "TrainerContext"):
        """Performs an optimizer step (using the scaler if defined)

        Gradient hooks run before the step; with a scaler, gradients are
        unscaled first so that hooks see unscaled values.
        """
        if self.scaler is None:
            self._apply_gradient_hooks()
            for optimizer in self.optimizers:
                optimizer.step()
            return

        for optimizer in self.optimizers:
            self.scaler.unscale_(optimizer)
        self._apply_gradient_hooks()
        for optimizer in self.optimizers:
            self.scaler.step(optimizer)
        context.add_metric(
            ScalarMetric("gradient/scaler", self.scaler.get_scale(), 1)
        )
        self.scaler.update()

    def scheduler_step(self, context: "TrainerContext"):
        """Performs a step for all the schedulers (logging the current rates)"""
        for ix, scheduler in enumerate(self.schedulers):
            if scheduler is not None:
                # Log the learning rates in use before stepping
                for p_ix, lr in enumerate(scheduler.get_last_lr()):
                    context.add_metric(
                        ScalarMetric(f"gradient/scheduler/{ix+1}/{p_ix+1}", lr, 1)
                    )
                scheduler.step()
        self.scheduler_steps += 1
# Forms accepted by get_optimizers(): a single ParameterOptimizer, a bare
# Optimizer (wrapped in a ParameterOptimizer with the default filter), or a
# list of ParameterOptimizer
Optimizers = Union[ParameterOptimizer, Optimizer, List[ParameterOptimizer]]
"""Defines a set of optimizers"""
def get_optimizers(optimizers: Optimizers):
    """Normalize ``optimizers`` into a list of ParameterOptimizer

    A list is returned unchanged (same object); a single ParameterOptimizer
    is wrapped in a list; a bare Optimizer is first wrapped in a
    ParameterOptimizer.
    """
    if not isinstance(optimizers, list):
        if not isinstance(optimizers, ParameterOptimizer):
            optimizers = ParameterOptimizer(optimizer=optimizers)
        optimizers = [optimizers]
    return optimizers
class TensorboardServiceListener(Listener):
    """Scheduler listener creating the symlink ``source`` -> ``target``

    The link is created on the first reported job state that is not
    "not started", and only if ``source`` is not already a symlink.
    """

    def __init__(self, source: Path, target: Path):
        self.source = source
        self.target = target

    def job_state(self, job: Job):
        # NOTE(review): `job` is not compared with the task's own job; if the
        # scheduler notifies listeners for every job, the link may be created
        # as soon as any job starts — confirm
        if job.state.notstarted():
            return
        # Already linked by an earlier notification
        if self.source.is_symlink():
            return
        try:
            self.source.symlink_to(self.target)
        except Exception:
            logger.exception("Cannot symlink %s to %s", self.source, self.target)
class TensorboardService(WebService):
    """Experimaestro web service running a TensorBoard server

    Tasks register their tensorboard directory with :meth:`add`; it is
    symlinked, under a name derived from the task tags, into ``path``, which
    is the TensorBoard log directory. Outside of ``RunMode.NORMAL`` the
    service does nothing.
    """

    id = "tensorboard"

    def __init__(self, xp: experiment, path: Path):
        super().__init__()
        self.path = path
        self.url = None
        # Created by _serve(); initialized here so that close() works even
        # if the server was never started (other run modes, startup failure)
        self.server = None
        self.run_mode = xp.run_mode
        if self.run_mode == RunMode.NORMAL:
            cleanupdir(self.path)
            self.path.mkdir(exist_ok=True, parents=True)
            logger.info("You can monitor learning with:")
            logger.info("tensorboard --logdir=%s", self.path)

    def add(self, task: Task, path: Path):
        """Link the tensorboard directory ``path`` of ``task`` into the service

        The link is created by a scheduler listener when job states change.
        Only active in ``RunMode.NORMAL``.
        """
        if self.run_mode != RunMode.NORMAL:
            return

        job = task.__xpm__.job
        if not job:
            logger.error("Task was not started: cannot link to tensorboard job path")
            return
        if job.scheduler is None:
            logger.debug("No scheduler: not adding the tensorboard data")
            return

        tag_path = tagspath(task)
        if not tag_path:
            logger.error(
                "The task is not associated with tags: "
                "cannot link to tensorboard data"
            )
            return

        job.scheduler.addlistener(
            TensorboardServiceListener(self.path / tag_path, path)
        )

    def description(self):
        return "Tensorboard service"

    def close(self):
        """Shut down the TensorBoard server if it was started"""
        if self.run_mode == RunMode.NORMAL and self.server is not None:
            self.server.shutdown()

    def _serve(self, running: threading.Event):
        """Run the TensorBoard server (blocking)

        ``running`` is set once ``url`` is known, or when there is nothing to
        serve or the server could not be started (``url`` then stays None).
        """
        if self.run_mode != RunMode.NORMAL:
            # Still signal readiness: anything waiting on `running` (presumably
            # the WebService URL lookup) would otherwise block forever
            running.set()
            return

        try:
            # Imported inside the try so that a missing/broken tensorboard
            # install is logged and still releases `running`
            import tensorboard as tb

            logger.info("Starting %s service", self.id)
            self.program = tb.program.TensorBoard()
            self.program.configure(
                host="localhost",
                logdir=str(self.path.absolute()),
                path_prefix=f"/services/{self.id}",
                port=0,  # 0 asks for a free port chosen by the OS
            )
            self.server = self.program._make_server()
            self.url = self.server.get_url()
            running.set()
            self.server.serve_forever()
        except Exception:
            logger.exception("Error while starting tensorboard")
            running.set()