"""Management of the arguments (params, options, etc) associated with the XPM objects"""
import warnings
from typing import Optional, TypeVar, TYPE_CHECKING, Callable, Any
from experimaestro.typingutils import get_optional
from pathlib import Path
from typing import Annotated
if TYPE_CHECKING:
import experimaestro.core.types
from experimaestro.core.partial import ParameterGroup
# Number of bare-default deprecation warnings already emitted, keyed by the
# module name of the class declaring the parameter.
_deprecation_warning_counts: dict[str, int] = {}
# Upper bound on the number of such warnings emitted for a single module.
_MAX_WARNINGS_PER_MODULE = 10
class Argument:
    """Represents an argument of a configuration or task.

    An argument bundles a name, a type, optional default/generator
    information, and the flags controlling how it is handled (ignored,
    constant, data path, parent override, parameter groups).
    """

    objecttype: Optional["experimaestro.core.types.ObjectType"]
    """The object for which this argument was declared"""

    def __init__(
        self,
        name,
        type: "experimaestro.core.types.Type",
        required=None,
        help=None,
        generator=None,
        ignored=None,
        field_or_default=None,
        checker=None,
        constant=False,
        is_data=False,
        overrides=False,
        groups: set["ParameterGroup"] = None,
    ):
        """Creates a new argument

        Args:
            name (str): The name of the argument
            type (experimaestro.core.types.Type): The type of the argument
            required (bool, optional): True if required (if None, determines
                automatically). Defaults to None.
            help (str, optional): Help string. Defaults to None.
            generator (Generator, optional): The value generator (e.g. for
                paths). Defaults to None.
            ignored (bool, optional): True if ignored (if None, taken from
                ``type.ignore``). Defaults to None.
            field_or_default (any | field, optional): Default value or field
                object containing default information. Defaults to None.
            checker (any, optional): Value checker. Defaults to None.
            constant (bool, optional): If true, the value is constant. Defaults
                to False.
            is_data (bool, optional): Flag for paths that are data path (to be
                serialized). Defaults to False.
            overrides (bool, optional): If True, this argument intentionally
                overrides a parent argument. Suppresses the warning that would
                otherwise be issued. Defaults to False.
            groups (set[ParameterGroup], optional): Set of groups this parameter
                belongs to. Used with partial to compute partial identifiers.
                Defaults to None (empty set).

        Raises:
            Exception: If the argument is required but a default is given.
            AssertionError: If both a generator and ``field_or_default`` are
                given, or if the argument is constant without a default.
        """
        # Without an explicit flag, an argument is required exactly when no
        # default information was supplied.
        required = (field_or_default is None) if required is None else required
        if field_or_default is not None and required:
            raise Exception(
                "argument '%s' is required but default value is given" % name
            )

        self.name = name
        self._help = help
        self.checker = checker
        self.type = type
        self.constant = constant
        self.ignored = self.type.ignore if ignored is None else ignored
        self.required = required
        self.objecttype = None
        self.is_data = is_data
        self.overrides = overrides
        self.generator = generator
        self.default = None
        self.ignore_generated = False
        self.ignore_default_in_identifier = False
        self.groups = groups if groups else set()

        if field_or_default is not None:
            assert self.generator is None, (
                "generator and field_or_default are exclusive options"
            )
            if isinstance(field_or_default, field):
                self.ignore_generated = field_or_default.ignore_generated

                # Allow field to override the overrides flag
                if field_or_default.overrides:
                    self.overrides = True

                # Process groups from field
                if field_or_default.groups:
                    self.groups = field_or_default.groups

                if field_or_default.default is not None:
                    self.default = field_or_default.default
                    self.ignore_default_in_identifier = field_or_default.ignore_default
                elif field_or_default.default_factory is not None:
                    self.generator = field_or_default.default_factory
                    if not self.ignored:
                        # For Param fields, eagerly compute default for
                        # identifier comparison (fixes #191)
                        self.default = field_or_default.default_factory()
                        self.ignore_default_in_identifier = (
                            field_or_default.ignore_default
                        )
            else:
                # Bare default: backwards compatible, ignore in identifier
                self.default = field_or_default
                self.ignore_default_in_identifier = True

        assert not self.constant or self.default is not None, (
            "Cannot be constant without default"
        )

    def __repr__(self):
        return f"Param[{self.name}:{self.type}]"

    def validate(self, value):
        """Validate (and possibly convert) a value for this argument.

        Args:
            value: The candidate value.

        Returns:
            The value returned by ``self.type.validate``.

        Raises:
            TypeError: If the type rejects the value (original error chained).
            ValueError: If a checker is set and its ``check`` returns a falsy
                result.
        """
        try:
            value = self.type.validate(value)
        except TypeError as e:
            raise TypeError(
                f"Value {value} is not valid for argument {self.name}: {e}"
            ) from e

        if self.checker and not self.checker.check(value):
            # The message is formatted here: passing the value as a second
            # exception argument would leave the "%s" placeholder unexpanded.
            raise ValueError(f"Value {value} is not valid for argument {self.name}")
        return value

    def isoutput(self):
        """Return the generator's ``isoutput()`` result, or False without generator."""
        if self.generator:
            return self.generator.isoutput()
        return False

    @property
    def help(self):
        """Help string of the argument.

        When no help is stored and the argument is attached to an object
        type, that type's ``__parsedoc__()`` is called before returning
        (presumably it fills in the help text — confirm in the type class).
        """
        if self._help is None and self.objecttype is not None:
            self.objecttype.__parsedoc__()
        return self._help

    @help.setter
    def help(self, help: str):
        self._help = help
class ArgumentOptions:
    """Helper class when using type hints.

    Collects the keyword arguments gathered from annotations and builds the
    matching :class:`Argument`.
    """

    def __init__(self):
        self.kwargs = {}
        self.constant = False

    def create(self, name, originaltype, typehint):
        """Build the ``Argument`` called ``name`` declared on ``originaltype``.

        When no default information was collected, the class attribute of
        the same name is used as default; a bare (non-``field``) default on a
        non-constant parameter triggers a rate-limited deprecation warning.
        """
        from experimaestro.core.types import Type

        inner_hint = get_optional(typehint)
        argument_type = Type.fromType(inner_hint or typehint)

        if self.kwargs.get("field_or_default") is None:
            class_default = getattr(originaltype, name, None)
            self.kwargs["field_or_default"] = class_default

            # Constant parameters are inherently constant, so their bare
            # defaults are not reported.
            is_bare_default = class_default is not None and not isinstance(
                class_default, field
            )
            if is_bare_default and not self.kwargs.get("constant"):
                module_name = originaltype.__module__
                emitted = _deprecation_warning_counts.get(module_name, 0)
                if emitted < _MAX_WARNINGS_PER_MODULE:
                    _deprecation_warning_counts[module_name] = emitted + 1
                    owner = originaltype.__qualname__
                    left = _MAX_WARNINGS_PER_MODULE - emitted - 1
                    if left > 0:
                        suffix = f" ({left} more warnings for this module)"
                    else:
                        suffix = " (no more warnings for this module)"
                    warnings.warn(
                        f"Deprecated: parameter `{name}` in {module_name}.{owner} "
                        f"has an ambiguous default value. Use "
                        f"`field(default=..., ignore_default=True)` "
                        f"to keep current behavior (default ignored in identifier) or "
                        f"`field(default=...)` to include default in identifier. "
                        f"Run `experimaestro refactor default-values` to fix automatically."
                        f"{suffix}",
                        DeprecationWarning,
                        stacklevel=6,
                    )

        # Required only when the hint is not optional and no default exists
        has_no_default = self.kwargs["field_or_default"] is None
        self.kwargs["required"] = (inner_hint is None) and has_no_default

        return Argument(name, argument_type, **self.kwargs)
class TypeAnnotation:
    """Base class for annotation objects that configure ``ArgumentOptions``."""

    def __call__(self, options: Optional[ArgumentOptions]):
        """Apply this annotation, creating the options object when missing."""
        target = ArgumentOptions() if options is None else options
        self.annotate(target)
        return target

    def annotate(self, options: ArgumentOptions):
        """Hook for subclasses; the base implementation does nothing."""
class _Param(TypeAnnotation):
    """Base annotation for types"""

    def __init__(self, **kwargs):
        # Keyword arguments later copied into the argument options
        self.kwargs = kwargs

    def annotate(self, options: ArgumentOptions):
        """Copy the stored keyword arguments into ``options.kwargs``."""
        for key, value in self.kwargs.items():
            options.kwargs[key] = value
        return options
# Generic placeholder for the annotated value type
T = TypeVar("T")

# Hint carrying no extra options
paramHint = _Param()

Param = Annotated[T, paramHint]
"""Type annotation for configuration parameters.

Parameters annotated with ``Param[T]`` are included in the configuration
identifier computation and must be set before the configuration is sealed.

Example::

    class MyConfig(Config):
        name: Param[str]
        count: Param[int] = field(default=10)
        threshold: Param[float] = field(default=0.5, ignore_default=True)
"""
# Shared hint for meta-parameters: passes ignored=True to the argument options.
optionHint = _Param(ignored=True)
Option = Annotated[T, optionHint]
"""Deprecated alias for Meta. Use Meta instead."""
# Meta reuses the same hint object as Option, so both annotations are equivalent.
Meta = Annotated[T, optionHint]
"""Type annotation for meta-parameters (ignored in identifier computation).
Use ``Meta[T]`` for parameters that should not affect the task identity,
such as output paths or runtime configuration.
Example::
class MyTask(Task):
# This affects the task identity
learning_rate: Param[float]
# This does not affect the identity
checkpoint_path: Meta[Path] = field(default_factory=PathGenerator("model.pt"))
"""
# Hint for data paths: passes ignored=True and is_data=True to the argument options.
dataHint = _Param(ignored=True, is_data=True)
DataPath = Annotated[Path, dataHint]
"""Type annotation for data paths that should be serialized.
Use ``DataPath`` for paths that point to data files that should be
preserved when serializing/deserializing a configuration. The path
is copied during serialization.
Example::
class MyConfig(Config):
model_weights: DataPath
"""
# Same hint as DataPath, applied to Optional[Path].
OptionalDataPath = Annotated[Optional[Path], dataHint]
"""Optional version of DataPath"""
class field:
"""Specify additional properties for a configuration parameter.
Use ``field()`` to control default value behavior and parameter grouping.
**Default value options and identifier behavior:**
``default``
The parameter has a default value that is **always included** in the
task identifier. Two configs with different values always get different
identifiers, even if one uses the default.
``default_factory``
A callable (zero-argument) that produces the default value. Behaves
like ``default`` — the value is **always included** in the identifier.
On ``Meta`` fields, the callable is invoked at seal time (e.g.
``PathGenerator``).
``ignore_default`` (bool)
When ``True`` and combined with ``default`` or ``default_factory``,
the default value is **excluded** from the identifier when the actual
value equals the default. This is the backwards-compatible behavior
matching bare defaults (``x: Param[int] = 23``, which is deprecated).
Example::
class MyConfig(Config):
# Default always included in identifier
count: Param[int] = field(default=10)
# Factory default always included in identifier
fabric: Param[FabricConfig] = field(
default_factory=FabricConfig.C
)
# Default ignored in identifier when value == default
threshold: Param[float] = field(default=0.5, ignore_default=True)
# Factory default ignored when value == default
fabric: Param[FabricConfig] = field(
default_factory=FabricConfig.C, ignore_default=True
)
# Generated path (Meta field, excluded from identifier)
output: Meta[Path] = field(
default_factory=PathGenerator("out.txt")
)
# Parameter in a group (for partial identifiers)
lr: Param[float] = field(groups=[training_group])
"""
def __init__(
self,
*,
default: Any = None,
default_factory: Callable = None,
ignore_default: bool | Any = None,
ignore_generated=False,
overrides=False,
groups: list["ParameterGroup"] = None,
):
"""Create a field specification.
:param default: Default value, always included in identifier
computation (unless ``ignore_default=True``).
:param default_factory: Callable that generates the default value.
On ``Param`` fields, the factory is called eagerly at class
definition time. The value is always included in the identifier
unless ``ignore_default=True``. On ``Meta`` fields, the callable
is invoked at seal time (use ``PathGenerator`` for
task-directory-relative paths).
:param ignore_default: When ``True``, the default value is excluded
from identifier computation when the actual value equals the
default. Must be used with ``default`` or ``default_factory``.
For backwards compatibility, passing a non-bool value (without
``default`` or ``default_factory``) is treated as
``field(default=value, ignore_default=True)`` but emits a
deprecation warning.
:param ignore_generated: If ``True``, the generated value is not
tracked as a "generated value", suppressing reproducibility
warnings. Controls whether context-dependent generator values
(e.g. ``PathGenerator`` with 2 params) are flagged.
:param overrides: If True, suppress warning when overriding a parent
parameter.
:param groups: List of ParameterGroup objects for partial identifiers.
Used with partial to compute identifiers that exclude certain
groups.
"""
assert not ((default is not None) and (default_factory is not None)), (
"default and default_factory are mutually exclusive options"
)
has_default_source = (default is not None) or (default_factory is not None)
if has_default_source:
# When default or default_factory is set, ignore_default must be bool or None
assert ignore_default is None or isinstance(ignore_default, bool), (
"ignore_default must be True, False, or None when used with "
"default or default_factory"
)
elif ignore_default is not None and not isinstance(ignore_default, bool):
# Legacy path: field(ignore_default=<value>) without default/default_factory
# Treat as field(default=<value>, ignore_default=True)
warnings.warn(
f"Deprecated: field(ignore_default={ignore_default!r}) should be "
f"field(default={ignore_default!r}, ignore_default=True). "
f"The old syntax still works but will be removed in a future version.",
DeprecationWarning,
stacklevel=2,
)
default = ignore_default
ignore_default = True
self.default_factory = default_factory
self.default = default
self.ignore_default = bool(ignore_default) if ignore_default else False
self.ignore_generated = ignore_generated
self.overrides = overrides
self.groups = set(groups) if groups else set()
class help(TypeAnnotation):
    """Annotation attaching a help text to an argument."""

    def __init__(self, text: str):
        self.text = text

    def annotate(self, options: ArgumentOptions):
        """Store the text under the ``help`` key of the option kwargs."""
        options.kwargs.update(help=self.text)
class ConstantHint(TypeAnnotation):
    """Annotation marking an argument as constant."""

    def annotate(self, options: ArgumentOptions):
        """Set the ``constant`` key of the option kwargs to True."""
        options.kwargs.update(constant=True)
# Single shared hint instance; its annotate() sets kwargs["constant"] = True.
constantHint = ConstantHint()
Constant = Annotated[T, constantHint]
"""Type annotation for constant (read-only) parameters.
Constants must have a default value and cannot be modified after creation.
They are included in the identifier computation.
Example::
class MyConfig(Config):
version: Constant[str] = "1.0"
"""