Source code for xpm_torch.configuration

from abc import ABC, abstractmethod
import logging
from typing import Optional, ParamSpec
from experimaestro import field, Config, Param, Meta
import lightning.fabric.strategies as strategies
import lightning as L
import torch

# Module-level logger, registered under the module's dotted path.
logger = logging.getLogger("xpm_torch.configuration")

# Parameter-specification variable; not referenced in the visible part of this file.
P = ParamSpec("P")


class Strategy(Config, strategies.Strategy):
    """Marker class deriving from both ``Config`` and Fabric's ``Strategy``.

    It adds no attributes or methods of its own.
    """
class FabricConfigurationBase(Config, ABC):
    """Abstract configuration that knows how to produce a Fabric object.

    Callers use :meth:`get_fabric`; subclasses implement :meth:`_get_fabric`.
    """

    def get_fabric(self, **kwargs) -> L.Fabric:
        """Return the Fabric object produced by :meth:`_get_fabric`.

        All keyword arguments are forwarded unchanged.
        """
        fabric = self._get_fabric(**kwargs)
        return fabric

    @abstractmethod
    def _get_fabric(self, **kwargs) -> L.Fabric:
        """Builds the Fabric object based on the configuration."""
class FabricConfiguration(FabricConfigurationBase):
    """Describe the computation device

    The backend is fabric, so the complete documentation can be found on
    https://lightning.ai/docs/fabric/stable/api/fabric_args.html
    """

    # Parameters - change Learner output
    precision: Param[str] = field(default="32-true", ignore_default=True)
    """Precision to use, e.g., '16-mixed', 'bf16-mixed', '32-true': see
    Lightning documentation at
    https://lightning.ai/docs/fabric/stable/api/fabric_args.html#precision
    """

    torch_fp32_precision: Param[Optional[str]]
    """Torch precision for torch.float32 operations, see
    https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision

    Automatically set depending on fabric_config.precision if not set, but can
    be overridden if needed (e.g., to force TF32 on Ampere GPUs while using
    bf16 precision for other operations)
    """

    # Meta-parameters - don't change output, just computing environment
    num_nodes: Meta[int] = field(default=1, ignore_default=True)
    """Number of nodes"""

    devices: Meta[str] = field(default="auto", ignore_default=True)
    """Configure the devices to run on.

    See https://lightning.ai/docs/fabric/stable/api/fabric_args.html#devices
    for more details and options. Note that for multi-node training, you
    should specify the devices per node, e.g., devices="4" for 4 GPUs per
    node, not devices="16" for a total of 16 GPUs across 4 nodes.
    """

    strategy: Meta[str] = field(default="auto", ignore_default=True)
    """The strategy to use

    See https://lightning.ai/docs/fabric/stable/api/fabric_args.html#strategy
    for more details and options.
    """

    accelerator: Meta[str] = field(default="auto", ignore_default=True)
    """The accelerator to use

    See https://lightning.ai/docs/fabric/stable/api/fabric_args.html#accelerator
    for more details and options.
    """

    # Flipped to True on the instance by the first _get_fabric call.
    is_built = False

    def _get_fabric(self, **kwargs) -> L.Fabric:
        """Build the Fabric object and set the torch.float32 matmul precision.

        This is called by the Learner before launching the training loop.

        Args:
            **kwargs: Extra keyword arguments forwarded to ``L.Fabric``. A
                ``precision`` given here takes precedence over
                ``self.precision``.

        Returns:
            The newly created Fabric object on the first call. Any later call
            on the same instance logs a warning and returns ``None``.
        """
        if self.is_built:
            logger.warning("FabricConfiguration.get_Fabric called multiple times.")
            return None
        self.is_built = True

        if self.torch_fp32_precision is None:
            # Not set explicitly: derive the float32 matmul precision from the
            # Fabric precision and store it back on the configuration.
            uses_mixed = self.precision in ("16-mixed", "bf16-mixed")
            self.torch_fp32_precision = "medium" if uses_mixed else "high"
            logger.info(
                "Setting torch.fp32 matmul precision to '%s' based on fabric precision '%s'",
                self.torch_fp32_precision,
                self.precision,
            )
        torch.set_float32_matmul_precision(self.torch_fp32_precision)

        # The configured precision was previously never handed to Fabric.
        # setdefault keeps an explicit caller-supplied ``precision`` intact and
        # avoids passing the keyword twice.
        kwargs.setdefault("precision", self.precision)
        fabric = L.Fabric(
            accelerator=self.accelerator,
            devices=self.devices,
            strategy=self.strategy,
            num_nodes=self.num_nodes,
            **kwargs,
        )
        # Use the module logger rather than the root logger.
        logger.info(
            "Using Fabric with accelerator=%s, devices=%s, strategy=%s",
            fabric.accelerator.__class__.__name__,
            fabric.world_size,
            fabric.strategy.__class__.__name__,
        )
        return fabric