"""Pydantic-compatible type annotations for `sympy` expressions.

`SympyExpr` validates and serializes `sympy` expressions/matrices.
`ConstrSympyExpr(...)` builds a `SympyExpr`-like annotation that additionally
constrains the expression's features, number set, structure, symbols, units
and matrix shape.
"""

import typing as typ

import pydantic as pyd
import sympy as sp
import sympy.physics.units as spu
import typing_extensions as typx
from pydantic_core import core_schema as pyd_core_schema

from . import extra_sympy_units as spux

####################
# - Missing Basics
####################
## NOTE: The `isinstance` checks below test against `sp.Expr | sp.MatrixBase`;
## keep them in sync with this union (`sp.MutableDenseMatrix` is already a
## subclass of `sp.MatrixBase`).
AllowedSympyExprs = sp.Expr | sp.MatrixBase | sp.MutableDenseMatrix

## `complex`, wrapped in an identity after-validator around pydantic's
## default `complex` schema.
Complex = typx.Annotated[
    complex,
    pyd.GetPydanticSchema(
        lambda tp, handler: pyd_core_schema.no_info_after_validator_function(
            lambda x: x, handler(tp)
        )
    ),
]


####################
# - Custom Pydantic Type for sp.Expr
####################
class _SympyExpr:
    """Pydantic schema provider for `sympy` expressions and matrices.

    Meant to be used as `typx.Annotated` metadata (see `SympyExpr`).

    Validation accepts either a `sympy` object or a string. Strings are parsed
    with `sp.sympify`, and then `spux.ALL_UNIT_SYMBOLS` is substituted into the
    result. The final value must be an instance of `AllowedSympyExprs`.
    Serialization emits `sp.srepr(value)`.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: typ.Any,
        _handler: pyd.GetCoreSchemaHandler,
    ) -> pyd_core_schema.CoreSchema:
        def validate_from_str(value: str) -> AllowedSympyExprs:
            # Non-strings pass through untouched; the next validator checks them.
            if not isinstance(value, str):
                return value

            # NOTE(security): `sp.sympify` evaluates strings via `eval`;
            # never feed it untrusted input.
            try:
                expr = sp.sympify(value)
            except ValueError as ex:  # `sp.SympifyError` subclasses `ValueError`
                msg = f'Value {value} is not a `sympify`able string'
                raise ValueError(msg) from ex

            return expr.subs(spux.ALL_UNIT_SYMBOLS)

        def validate_from_expr(value: AllowedSympyExprs) -> AllowedSympyExprs:
            if not (isinstance(value, sp.Expr | sp.MatrixBase)):
                msg = f'Value {value} is not a `sympy` expression'
                raise ValueError(msg)

            return value

        sympy_expr_schema = pyd_core_schema.chain_schema(
            [
                pyd_core_schema.no_info_plain_validator_function(validate_from_str),
                pyd_core_schema.no_info_plain_validator_function(validate_from_expr),
                pyd_core_schema.is_instance_schema(AllowedSympyExprs),
            ]
        )
        return pyd_core_schema.json_or_python_schema(
            json_schema=sympy_expr_schema,
            python_schema=sympy_expr_schema,
            serialization=pyd_core_schema.plain_serializer_function_ser_schema(
                lambda instance: sp.srepr(instance)
            ),
        )


####################
# - Configurable Expression Validation
####################
## Annotation for any `sympy` expression / matrix, validated by `_SympyExpr`.
SympyExpr = typx.Annotated[
    AllowedSympyExprs,
    _SympyExpr,
]


def ConstrSympyExpr(
    # Feature Class
    allow_variables: bool = True,
    allow_units: bool = True,
    # Structure Class
    allowed_sets: set[typ.Literal['integer', 'rational', 'real', 'complex']]
    | None = None,
    allowed_structures: set[typ.Literal['scalar', 'matrix']] | None = None,
    # Element Class
    allowed_symbols: set[sp.Symbol] | None = None,
    allowed_units: set[spu.Quantity] | None = None,
    # Shape Class
    allowed_matrix_shapes: set[tuple[int, int]] | None = None,
):
    """Build a `SympyExpr` annotation with extra constraints on the expression.

    A `None` or empty constraint set disables that constraint.

    Args:
        allow_variables: If `False`, reject expressions with free symbols.
        allow_units: If `False`, reject expressions for which
            `spux.uses_units` is true.
        allowed_sets: For scalar (non-matrix) expressions, require at least one
            of the `sympy` assumption predicates (`is_integer`, `is_rational`,
            `is_real`, `is_complex`) to be `True`. These predicates return
            `None` when `sympy` can't decide, so symbols need assumptions,
            e.g. `sp.Symbol('x', real=True)`.
        allowed_structures: `'matrix'` accepts `sp.MatrixBase` instances;
            `'scalar'` accepts everything else.
        allowed_symbols: Every free symbol of the expression must be in here.
        allowed_units: Every unit returned by `spux.get_units` must be in here.
        allowed_matrix_shapes: For matrices, `expr.shape` must be in here.

    Returns:
        A `typx.Annotated` type that validates with `_SympyExpr`, then checks
        the constraints above.

    Raises (at validation time):
        ValueError: If the value is not an allowed `sympy` object, or if any
            constraint is violated. All violations are reported together.
    """

    def validate_expr(expr: AllowedSympyExprs) -> AllowedSympyExprs:
        ## NOTE: Must match AllowedSympyExprs union elements.
        if not isinstance(expr, sp.Expr | sp.MatrixBase):
            msg = f"expr '{expr}' is not an allowed Sympy expression ({AllowedSympyExprs})"
            raise ValueError(msg)

        is_matrix = isinstance(expr, sp.MatrixBase)
        # A list (not a set) keeps the error message in a deterministic order.
        msgs: list[str] = []

        # Validate Feature Class
        if (not allow_variables) and (len(expr.free_symbols) > 0):
            msgs.append(
                f'allow_variables={allow_variables} does not match expression {expr}.'
            )
        if (not allow_units) and spux.uses_units(expr):
            msgs.append(f'allow_units={allow_units} does not match expression {expr}.')

        # Validate Structure Class
        ## Number-set predicates only make sense for scalars. Immutable
        ## matrices are also `sp.Expr` instances, so exclude them explicitly.
        if allowed_sets and isinstance(expr, sp.Expr) and not is_matrix:
            in_set = {
                'integer': expr.is_integer,
                'rational': expr.is_rational,
                'real': expr.is_real,
                'complex': expr.is_complex,
            }
            if not any(in_set.get(allowed_set) for allowed_set in allowed_sets):
                msgs.append(
                    f"allowed_sets={allowed_sets} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True)`)"
                )

        if allowed_structures:
            structure = 'matrix' if is_matrix else 'scalar'
            if structure not in allowed_structures:
                msgs.append(
                    f'allowed_structures={allowed_structures} does not match expression {expr}'
                )

        # Validate Element Class
        if allowed_symbols and not expr.free_symbols.issubset(allowed_symbols):
            msgs.append(
                f'allowed_symbols={allowed_symbols} does not match expression {expr}'
            )
        if allowed_units and not spux.get_units(expr).issubset(allowed_units):
            msgs.append(
                f'allowed_units={allowed_units} does not match expression {expr}'
            )

        # Validate Shape Class
        if (
            allowed_matrix_shapes
            and is_matrix
            and expr.shape not in allowed_matrix_shapes
        ):
            msgs.append(
                f'allowed_matrix_shapes={allowed_matrix_shapes} does not match expression {expr} with shape {expr.shape}'
            )

        # Error or Return
        if msgs:
            raise ValueError('; '.join(msgs))
        return expr

    return typx.Annotated[
        AllowedSympyExprs,
        _SympyExpr,
        pyd.AfterValidator(validate_expr),
    ]