# blender_maxwell # Copyright (C) 2024 blender_maxwell Project Contributors # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . """Declares useful sympy units and functions, to make it easier to work with `sympy` as the basis for a unit-aware system. Attributes: UNIT_BY_SYMBOL: Maps all abbreviated Sympy symbols to their corresponding Sympy unit. This is essential for parsing string expressions that use units, since a pure parse of ex. `a*m + m` would not otherwise be able to differentiate between `sp.Symbol(m)` and `spu.meter`. SympyType: A simple union of valid `sympy` types, used to check whether arbitrary objects should be handled using `sympy` functions. For simple `isinstance` checks, this should be preferred, as it is most performant. For general use, `SympyExpr` should be preferred. SympyExpr: A `SympyType` that is compatible with `pydantic`, including serialization/deserialization. Should be used via the `ConstrSympyExpr`, which also adds expression validation. """ import enum import functools import sys import typing as typ from fractions import Fraction import jax import jax.numpy as jnp import jaxtyping as jtyp import pydantic as pyd import sympy as sp import sympy.physics.units as spu import typing_extensions as typx from pydantic_core import core_schema as pyd_core_schema from blender_maxwell import contracts as ct from . import logger from .staticproperty import staticproperty log = logger.get(__name__) SympyType = ( sp.Basic | sp.Expr | sp.MatrixBase | sp.MutableDenseMatrix | spu.Quantity | spu.Dimension ) #################### # - Math Type #################### class MathType(enum.StrEnum): """Type identifiers that encompass common sets of mathematical objects.""" Integer = enum.auto() Rational = enum.auto() Real = enum.auto() Complex = enum.auto() @staticmethod def combine(*mathtypes: list[typ.Self], optional: bool = False) -> typ.Self | None: if MathType.Complex in mathtypes: return MathType.Complex if MathType.Real in mathtypes: return MathType.Real if MathType.Rational in mathtypes: return MathType.Rational if MathType.Integer in mathtypes: return MathType.Integer if optional: return None msg = f"Can't combine mathtypes {mathtypes}" raise ValueError(msg) def is_compatible(self, other: typ.Self) -> bool: MT = MathType return ( other in { MT.Integer: [MT.Integer], MT.Rational: [MT.Integer, MT.Rational], MT.Real: [MT.Integer, MT.Rational, MT.Real], MT.Complex: [MT.Integer, MT.Rational, MT.Real, MT.Complex], }[self] ) def coerce_compatible_pyobj( self, pyobj: bool | int | Fraction | float | complex ) -> int | Fraction | float | complex: MT = MathType match self: case MT.Integer: return int(pyobj) case MT.Rational if isinstance(pyobj, int): return Fraction(pyobj, 1) case MT.Rational if isinstance(pyobj, Fraction): return pyobj case MT.Real: return float(pyobj) case MT.Complex if isinstance(pyobj, int | Fraction): return complex(float(pyobj), 0) case MT.Complex if isinstance(pyobj, float): return complex(pyobj, 0) @staticmethod def from_expr(sp_obj: SympyType, optional: bool = False) -> type | None: if isinstance(sp_obj, sp.MatrixBase): return MathType.combine( *[MathType.from_expr(v) for v in sp.flatten(sp_obj)] ) if sp_obj.is_integer: return MathType.Integer if sp_obj.is_rational: return MathType.Rational if sp_obj.is_real: return MathType.Real if sp_obj.is_complex: return MathType.Complex # Infinities if sp_obj in [sp.oo, -sp.oo]: return MathType.Real ## TODO: Strictly, could be ex. integer... if sp_obj in [sp.zoo, -sp.zoo]: return MathType.Complex if optional: return None msg = f"Can't determine MathType from sympy object: {sp_obj}" raise ValueError(msg) @staticmethod def from_pytype(dtype: type) -> type: return { int: MathType.Integer, Fraction: MathType.Rational, float: MathType.Real, complex: MathType.Complex, }[dtype] @staticmethod def from_jax_array(data: jtyp.Shaped[jtyp.Array, '...']) -> type: """Deduce the MathType corresponding to a JAX array. We go about this by leveraging that: - `data` is of a homogeneous type. - `data.item(0)` returns a single element of the array w/pure-python type. By combing this with `type()` and `MathType.from_pytype`, we can effectively deduce the `MathType` of the entire array with relative efficiency. Notes: Should also work with numpy arrays. """ return MathType.from_pytype(type(data.item(0))) @staticmethod def has_mathtype(obj: typ.Any) -> typ.Literal['pytype', 'expr'] | None: if isinstance(obj, bool | int | Fraction | float | complex): return 'pytype' if isinstance(obj, sp.Basic | sp.MatrixBase | sp.MutableDenseMatrix): return 'expr' return None @property def pytype(self) -> type: MT = MathType return { MT.Integer: int, MT.Rational: Fraction, MT.Real: float, MT.Complex: complex, }[self] @property def symbolic_set(self) -> type: MT = MathType return { MT.Integer: sp.Integers, MT.Rational: sp.Rationals, MT.Real: sp.Reals, MT.Complex: sp.Complexes, }[self] @property def inf_finite(self) -> type: """Opinionated finite representation of "infinity" within this `MathType`. These are chosen using `sys.maxsize` and `sys.float_info`. As such, while not arbitrary, this "finite representation of infinity" certainly is opinionated. **Note** that, in practice, most systems will have no trouble working with values that exceed those defined here. Notes: Values should be presumed to vary by-platform, as the `sys` attributes may be influenced by CPU architecture, OS, runtime environment, etc. . These values can be used directly in `jax` arrays, but at the cost of an overflow warning (in part because `jax` generally only allows the use of `float32`). In this case, the warning doesn't matter, as the value will be cast to `jnp.inf` anyway. However, it's generally cleaner to directly use `jnp.inf` if infinite values must be defined in an array context. """ MT = MathType Z = MT.Integer R = MT.Integer return { MT.Integer: (-sys.maxsize, sys.maxsize), MT.Rational: ( Fraction(Z.inf_finite[0], 1), Fraction(Z.inf_finite[1], 1), ), MT.Real: -(sys.float_info.min, sys.float_info.max), MT.Complex: ( complex(R.inf_finite[0], R.inf_finite[0]), complex(R.inf_finite[1], R.inf_finite[1]), ), }[self] @property def sp_symbol_a(self) -> type: MT = MathType return { MT.Integer: sp.Symbol('a', integer=True), MT.Rational: sp.Symbol('a', rational=True), MT.Real: sp.Symbol('a', real=True), MT.Complex: sp.Symbol('a', complex=True), }[self] @staticmethod def to_str(value: typ.Self) -> type: return { MathType.Integer: 'ℤ', MathType.Rational: 'ℚ', MathType.Real: 'ℝ', MathType.Complex: 'ℂ', }[value] @property def label_pretty(self) -> str: return MathType.to_str(self) @staticmethod def to_name(value: typ.Self) -> str: return MathType.to_str(value) @staticmethod def to_icon(value: typ.Self) -> str: return '' def bl_enum_element(self, i: int) -> ct.BLEnumElement: return ( str(self), MathType.to_name(self), MathType.to_name(self), MathType.to_icon(self), i, ) #################### # - Size: 1D #################### class NumberSize1D(enum.StrEnum): """Valid 1D-constrained shape.""" Scalar = enum.auto() Vec2 = enum.auto() Vec3 = enum.auto() Vec4 = enum.auto() @staticmethod def to_name(value: typ.Self) -> str: NS = NumberSize1D return { NS.Scalar: 'Scalar', NS.Vec2: '2D', NS.Vec3: '3D', NS.Vec4: '4D', }[value] @staticmethod def to_icon(value: typ.Self) -> str: NS = NumberSize1D return { NS.Scalar: '', NS.Vec2: '', NS.Vec3: '', NS.Vec4: '', }[value] def bl_enum_element(self, i: int) -> ct.BLEnumElement: return ( str(self), NumberSize1D.to_name(self), NumberSize1D.to_name(self), NumberSize1D.to_icon(self), i, ) @staticmethod def has_shape(shape: tuple[int, ...] | None): return shape in [None, (2,), (3,), (4,), (2, 1), (3, 1), (4, 1)] def supports_shape(self, shape: tuple[int, ...] | None): NS = NumberSize1D match self: case NS.Scalar: return shape is None case NS.Vec2: return shape in ((2,), (2, 1)) case NS.Vec3: return shape in ((3,), (3, 1)) case NS.Vec4: return shape in ((4,), (4, 1)) @staticmethod def from_shape(shape: tuple[typ.Literal[2, 3]] | None) -> typ.Self: NS = NumberSize1D return { None: NS.Scalar, (2,): NS.Vec2, (3,): NS.Vec3, (4,): NS.Vec4, (2, 1): NS.Vec2, (3, 1): NS.Vec3, (4, 1): NS.Vec4, }[shape] @property def rows(self): NS = NumberSize1D return { NS.Scalar: 1, NS.Vec2: 2, NS.Vec3: 3, NS.Vec4: 4, }[self] @property def cols(self): return 1 @property def shape(self): NS = NumberSize1D return { NS.Scalar: None, NS.Vec2: (2,), NS.Vec3: (3,), NS.Vec4: (4,), }[self] def symbol_range(sym: sp.Symbol) -> str: return f'{sym.name} ∈ ' + ( 'ℂ' if sym.is_complex else ('ℝ' if sym.is_real else ('ℤ' if sym.is_integer else '?')) ) #################### # - Symbol Sizes #################### class SimpleSize2D(enum.StrEnum): """Simple subset of sizes for rank-2 tensors.""" Scalar = enum.auto() # Vectors Vec2 = enum.auto() ## 2x1 Vec3 = enum.auto() ## 3x1 Vec4 = enum.auto() ## 4x1 # Covectors CoVec2 = enum.auto() ## 1x2 CoVec3 = enum.auto() ## 1x3 CoVec4 = enum.auto() ## 1x4 # Square Matrices Mat22 = enum.auto() ## 2x2 Mat33 = enum.auto() ## 3x3 Mat44 = enum.auto() ## 4x4 #################### # - Unit Dimensions #################### class DimsMeta(type): def __getattr__(cls, attr: str) -> spu.Dimension: if ( attr in spu.definitions.dimension_definitions.__dir__() and not attr.startswith('__') ): return getattr(spu.definitions.dimension_definitions, attr) raise AttributeError(name=attr, obj=Dims) class Dims(metaclass=DimsMeta): """Access `sympy.physics.units` dimensions with less hassle. Any unit dimension available in `sympy.physics.units.definitions.dimension_definitions` can be accessed as an attribute of `Dims`. An `AttributeError` is raised if the unit cannot be found in `sympy`. Examples: The objects returned are a direct alias to `sympy`, with less hassle: ```python assert Dims.length == ( sympy.physics.units.definitions.dimension_definitions.length ) ``` """ #################### # - Units #################### femtosecond = fs = spu.Quantity('femtosecond', abbrev='fs') femtosecond.set_global_relative_scale_factor(spu.femto, spu.second) # Length femtometer = fm = spu.Quantity('femtometer', abbrev='fm') femtometer.set_global_relative_scale_factor(spu.femto, spu.meter) # Lum Flux lumen = lm = spu.Quantity('lumen', abbrev='lm') lumen.set_global_relative_scale_factor(1, spu.candela * spu.steradian) # Force nanonewton = nN = spu.Quantity('nanonewton', abbrev='nN') # noqa: N816 nanonewton.set_global_relative_scale_factor(spu.nano, spu.newton) micronewton = uN = spu.Quantity('micronewton', abbrev='μN') # noqa: N816 micronewton.set_global_relative_scale_factor(spu.micro, spu.newton) millinewton = mN = spu.Quantity('micronewton', abbrev='mN') # noqa: N816 micronewton.set_global_relative_scale_factor(spu.milli, spu.newton) # Frequency kilohertz = KHz = spu.Quantity('kilohertz', abbrev='KHz') kilohertz.set_global_relative_scale_factor(spu.kilo, spu.hertz) megahertz = MHz = spu.Quantity('megahertz', abbrev='MHz') kilohertz.set_global_relative_scale_factor(spu.kilo, spu.hertz) gigahertz = GHz = spu.Quantity('gigahertz', abbrev='GHz') gigahertz.set_global_relative_scale_factor(spu.giga, spu.hertz) terahertz = THz = spu.Quantity('terahertz', abbrev='THz') terahertz.set_global_relative_scale_factor(spu.tera, spu.hertz) petahertz = PHz = spu.Quantity('petahertz', abbrev='PHz') petahertz.set_global_relative_scale_factor(spu.peta, spu.hertz) exahertz = EHz = spu.Quantity('exahertz', abbrev='EHz') exahertz.set_global_relative_scale_factor(spu.exa, spu.hertz) # Pressure millibar = mbar = spu.Quantity('millibar', abbrev='mbar') millibar.set_global_relative_scale_factor(spu.milli, spu.bar) hectopascal = hPa = spu.Quantity('hectopascal', abbrev='hPa') # noqa: N816 hectopascal.set_global_relative_scale_factor(spu.hecto, spu.pascal) UNIT_BY_SYMBOL: dict[sp.Symbol, spu.Quantity] = { unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity) } | {unit.name: unit for unit in globals().values() if isinstance(unit, spu.Quantity)} UNIT_TO_1: dict[spu.Quantity, 1] = {unit: 1 for unit in UNIT_BY_SYMBOL.values()} #################### # - Expr Analysis: Units #################### ## TODO: Caching w/srepr'ed expression. ## TODO: An LFU cache could do better than an LRU. def uses_units(sp_obj: SympyType) -> bool: """Determines if an expression uses any units. Notes: The expression graph is traversed depth-first with `sp.postorder_traversal`, to search for `sp.Quantity` elements. Depth-first was chosen since `sp.Quantity`s are likelier to be found among individual symbols, rather than complete subexpressions. The **worst-case** runtime is when there are no units, in which case the **entire expression graph will be traversed**. Parameters: expr: The sympy expression that may contain units. Returns: Whether or not there are units used within the expression. """ return sp_obj.has(spu.Quantity) # return any( # isinstance(subexpr, spu.Quantity) for subexpr in sp.postorder_traversal(sp_obj) # ) ## TODO: Caching w/srepr'ed expression. ## TODO: An LFU cache could do better than an LRU. def get_units(expr: sp.Expr) -> set[spu.Quantity]: """Finds all units used by the expression, and returns them as a set. No information about _the relationship between units_ is exposed. For example, compound units like `spu.meter / spu.second` would be mapped to `{spu.meter, spu.second}`. Notes: The expression graph is traversed depth-first with `sp.postorder_traversal`, to search for `sp.Quantity` elements. The performance is comparable to the performance of `sp.postorder_traversal`, since the **entire expression graph will always be traversed**, with the added overhead of one `isinstance` call per expression-graph-node. Parameters: expr: The sympy expression that may contain units. Returns: All units (`spu.Quantity`) used within the expression. """ return { subexpr for subexpr in sp.postorder_traversal(expr) if isinstance(subexpr, spu.Quantity) } def parse_shape(sp_obj: SympyType) -> int | None: if isinstance(sp_obj, sp.MatrixBase): return sp_obj.shape return None #################### # - Pydantic-Validated SympyExpr #################### class _SympyExpr: """Low-level `pydantic`, schema describing how to serialize/deserialize fields that have a `SympyType` (like `sp.Expr`), so we can cleanly use `sympy` types in `pyd.BaseModel`. Notes: You probably want to use `SympyExpr`. Examples: To be usable as a type annotation on `pyd.BaseModel`, attach this to `SympyType` using `typx.Annotated`: ```python SympyExpr = typx.Annotated[SympyType, _SympyExpr] class Spam(pyd.BaseModel): line: SympyExpr = sp.Eq(sp.y, 2*sp.Symbol(x, real=True) - 3) ``` """ @classmethod def __get_pydantic_core_schema__( cls, _source_type: SympyType, _handler: pyd.GetCoreSchemaHandler, ) -> pyd_core_schema.CoreSchema: """Compute a schema that allows `pydantic` to validate a `sympy` type.""" def validate_from_str(sp_str: str | typ.Any) -> SympyType | typ.Any: """Parse and validate a string expression. Parameters: sp_str: A stringified `sympy` object, that will be parsed to a sympy type. Before use, `isinstance(expr_str, str)` is checked. If the object isn't a string, then the validation will be skipped. Returns: Either a `sympy` object, if the input is parseable, or the same untouched object. Raises: ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression. """ # Constrain to String if not isinstance(sp_str, str): return sp_str # Parse String -> Sympy try: expr = sp.sympify(sp_str) except ValueError as ex: msg = f'String {sp_str} is not a valid sympy expression' raise ValueError(msg) from ex # Substitute Symbol -> Quantity return expr.subs(UNIT_BY_SYMBOL) def validate_from_pytype( sp_pytype: int | Fraction | float | complex, ) -> SympyType | typ.Any: """Parse and validate a pure Python type. Parameters: sp_str: A stringified `sympy` object, that will be parsed to a sympy type. Before use, `isinstance(expr_str, str)` is checked. If the object isn't a string, then the validation will be skipped. Returns: Either a `sympy` object, if the input is parseable, or the same untouched object. Raises: ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression. """ # Constrain to String if not isinstance(sp_pytype, int | Fraction | float | complex): return sp_pytype if isinstance(sp_pytype, int): return sp.Integer(sp_pytype) if isinstance(sp_pytype, Fraction): return sp.Rational(sp_pytype.numerator, sp_pytype.denominator) if isinstance(sp_pytype, float): return sp.Float(sp_pytype) # sp_pytype => Complex return sp_pytype.real + sp.I * sp_pytype.imag sympy_expr_schema = pyd_core_schema.chain_schema( [ pyd_core_schema.no_info_plain_validator_function(validate_from_str), pyd_core_schema.no_info_plain_validator_function(validate_from_pytype), pyd_core_schema.is_instance_schema(SympyType), ] ) return pyd_core_schema.json_or_python_schema( json_schema=sympy_expr_schema, python_schema=sympy_expr_schema, serialization=pyd_core_schema.plain_serializer_function_ser_schema( lambda sp_obj: sp.srepr(sp_obj) ), ) SympyExpr = typx.Annotated[ sp.Basic, ## Treat all sympy types as sp.Basic _SympyExpr, ] ## TODO: The type game between SympyType, SympyExpr, and the various flavors of ConstrSympyExpr(), is starting to be a bit much. Let's consolidate. def ConstrSympyExpr( # noqa: N802, PLR0913 # Features allow_variables: bool = True, allow_units: bool = True, # Structures allowed_sets: set[typ.Literal['integer', 'rational', 'real', 'complex']] | None = None, allowed_structures: set[typ.Literal['scalar', 'matrix']] | None = None, # Element Class max_symbols: int | None = None, allowed_symbols: set[sp.Symbol] | None = None, allowed_units: set[spu.Quantity] | None = None, # Shape Class allowed_matrix_shapes: set[tuple[int, int]] | None = None, ) -> SympyType: """Constructs a `SympyExpr` type, which will validate `sympy` types when used in a `pyd.BaseModel`. Relies on the `sympy` assumptions system. See Parameters (TBD): Returns: A type that represents a constrained `sympy` expression. """ def validate_expr(expr: SympyType): if not (isinstance(expr, SympyType),): msg = f"expr '{expr}' is not an allowed Sympy expression ({SympyType})" raise ValueError(msg) msgs = set() # Validate Feature Class if (not allow_variables) and (len(expr.free_symbols) > 0): msgs.add( f'allow_variables={allow_variables} does not match expression {expr}.' ) if (not allow_units) and uses_units(expr): msgs.add(f'allow_units={allow_units} does not match expression {expr}.') # Validate Structure Class if ( allowed_sets and isinstance(expr, sp.Expr) and not any( { 'integer': expr.is_integer, 'rational': expr.is_rational, 'real': expr.is_real, 'complex': expr.is_complex, }[allowed_set] for allowed_set in allowed_sets ) ): msgs.add( f"allowed_sets={allowed_sets} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))" ) if allowed_structures and not any( { 'scalar': True, 'matrix': isinstance(expr, sp.MatrixBase), }[allowed_set] for allowed_set in allowed_structures ): msgs.add( f"allowed_structures={allowed_structures} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))" ) # Validate Element Class if max_symbols and len(expr.free_symbols) > max_symbols: msgs.add(f'max_symbols={max_symbols} does not match expression {expr}') if allowed_symbols and expr.free_symbols.issubset(allowed_symbols): msgs.add( f'allowed_symbols={allowed_symbols} does not match expression {expr}' ) if allowed_units and get_units(expr).issubset(allowed_units): msgs.add(f'allowed_units={allowed_units} does not match expression {expr}') # Validate Shape Class if ( allowed_matrix_shapes and isinstance(expr, sp.MatrixBase) ) and expr.shape not in allowed_matrix_shapes: msgs.add( f'allowed_matrix_shapes={allowed_matrix_shapes} does not match expression {expr} with shape {expr.shape}' ) # Error or Return if msgs: raise ValueError(str(msgs)) return expr return typx.Annotated[ sp.Basic, _SympyExpr, pyd.AfterValidator(validate_expr), ] #################### # - Common ConstrSympyExpr #################### # Expression ScalarUnitlessRealExpr: typ.TypeAlias = ConstrSympyExpr( allow_variables=False, allow_units=False, allowed_structures={'scalar'}, allowed_sets={'integer', 'rational', 'real'}, ) ScalarUnitlessComplexExpr: typ.TypeAlias = ConstrSympyExpr( allow_variables=False, allow_units=False, allowed_structures={'scalar'}, allowed_sets={'integer', 'rational', 'real', 'complex'}, ) # Symbol IntSymbol: typ.TypeAlias = ConstrSympyExpr( allow_variables=True, allow_units=False, allowed_sets={'integer'}, max_symbols=1, ) RationalSymbol: typ.TypeAlias = ConstrSympyExpr( allow_variables=True, allow_units=False, allowed_sets={'integer', 'rational'}, max_symbols=1, ) RealSymbol: typ.TypeAlias = ConstrSympyExpr( allow_variables=True, allow_units=False, allowed_sets={'integer', 'rational', 'real'}, max_symbols=1, ) ComplexSymbol: typ.TypeAlias = ConstrSympyExpr( allow_variables=True, allow_units=False, allowed_sets={'integer', 'rational', 'real', 'complex'}, max_symbols=1, ) Symbol: typ.TypeAlias = IntSymbol | RealSymbol | ComplexSymbol # Unit UnitDimension: typ.TypeAlias = SympyExpr ## Actually spu.Dimension ## Technically a "unit expression", which includes compound types. ## Support for this is the reason to prefer over raw spu.Quantity. Unit: typ.TypeAlias = ConstrSympyExpr( allow_variables=False, allow_units=True, allowed_structures={'scalar'}, ) # Number IntNumber: typ.TypeAlias = ConstrSympyExpr( allow_variables=False, allow_units=False, allowed_sets={'integer'}, allowed_structures={'scalar'}, ) RealNumber: typ.TypeAlias = ConstrSympyExpr( allow_variables=False, allow_units=False, allowed_sets={'integer', 'rational', 'real'}, allowed_structures={'scalar'}, ) ComplexNumber: typ.TypeAlias = ConstrSympyExpr( allow_variables=False, allow_units=False, allowed_sets={'integer', 'rational', 'real', 'complex'}, allowed_structures={'scalar'}, ) Number: typ.TypeAlias = IntNumber | RealNumber | ComplexNumber # Number PhysicalRealNumber: typ.TypeAlias = ConstrSympyExpr( allow_variables=False, allow_units=True, allowed_sets={'integer', 'rational', 'real'}, allowed_structures={'scalar'}, ) PhysicalComplexNumber: typ.TypeAlias = ConstrSympyExpr( allow_variables=False, allow_units=True, allowed_sets={'integer', 'rational', 'real', 'complex'}, allowed_structures={'scalar'}, ) PhysicalNumber: typ.TypeAlias = PhysicalRealNumber | PhysicalComplexNumber # Vector Real3DVector: typ.TypeAlias = ConstrSympyExpr( allow_variables=False, allow_units=False, allowed_sets={'integer', 'rational', 'real'}, allowed_structures={'matrix'}, allowed_matrix_shapes={(3, 1)}, ) #################### # - Sympy Utilities: Printing #################### _SYMPY_EXPR_PRINTER_STR = sp.printing.str.StrPrinter( settings={ 'abbrev': True, } ) def sp_to_str(sp_obj: SympyExpr) -> str: """Converts a sympy object to an output-oriented string (w/abbreviated units), using a dedicated StrPrinter. This should be used whenever a **string for UI use** is needed from a `sympy` object. Notes: This should **NOT** be used in cases where the string will be `sp.sympify()`ed back into a sympy expression. For such cases, rely on `sp.srepr()`, which uses an _explicit_ representation. Parameters: sp_obj: The `sympy` object to convert to a string. Returns: A string representing the expression for human use. _The string is not re-encodable to the expression._ """ ## TODO: A bool flag property that does a lot of find/replace to make it super pretty return _SYMPY_EXPR_PRINTER_STR.doprint(sp_obj) def pretty_symbol(sym: sp.Symbol) -> str: return f'{sym.name} ∈ ' + ( 'ℤ' if sym.is_integer else ('ℝ' if sym.is_real else ('ℂ' if sym.is_complex else '?')) ) #################### # - Unit Utilities #################### def scale_to_unit(sp_obj: SympyType, unit: spu.Quantity) -> Number: """Convert an expression that uses units to a different unit, then strip all units, leaving only a unitless `sympy` value. This is used whenever the unitless part of an expression is needed, but guaranteed expressed in a particular unit, aka. **unit system normalization**. Notes: The unitless output is still an `sp.Expr`, which may contain ex. symbols. If you know that the output **should** work as a corresponding Python type (ex. `sp.Integer` vs. `int`), but it doesn't, you can use `sympy_to_python()` to produce a pure-Python type. In this way, with a little care, broad compatiblity can be bridged between the `sympy.physics.units` unit system and the wider Python ecosystem. Parameters: expr: The unit-containing expression to convert. unit_to: The unit that is converted to. Returns: The unitless part of `expr`, after scaling the entire expression to `unit`. Raises: ValueError: If the result of unit-conversion and -stripping still has units, as determined by `uses_units()`. """ unitless_expr = spu.convert_to(sp_obj, unit) / unit if unit is not None else sp_obj if not uses_units(unitless_expr): return unitless_expr msg = f'Sympy object "{sp_obj}" was scaled to the unit "{unit}" with the expectation that the result would be unitless, but the result "{unitless_expr}" has units "{get_units(unitless_expr)}"' raise ValueError(msg) def scaling_factor(unit_from: spu.Quantity, unit_to: spu.Quantity) -> Number: """Compute the numerical scaling factor imposed on the unitless part of the expression when converting from one unit to another. Parameters: unit_from: The unit that is converted from. unit_to: The unit that is converted to. Returns: The numerical scaling factor between the two units. Raises: ValueError: If the two units don't share a common dimension. """ if unit_from.dimension == unit_to.dimension: return scale_to_unit(unit_from, unit_to) msg = f"Dimension of unit_from={unit_from} ({unit_from.dimension}) doesn't match the dimension of unit_to={unit_to} ({unit_to.dimension}); therefore, there is no scaling factor between them" raise ValueError(msg) @functools.cache def unit_str_to_unit(unit_str: str) -> Unit | None: # Edge Case: Manually Parse Degrees ## -> sp.sympify('degree') actually produces the sp.degree() function. ## -> Therefore, we must special case this particular unit. if unit_str == 'degree': expr = spu.degree else: expr = sp.sympify(unit_str).subs(UNIT_BY_SYMBOL) if expr.has(spu.Quantity): return expr msg = f'No valid unit for unit string {unit_str}' raise ValueError(msg) #################### # - "Physical" Type #################### def unit_dim_to_unit_dim_deps( unit_dims: SympyType, ) -> dict[spu.dimensions.Dimension, int] | None: dimsys_SI = spu.systems.si.dimsys_SI # Retrieve Dimensional Dependencies try: return dimsys_SI.get_dimensional_dependencies(unit_dims) # Catch TypeError ## -> Happens if `+` or `-` is in `unit`. ## -> Generally, it doesn't make sense to add/subtract differing unit dims. ## -> Thus, when trying to figure out the unit dimension, there isn't one. except TypeError: return None def unit_to_unit_dim_deps( unit: SympyType, ) -> dict[spu.dimensions.Dimension, int] | None: # Retrieve Dimensional Dependencies ## -> NOTE: .subs() alone seems to produce sp.Symbol atoms. ## -> This is extremely problematic; `Dims` arithmetic has key properties. ## -> So we have to go all the way to the dimensional dependencies. ## -> This isn't really respecting the args, but it seems to work :) return unit_dim_to_unit_dim_deps( unit.subs({arg: arg.dimension for arg in unit.atoms(spu.Quantity)}) ) def compare_unit_dims(unit_dim_l: SympyType, unit_dim_r: SympyType) -> bool: return unit_dim_to_unit_dim_deps(unit_dim_l) == unit_dim_to_unit_dim_deps( unit_dim_r ) def compare_unit_dim_to_unit_dim_deps( unit_dim: SympyType, unit_dim_deps: dict[spu.dimensions.Dimension, int] ) -> bool: return unit_dim_to_unit_dim_deps(unit_dim) == unit_dim_deps class PhysicalType(enum.StrEnum): """Type identifiers for expressions with both `MathType` and a unit, aka a "physical" type.""" # Unitless NonPhysical = enum.auto() # Global Time = enum.auto() Angle = enum.auto() SolidAngle = enum.auto() ## TODO: Some kind of 3D-specific orientation ex. a quaternion Freq = enum.auto() AngFreq = enum.auto() ## rad*hertz # Cartesian Length = enum.auto() Area = enum.auto() Volume = enum.auto() # Mechanical Vel = enum.auto() Accel = enum.auto() Mass = enum.auto() Force = enum.auto() Pressure = enum.auto() # Energy Work = enum.auto() ## joule Power = enum.auto() ## watt PowerFlux = enum.auto() ## watt Temp = enum.auto() # Electrodynamics Current = enum.auto() ## ampere CurrentDensity = enum.auto() Charge = enum.auto() ## coulomb Voltage = enum.auto() Capacitance = enum.auto() ## farad Impedance = enum.auto() ## ohm Conductance = enum.auto() ## siemens Conductivity = enum.auto() ## siemens / length MFlux = enum.auto() ## weber MFluxDensity = enum.auto() ## tesla Inductance = enum.auto() ## henry EField = enum.auto() HField = enum.auto() # Luminal LumIntensity = enum.auto() LumFlux = enum.auto() Illuminance = enum.auto() @functools.cached_property def unit_dim(self) -> SympyType: PT = PhysicalType return { PT.NonPhysical: None, # Global PT.Time: Dims.time, PT.Angle: Dims.angle, PT.SolidAngle: spu.steradian.dimension, ## MISSING PT.Freq: Dims.frequency, PT.AngFreq: Dims.angle * Dims.frequency, # Cartesian PT.Length: Dims.length, PT.Area: Dims.length**2, PT.Volume: Dims.length**3, # Mechanical PT.Vel: Dims.length / Dims.time, PT.Accel: Dims.length / Dims.time**2, PT.Mass: Dims.mass, PT.Force: Dims.force, PT.Pressure: Dims.pressure, # Energy PT.Work: Dims.energy, PT.Power: Dims.power, PT.PowerFlux: Dims.power / Dims.length**2, PT.Temp: Dims.temperature, # Electrodynamics PT.Current: Dims.current, PT.CurrentDensity: Dims.current / Dims.length**2, PT.Charge: Dims.charge, PT.Voltage: Dims.voltage, PT.Capacitance: Dims.capacitance, PT.Impedance: Dims.impedance, PT.Conductance: Dims.conductance, PT.Conductivity: Dims.conductance / Dims.length, PT.MFlux: Dims.magnetic_flux, PT.MFluxDensity: Dims.magnetic_density, PT.Inductance: Dims.inductance, PT.EField: Dims.voltage / Dims.length, PT.HField: Dims.current / Dims.length, # Luminal PT.LumIntensity: Dims.luminous_intensity, PT.LumFlux: Dims.luminous_intensity * spu.steradian.dimension, PT.Illuminance: Dims.luminous_intensity / Dims.length**2, }[self] @staticproperty def unit_dims() -> dict[typ.Self, SympyType]: return { physical_type: physical_type.unit_dim for physical_type in list(PhysicalType) } @functools.cached_property def color(self): """A color corresponding to the physical type. The color selections were initially generated using AI, as this is a rote task that's better adjusted than invented. The LLM provided the following rationale for its choices: > Non-Physical: Grey signifies neutrality and non-physical nature. > Global: > Time: Blue is often associated with calmness and the passage of time. > Angle and Solid Angle: Different shades of blue and cyan suggest angular dimensions and spatial aspects. > Frequency and Angular Frequency: Darker shades of blue to maintain the link to time. > Cartesian: > Length, Area, Volume: Shades of green to represent spatial dimensions, with intensity increasing with dimension. > Mechanical: > Velocity and Acceleration: Red signifies motion and dynamics, with lighter reds for related quantities. > Mass: Dark red for the fundamental property. > Force and Pressure: Shades of red indicating intensity. > Energy: > Work and Power: Orange signifies energy transformation, with lighter oranges for related quantities. > Temperature: Yellow for heat. > Electrodynamics: > Current and related quantities: Cyan shades indicating flow. > Voltage, Capacitance: Greenish and blueish cyan for electrical potential. > Impedance, Conductance, Conductivity: Purples and magentas to signify resistance and conductance. > Magnetic properties: Magenta shades for magnetism. > Electric Field: Light blue. > Magnetic Field: Grey, as it can be considered neutral in terms of direction. > Luminal: > Luminous properties: Yellows to signify light and illumination. > > This color mapping helps maintain intuitive connections for users interacting with these physical types. """ PT = PhysicalType return { PT.NonPhysical: (0.75, 0.75, 0.75, 1.0), # Light Grey: Non-physical # Global PT.Time: (0.5, 0.5, 1.0, 1.0), # Light Blue: Time PT.Angle: (0.5, 0.75, 1.0, 1.0), # Light Blue: Angle PT.SolidAngle: (0.5, 0.75, 0.75, 1.0), # Light Cyan: Solid Angle PT.Freq: (0.5, 0.5, 0.9, 1.0), # Light Blue: Frequency PT.AngFreq: (0.5, 0.5, 0.8, 1.0), # Light Blue: Angular Frequency # Cartesian PT.Length: (0.5, 1.0, 0.5, 1.0), # Light Green: Length PT.Area: (0.6, 1.0, 0.6, 1.0), # Light Green: Area PT.Volume: (0.7, 1.0, 0.7, 1.0), # Light Green: Volume # Mechanical PT.Vel: (1.0, 0.5, 0.5, 1.0), # Light Red: Velocity PT.Accel: (1.0, 0.6, 0.6, 1.0), # Light Red: Acceleration PT.Mass: (0.75, 0.5, 0.5, 1.0), # Light Red: Mass PT.Force: (0.9, 0.5, 0.5, 1.0), # Light Red: Force PT.Pressure: (1.0, 0.7, 0.7, 1.0), # Light Red: Pressure # Energy PT.Work: (1.0, 0.75, 0.5, 1.0), # Light Orange: Work PT.Power: (1.0, 0.85, 0.5, 1.0), # Light Orange: Power PT.PowerFlux: (1.0, 0.8, 0.6, 1.0), # Light Orange: Power Flux PT.Temp: (1.0, 1.0, 0.5, 1.0), # Light Yellow: Temperature # Electrodynamics PT.Current: (0.5, 1.0, 1.0, 1.0), # Light Cyan: Current PT.CurrentDensity: (0.5, 0.9, 0.9, 1.0), # Light Cyan: Current Density PT.Charge: (0.5, 0.85, 0.85, 1.0), # Light Cyan: Charge PT.Voltage: (0.5, 1.0, 0.75, 1.0), # Light Greenish Cyan: Voltage PT.Capacitance: (0.5, 0.75, 1.0, 1.0), # Light Blueish Cyan: Capacitance PT.Impedance: (0.6, 0.5, 0.75, 1.0), # Light Purple: Impedance PT.Conductance: (0.7, 0.5, 0.8, 1.0), # Light Purple: Conductance PT.Conductivity: (0.8, 0.5, 0.9, 1.0), # Light Purple: Conductivity PT.MFlux: (0.75, 0.5, 0.75, 1.0), # Light Magenta: Magnetic Flux PT.MFluxDensity: ( 0.85, 0.5, 0.85, 1.0, ), # Light Magenta: Magnetic Flux Density PT.Inductance: (0.8, 0.5, 0.8, 1.0), # Light Magenta: Inductance PT.EField: (0.75, 0.75, 1.0, 1.0), # Light Blue: Electric Field PT.HField: (0.75, 0.75, 0.75, 1.0), # Light Grey: Magnetic Field # Luminal PT.LumIntensity: (1.0, 0.95, 0.5, 1.0), # Light Yellow: Luminous Intensity PT.LumFlux: (1.0, 0.95, 0.6, 1.0), # Light Yellow: Luminous Flux PT.Illuminance: (1.0, 1.0, 0.75, 1.0), # Pale Yellow: Illuminance }[self] @functools.cached_property def default_unit(self) -> list[Unit]: PT = PhysicalType return { PT.NonPhysical: None, # Global PT.Time: spu.picosecond, PT.Angle: spu.radian, PT.SolidAngle: spu.steradian, PT.Freq: terahertz, PT.AngFreq: spu.radian * terahertz, # Cartesian PT.Length: spu.micrometer, PT.Area: spu.um**2, PT.Volume: spu.um**3, # Mechanical PT.Vel: spu.um / spu.second, PT.Accel: spu.um / spu.second, PT.Mass: spu.microgram, PT.Force: micronewton, PT.Pressure: millibar, # Energy PT.Work: spu.joule, PT.Power: spu.watt, PT.PowerFlux: spu.watt / spu.meter**2, PT.Temp: spu.kelvin, # Electrodynamics PT.Current: spu.ampere, PT.CurrentDensity: spu.ampere / spu.meter**2, PT.Charge: spu.coulomb, PT.Voltage: spu.volt, PT.Capacitance: spu.farad, PT.Impedance: spu.ohm, PT.Conductance: spu.siemens, PT.Conductivity: spu.siemens / spu.micrometer, PT.MFlux: spu.weber, PT.MFluxDensity: spu.tesla, PT.Inductance: spu.henry, PT.EField: spu.volt / spu.micrometer, PT.HField: spu.ampere / spu.micrometer, # Luminal PT.LumIntensity: spu.candela, PT.LumFlux: spu.candela * spu.steradian, PT.Illuminance: spu.candela / spu.meter**2, }[self] @functools.cached_property def valid_units(self) -> list[Unit]: """Retrieve an ordered (by subjective usefulness) list of units for this physical type. Notes: The order in which valid units are declared is the exact same order that UI dropdowns display them. **Altering the order of units breaks backwards compatibility**. """ PT = PhysicalType return { PT.NonPhysical: [None], # Global PT.Time: [ spu.picosecond, femtosecond, spu.nanosecond, spu.microsecond, spu.millisecond, spu.second, spu.minute, spu.hour, spu.day, ], PT.Angle: [ spu.radian, spu.degree, ], PT.SolidAngle: [ spu.steradian, ], PT.Freq: ( _valid_freqs := [ terahertz, spu.hertz, kilohertz, megahertz, gigahertz, petahertz, exahertz, ] ), PT.AngFreq: [spu.radian * _unit for _unit in _valid_freqs], # Cartesian PT.Length: ( _valid_lens := [ spu.micrometer, spu.nanometer, spu.picometer, spu.angstrom, spu.millimeter, spu.centimeter, spu.meter, spu.inch, spu.foot, spu.yard, spu.mile, ] ), PT.Area: [_unit**2 for _unit in _valid_lens], PT.Volume: [_unit**3 for _unit in _valid_lens], # Mechanical PT.Vel: [_unit / spu.second for _unit in _valid_lens], PT.Accel: [_unit / spu.second**2 for _unit in _valid_lens], PT.Mass: [ spu.kilogram, spu.electron_rest_mass, spu.dalton, spu.microgram, spu.milligram, spu.gram, spu.metric_ton, ], PT.Force: [ micronewton, nanonewton, millinewton, spu.newton, spu.kg * spu.meter / spu.second**2, ], PT.Pressure: [ spu.bar, millibar, spu.pascal, hectopascal, spu.atmosphere, spu.psi, spu.mmHg, spu.torr, ], # Energy PT.Work: [ spu.joule, spu.electronvolt, ], PT.Power: [ spu.watt, ], PT.PowerFlux: [ spu.watt / spu.meter**2, ], PT.Temp: [ spu.kelvin, ], # Electrodynamics PT.Current: [ spu.ampere, ], PT.CurrentDensity: [ spu.ampere / spu.meter**2, ], PT.Charge: [ spu.coulomb, ], PT.Voltage: [ spu.volt, ], PT.Capacitance: [ spu.farad, ], PT.Impedance: [ spu.ohm, ], PT.Conductance: [ spu.siemens, ], PT.Conductivity: [ spu.siemens / spu.micrometer, spu.siemens / spu.meter, ], PT.MFlux: [ spu.weber, ], PT.MFluxDensity: [ spu.tesla, ], PT.Inductance: [ spu.henry, ], PT.EField: [ spu.volt / spu.micrometer, spu.volt / spu.meter, ], PT.HField: [ spu.ampere / spu.micrometer, spu.ampere / spu.meter, ], # Luminal PT.LumIntensity: [ spu.candela, ], PT.LumFlux: [ spu.candela * spu.steradian, ], PT.Illuminance: [ spu.candela / spu.meter**2, ], }[self] @staticmethod def from_unit(unit: Unit | None, optional: bool = False) -> typ.Self | None: """Attempt to determine a matching `PhysicalType` from a unit. NOTE: It is not guaranteed that `unit` is within `valid_units`, only that it can be converted to any unit in `valid_units`. Returns: The matched `PhysicalType`. If none could be matched, then either return `None` (if `optional` is set) or error. Raises: ValueError: If no `PhysicalType` could be matched, and `optional` is `False`. """ if unit is None: return ct.PhysicalType.NonPhysical unit_dim_deps = unit_to_unit_dim_deps(unit) if unit_dim_deps is not None: for physical_type, candidate_unit_dim in PhysicalType.unit_dims.items(): if compare_unit_dim_to_unit_dim_deps(candidate_unit_dim, unit_dim_deps): return physical_type if optional: return None msg = f'Could not determine PhysicalType for {unit}' raise ValueError(msg) @staticmethod def from_unit_dim( unit_dim: SympyType | None, optional: bool = False ) -> typ.Self | None: """Attempts to match an arbitrary unit dimension expression to a corresponding `PhysicalType`. For comparing arbitrary unit dimensions (via expressions of `spu.dimensions.Dimension`), it is critical that equivalent dimensions are also compared as equal (ex. `mass*length/time^2 == force`). To do so, we employ the `SI` unit conventions, for extracting the fundamental dimensional dependencies of unit dimension expressions. Returns: The matched `PhysicalType`. If none could be matched, then either return `None` (if `optional` is set) or error. Raises: ValueError: If no `PhysicalType` could be matched, and `optional` is `False`. """ for physical_type, candidate_unit_dim in PhysicalType.unit_dims.items(): if compare_unit_dims(unit_dim, candidate_unit_dim): return physical_type if optional: return None msg = f'Could not determine PhysicalType for {unit_dim}' raise ValueError(msg) @functools.cached_property def valid_shapes(self) -> list[typ.Literal[(3,), (2,)] | None]: PT = PhysicalType overrides = { # Cartesian PT.Length: [None, (2,), (3,)], # Mechanical PT.Vel: [None, (2,), (3,)], PT.Accel: [None, (2,), (3,)], PT.Force: [None, (2,), (3,)], # Energy PT.Work: [None, (2,), (3,)], PT.PowerFlux: [None, (2,), (3,)], # Electrodynamics PT.CurrentDensity: [None, (2,), (3,)], PT.MFluxDensity: [None, (2,), (3,)], PT.EField: [None, (2,), (3,)], PT.HField: [None, (2,), (3,)], # Luminal PT.LumFlux: [None, (2,), (3,)], } return overrides.get(self, [None]) @functools.cached_property def valid_mathtypes(self) -> list[MathType]: """Returns a list of valid mathematical types, especially whether it can be real- or complex-valued. Generally, all unit quantities are real, in the algebraic mathematical sense. However, in electrodynamics especially, it becomes enormously useful to bake in a _rotational component_ as an imaginary value, be it simply to model phase or oscillation-oriented dampening. This imaginary part has physical meaning, which can be expressed using the same mathematical formalism associated with unit systems. In general, the value is a phasor. While it is difficult to arrive at a well-defined way of saying, "this is when a quantity is complex", an attempt has been made to form a sensible baseline based on when phasor math may apply. Notes: - **Freq**/**AngFreq**: The imaginary part represents growth/dampening of the oscillation. - **Current**/**Voltage**: The imaginary part represents the phase. This also holds for any downstream units. - **Charge**: Generally, it is real. However, an imaginary phase term seems to have research applications when dealing with high-order harmonics in high-energy pulsed lasers: - **Conductance**: The imaginary part represents the extinction, in the Drude-model sense. """ MT = MathType PT = PhysicalType overrides = { PT.NonPhysical: list(MT), ## Support All # Cartesian PT.Freq: [MT.Real, MT.Complex], ## Im -> Growth/Damping PT.AngFreq: [MT.Real, MT.Complex], ## Im -> Growth/Damping # Mechanical # Energy # Electrodynamics PT.Current: [MT.Real, MT.Complex], ## Im -> Phase PT.CurrentDensity: [MT.Real, MT.Complex], ## Im -> Phase PT.Charge: [MT.Real, MT.Complex], ## Im -> Phase PT.Voltage: [MT.Real, MT.Complex], ## Im -> Phase PT.Capacitance: [MT.Real, MT.Complex], ## Im -> Phase PT.Impedance: [MT.Real, MT.Complex], ## Im -> Reactance PT.Inductance: [MT.Real, MT.Complex], ## Im -> Extinction PT.Conductance: [MT.Real, MT.Complex], ## Im -> Extinction PT.Conductivity: [MT.Real, MT.Complex], ## Im -> Extinction PT.MFlux: [MT.Real, MT.Complex], ## Im -> Phase PT.MFluxDensity: [MT.Real, MT.Complex], ## Im -> Phase PT.EField: [MT.Real, MT.Complex], ## Im -> Phase PT.HField: [MT.Real, MT.Complex], ## Im -> Phase # Luminal } return overrides.get(self, [MT.Real]) @staticmethod def to_name(value: typ.Self) -> str: if value is PhysicalType.NonPhysical: return 'Unitless' return PhysicalType(value).name @staticmethod def to_icon(value: typ.Self) -> str: return '' def bl_enum_element(self, i: int) -> ct.BLEnumElement: PT = PhysicalType return ( str(self), PT.to_name(self), PT.to_name(self), PT.to_icon(self), i, ) #################### # - Standard Unit Systems #################### UnitSystem: typ.TypeAlias = dict[PhysicalType, Unit] _PT = PhysicalType UNITS_SI: UnitSystem = { _PT.NonPhysical: None, # Global _PT.Time: spu.second, _PT.Angle: spu.radian, _PT.SolidAngle: spu.steradian, _PT.Freq: spu.hertz, _PT.AngFreq: spu.radian * spu.hertz, # Cartesian _PT.Length: spu.meter, _PT.Area: spu.meter**2, _PT.Volume: spu.meter**3, # Mechanical _PT.Vel: spu.meter / spu.second, _PT.Accel: spu.meter / spu.second**2, _PT.Mass: spu.kilogram, _PT.Force: spu.newton, # Energy _PT.Work: spu.joule, _PT.Power: spu.watt, _PT.PowerFlux: spu.watt / spu.meter**2, _PT.Temp: spu.kelvin, # Electrodynamics _PT.Current: spu.ampere, _PT.CurrentDensity: spu.ampere / spu.meter**2, _PT.Voltage: spu.volt, _PT.Capacitance: spu.farad, _PT.Impedance: spu.ohm, _PT.Conductance: spu.siemens, _PT.Conductivity: spu.siemens / spu.meter, _PT.MFlux: spu.weber, _PT.MFluxDensity: spu.tesla, _PT.Inductance: spu.henry, _PT.EField: spu.volt / spu.meter, _PT.HField: spu.ampere / spu.meter, # Luminal _PT.LumIntensity: spu.candela, _PT.LumFlux: lumen, _PT.Illuminance: spu.lux, } #################### # - Sympy Utilities: Cast to Python #################### def sympy_to_python( scalar: sp.Basic, use_jax_array: bool = False ) -> int | float | complex | tuple | jax.Array: """Convert a scalar sympy expression to the directly corresponding Python type. Arguments: scalar: A sympy expression that has no symbols, but is expressed as a Sympy type. For expressions that are equivalent to a scalar (ex. "(2a + a)/a"), you must simplify the expression with ex. `sp.simplify()` before passing to this parameter. Returns: A pure Python type that directly corresponds to the input scalar expression. """ if isinstance(scalar, sp.MatrixBase): # Detect Single Column Vector ## --> Flatten to Single Row Vector if len(scalar.shape) == 2 and scalar.shape[1] == 1: _scalar = scalar.T else: _scalar = scalar # Convert to Tuple of Tuples matrix = tuple( [tuple([sympy_to_python(el) for el in row]) for row in _scalar.tolist()] ) # Detect Single Row Vector ## --> This could be because the scalar had it. ## --> This could also be because we flattened a column vector. ## Either way, we should strip the pointless dimensions. if len(matrix) == 1: return matrix[0] if not use_jax_array else jnp.array(matrix[0]) return matrix if not use_jax_array else jnp.array(matrix) if scalar.is_integer: return int(scalar) if scalar.is_rational or scalar.is_real: return float(scalar) if scalar.is_complex: return complex(scalar) msg = f'Cannot convert sympy scalar expression "{scalar}" to a Python type. Check the assumptions on the expr (current expr assumptions: "{scalar._assumptions}")' # noqa: SLF001 raise ValueError(msg) #################### # - Convert to Unit System #################### def strip_unit_system( sp_obj: SympyExpr, unit_system: UnitSystem | None = None ) -> SympyExpr: """Strip units occurring in the given unit system from the expression. Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`". Obviously, the semantic correctness of this operation depends entirely on _the units adding no semantic meaning to the expression_. Notes: You should probably use `scale_to_unit_system()` or `convert_to_unit_system()`. """ if unit_system is None: return sp_obj.subs(UNIT_TO_1) return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None}) def convert_to_unit_system( sp_obj: SympyExpr, unit_system: UnitSystem | None ) -> SympyExpr: """Convert an expression to the units of a given unit system.""" if unit_system is None: return sp_obj return spu.convert_to( sp_obj, {unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)}, ) def scale_to_unit_system( sp_obj: SympyExpr, unit_system: UnitSystem | None, use_jax_array: bool = False ) -> int | float | complex | tuple | jax.Array: """Convert an expression to the units of a given unit system, then strip all units of the unit system. Afterwards, it is converted to an appropriate Python type. Notes: For stability, and performance, reasons, this should only be used at the very last stage. Regarding performance: **This is not a fast function**. Parameters: sp_obj: An arbitrary sympy object, presumably with units. unit_system: A unit system mapping `PhysicalType` to particular choices of (compound) units. Note that, in this context, only `unit_system.values()` is used. Returns: An appropriate pure Python type, after scaling to the unit system and stripping all units away. If the returned type is array-like, and `use_jax_array` is specified, then (and **only** then) will a `jax.Array` be returned instead of a nested `tuple`. """ return sympy_to_python( strip_unit_system(convert_to_unit_system(sp_obj, unit_system), unit_system), use_jax_array=use_jax_array, )