oscillode/src/blender_maxwell/utils/pydantic_sympy.py

165 lines
4.7 KiB
Python

import typing as typ
import pydantic as pyd
import sympy as sp
import sympy.physics.units as spu
import typing_extensions as typx
from pydantic_core import core_schema as pyd_core_schema
from . import extra_sympy_units as spux
####################
# - Missing Basics
####################
# Union of the `sympy` object types accepted by the custom pydantic types below.
## NOTE(review): `sp.MutableDenseMatrix` appears to already subclass
## `sp.MatrixBase`, which would make it redundant here — confirm before removing.
AllowedSympyExprs = sp.Expr | sp.MatrixBase | sp.MutableDenseMatrix


def _complex_core_schema(source_type, handler):
    """Wrap the handler's schema for `source_type` in a pass-through after-validator."""

    def passthrough(value):
        return value

    inner_schema = handler(source_type)
    return pyd_core_schema.no_info_after_validator_function(passthrough, inner_schema)


# `complex`, validated by pydantic's own handler for the source type, followed
# by an after-validator that returns the value unchanged.
Complex = typx.Annotated[complex, pyd.GetPydanticSchema(_complex_core_schema)]
####################
# - Custom Pydantic Type for sp.Expr
####################
class _SympyExpr:
    """Pydantic annotation marker that validates and serializes `sympy` objects.

    Intended as metadata, ex. `typx.Annotated[AllowedSympyExprs, _SympyExpr]`.

    Validation (same chain for Python and JSON input):

    1. A `str` is parsed with `sp.sympify`, and `spux.ALL_UNIT_SYMBOLS` is
       substituted into the result. Any other value passes through unchanged.
    2. The value must be an instance of `sp.Expr` or `sp.MatrixBase`.
    3. The value must be an instance of `AllowedSympyExprs`.

    Serialization emits `sp.srepr(value)`.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: AllowedSympyExprs,
        _handler: pyd.GetCoreSchemaHandler,
    ) -> pyd_core_schema.CoreSchema:
        """Build the pydantic-core schema for `sympy` expressions.

        Parameters:
            _source_type: The annotated source type (unused).
            _handler: Pydantic's core-schema handler (unused).

        Returns:
            A `json_or_python_schema` that runs the validation chain described
            on the class and serializes with `sp.srepr`.
        """

        def validate_from_str(value: str) -> AllowedSympyExprs:
            """Parse a string with `sp.sympify`; pass any non-string through.

            Raises:
                ValueError: If the string cannot be turned into a `sympy` object.
            """
            if not isinstance(value, str):
                return value

            # SECURITY: `sp.sympify` evaluates strings with `eval` internally,
            # so parsing can execute arbitrary code. Do not validate untrusted input.
            try:
                expr = sp.sympify(value)
            except (ValueError, TypeError, AttributeError) as ex:
                # `sp.SympifyError` subclasses `ValueError`, but evaluating the
                # parsed string can also raise e.g. `TypeError`/`AttributeError`.
                # Pydantic only turns `ValueError`/`AssertionError` raised in
                # validators into a `ValidationError`, so normalize here.
                msg = f'Value {value} is not a `sympify`able string'
                raise ValueError(msg) from ex

            # `sympify` can return objects without `.subs` (ex. a Python list);
            # leave those for `validate_from_expr` to reject with a ValueError.
            if not isinstance(expr, sp.Basic | sp.MatrixBase):
                return expr

            return expr.subs(spux.ALL_UNIT_SYMBOLS)

        def validate_from_expr(value: AllowedSympyExprs) -> AllowedSympyExprs:
            """Reject anything that is not an `sp.Expr` or `sp.MatrixBase`."""
            if not isinstance(value, sp.Expr | sp.MatrixBase):
                msg = f'Value {value} is not a `sympy` expression'
                raise ValueError(msg)

            return value

        sympy_expr_schema = pyd_core_schema.chain_schema(
            [
                pyd_core_schema.no_info_plain_validator_function(validate_from_str),
                pyd_core_schema.no_info_plain_validator_function(validate_from_expr),
                pyd_core_schema.is_instance_schema(AllowedSympyExprs),
            ]
        )

        return pyd_core_schema.json_or_python_schema(
            json_schema=sympy_expr_schema,
            python_schema=sympy_expr_schema,
            serialization=pyd_core_schema.plain_serializer_function_ser_schema(
                sp.srepr
            ),
        )
####################
# - Configurable Expression Validation
####################
# A `sympy` expression or matrix, parsed/validated and serialized by `_SympyExpr`.
## For additional constraints, build a type with `ConstrSympyExpr(...)`.
SympyExpr = typx.Annotated[AllowedSympyExprs, _SympyExpr]
def ConstrSympyExpr(
    # Feature Class
    allow_variables: bool = True,
    allow_units: bool = True,
    # Structure Class
    allowed_sets: set[typ.Literal['integer', 'rational', 'real', 'complex']]
    | None = None,
    allowed_structures: set[typ.Literal['scalar', 'matrix']] | None = None,
    # Element Class
    allowed_symbols: set[sp.Symbol] | None = None,
    allowed_units: set[spu.Quantity] | None = None,
    # Shape Class
    allowed_matrix_shapes: set[tuple[int, int]] | None = None,
):
    """Build a constrained variant of `SympyExpr`.

    The returned type validates like `SympyExpr`, then runs an after-validator
    that checks every enabled constraint. All violations are collected and
    raised together as a single `ValueError`. A constraint given as `None`
    (or an empty set) is not checked.

    Parameters:
        allow_variables: If False, reject expressions with any free symbols.
        allow_units: If False, reject expressions for which
            `spux.uses_units(expr)` is truthy.
        allowed_sets: Number sets an `sp.Expr` must provably belong to (at
            least one of them). Checked via `sympy` assumption predicates
            (`is_integer`, ...). An undetermined predicate (`None`) counts as
            a mismatch, so symbols usually need assumptions, ex.
            `sp.Symbol('x', real=True)`. Only applied to `sp.Expr` instances.
        allowed_structures: `'matrix'` matches `sp.MatrixBase` instances;
            `'scalar'` matches everything else.
        allowed_symbols: The expression's free symbols must be a subset.
        allowed_units: The units returned by `spux.get_units(expr)` must be
            a subset.
        allowed_matrix_shapes: Allowed `.shape` values for `sp.MatrixBase`
            instances. Non-matrices are not checked.

    Returns:
        `typx.Annotated[AllowedSympyExprs, _SympyExpr, pyd.AfterValidator(...)]`.
    """
    # Allowed set name -> `sympy` assumption predicate attribute.
    ## See `sympy` predicates:
    ## - <https://docs.sympy.org/latest/guides/assumptions.html#predicates>
    set_predicates = {
        'integer': 'is_integer',
        'rational': 'is_rational',
        'real': 'is_real',
        'complex': 'is_complex',
    }

    def validate_expr(expr: AllowedSympyExprs) -> AllowedSympyExprs:
        """Check `expr` against every enabled constraint.

        Raises:
            ValueError: If `expr` is not a `sympy` expression/matrix, or if
                any constraint is violated (all violations are listed).
        """
        if not isinstance(expr, sp.Expr | sp.MatrixBase):
            ## NOTE: Must match AllowedSympyExprs union elements.
            msg = f"expr '{expr}' is not an allowed Sympy expression ({AllowedSympyExprs})"
            raise ValueError(msg)

        # A list (not a set) keeps the error message order deterministic.
        msgs: list[str] = []
        is_matrix = isinstance(expr, sp.MatrixBase)

        # Validate Feature Class
        if not allow_variables and expr.free_symbols:
            msgs.append(
                f'allow_variables={allow_variables} does not match expression {expr}.'
            )
        if not allow_units and spux.uses_units(expr):
            msgs.append(f'allow_units={allow_units} does not match expression {expr}.')

        # Validate Structure Class
        ## `any` treats an undetermined (`None`) predicate as not matching.
        if (
            allowed_sets
            and isinstance(expr, sp.Expr)
            and not any(
                getattr(expr, set_predicates[allowed_set])
                for allowed_set in allowed_sets
            )
        ):
            msgs.append(
                f"allowed_sets={allowed_sets} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))"
            )

        structure = 'matrix' if is_matrix else 'scalar'
        if allowed_structures and structure not in allowed_structures:
            msgs.append(
                f'allowed_structures={allowed_structures} does not match expression {expr}'
            )

        # Validate Element Class
        if allowed_symbols and not expr.free_symbols.issubset(allowed_symbols):
            msgs.append(
                f'allowed_symbols={allowed_symbols} does not match expression {expr}'
            )
        if allowed_units and not spux.get_units(expr).issubset(allowed_units):
            msgs.append(f'allowed_units={allowed_units} does not match expression {expr}')

        # Validate Shape Class
        if (
            allowed_matrix_shapes
            and is_matrix
            and expr.shape not in allowed_matrix_shapes
        ):
            msgs.append(
                f'allowed_matrix_shapes={allowed_matrix_shapes} does not match expression {expr} with shape {expr.shape}'
            )

        # Error or Return
        if msgs:
            raise ValueError('; '.join(msgs))

        return expr

    return typx.Annotated[
        AllowedSympyExprs,
        _SympyExpr,
        pyd.AfterValidator(validate_expr),
    ]