"""Constraints for parameters of Problem classes."""
from collections.abc import Callable, Iterable
import dataclasses
from dataclasses import dataclass
from dataclasses import is_dataclass
from enum import auto
from enum import Enum
from enum import Flag
import inspect
import typing
from typing import Any, ClassVar, overload, TypeVar
import numpy as np
# Signature shared by all check callbacks: they accept arbitrary arguments and
# return nothing; a violated constraint is signalled by raising AssertionError.
Check = Callable[..., None]
class Category(Flag):
    """Category of a constraint.

    As an :class:`enum.Flag`, members can be combined with ``|`` and a
    combination can be probed with ``in``.
    """

    # auto() assigns every member its own bit, so members never overlap.
    Theory = auto()
    Implementation = auto()
IMPL = Category.Implementation
"""Violating the constraint will cause runtime errors or undefined behavior
due to the implementation."""

THEORY = Category.Theory
"""The constraint is not known to cause a runtime error, but values outside of
the constraint domain are unphysical and might lead to unphysical domains."""

# Empty flag (no category bit set).
UNCATEGORIZED = Category(0)
class Criticality(Enum):
    """Criticality of a constraint violation."""

    # Two severity levels; how a caller reacts to each level is decided
    # outside of this class.
    Error = auto()
    Warning = auto()
@dataclass
class Constraint:
    """Constraint for parameters passed to e.g. :py:meth:`engibench.core.Problem.optimize()`.

    A constraint wraps a ``check`` callback together with the categories it
    belongs to and the criticality of a violation. The callback signals a
    violation by raising :class:`AssertionError`; any other exception
    propagates to the caller unchanged.
    """

    check: Check
    """Check callback raising an AssertionError if the constraint is violated."""

    categories: Category = UNCATEGORIZED
    """Categories of the constraint."""

    criticality: Criticality = Criticality.Error
    """Criticality of a violation of the constraint."""

    def category(self, category: Category) -> "Constraint":
        """Return a copy of the constraint which has the specified category added."""
        # dataclasses.replace() copies every field, so only the changed one is spelled out.
        return dataclasses.replace(self, categories=self.categories | category)

    def warning(self) -> "Constraint":
        """Return a copy of the constraint with the criticality level set to "warning"."""
        return dataclasses.replace(self, criticality=Criticality.Warning)

    def check_dict(self, parameter_args: dict[str, Any]) -> "Violation | None":
        """Check for a violation of the given constraint for the given parameters.

        Args:
            parameter_args: Mapping from parameter name to value. Only the entries
                named in the signature of ``check`` are passed on, unless ``check``
                accepts ``**kwargs``, in which case all entries are passed.

        Returns:
            A :class:`Violation` if ``check`` raised an ``AssertionError``, otherwise ``None``.

        Raises:
            TypeError: If a parameter required by ``check`` is missing from ``parameter_args``.
        """
        # We first inspect the arguments of check callback:
        sig = inspect.signature(self.check)
        parameters = sig.parameters.values()
        required_args = {
            p.name for p in parameters if p.default is p.empty and p.kind not in {p.VAR_KEYWORD, p.VAR_POSITIONAL}
        }
        # Sorted so that the error message does not depend on set iteration order.
        missing_args = sorted(required_args.difference(parameter_args))
        if missing_args:
            msg = f"Missing argument(s) {', '.join(missing_args)} for constraint {self.check.__name__}"
            raise TypeError(msg)

        # If self.check() accepts `**kwargs`, we feed in **parameter_args:
        if any(p.kind is p.VAR_KEYWORD for p in parameters):
            constraint_args = parameter_args
        # Otherwise, we only feed in parameters defined in the signature:
        else:
            constraint_args = {param: value for param, value in parameter_args.items() if param in sig.parameters}

        try:
            self.check(**constraint_args)
        except AssertionError as e:
            return Violation(self, str(e))
        return None

    def check_value(self, value: Any) -> "Violation | None":
        """Check for a violation for the given single positional value.

        Returns:
            A :class:`Violation` if ``check(value)`` raised an ``AssertionError``, otherwise ``None``.
        """
        try:
            self.check(value)
        except AssertionError as e:
            return Violation(self, str(e))
        return None
@overload
def constraint(
    check: Check, /, *, categories: Category = UNCATEGORIZED, criticality: Criticality = Criticality.Error
) -> Constraint: ...


@overload
def constraint(
    *, categories: Category = UNCATEGORIZED, criticality: Criticality = Criticality.Error
) -> Callable[[Check], Constraint]: ...


def constraint(
    check: Check | None = None,
    /,
    *,
    categories: Category = UNCATEGORIZED,
    criticality: Criticality = Criticality.Error,
) -> Callable[[Check], Constraint] | Constraint:
    """Decorator for check callbacks to convert the callback to a :class:`Constraint`.

    Can be applied bare (``@constraint``) or with keyword options
    (``@constraint(categories=..., criticality=...)``).

    Args:
        check: The check callback when used as a bare decorator or called directly.
        categories: Categories assigned to the resulting constraint.
        criticality: Criticality assigned to the resulting constraint.

    Returns:
        A :class:`Constraint` if ``check`` is given, otherwise a decorator which
        turns a check callback into a :class:`Constraint`.
    """

    def decorator(check: Check) -> Constraint:
        return Constraint(check, categories=categories, criticality=criticality)

    if check is None:
        return decorator
    # Route the direct call through the same path so that explicitly passed
    # `categories` / `criticality` are honored instead of silently dropped.
    return decorator(check)
class Var:
    """Helper class to bind variable names to a constraint."""

    def __init__(self, *names: str) -> None:
        # Names of the parameters whose values are handed, in this order, to the bound check.
        self.names = names

    def check(self, constraint: Constraint) -> Constraint:
        """Bind the variable names to `constraint`.

        Returns a copy of `constraint` whose check callback accepts keyword
        arguments, picks the values stored under the bound names and forwards
        them positionally to the original check callback.
        """
        # Name of the derived check: the bound names followed by the original check name.
        name = "_".join((*self.names, constraint.check.__name__))

        def extracting_check(**kwargs) -> None:
            values = []
            try:
                for key in self.names:
                    values.append(kwargs[key])
            except KeyError as e:
                # A missing name is reported like a violation (AssertionError),
                # the KeyError itself is deliberately not chained.
                msg = f"Missing argument {e} for constraint {name}"
                raise AssertionError(msg) from None
            constraint.check(*values)

        extracting_check.__name__ = name
        return dataclasses.replace(constraint, check=extracting_check)
@dataclass
class Violation:
    """Representation of a violation of a constraint."""

    # The constraint which was violated.
    constraint: Constraint
    # Textual description of why the constraint was violated.
    cause: str

    def __str__(self) -> str:
        """Format the violation as ``<name of the check callback>: <cause>``."""
        return f"{self.constraint.check.__name__}: {self.cause}"
class Violations:
    """Filterable collection of :class:`Violation` instances returned by :func:`check_constraints`."""

    def __init__(self, violations: list[Violation], n_constraints: int) -> None:
        self.violations = violations
        # Stored as given and carried over unchanged by the filter methods.
        self.n_constraints = n_constraints

    def by_category(self, category: Category) -> "Violations":
        """Filter the violations by the category of the constraint causing the violation."""
        selected = filter(lambda v: category in v.constraint.categories, self.violations)
        return Violations(list(selected), self.n_constraints)

    def by_criticality(self, criticality: Criticality) -> "Violations":
        """Filter the violations by criticality."""
        selected = filter(lambda v: v.constraint.criticality == criticality, self.violations)
        return Violations(list(selected), self.n_constraints)

    def __bool__(self) -> bool:
        """Return ``True`` if the collection holds at least one violation."""
        return bool(self.violations)

    def __len__(self) -> int:
        """Return the number of violations in the collection."""
        return len(self.violations)

    def __str__(self) -> str:
        """Return the string form of every violation, one per line."""
        return "\n".join(map(str, self.violations))
# Type variable for numeric bounds and values; constrained to `int` or `float`.
T = TypeVar("T", int, float)
def bounded(*, lower: T | None = None, upper: T | None = None) -> Constraint:
    """Create a constraint which checks that the specified parameter is contained in an interval `[lower, upper]`.

    Args:
        lower: Inclusive lower bound, or ``None`` for no lower bound.
        upper: Inclusive upper bound, or ``None`` for no upper bound.

    Returns:
        A :class:`Constraint` whose check raises ``AssertionError`` if any
        element of the checked value lies outside of the interval.
    """

    def check(value: T) -> None:
        # np.all() reduces element-wise comparisons to a single truth value.
        lower_ok = lower is None or np.all(lower <= value)
        if lower_ok and (upper is None or np.all(value <= upper)):
            return
        # Raised explicitly instead of using `assert`, which would be stripped
        # under `python -O` and silently disable the constraint.
        msg = f"{value} ∉ [{lower if lower is not None else '-∞'}, {upper if upper is not None else '∞'}]"
        raise AssertionError(msg)

    return Constraint(check)
def greater_than(lower: T, /) -> Constraint:
    """Create a constraint which checks that the specified parameter is greater than `lower`.

    Args:
        lower: Exclusive lower bound.

    Returns:
        A :class:`Constraint` whose check raises ``AssertionError`` unless
        every element of the checked value is strictly greater than `lower`.
    """

    def check(value: T) -> None:
        # Raised explicitly instead of using `assert`, which would be stripped
        # under `python -O` and silently disable the constraint.
        if not np.all(value > lower):
            msg = f"{value} ∉ ({lower}, ∞)"
            raise AssertionError(msg)

    return Constraint(check)
def less_than(upper: T, /) -> Constraint:
    """Create a constraint which checks that the specified parameter is less than `upper`.

    Args:
        upper: Exclusive upper bound.

    Returns:
        A :class:`Constraint` whose check raises ``AssertionError`` unless
        every element of the checked value is strictly less than `upper`.
    """

    def check(value: T) -> None:
        # Raised explicitly instead of using `assert`, which would be stripped
        # under `python -O` and silently disable the constraint.
        if not np.all(value < upper):
            msg = f"{value} ∉ (-∞, {upper})"
            raise AssertionError(msg)

    return Constraint(check)
def check_optimize_constraints(
    constraints: Iterable[Constraint],
    design: Any,
    config: dict[str, Any],
) -> Violations:
    """Specifically check the arguments of :meth:`engibench.core.Problem.optimize()`.

    The design is made available to the constraints under the name
    ``"design"``, next to all entries of `config`.
    """
    # NOTE: entries of `config` are merged last, so a "design" key in `config`
    # takes precedence over the `design` argument.
    parameter_args: dict[str, Any] = {"design": design, **config}
    return check_constraints(constraints, parameter_args)
def check_constraints(
    constraints: Iterable[Constraint],
    parameter_args: dict[str, Any],
) -> Violations:
    """Check for violations of the given constraints for the given parameters.

    Every constraint is evaluated via :meth:`Constraint.check_dict`; the
    returned collection also records how many constraints were evaluated.
    """
    # Materialize first: the iterable is traversed once but its length is needed too.
    all_constraints = list(constraints)
    found = []
    for current in all_constraints:
        outcome = current.check_dict(parameter_args)
        if outcome is not None:
            found.append(outcome)
    return Violations(found, len(all_constraints))
def check_field_constraints(
    data: Any,
) -> Violations:
    """Check for violations of constraints for fields of a dataclass which declare constraints via :func:`field_constraints`.

    `data` must be a dataclass instance (not a dataclass type). Constraints
    bound to a field are checked against the value of that field; constraints
    without a field name are checked against all fields at once.
    """
    assert is_dataclass(data)
    assert not isinstance(data, type)

    violations = []
    n_constraints = 0
    for n_constraints, (f_name, constraint) in enumerate(field_constraints(data), start=1):
        if f_name is None:
            # No field name: the check receives the fields as a dict.
            violation = constraint.check_dict(dataclasses.asdict(data))
        else:
            violation = constraint.check_value(getattr(data, f_name))

        if violation is None:
            continue
        if f_name is not None:
            # Prefix the cause with the qualified field name to locate the offending field.
            violation = Violation(violation.constraint, f"{type(data).__name__}.{f_name}: {violation.cause}")
        violations.append(violation)

    return Violations(violations, n_constraints)
def field_constraints(data: Any) -> Iterable[tuple[str | None, Constraint]]:
    """Iterate over all constraints declared on the dataclass (instance or type) `data`.

    Constraints are collected from three places, in this order:

    1. ``ClassVar[Annotated[..., constraint]]`` annotations,
    2. ``Annotated[..., constraint]`` types of the dataclass fields,
    3. class attributes which are :class:`Constraint` instances.

    Yields:
        Pairs ``(name, constraint)`` where ``name`` is the name of the
        annotated attribute, or ``None`` for constraints declared as class
        attributes.
    """
    assert is_dataclass(data)
    # Class attributes live on the class itself, also when a dataclass *type*
    # is passed (type(data) would then be the metaclass).
    cls = data if isinstance(data, type) else type(data)

    # Check for annotated ClassVar:
    for f_name, f in data.__annotations__.items():
        if typing.get_origin(f) is not ClassVar:
            continue
        args = typing.get_args(f)
        # Explicit length check: a failing tuple unpack would raise ValueError.
        if len(args) != 1:
            continue
        (annotation,) = args
        yield from ((f_name, c) for c in getattr(annotation, "__metadata__", ()) if isinstance(c, Constraint))

    for f in dataclasses.fields(data):
        # Check for typing.Annotated:
        yield from ((f.name, c) for c in getattr(f.type, "__metadata__", ()) if isinstance(c, Constraint))

    yield from ((None, c) for c in vars(cls).values() if isinstance(c, Constraint))
def count_field_constraints(data: Any) -> int:
    """Return the number of constraints declared on the dataclass `data`."""
    # Count while iterating, without building an intermediate list.
    return sum(1 for _ in field_constraints(data))