165 lines
4.7 KiB
Python
165 lines
4.7 KiB
Python
import typing as typ
|
|
|
|
import pydantic as pyd
|
|
import sympy as sp
|
|
import sympy.physics.units as spu
|
|
import typing_extensions as typx
|
|
from pydantic_core import core_schema as pyd_core_schema
|
|
|
|
from . import extra_sympy_units as spux
|
|
|
|
####################
|
|
# - Missing Basics
|
|
####################
|
|
## Union of the `sympy` object types accepted by the validators in this module.
## (`sp.MutableDenseMatrix` is listed explicitly alongside its base `sp.MatrixBase`.)
AllowedSympyExprs = (
    sp.Expr
    | sp.MatrixBase
    | sp.MutableDenseMatrix
)
|
|
## `complex` annotated with an identity after-validator: validation uses the schema
## pydantic's handler generates for `complex`, and the result is passed through as-is.
Complex = typx.Annotated[
    complex,
    pyd.GetPydanticSchema(
        lambda source, schema_handler: pyd_core_schema.no_info_after_validator_function(
            lambda value: value,
            schema_handler(source),
        )
    ),
]
|
|
|
|
|
|
####################
|
|
# - Custom Pydantic Type for sp.Expr
|
|
####################
|
|
class _SympyExpr:
    """Pydantic adapter for `sympy` expressions (scalars and matrices).

    Intended as metadata in `typx.Annotated[AllowedSympyExprs, _SympyExpr]`.
    Strings are parsed with `sp.sympify` and then have
    `spux.ALL_UNIT_SYMBOLS` substituted into them. Any other value must
    already be a `sp.Expr` or `sp.MatrixBase`. Values serialize to their
    `sp.srepr` string, presumably so they can be parsed back by the string
    branch on the next validation.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: AllowedSympyExprs,
        _handler: pyd.GetCoreSchemaHandler,
    ) -> pyd_core_schema.CoreSchema:
        """Build the validation/serialization core schema for sympy expressions."""

        def validate_from_str(value: object) -> object:
            """Parse a string into a sympy expression; pass non-strings through.

            Raises:
                ValueError: If the string cannot be parsed into a
                    `sp.Expr` / `sp.MatrixBase`.
            """
            if not isinstance(value, str):
                # Non-strings are type-checked by `validate_from_expr` next.
                return value

            # SECURITY(review): `sp.sympify` evaluates strings with `eval`.
            # Do not feed it untrusted input (e.g. external JSON) without
            # sanitizing first.
            try:
                expr = sp.sympify(value)
            except Exception as ex:
                # `sympify` wraps parse failures in `SympifyError` (a
                # `ValueError`). Errors raised while evaluating the parsed code
                # (e.g. `TypeError`, `AttributeError`) may propagate unchanged.
                # Pydantic only turns `ValueError`/`AssertionError` from
                # validators into a `ValidationError`, so normalize here.
                msg = f'Value {value} is not a `sympify`able string'
                raise ValueError(msg) from ex

            # `sympify` can return non-sympy objects (e.g. a Python list),
            # which have no `.subs`; reject them as a validation error.
            if not isinstance(expr, sp.Expr | sp.MatrixBase):
                msg = f'Value {value} is not a `sympy` expression'
                raise ValueError(msg)

            return expr.subs(spux.ALL_UNIT_SYMBOLS)

        def validate_from_expr(value: AllowedSympyExprs) -> AllowedSympyExprs:
            """Ensure the value is a `sp.Expr` or `sp.MatrixBase`.

            Raises:
                ValueError: If it is neither.
            """
            if not isinstance(value, sp.Expr | sp.MatrixBase):
                msg = f'Value {value} is not a `sympy` expression'
                raise ValueError(msg)

            return value

        # Each step receives the previous step's output.
        sympy_expr_schema = pyd_core_schema.chain_schema(
            [
                pyd_core_schema.no_info_plain_validator_function(validate_from_str),
                pyd_core_schema.no_info_plain_validator_function(validate_from_expr),
                pyd_core_schema.is_instance_schema(AllowedSympyExprs),
            ]
        )

        return pyd_core_schema.json_or_python_schema(
            json_schema=sympy_expr_schema,
            python_schema=sympy_expr_schema,
            serialization=pyd_core_schema.plain_serializer_function_ser_schema(
                sp.srepr
            ),
        )
|
|
|
|
|
|
####################
|
|
# - Configurable Expression Validation
|
|
####################
|
|
## A `sympy` expression (scalar or matrix) that pydantic validates and serializes
## via the `_SympyExpr` core-schema hook.
SympyExpr = typx.Annotated[AllowedSympyExprs, _SympyExpr]
|
|
|
|
|
|
def ConstrSympyExpr(
    # Feature Class
    allow_variables: bool = True,
    allow_units: bool = True,
    # Structure Class
    allowed_sets: set[typ.Literal['integer', 'rational', 'real', 'complex']]
    | None = None,
    allowed_structures: set[typ.Literal['scalar', 'matrix']] | None = None,
    # Element Class
    allowed_symbols: set[sp.Symbol] | None = None,
    allowed_units: set[spu.Quantity] | None = None,
    # Shape Class
    allowed_matrix_shapes: set[tuple[int, int]] | None = None,
):
    """Build a `SympyExpr`-like annotated type with extra validation constraints.

    The returned type parses and serializes like `SympyExpr` (through
    `_SympyExpr`). It then runs an after-validator that checks every enabled
    constraint and reports all violations in a single `ValueError`.

    A constraint left at its default is not checked. The defaults are `True`
    for the `allow_*` flags and `None` or an empty set for the `allowed_*`
    parameters.

    Parameters:
        allow_variables: When False, reject expressions with any free symbols.
        allow_units: When False, reject expressions for which
            `spux.uses_units(expr)` is truthy.
        allowed_sets: Number sets a scalar `sp.Expr` may belong to. At least
            one matching `is_<set>` predicate must be True; `None`
            ("unknown") counts as not matching. Matrices are not checked.
        allowed_structures: Allowed structures. An `sp.MatrixBase` is
            `'matrix'`; any other expression is `'scalar'`.
        allowed_symbols: Every free symbol of the expression must be in this set.
        allowed_units: Every unit returned by `spux.get_units(expr)` must be in
            this set.
        allowed_matrix_shapes: Allowed `.shape` tuples for matrices. Scalars are
            not checked.

    Returns:
        `typx.Annotated[AllowedSympyExprs, _SympyExpr, pyd.AfterValidator(...)]`,
        usable as a pydantic field annotation.
    """
    ## See `sympy` predicates:
    ## - <https://docs.sympy.org/latest/guides/assumptions.html#predicates>
    # Attribute names are looked up lazily, so only the requested predicates
    # are evaluated.
    set_predicates = {
        'integer': 'is_integer',
        'rational': 'is_rational',
        'real': 'is_real',
        'complex': 'is_complex',
    }

    def validate_expr(expr: AllowedSympyExprs) -> AllowedSympyExprs:
        """Return `expr` unchanged if it satisfies every enabled constraint.

        Raises:
            ValueError: If `expr` is not a `sp.Expr`/`sp.MatrixBase`, or if any
                constraint is violated (all violations are listed).
        """
        ## NOTE: Must match AllowedSympyExprs union elements.
        if not isinstance(expr, sp.Expr | sp.MatrixBase):
            msg = f"expr '{expr}' is not an allowed Sympy expression ({AllowedSympyExprs})"
            raise ValueError(msg)

        # A list (not a set) keeps the reported violations in a stable order.
        msgs: list[str] = []

        # Validate Feature Class
        if not allow_variables and expr.free_symbols:
            msgs.append(
                f'allow_variables={allow_variables} does not match expression {expr}.'
            )
        if not allow_units and spux.uses_units(expr):
            msgs.append(f'allow_units={allow_units} does not match expression {expr}.')

        # Validate Structure Class
        if (
            allowed_sets
            and isinstance(expr, sp.Expr)
            and not any(
                getattr(expr, set_predicates[allowed_set])
                for allowed_set in allowed_sets
            )
        ):
            msgs.append(
                f"allowed_sets={allowed_sets} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True)`)"
            )

        structure = 'matrix' if isinstance(expr, sp.MatrixBase) else 'scalar'
        if allowed_structures and structure not in allowed_structures:
            msgs.append(
                f'allowed_structures={allowed_structures} does not match expression {expr} (structure is {structure!r})'
            )

        # Validate Element Class
        if allowed_symbols and not expr.free_symbols.issubset(allowed_symbols):
            msgs.append(
                f'allowed_symbols={allowed_symbols} does not match expression {expr}'
            )
        if allowed_units and not spux.get_units(expr).issubset(allowed_units):
            msgs.append(f'allowed_units={allowed_units} does not match expression {expr}')

        # Validate Shape Class
        if (
            allowed_matrix_shapes
            and isinstance(expr, sp.MatrixBase)
            and expr.shape not in allowed_matrix_shapes
        ):
            msgs.append(
                f'allowed_matrix_shapes={allowed_matrix_shapes} does not match expression {expr} with shape {expr.shape}'
            )

        # Error or Return
        if msgs:
            raise ValueError('; '.join(msgs))
        return expr

    return typx.Annotated[
        AllowedSympyExprs,
        _SympyExpr,
        pyd.AfterValidator(validate_expr),
    ]
|