from abc import ABC, abstractmethod
import logging
from typing import Optional, ParamSpec
from experimaestro import field, Config, Param, Meta
import lightning.fabric.strategies as strategies
import lightning as L
import torch
logger = logging.getLogger("xpm_torch.configuration")
P = ParamSpec("P")
[docs]
class Strategy(Config, strategies.Strategy):
    """Lightning Fabric strategy exposed as an experimaestro configuration.

    Combines :class:`experimaestro.Config` with
    :class:`lightning.fabric.strategies.Strategy`; it adds no behavior of its own.
    """
[docs]
class FabricConfigurationBase(Config, ABC):
    """Abstract configuration able to build a Lightning ``Fabric`` object.

    Subclasses implement :meth:`_get_fabric`; callers go through the public
    :meth:`get_fabric` entry point.
    """

    def get_fabric(self, **kwargs) -> L.Fabric:
        """Return the Fabric object built by :meth:`_get_fabric`.

        :param kwargs: forwarded unchanged to :meth:`_get_fabric`
        """
        build = self._get_fabric
        return build(**kwargs)

    @abstractmethod
    def _get_fabric(self, **kwargs) -> L.Fabric:
        """Build the Fabric object from this configuration (implemented by subclasses)."""
[docs]
class FabricConfiguration(FabricConfigurationBase):
    """Describe the computation device(s) and numerical precision.

    The backend is Lightning Fabric, so the complete documentation of the
    arguments can be found at
    https://lightning.ai/docs/fabric/stable/api/fabric_args.html
    """

    # Parameters - they change the Learner output
    precision: Param[str] = field(default="32-true", ignore_default=True)
    """Precision to use, e.g., '16-mixed', 'bf16-mixed', '32-true':
    see Lightning documentation at https://lightning.ai/docs/fabric/stable/api/fabric_args.html#precision
    """

    torch_fp32_precision: Param[Optional[str]]
    """Torch precision for torch.float32 operations, see https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
    If not set, it is derived from `precision` ('medium' for '16-mixed' / 'bf16-mixed', 'high' otherwise),
    but it can be overridden if needed (e.g., to force TF32 on Ampere GPUs while using bf16 precision for other operations)
    """

    # Meta-parameters - they don't change the output, only the computing environment
    num_nodes: Meta[int] = field(default=1, ignore_default=True)
    """Number of nodes"""

    devices: Meta[str] = field(default="auto", ignore_default=True)
    """Configure the devices to run on.
    See https://lightning.ai/docs/fabric/stable/api/fabric_args.html#devices for more details and options.
    Note that for multi-node training, you should specify the devices per node, e.g., devices="4" for 4 GPUs per node, not devices="16" for a total of 16 GPUs across 4 nodes.
    """

    strategy: Meta[str] = field(default="auto", ignore_default=True)
    """The strategy to use
    See https://lightning.ai/docs/fabric/stable/api/fabric_args.html#strategy for more details and options.
    """

    accelerator: Meta[str] = field(default="auto", ignore_default=True)
    """The accelerator to use
    See https://lightning.ai/docs/fabric/stable/api/fabric_args.html#accelerator for more details and options.
    """

    # True once a Fabric object has been successfully built by _get_fabric
    is_built = False
    # Fabric object built by the first successful _get_fabric call (reused afterwards)
    _fabric = None

    def _get_fabric(self, **kwargs) -> L.Fabric:
        """Build the Fabric object and set the torch.float32 matmul precision.

        This is called by the Learner before launching the training loop.

        :param kwargs: extra keyword arguments forwarded to ``lightning.Fabric``.
            ``precision`` defaults to :attr:`precision` unless given explicitly.
        :return: the Fabric object. Repeated calls return the object built by
            the first successful call; their ``kwargs`` are ignored.
        """
        if self.is_built:
            # Returning None here would break callers expecting a Fabric object
            logger.warning(
                "FabricConfiguration.get_fabric called multiple times: "
                "returning the previously built Fabric (new arguments are ignored)."
            )
            return self._fabric

        if self.torch_fp32_precision is None:
            # Auto-select the float32 matmul precision from the Fabric precision
            # (only when not set explicitly)
            if self.precision in ("16-mixed", "bf16-mixed"):
                self.torch_fp32_precision = "medium"
            else:
                self.torch_fp32_precision = "high"
            logger.info(
                "Setting torch.fp32 matmul precision to '%s' based on fabric precision '%s'",
                self.torch_fp32_precision,
                self.precision,
            )
        torch.set_float32_matmul_precision(self.torch_fp32_precision)

        # The configured precision was previously never handed to Fabric;
        # an explicit `precision` keyword from the caller still takes priority.
        kwargs.setdefault("precision", self.precision)
        fabric = L.Fabric(
            accelerator=self.accelerator,
            devices=self.devices,
            strategy=self.strategy,
            num_nodes=self.num_nodes,
            **kwargs,
        )

        # Only mark as built once construction succeeded, so a failure can be retried
        self._fabric = fabric
        self.is_built = True

        logger.info(
            "Using Fabric with accelerator=%s, world_size=%s, strategy=%s",
            fabric.accelerator.__class__.__name__,
            fabric.world_size,
            fabric.strategy.__class__.__name__,
        )
        return fabric