2024-03-10 11:56:37 +01:00
import typing as typ
import typing_extensions as typx
import pydantic as pyd
from pydantic_core import core_schema as pyd_core_schema
import sympy as sp
import sympy . physics . units as spu
2024-03-13 19:10:54 +01:00
from . import extra_sympy_units as spux
2024-03-10 11:56:37 +01:00
####################
# - Missing Basics
####################
2024-03-13 19:10:54 +01:00
# Union of sympy object types accepted as "sympy expressions" by this module.
## NOTE: `sp.MutableDenseMatrix` is already a subclass of `sp.MatrixBase`, but the
## union is left untouched: its repr is interpolated into validation error messages.
AllowedSympyExprs = sp.Expr | sp.MatrixBase | sp.MutableDenseMatrix


def _complex_core_schema(
    source_type: typ.Any,
    handler: pyd.GetCoreSchemaHandler,
) -> pyd_core_schema.CoreSchema:
    """Wrap pydantic's own schema for `source_type` in an identity after-validator."""
    inner_schema = handler(source_type)
    return pyd_core_schema.no_info_after_validator_function(
        lambda value: value,
        inner_schema,
    )


# `complex`, annotated so that pydantic builds its schema via `_complex_core_schema`.
Complex = typx.Annotated[
    complex,
    pyd.GetPydanticSchema(get_pydantic_core_schema=_complex_core_schema),
]
####################
# - Custom Pydantic Type for sp.Expr
####################
class _SympyExpr:
    """Pydantic `Annotated` marker that validates and serializes sympy objects.

    Validation (identical for JSON and Python input) is a chain of:
    1. Strings are parsed with `sp.sympify`, then unit symbols are substituted
       via `expr.subs(spux.ALL_UNIT_SYMBOLS)`; non-strings pass through untouched.
    2. The result must be an `sp.Expr` or `sp.MatrixBase`.
    3. The result must be an instance of `AllowedSympyExprs`.

    Serialization always emits `str(value)`.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: typ.Any,
        _handler: pyd.GetCoreSchemaHandler,
    ) -> pyd_core_schema.CoreSchema:
        """Return the pydantic-core schema described in the class docstring.

        Both arguments are required by pydantic's protocol but unused here.
        """

        def validate_from_str(value: typ.Any) -> typ.Any:
            """Parse a string into a sympy object; pass non-strings through unchanged.

            Raises:
                ValueError: If `value` cannot be sympified, or sympifies to
                    something that is not a sympy object.
            """
            if not isinstance(value, str):
                return value

            # NOTE(security): `sp.sympify` evaluates strings with `eval`;
            # never feed it untrusted input.
            try:
                expr = sp.sympify(value)
            except (ValueError, TypeError, AttributeError) as ex:
                # `SympifyError` subclasses ValueError, but errors raised while
                # evaluating the parsed string (e.g. TypeError for "[1]+1",
                # AttributeError for "x.foo") can escape `sympify` unchanged.
                # Pydantic only turns ValueError/AssertionError into validation
                # errors, so re-raise them all as ValueError.
                msg = f"Value {value} is not a `sympify`able string"
                raise ValueError(msg) from ex

            # Some strings (e.g. "[1, 2]") sympify to plain Python containers,
            # which have no `.subs`; reject those instead of crashing below.
            if not isinstance(expr, (sp.Basic, sp.MatrixBase)):
                msg = f"Value {value} is not a `sympy` expression"
                raise ValueError(msg)

            return expr.subs(spux.ALL_UNIT_SYMBOLS)

        def validate_from_expr(value: typ.Any) -> AllowedSympyExprs:
            """Reject anything that is not an `sp.Expr` or `sp.MatrixBase`.

            Raises:
                ValueError: If `value` has any other type.
            """
            if not isinstance(value, (sp.Expr, sp.MatrixBase)):
                msg = f"Value {value} is not a `sympy` expression"
                raise ValueError(msg)
            return value

        sympy_expr_schema = pyd_core_schema.chain_schema(
            [
                pyd_core_schema.no_info_plain_validator_function(validate_from_str),
                pyd_core_schema.no_info_plain_validator_function(validate_from_expr),
                pyd_core_schema.is_instance_schema(AllowedSympyExprs),
            ]
        )
        return pyd_core_schema.json_or_python_schema(
            json_schema=sympy_expr_schema,
            python_schema=sympy_expr_schema,
            serialization=pyd_core_schema.plain_serializer_function_ser_schema(
                lambda instance: str(instance)
            ),
        )
####################
# - Configurable Expression Validation
####################
# Unconstrained sympy expression type for pydantic fields: validated and
# serialized by `_SympyExpr`. See `ConstrSympyExpr` for a constrained variant.
SympyExpr = typx.Annotated[AllowedSympyExprs, _SympyExpr]
def ConstrSympyExpr(
    # Feature Class
    allow_variables: bool = True,
    allow_units: bool = True,
    # Structure Class
    allowed_sets: (
        set[typx.Literal["integer", "rational", "real", "complex"]] | None
    ) = None,
    allowed_structures: set[typx.Literal["scalar", "matrix"]] | None = None,
    # Element Class
    allowed_symbols: set[sp.Symbol] | None = None,
    allowed_units: set[spu.Quantity] | None = None,
    # Shape Class
    allowed_matrix_shapes: set[tuple[int, int]] | None = None,
):
    """Build a pydantic-ready sympy type whose values must satisfy constraints.

    Constraints given as `None` (or an empty set) are not checked.

    Args:
        allow_variables: If False, reject expressions with any free symbols.
        allow_units: If False, reject expressions for which
            `spux.uses_units(expr)` is truthy.
        allowed_sets: For scalar values, require at least one of the named
            assumptions (`is_integer`, `is_rational`, `is_real`, `is_complex`)
            to be True; an unknown (None) assumption does not count.
            Matrix values are not checked.
        allowed_structures: Require the value to be a "scalar" (not an
            `sp.MatrixBase`) or a "matrix" (an `sp.MatrixBase`), as listed.
        allowed_symbols: Require every free symbol to be in this set.
        allowed_units: Require every unit in `spux.get_units(expr)` to be in
            this set.
        allowed_matrix_shapes: For matrix values, require `expr.shape` to be
            one of these `(rows, cols)` tuples.

    Returns:
        `typx.Annotated[AllowedSympyExprs, _SympyExpr, pyd.AfterValidator(...)]`,
        usable as a pydantic field type.
    """

    ## See `sympy` predicates:
    ## - <https://docs.sympy.org/latest/guides/assumptions.html#predicates>
    def validate_expr(expr: AllowedSympyExprs) -> AllowedSympyExprs:
        """Check `expr` against every enabled constraint.

        Returns:
            `expr`, unchanged, when all enabled constraints hold.

        Raises:
            ValueError: If `expr` is not a sympy expression/matrix, or with one
                line per violated constraint.
        """
        ## NOTE: Must match AllowedSympyExprs union elements.
        if not isinstance(expr, (sp.Expr, sp.MatrixBase)):
            msg = f"expr '{expr}' is not an allowed Sympy expression ({AllowedSympyExprs})"
            raise ValueError(msg)

        # A list (not a set) keeps the reported violations in a stable order.
        msgs: list[str] = []
        is_matrix = isinstance(expr, sp.MatrixBase)

        # Validate Feature Class
        if not allow_variables and expr.free_symbols:
            msgs.append(
                f"allow_variables={allow_variables} does not match expression {expr}."
            )
        if not allow_units and spux.uses_units(expr):
            msgs.append(
                f"allow_units={allow_units} does not match expression {expr}."
            )

        # Validate Structure Class
        if allowed_sets and not is_matrix:
            # Evaluate each assumption once; values may be True, False or None.
            set_membership = {
                "integer": expr.is_integer,
                "rational": expr.is_rational,
                "real": expr.is_real,
                "complex": expr.is_complex,
            }
            if not any(set_membership[allowed_set] for allowed_set in allowed_sets):
                msgs.append(
                    f"allowed_sets={allowed_sets} does not match expression {expr} "
                    "(remember to add assumptions to symbols, "
                    "ex. `x = sp.Symbol('x', real=True)`)"
                )

        if allowed_structures:
            structure = "matrix" if is_matrix else "scalar"
            if structure not in allowed_structures:
                msgs.append(
                    f"allowed_structures={allowed_structures} does not match expression {expr}"
                )

        # Validate Element Class
        if allowed_symbols and not expr.free_symbols.issubset(allowed_symbols):
            msgs.append(
                f"allowed_symbols={allowed_symbols} does not match expression {expr}"
            )
        if allowed_units and not spux.get_units(expr).issubset(allowed_units):
            msgs.append(
                f"allowed_units={allowed_units} does not match expression {expr}"
            )

        # Validate Shape Class
        if (
            allowed_matrix_shapes
            and is_matrix
            and expr.shape not in allowed_matrix_shapes
        ):
            msgs.append(
                f"allowed_matrix_shapes={allowed_matrix_shapes} does not match "
                f"expression {expr} with shape {expr.shape}"
            )

        # Error or Return
        if msgs:
            raise ValueError("\n".join(msgs))
        return expr

    return typx.Annotated[
        AllowedSympyExprs,
        _SympyExpr,
        pyd.AfterValidator(validate_expr),
    ]