oscillode/src/blender_maxwell/utils/extra_sympy_units.py

1696 lines
50 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Declares useful sympy units and functions, to make it easier to work with `sympy` as the basis for a unit-aware system.
Attributes:
UNIT_BY_SYMBOL: Maps all abbreviated Sympy symbols to their corresponding Sympy unit.
This is essential for parsing string expressions that use units, since a pure parse of ex. `a*m + m` would not otherwise be able to differentiate between `sp.Symbol(m)` and `spu.meter`.
SympyType: A simple union of valid `sympy` types, used to check whether arbitrary objects should be handled using `sympy` functions.
For simple `isinstance` checks, this should be preferred, as it is most performant.
For general use, `SympyExpr` should be preferred.
SympyExpr: A `SympyType` that is compatible with `pydantic`, including serialization/deserialization.
Should be used via the `ConstrSympyExpr`, which also adds expression validation.
"""
import enum
import functools
import sys
import typing as typ
from fractions import Fraction
import jax
import jax.numpy as jnp
import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
import sympy.physics.units as spu
import typing_extensions as typx
from pydantic_core import core_schema as pyd_core_schema
from blender_maxwell import contracts as ct
from . import logger
from .staticproperty import staticproperty
# Module-level logger, obtained from the project's `logger` helper.
log = logger.get(__name__)

# Union of the `sympy` classes that this module treats as "sympy objects".
# It is used below both in annotations and in `isinstance()` checks.
SympyType = (
    sp.Basic
    | sp.Expr
    | sp.MatrixBase
    | sp.MutableDenseMatrix
    | spu.Quantity
    | spu.Dimension
)
####################
# - Math Type
####################
class MathType(enum.StrEnum):
"""Type identifiers that encompass common sets of mathematical objects."""
Integer = enum.auto()
Rational = enum.auto()
Real = enum.auto()
Complex = enum.auto()
@staticmethod
def combine(*mathtypes: list[typ.Self], optional: bool = False) -> typ.Self | None:
if MathType.Complex in mathtypes:
return MathType.Complex
if MathType.Real in mathtypes:
return MathType.Real
if MathType.Rational in mathtypes:
return MathType.Rational
if MathType.Integer in mathtypes:
return MathType.Integer
if optional:
return None
msg = f"Can't combine mathtypes {mathtypes}"
raise ValueError(msg)
def is_compatible(self, other: typ.Self) -> bool:
MT = MathType
return (
other
in {
MT.Integer: [MT.Integer],
MT.Rational: [MT.Integer, MT.Rational],
MT.Real: [MT.Integer, MT.Rational, MT.Real],
MT.Complex: [MT.Integer, MT.Rational, MT.Real, MT.Complex],
}[self]
)
def coerce_compatible_pyobj(
self, pyobj: bool | int | Fraction | float | complex
) -> int | Fraction | float | complex:
MT = MathType
match self:
case MT.Integer:
return int(pyobj)
case MT.Rational if isinstance(pyobj, int):
return Fraction(pyobj, 1)
case MT.Rational if isinstance(pyobj, Fraction):
return pyobj
case MT.Real:
return float(pyobj)
case MT.Complex if isinstance(pyobj, int | Fraction):
return complex(float(pyobj), 0)
case MT.Complex if isinstance(pyobj, float):
return complex(pyobj, 0)
@staticmethod
def from_expr(sp_obj: SympyType, optional: bool = False) -> type | None:
if isinstance(sp_obj, sp.MatrixBase):
return MathType.combine(
*[MathType.from_expr(v) for v in sp.flatten(sp_obj)]
)
if sp_obj.is_integer:
return MathType.Integer
if sp_obj.is_rational:
return MathType.Rational
if sp_obj.is_real:
return MathType.Real
if sp_obj.is_complex:
return MathType.Complex
# Infinities
if sp_obj in [sp.oo, -sp.oo]:
return MathType.Real ## TODO: Strictly, could be ex. integer...
if sp_obj in [sp.zoo, -sp.zoo]:
return MathType.Complex
if optional:
return None
msg = f"Can't determine MathType from sympy object: {sp_obj}"
raise ValueError(msg)
@staticmethod
def from_pytype(dtype: type) -> type:
return {
int: MathType.Integer,
Fraction: MathType.Rational,
float: MathType.Real,
complex: MathType.Complex,
}[dtype]
@staticmethod
def from_jax_array(data: jtyp.Shaped[jtyp.Array, '...']) -> type:
"""Deduce the MathType corresponding to a JAX array.
We go about this by leveraging that:
- `data` is of a homogeneous type.
- `data.item(0)` returns a single element of the array w/pure-python type.
By combing this with `type()` and `MathType.from_pytype`, we can effectively deduce the `MathType` of the entire array with relative efficiency.
Notes:
Should also work with numpy arrays.
"""
return MathType.from_pytype(type(data.item(0)))
@staticmethod
def has_mathtype(obj: typ.Any) -> typ.Literal['pytype', 'expr'] | None:
if isinstance(obj, bool | int | Fraction | float | complex):
return 'pytype'
if isinstance(obj, sp.Basic | sp.MatrixBase | sp.MutableDenseMatrix):
return 'expr'
return None
@property
def pytype(self) -> type:
MT = MathType
return {
MT.Integer: int,
MT.Rational: Fraction,
MT.Real: float,
MT.Complex: complex,
}[self]
@property
def symbolic_set(self) -> type:
MT = MathType
return {
MT.Integer: sp.Integers,
MT.Rational: sp.Rationals,
MT.Real: sp.Reals,
MT.Complex: sp.Complexes,
}[self]
@property
def inf_finite(self) -> type:
"""Opinionated finite representation of "infinity" within this `MathType`.
These are chosen using `sys.maxsize` and `sys.float_info`.
As such, while not arbitrary, this "finite representation of infinity" certainly is opinionated.
**Note** that, in practice, most systems will have no trouble working with values that exceed those defined here.
Notes:
Values should be presumed to vary by-platform, as the `sys` attributes may be influenced by CPU architecture, OS, runtime environment, etc. .
These values can be used directly in `jax` arrays, but at the cost of an overflow warning (in part because `jax` generally only allows the use of `float32`).
In this case, the warning doesn't matter, as the value will be cast to `jnp.inf` anyway.
However, it's generally cleaner to directly use `jnp.inf` if infinite values must be defined in an array context.
"""
MT = MathType
Z = MT.Integer
R = MT.Integer
return {
MT.Integer: (-sys.maxsize, sys.maxsize),
MT.Rational: (
Fraction(Z.inf_finite[0], 1),
Fraction(Z.inf_finite[1], 1),
),
MT.Real: -(sys.float_info.min, sys.float_info.max),
MT.Complex: (
complex(R.inf_finite[0], R.inf_finite[0]),
complex(R.inf_finite[1], R.inf_finite[1]),
),
}[self]
@property
def sp_symbol_a(self) -> type:
MT = MathType
return {
MT.Integer: sp.Symbol('a', integer=True),
MT.Rational: sp.Symbol('a', rational=True),
MT.Real: sp.Symbol('a', real=True),
MT.Complex: sp.Symbol('a', complex=True),
}[self]
@staticmethod
def to_str(value: typ.Self) -> type:
return {
MathType.Integer: '',
MathType.Rational: '',
MathType.Real: '',
MathType.Complex: '',
}[value]
@property
def label_pretty(self) -> str:
return MathType.to_str(self)
@staticmethod
def to_name(value: typ.Self) -> str:
return MathType.to_str(value)
@staticmethod
def to_icon(value: typ.Self) -> str:
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
return (
str(self),
MathType.to_name(self),
MathType.to_name(self),
MathType.to_icon(self),
i,
)
####################
# - Size: 1D
####################
class NumberSize1D(enum.StrEnum):
"""Valid 1D-constrained shape."""
Scalar = enum.auto()
Vec2 = enum.auto()
Vec3 = enum.auto()
Vec4 = enum.auto()
@staticmethod
def to_name(value: typ.Self) -> str:
NS = NumberSize1D
return {
NS.Scalar: 'Scalar',
NS.Vec2: '2D',
NS.Vec3: '3D',
NS.Vec4: '4D',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> str:
NS = NumberSize1D
return {
NS.Scalar: '',
NS.Vec2: '',
NS.Vec3: '',
NS.Vec4: '',
}[value]
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
return (
str(self),
NumberSize1D.to_name(self),
NumberSize1D.to_name(self),
NumberSize1D.to_icon(self),
i,
)
@staticmethod
def has_shape(shape: tuple[int, ...] | None):
return shape in [None, (2,), (3,), (4,), (2, 1), (3, 1), (4, 1)]
def supports_shape(self, shape: tuple[int, ...] | None):
NS = NumberSize1D
match self:
case NS.Scalar:
return shape is None
case NS.Vec2:
return shape in ((2,), (2, 1))
case NS.Vec3:
return shape in ((3,), (3, 1))
case NS.Vec4:
return shape in ((4,), (4, 1))
@staticmethod
def from_shape(shape: tuple[typ.Literal[2, 3]] | None) -> typ.Self:
NS = NumberSize1D
return {
None: NS.Scalar,
(2,): NS.Vec2,
(3,): NS.Vec3,
(4,): NS.Vec4,
(2, 1): NS.Vec2,
(3, 1): NS.Vec3,
(4, 1): NS.Vec4,
}[shape]
@property
def rows(self):
NS = NumberSize1D
return {
NS.Scalar: 1,
NS.Vec2: 2,
NS.Vec3: 3,
NS.Vec4: 4,
}[self]
@property
def cols(self):
return 1
@property
def shape(self):
NS = NumberSize1D
return {
NS.Scalar: None,
NS.Vec2: (2,),
NS.Vec3: (3,),
NS.Vec4: (4,),
}[self]
def symbol_range(sym: sp.Symbol) -> str:
    """Label a symbol by its name, followed by a suffix chosen from its assumptions.

    The assumptions are tested in the order complex, real, integer; `'?'` is appended when none holds.

    NOTE(review): the three assumption-specific suffixes are empty strings in this copy of the file.
    Possibly non-ASCII glyphs were lost in transit; confirm against the repository.
    """
    if sym.is_complex:
        suffix = ''
    elif sym.is_real:
        suffix = ''
    elif sym.is_integer:
        suffix = ''
    else:
        suffix = '?'
    return f'{sym.name}' + suffix
####################
# - Symbol Sizes
####################
class SimpleSize2D(enum.StrEnum):
    """Simple subset of sizes for rank-2 tensors.

    The trailing `##` comments give the `rows x cols` size that each member stands for.
    """

    Scalar = enum.auto()
    # Vectors
    Vec2 = enum.auto()  ## 2x1
    Vec3 = enum.auto()  ## 3x1
    Vec4 = enum.auto()  ## 4x1
    # Covectors
    CoVec2 = enum.auto()  ## 1x2
    CoVec3 = enum.auto()  ## 1x3
    CoVec4 = enum.auto()  ## 1x4
    # Square Matrices
    Mat22 = enum.auto()  ## 2x2
    Mat33 = enum.auto()  ## 3x3
    Mat44 = enum.auto()  ## 4x4
####################
# - Unit Dimensions
####################
class DimsMeta(type):
    """Metaclass that resolves missing class attributes against `sympy`'s dimension definitions."""

    def __getattr__(cls, attr: str) -> spu.Dimension:
        """Look `attr` up in `sympy.physics.units.definitions.dimension_definitions`.

        Only invoked when normal attribute lookup on the class fails.

        Parameters:
            attr: Name of the attribute that was requested on the class.

        Raises:
            AttributeError: If `attr` is a dunder name, or isn't defined by the `sympy` module.
        """
        dim_defs = spu.definitions.dimension_definitions

        # Dunder names are never forwarded, so protocol probes fail as usual.
        if not attr.startswith('__') and attr in dir(dim_defs):
            return getattr(dim_defs, attr)

        msg = f"'{cls.__name__}' has no dimension named '{attr}'"
        raise AttributeError(msg, name=attr, obj=cls)
class Dims(metaclass=DimsMeta):
    """Access `sympy.physics.units` dimensions with less hassle.

    Any unit dimension available in `sympy.physics.units.definitions.dimension_definitions` can be accessed as an attribute of `Dims`.
    The lookup itself is performed by the metaclass, `DimsMeta`.

    An `AttributeError` is raised if the dimension cannot be found in `sympy`.

    Examples:
        The objects returned are a direct alias to `sympy`, with less hassle:
        ```python
        assert Dims.length == (
            sympy.physics.units.definitions.dimension_definitions.length
        )
        ```
    """
####################
# - Units
####################
# Each extra unit is declared as a `Quantity`, then tied to an existing
# `sympy` unit through a prefix (or a plain factor).

# Time
femtosecond = fs = spu.Quantity('femtosecond', abbrev='fs')
femtosecond.set_global_relative_scale_factor(spu.femto, spu.second)

# Length
femtometer = fm = spu.Quantity('femtometer', abbrev='fm')
femtometer.set_global_relative_scale_factor(spu.femto, spu.meter)

# Lum Flux
lumen = lm = spu.Quantity('lumen', abbrev='lm')
lumen.set_global_relative_scale_factor(1, spu.candela * spu.steradian)

# Force
nanonewton = nN = spu.Quantity('nanonewton', abbrev='nN')  # noqa: N816
nanonewton.set_global_relative_scale_factor(spu.nano, spu.newton)

micronewton = uN = spu.Quantity('micronewton', abbrev='μN')  # noqa: N816
micronewton.set_global_relative_scale_factor(spu.micro, spu.newton)

## The millinewton used to be named 'micronewton', and its scale factor was
## applied to `micronewton`, which silently turned micronewtons into 1e-3 N.
millinewton = mN = spu.Quantity('millinewton', abbrev='mN')  # noqa: N816
millinewton.set_global_relative_scale_factor(spu.milli, spu.newton)

# Frequency
kilohertz = KHz = spu.Quantity('kilohertz', abbrev='KHz')
kilohertz.set_global_relative_scale_factor(spu.kilo, spu.hertz)

## The megahertz used to be left without any scale factor at all.
megahertz = MHz = spu.Quantity('megahertz', abbrev='MHz')
megahertz.set_global_relative_scale_factor(spu.mega, spu.hertz)

gigahertz = GHz = spu.Quantity('gigahertz', abbrev='GHz')
gigahertz.set_global_relative_scale_factor(spu.giga, spu.hertz)

terahertz = THz = spu.Quantity('terahertz', abbrev='THz')
terahertz.set_global_relative_scale_factor(spu.tera, spu.hertz)

petahertz = PHz = spu.Quantity('petahertz', abbrev='PHz')
petahertz.set_global_relative_scale_factor(spu.peta, spu.hertz)

exahertz = EHz = spu.Quantity('exahertz', abbrev='EHz')
exahertz.set_global_relative_scale_factor(spu.exa, spu.hertz)

# Pressure
millibar = mbar = spu.Quantity('millibar', abbrev='mbar')
millibar.set_global_relative_scale_factor(spu.milli, spu.bar)

hectopascal = hPa = spu.Quantity('hectopascal', abbrev='hPa')  # noqa: N816
hectopascal.set_global_relative_scale_factor(spu.hecto, spu.pascal)

# Lookup of every known unit by its `.name`.
## -> Units from `sympy.physics.units` come first.
## -> Units declared in this module are merged on top, and win on a name clash.
## -> The module globals are snapshotted before being iterated.
UNIT_BY_SYMBOL: dict[sp.Symbol, spu.Quantity] = {
    unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity)
} | {
    unit.name: unit
    for unit in tuple(globals().values())
    if isinstance(unit, spu.Quantity)
}

# Maps every known unit to the number 1.
UNIT_TO_1: dict[spu.Quantity, int] = dict.fromkeys(UNIT_BY_SYMBOL.values(), 1)
####################
# - Expr Analysis: Units
####################
## TODO: Caching w/srepr'ed expression.
## TODO: An LFU cache could do better than an LRU.
def uses_units(sp_obj: SympyType) -> bool:
    """Determines if a `sympy` object uses any units.

    Notes:
        The check is delegated to the object's own `.has()` method, asked about `spu.Quantity`.

    Parameters:
        sp_obj: The sympy object that may contain units.

    Returns:
        Whether or not there are units used within the object.
    """
    contains_quantity = sp_obj.has(spu.Quantity)
    return contains_quantity
## TODO: Caching w/srepr'ed expression.
## TODO: An LFU cache could do better than an LRU.
def get_units(expr: sp.Expr) -> set[spu.Quantity]:
    """Collects every unit that appears anywhere within the expression.

    Only the units themselves are reported, not how they relate to each other.
    For example, `spu.meter / spu.second` yields `{spu.meter, spu.second}`.

    Notes:
        Every node visited by `sp.postorder_traversal` is tested with `isinstance`, so the entire expression graph is always walked.

    Parameters:
        expr: The sympy expression that may contain units.

    Returns:
        All units (`spu.Quantity`) used within the expression.
    """
    found_units = set()
    for node in sp.postorder_traversal(expr):
        if isinstance(node, spu.Quantity):
            found_units.add(node)
    return found_units
def parse_shape(sp_obj: SympyType) -> tuple[int, ...] | None:
    """Determine the shape of a `sympy` object.

    Parameters:
        sp_obj: The `sympy` object to inspect.

    Returns:
        The `.shape` of `sp_obj` if it is a matrix (`sp.MatrixBase`), otherwise `None`.
    """
    return sp_obj.shape if isinstance(sp_obj, sp.MatrixBase) else None
####################
# - Pydantic-Validated SympyExpr
####################
class _SympyExpr:
"""Low-level `pydantic`, schema describing how to serialize/deserialize fields that have a `SympyType` (like `sp.Expr`), so we can cleanly use `sympy` types in `pyd.BaseModel`.
Notes:
You probably want to use `SympyExpr`.
Examples:
To be usable as a type annotation on `pyd.BaseModel`, attach this to `SympyType` using `typx.Annotated`:
```python
SympyExpr = typx.Annotated[SympyType, _SympyExpr]
class Spam(pyd.BaseModel):
line: SympyExpr = sp.Eq(sp.y, 2*sp.Symbol(x, real=True) - 3)
```
"""
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: SympyType,
_handler: pyd.GetCoreSchemaHandler,
) -> pyd_core_schema.CoreSchema:
"""Compute a schema that allows `pydantic` to validate a `sympy` type."""
def validate_from_str(sp_str: str | typ.Any) -> SympyType | typ.Any:
"""Parse and validate a string expression.
Parameters:
sp_str: A stringified `sympy` object, that will be parsed to a sympy type.
Before use, `isinstance(expr_str, str)` is checked.
If the object isn't a string, then the validation will be skipped.
Returns:
Either a `sympy` object, if the input is parseable, or the same untouched object.
Raises:
ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression.
"""
# Constrain to String
if not isinstance(sp_str, str):
return sp_str
# Parse String -> Sympy
try:
expr = sp.sympify(sp_str)
except ValueError as ex:
msg = f'String {sp_str} is not a valid sympy expression'
raise ValueError(msg) from ex
# Substitute Symbol -> Quantity
return expr.subs(UNIT_BY_SYMBOL)
def validate_from_pytype(
sp_pytype: int | Fraction | float | complex,
) -> SympyType | typ.Any:
"""Parse and validate a pure Python type.
Parameters:
sp_str: A stringified `sympy` object, that will be parsed to a sympy type.
Before use, `isinstance(expr_str, str)` is checked.
If the object isn't a string, then the validation will be skipped.
Returns:
Either a `sympy` object, if the input is parseable, or the same untouched object.
Raises:
ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression.
"""
# Constrain to String
if not isinstance(sp_pytype, int | Fraction | float | complex):
return sp_pytype
if isinstance(sp_pytype, int):
return sp.Integer(sp_pytype)
if isinstance(sp_pytype, Fraction):
return sp.Rational(sp_pytype.numerator, sp_pytype.denominator)
if isinstance(sp_pytype, float):
return sp.Float(sp_pytype)
# sp_pytype => Complex
return sp_pytype.real + sp.I * sp_pytype.imag
sympy_expr_schema = pyd_core_schema.chain_schema(
[
pyd_core_schema.no_info_plain_validator_function(validate_from_str),
pyd_core_schema.no_info_plain_validator_function(validate_from_pytype),
pyd_core_schema.is_instance_schema(SympyType),
]
)
return pyd_core_schema.json_or_python_schema(
json_schema=sympy_expr_schema,
python_schema=sympy_expr_schema,
serialization=pyd_core_schema.plain_serializer_function_ser_schema(
lambda sp_obj: sp.srepr(sp_obj)
),
)
# Field annotation for `pydantic` models: the `_SympyExpr` metadata supplies
# the validation / serialization schema.
SympyExpr = typx.Annotated[
    sp.Basic,  ## Treat all sympy types as sp.Basic
    _SympyExpr,
]
## TODO: The type game between SympyType, SympyExpr, and the various flavors of ConstrSympyExpr(), is starting to be a bit much. Let's consolidate.
def ConstrSympyExpr(  # noqa: N802, PLR0913
    # Features
    allow_variables: bool = True,
    allow_units: bool = True,
    # Structures
    allowed_sets: set[typ.Literal['integer', 'rational', 'real', 'complex']]
    | None = None,
    allowed_structures: set[typ.Literal['scalar', 'matrix']] | None = None,
    # Element Class
    max_symbols: int | None = None,
    allowed_symbols: set[sp.Symbol] | None = None,
    allowed_units: set[spu.Quantity] | None = None,
    # Shape Class
    allowed_matrix_shapes: set[tuple[int, int]] | None = None,
) -> SympyType:
    """Constructs a `SympyExpr` type, which will validate `sympy` types when used in a `pyd.BaseModel`.

    Relies on the `sympy` assumptions system.
    See <https://docs.sympy.org/latest/guides/assumptions.html#predicates>

    Parameters:
        allow_variables: Whether the expression may have free symbols.
        allow_units: Whether the expression may use units.
        allowed_sets: If given, a scalar expression must satisfy at least one of these assumptions.
            Not checked for objects that aren't `sp.Expr`.
        allowed_structures: If given, the expression must match at least one of these structures.
        max_symbols: If given, the maximum number of free symbols.
        allowed_symbols: If given (and non-empty), every free symbol must be one of these.
        allowed_units: If given (and non-empty), every unit used must be one of these.
        allowed_matrix_shapes: If given (and non-empty), a matrix must have one of these shapes.

    Returns:
        A type that represents a constrained `sympy` expression.
    """

    def validate_expr(expr: SympyType):
        """Check `expr` against every configured constraint.

        Raises:
            ValueError: If `expr` isn't a `SympyType`, or if any constraint is violated.
                All violations are reported together.
        """
        # The check used to test a one-element tuple, which is always truthy.
        if not isinstance(expr, SympyType):
            msg = f"expr '{expr}' is not an allowed Sympy expression ({SympyType})"
            raise ValueError(msg)

        msgs = set()

        # Validate Feature Class
        if (not allow_variables) and (len(expr.free_symbols) > 0):
            msgs.add(
                f'allow_variables={allow_variables} does not match expression {expr}.'
            )
        if (not allow_units) and uses_units(expr):
            msgs.add(f'allow_units={allow_units} does not match expression {expr}.')

        # Validate Structure Class
        if (
            allowed_sets
            and isinstance(expr, sp.Expr)
            and not any(
                {
                    'integer': expr.is_integer,
                    'rational': expr.is_rational,
                    'real': expr.is_real,
                    'complex': expr.is_complex,
                }[allowed_set]
                for allowed_set in allowed_sets
            )
        ):
            msgs.add(
                f"allowed_sets={allowed_sets} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))"
            )

        ## NOTE(review): 'scalar' accepts everything, matrices included.
        ## -> Confirm whether it should instead reject `sp.MatrixBase`.
        if allowed_structures and not any(
            {
                'scalar': True,
                'matrix': isinstance(expr, sp.MatrixBase),
            }[allowed_set]
            for allowed_set in allowed_structures
        ):
            msgs.add(
                f"allowed_structures={allowed_structures} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))"
            )

        # Validate Element Class
        ## -> `is not None`, so that `max_symbols=0` is enforced too.
        if max_symbols is not None and len(expr.free_symbols) > max_symbols:
            msgs.add(f'max_symbols={max_symbols} does not match expression {expr}')

        ## -> A violation is when the expression is NOT a subset of what's allowed.
        if allowed_symbols and not expr.free_symbols.issubset(allowed_symbols):
            msgs.add(
                f'allowed_symbols={allowed_symbols} does not match expression {expr}'
            )
        if allowed_units and not get_units(expr).issubset(allowed_units):
            msgs.add(f'allowed_units={allowed_units} does not match expression {expr}')

        # Validate Shape Class
        if (
            allowed_matrix_shapes and isinstance(expr, sp.MatrixBase)
        ) and expr.shape not in allowed_matrix_shapes:
            msgs.add(
                f'allowed_matrix_shapes={allowed_matrix_shapes} does not match expression {expr} with shape {expr.shape}'
            )

        # Error or Return
        if msgs:
            raise ValueError(str(msgs))
        return expr

    return typx.Annotated[
        sp.Basic,
        _SympyExpr,
        pyd.AfterValidator(validate_expr),
    ]
####################
# - Common ConstrSympyExpr
####################
# Expression
# Every alias below is built by `ConstrSympyExpr`, and differs only in the
# constraints that are passed to it.

## Unitless, symbol-free scalars.
ScalarUnitlessRealExpr: typ.TypeAlias = ConstrSympyExpr(
    allow_variables=False,
    allow_units=False,
    allowed_structures={'scalar'},
    allowed_sets={'integer', 'rational', 'real'},
)
ScalarUnitlessComplexExpr: typ.TypeAlias = ConstrSympyExpr(
    allow_variables=False,
    allow_units=False,
    allowed_structures={'scalar'},
    allowed_sets={'integer', 'rational', 'real', 'complex'},
)

# Symbol
## Unitless expressions with at most one free symbol.
IntSymbol: typ.TypeAlias = ConstrSympyExpr(
    allow_variables=True,
    allow_units=False,
    allowed_sets={'integer'},
    max_symbols=1,
)
RationalSymbol: typ.TypeAlias = ConstrSympyExpr(
    allow_variables=True,
    allow_units=False,
    allowed_sets={'integer', 'rational'},
    max_symbols=1,
)
RealSymbol: typ.TypeAlias = ConstrSympyExpr(
    allow_variables=True,
    allow_units=False,
    allowed_sets={'integer', 'rational', 'real'},
    max_symbols=1,
)
ComplexSymbol: typ.TypeAlias = ConstrSympyExpr(
    allow_variables=True,
    allow_units=False,
    allowed_sets={'integer', 'rational', 'real', 'complex'},
    max_symbols=1,
)
## NOTE: `RationalSymbol` is not part of this union.
Symbol: typ.TypeAlias = IntSymbol | RealSymbol | ComplexSymbol

# Unit
UnitDimension: typ.TypeAlias = SympyExpr  ## Actually spu.Dimension

## Technically a "unit expression", which includes compound types.
## Support for this is the reason to prefer over raw spu.Quantity.
Unit: typ.TypeAlias = ConstrSympyExpr(
    allow_variables=False,
    allow_units=True,
    allowed_structures={'scalar'},
)

# Number
## Unitless, symbol-free scalars.
IntNumber: typ.TypeAlias = ConstrSympyExpr(
    allow_variables=False,
    allow_units=False,
    allowed_sets={'integer'},
    allowed_structures={'scalar'},
)
RealNumber: typ.TypeAlias = ConstrSympyExpr(
    allow_variables=False,
    allow_units=False,
    allowed_sets={'integer', 'rational', 'real'},
    allowed_structures={'scalar'},
)
ComplexNumber: typ.TypeAlias = ConstrSympyExpr(
    allow_variables=False,
    allow_units=False,
    allowed_sets={'integer', 'rational', 'real', 'complex'},
    allowed_structures={'scalar'},
)
Number: typ.TypeAlias = IntNumber | RealNumber | ComplexNumber

# Physical Number
## Same as the numbers above, except that units are allowed.
PhysicalRealNumber: typ.TypeAlias = ConstrSympyExpr(
    allow_variables=False,
    allow_units=True,
    allowed_sets={'integer', 'rational', 'real'},
    allowed_structures={'scalar'},
)
PhysicalComplexNumber: typ.TypeAlias = ConstrSympyExpr(
    allow_variables=False,
    allow_units=True,
    allowed_sets={'integer', 'rational', 'real', 'complex'},
    allowed_structures={'scalar'},
)
PhysicalNumber: typ.TypeAlias = PhysicalRealNumber | PhysicalComplexNumber

# Vector
## Unitless, symbol-free matrix of shape (3, 1).
Real3DVector: typ.TypeAlias = ConstrSympyExpr(
    allow_variables=False,
    allow_units=False,
    allowed_sets={'integer', 'rational', 'real'},
    allowed_structures={'matrix'},
    allowed_matrix_shapes={(3, 1)},
)
####################
# - Sympy Utilities: Printing
####################
# Shared string printer, created once at import time and used by `sp_to_str`.
## -> The 'abbrev' setting is enabled, which `sp_to_str` relies on to print abbreviated units.
_SYMPY_EXPR_PRINTER_STR = sp.printing.str.StrPrinter(
    settings={
        'abbrev': True,
    }
)
def sp_to_str(sp_obj: SympyExpr) -> str:
    """Renders a sympy object as a display string (w/abbreviated units), using the module's dedicated `StrPrinter`.

    Use this whenever a `sympy` object must be shown as a **string in the UI**.

    Notes:
        Do **NOT** use this when the string is meant to be `sp.sympify()`ed back into a sympy expression.
        For such cases, rely on `sp.srepr()`, which uses an _explicit_ representation.

    Parameters:
        sp_obj: The `sympy` object to convert to a string.

    Returns:
        A string representing the expression for human use.
        _The string is not re-encodable to the expression._
    """
    ## TODO: A bool flag property that does a lot of find/replace to make it super pretty
    printer = _SYMPY_EXPR_PRINTER_STR
    return printer.doprint(sp_obj)
def pretty_symbol(sym: sp.Symbol) -> str:
    """Label a symbol by its name, followed by a suffix chosen from its assumptions.

    The assumptions are tested in the order integer, real, complex; `'?'` is appended when none holds.

    NOTE(review): the three assumption-specific suffixes are empty strings in this copy of the file.
    Possibly non-ASCII glyphs were lost in transit; confirm against the repository.
    """
    if sym.is_integer:
        suffix = ''
    elif sym.is_real:
        suffix = ''
    elif sym.is_complex:
        suffix = ''
    else:
        suffix = '?'
    return f'{sym.name}' + suffix
####################
# - Unit Utilities
####################
def scale_to_unit(sp_obj: SympyType, unit: spu.Quantity) -> Number:
    """Express a unit-carrying object in the given unit, then divide that unit away, leaving a unitless `sympy` value.

    This is how the unitless part of an expression is obtained while guaranteeing which unit it is expressed in, aka. **unit system normalization**.

    Notes:
        The unitless output is still a `sympy` object, which may contain ex. symbols.

    Parameters:
        sp_obj: The unit-containing object to convert.
        unit: The unit that is converted to.
            When `None`, `sp_obj` is used as-is, without any conversion.

    Returns:
        The unitless part of `sp_obj`, after scaling the entire object to `unit`.

    Raises:
        ValueError: If the result of unit-conversion and -stripping still has units, as determined by `uses_units()`.
    """
    if unit is None:
        unitless_expr = sp_obj
    else:
        unitless_expr = spu.convert_to(sp_obj, unit) / unit

    # Leftover units mean that `unit` couldn't express all of `sp_obj`.
    if uses_units(unitless_expr):
        msg = f'Sympy object "{sp_obj}" was scaled to the unit "{unit}" with the expectation that the result would be unitless, but the result "{unitless_expr}" has units "{get_units(unitless_expr)}"'
        raise ValueError(msg)

    return unitless_expr
def scaling_factor(unit_from: spu.Quantity, unit_to: spu.Quantity) -> Number:
    """Compute the factor by which the unitless part of an expression is multiplied when it is converted from one unit to another.

    Parameters:
        unit_from: The unit that is converted from.
        unit_to: The unit that is converted to.

    Returns:
        The numerical scaling factor between the two units, computed by `scale_to_unit`.

    Raises:
        ValueError: If the two units don't share a common dimension.
    """
    dims_match = unit_from.dimension == unit_to.dimension
    if not dims_match:
        msg = f"Dimension of unit_from={unit_from} ({unit_from.dimension}) doesn't match the dimension of unit_to={unit_to} ({unit_to.dimension}); therefore, there is no scaling factor between them"
        raise ValueError(msg)

    return scale_to_unit(unit_from, unit_to)
@functools.cache
def unit_str_to_unit(unit_str: str) -> Unit | None:
    """Parse a string into a unit expression; results are cached per string.

    Parameters:
        unit_str: The string to parse.

    Returns:
        The parsed expression, with symbols found in `UNIT_BY_SYMBOL` substituted by their unit.

    Raises:
        ValueError: If the parsed expression doesn't contain any unit.
    """
    # Edge Case: Manually Parse Degrees
    ## -> sp.sympify('degree') actually produces the sp.degree() function.
    ## -> Therefore, we must special case this particular unit.
    parsed = (
        spu.degree
        if unit_str == 'degree'
        else sp.sympify(unit_str).subs(UNIT_BY_SYMBOL)
    )

    if not parsed.has(spu.Quantity):
        msg = f'No valid unit for unit string {unit_str}'
        raise ValueError(msg)

    return parsed
####################
# - "Physical" Type
####################
def unit_dim_to_unit_dim_deps(
    unit_dims: SympyType,
) -> dict[spu.dimensions.Dimension, int] | None:
    """Retrieve the dimensional dependencies of a unit dimension, using the SI dimension system.

    Parameters:
        unit_dims: The unit dimension (expression) to analyze.

    Returns:
        Whatever `get_dimensional_dependencies` of the SI dimension system reports, or `None` if that call raises a `TypeError`.
    """
    si_dims = spu.systems.si.dimsys_SI

    try:
        deps = si_dims.get_dimensional_dependencies(unit_dims)
    except TypeError:
        # Catch TypeError
        ## -> Happens if `+` or `-` is in `unit`.
        ## -> Generally, it doesn't make sense to add/subtract differing unit dims.
        ## -> Thus, when trying to figure out the unit dimension, there isn't one.
        return None

    return deps
def unit_to_unit_dim_deps(
    unit: SympyType,
) -> dict[spu.dimensions.Dimension, int] | None:
    """Retrieve the dimensional dependencies of a unit expression.

    Every `spu.Quantity` within `unit` is replaced by its `.dimension`, and the result is handed to `unit_dim_to_unit_dim_deps`.

    Parameters:
        unit: The unit (expression) to analyze.

    Returns:
        The result of `unit_dim_to_unit_dim_deps` on the substituted expression.
    """
    ## -> NOTE: .subs() alone seems to produce sp.Symbol atoms.
    ## -> This is extremely problematic; `Dims` arithmetic has key properties.
    ## -> So we have to go all the way to the dimensional dependencies.
    ## -> This isn't really respecting the args, but it seems to work :)
    dim_by_quantity = {
        quantity: quantity.dimension for quantity in unit.atoms(spu.Quantity)
    }
    unit_dims = unit.subs(dim_by_quantity)
    return unit_dim_to_unit_dim_deps(unit_dims)
def compare_unit_dims(unit_dim_l: SympyType, unit_dim_r: SympyType) -> bool:
    """Whether two unit dimensions have equal dimensional dependencies, as computed by `unit_dim_to_unit_dim_deps`."""
    deps_l = unit_dim_to_unit_dim_deps(unit_dim_l)
    deps_r = unit_dim_to_unit_dim_deps(unit_dim_r)
    return deps_l == deps_r
def compare_unit_dim_to_unit_dim_deps(
    unit_dim: SympyType, unit_dim_deps: dict[spu.dimensions.Dimension, int]
) -> bool:
    """Whether the dimensional dependencies of `unit_dim` equal the already-computed `unit_dim_deps`."""
    computed_deps = unit_dim_to_unit_dim_deps(unit_dim)
    return computed_deps == unit_dim_deps
class PhysicalType(enum.StrEnum):
"""Type identifiers for expressions with both `MathType` and a unit, aka a "physical" type."""
# Unitless
NonPhysical = enum.auto()
# Global
Time = enum.auto()
Angle = enum.auto()
SolidAngle = enum.auto()
## TODO: Some kind of 3D-specific orientation ex. a quaternion
Freq = enum.auto()
AngFreq = enum.auto() ## rad*hertz
# Cartesian
Length = enum.auto()
Area = enum.auto()
Volume = enum.auto()
# Mechanical
Vel = enum.auto()
Accel = enum.auto()
Mass = enum.auto()
Force = enum.auto()
Pressure = enum.auto()
# Energy
Work = enum.auto() ## joule
Power = enum.auto() ## watt
PowerFlux = enum.auto() ## watt
Temp = enum.auto()
# Electrodynamics
Current = enum.auto() ## ampere
CurrentDensity = enum.auto()
Charge = enum.auto() ## coulomb
Voltage = enum.auto()
Capacitance = enum.auto() ## farad
Impedance = enum.auto() ## ohm
Conductance = enum.auto() ## siemens
Conductivity = enum.auto() ## siemens / length
MFlux = enum.auto() ## weber
MFluxDensity = enum.auto() ## tesla
Inductance = enum.auto() ## henry
EField = enum.auto()
HField = enum.auto()
# Luminal
LumIntensity = enum.auto()
LumFlux = enum.auto()
Illuminance = enum.auto()
@functools.cached_property
def unit_dim(self) -> SympyType:
    """The unit dimension associated with this physical type.

    Dimensions are looked up through `Dims`, except for the solid angle, which is read from `spu.steradian.dimension`.

    Returns:
        The unit dimension (expression), or `None` for `NonPhysical`.
    """
    PT = PhysicalType
    return {
        PT.NonPhysical: None,
        # Global
        PT.Time: Dims.time,
        PT.Angle: Dims.angle,
        ## NOTE(review): 'MISSING' presumably means that `Dims` has no solid-angle entry -- confirm.
        PT.SolidAngle: spu.steradian.dimension,  ## MISSING
        PT.Freq: Dims.frequency,
        PT.AngFreq: Dims.angle * Dims.frequency,
        # Cartesian
        PT.Length: Dims.length,
        PT.Area: Dims.length**2,
        PT.Volume: Dims.length**3,
        # Mechanical
        PT.Vel: Dims.length / Dims.time,
        PT.Accel: Dims.length / Dims.time**2,
        PT.Mass: Dims.mass,
        PT.Force: Dims.force,
        PT.Pressure: Dims.pressure,
        # Energy
        PT.Work: Dims.energy,
        PT.Power: Dims.power,
        PT.PowerFlux: Dims.power / Dims.length**2,
        PT.Temp: Dims.temperature,
        # Electrodynamics
        PT.Current: Dims.current,
        PT.CurrentDensity: Dims.current / Dims.length**2,
        PT.Charge: Dims.charge,
        PT.Voltage: Dims.voltage,
        PT.Capacitance: Dims.capacitance,
        PT.Impedance: Dims.impedance,
        PT.Conductance: Dims.conductance,
        PT.Conductivity: Dims.conductance / Dims.length,
        PT.MFlux: Dims.magnetic_flux,
        PT.MFluxDensity: Dims.magnetic_density,
        PT.Inductance: Dims.inductance,
        PT.EField: Dims.voltage / Dims.length,
        PT.HField: Dims.current / Dims.length,
        # Luminal
        PT.LumIntensity: Dims.luminous_intensity,
        PT.LumFlux: Dims.luminous_intensity * spu.steradian.dimension,
        PT.Illuminance: Dims.luminous_intensity / Dims.length**2,
    }[self]
@staticproperty
def unit_dims() -> dict[typ.Self, SympyType]:
    """Map every `PhysicalType` member to its `unit_dim`, in declaration order."""
    dims_by_type = {}
    for member in PhysicalType:
        dims_by_type[member] = member.unit_dim
    return dims_by_type
@functools.cached_property
def color(self):
    """A color corresponding to the physical type.

    Each color is a 4-tuple of floats (presumably RGBA, each component in `[0, 1]` - confirm against the consumer of this property).

    The color selections were initially generated using AI, as this is a rote task that's better adjusted than invented.
    The LLM's stated rationale, per category:

    - **Non-Physical**: Grey, for neutrality.
    - **Global**: Blues for time, angle, frequency and angular frequency (to keep the link to time); a cyan shade for solid angle.
    - **Cartesian**: Greens for length, area and volume, varying with dimension.
    - **Mechanical**: Reds for velocity, acceleration, mass, force and pressure.
    - **Energy**: Oranges for work, power and power flux; yellow for temperature.
    - **Electrodynamics**: Cyans for current-like quantities, voltage and capacitance; purples for impedance, conductance and conductivity; magentas for magnetic flux, flux density and inductance; light blue for the electric field; grey for the magnetic field.
    - **Luminal**: Yellows.

    Notes:
        Two pairs of physical types share an identical color: `Capacitance` with `Angle`, and `HField` with `NonPhysical`.
    """
    PT = PhysicalType
    color_by_type = {
        PT.NonPhysical: (0.75, 0.75, 0.75, 1.0),  ## Light grey
        # Global -> Blues / Cyan
        PT.Time: (0.5, 0.5, 1.0, 1.0),
        PT.Angle: (0.5, 0.75, 1.0, 1.0),
        PT.SolidAngle: (0.5, 0.75, 0.75, 1.0),
        PT.Freq: (0.5, 0.5, 0.9, 1.0),
        PT.AngFreq: (0.5, 0.5, 0.8, 1.0),
        # Cartesian -> Greens
        PT.Length: (0.5, 1.0, 0.5, 1.0),
        PT.Area: (0.6, 1.0, 0.6, 1.0),
        PT.Volume: (0.7, 1.0, 0.7, 1.0),
        # Mechanical -> Reds
        PT.Vel: (1.0, 0.5, 0.5, 1.0),
        PT.Accel: (1.0, 0.6, 0.6, 1.0),
        PT.Mass: (0.75, 0.5, 0.5, 1.0),
        PT.Force: (0.9, 0.5, 0.5, 1.0),
        PT.Pressure: (1.0, 0.7, 0.7, 1.0),
        # Energy -> Oranges / Yellow
        PT.Work: (1.0, 0.75, 0.5, 1.0),
        PT.Power: (1.0, 0.85, 0.5, 1.0),
        PT.PowerFlux: (1.0, 0.8, 0.6, 1.0),
        PT.Temp: (1.0, 1.0, 0.5, 1.0),
        # Electrodynamics -> Cyans
        PT.Current: (0.5, 1.0, 1.0, 1.0),
        PT.CurrentDensity: (0.5, 0.9, 0.9, 1.0),
        PT.Charge: (0.5, 0.85, 0.85, 1.0),
        PT.Voltage: (0.5, 1.0, 0.75, 1.0),  ## Greenish cyan
        PT.Capacitance: (0.5, 0.75, 1.0, 1.0),  ## Blueish cyan
        # Electrodynamics -> Purples
        PT.Impedance: (0.6, 0.5, 0.75, 1.0),
        PT.Conductance: (0.7, 0.5, 0.8, 1.0),
        PT.Conductivity: (0.8, 0.5, 0.9, 1.0),
        # Electrodynamics -> Magentas
        PT.MFlux: (0.75, 0.5, 0.75, 1.0),
        PT.MFluxDensity: (0.85, 0.5, 0.85, 1.0),
        PT.Inductance: (0.8, 0.5, 0.8, 1.0),
        # Electrodynamics -> Fields
        PT.EField: (0.75, 0.75, 1.0, 1.0),  ## Light blue
        PT.HField: (0.75, 0.75, 0.75, 1.0),  ## Light grey
        # Luminal -> Yellows
        PT.LumIntensity: (1.0, 0.95, 0.5, 1.0),
        PT.LumFlux: (1.0, 0.95, 0.6, 1.0),
        PT.Illuminance: (1.0, 1.0, 0.75, 1.0),  ## Pale yellow
    }
    return color_by_type[self]
@functools.cached_property
def default_unit(self) -> Unit | None:
    """The unit that should be used by default for this physical type.

    Returns:
        A single (possibly compound) unit, looked up in a fixed table.
        `PhysicalType.NonPhysical` maps to `None`.
    """
    PT = PhysicalType
    return {
        PT.NonPhysical: None,
        # Global
        PT.Time: spu.picosecond,
        PT.Angle: spu.radian,
        PT.SolidAngle: spu.steradian,
        PT.Freq: terahertz,
        PT.AngFreq: spu.radian * terahertz,
        # Cartesian
        PT.Length: spu.micrometer,
        PT.Area: spu.um**2,
        PT.Volume: spu.um**3,
        # Mechanical
        PT.Vel: spu.um / spu.second,
        ## Acceleration is length / time**2.
        ## `um / s` is a velocity unit, and would not be among `valid_units` for `Accel`.
        PT.Accel: spu.um / spu.second**2,
        PT.Mass: spu.microgram,
        PT.Force: micronewton,
        PT.Pressure: millibar,
        # Energy
        PT.Work: spu.joule,
        PT.Power: spu.watt,
        PT.PowerFlux: spu.watt / spu.meter**2,
        PT.Temp: spu.kelvin,
        # Electrodynamics
        PT.Current: spu.ampere,
        PT.CurrentDensity: spu.ampere / spu.meter**2,
        PT.Charge: spu.coulomb,
        PT.Voltage: spu.volt,
        PT.Capacitance: spu.farad,
        PT.Impedance: spu.ohm,
        PT.Conductance: spu.siemens,
        PT.Conductivity: spu.siemens / spu.micrometer,
        PT.MFlux: spu.weber,
        PT.MFluxDensity: spu.tesla,
        PT.Inductance: spu.henry,
        PT.EField: spu.volt / spu.micrometer,
        PT.HField: spu.ampere / spu.micrometer,
        # Luminal
        PT.LumIntensity: spu.candela,
        PT.LumFlux: spu.candela * spu.steradian,
        PT.Illuminance: spu.candela / spu.meter**2,
    }[self]
@functools.cached_property
def valid_units(self) -> list[Unit]:
    """Retrieve an ordered (by subjective usefulness) list of units for this physical type.

    Units derived from a frequency or a length (angular frequency, area, volume, velocity, acceleration) are generated from the corresponding base list, and therefore share its ordering.

    Notes:
        The order in which valid units are declared is the exact same order that UI dropdowns display them.

        **Altering the order of units breaks backwards compatibility**.
    """
    PT = PhysicalType

    # Base lists, from which several derived lists are generated.
    freq_units = [
        terahertz,
        spu.hertz,
        kilohertz,
        megahertz,
        gigahertz,
        petahertz,
        exahertz,
    ]
    length_units = [
        spu.micrometer,
        spu.nanometer,
        spu.picometer,
        spu.angstrom,
        spu.millimeter,
        spu.centimeter,
        spu.meter,
        spu.inch,
        spu.foot,
        spu.yard,
        spu.mile,
    ]

    units_by_type = {
        PT.NonPhysical: [None],
        # Global
        PT.Time: [
            spu.picosecond,
            femtosecond,
            spu.nanosecond,
            spu.microsecond,
            spu.millisecond,
            spu.second,
            spu.minute,
            spu.hour,
            spu.day,
        ],
        PT.Angle: [spu.radian, spu.degree],
        PT.SolidAngle: [spu.steradian],
        PT.Freq: freq_units,
        PT.AngFreq: [spu.radian * freq_unit for freq_unit in freq_units],
        # Cartesian
        PT.Length: length_units,
        PT.Area: [length_unit**2 for length_unit in length_units],
        PT.Volume: [length_unit**3 for length_unit in length_units],
        # Mechanical
        PT.Vel: [length_unit / spu.second for length_unit in length_units],
        PT.Accel: [length_unit / spu.second**2 for length_unit in length_units],
        PT.Mass: [
            spu.kilogram,
            spu.electron_rest_mass,
            spu.dalton,
            spu.microgram,
            spu.milligram,
            spu.gram,
            spu.metric_ton,
        ],
        PT.Force: [
            micronewton,
            nanonewton,
            millinewton,
            spu.newton,
            spu.kg * spu.meter / spu.second**2,
        ],
        PT.Pressure: [
            spu.bar,
            millibar,
            spu.pascal,
            hectopascal,
            spu.atmosphere,
            spu.psi,
            spu.mmHg,
            spu.torr,
        ],
        # Energy
        PT.Work: [spu.joule, spu.electronvolt],
        PT.Power: [spu.watt],
        PT.PowerFlux: [spu.watt / spu.meter**2],
        PT.Temp: [spu.kelvin],
        # Electrodynamics
        PT.Current: [spu.ampere],
        PT.CurrentDensity: [spu.ampere / spu.meter**2],
        PT.Charge: [spu.coulomb],
        PT.Voltage: [spu.volt],
        PT.Capacitance: [spu.farad],
        PT.Impedance: [spu.ohm],
        PT.Conductance: [spu.siemens],
        PT.Conductivity: [
            spu.siemens / spu.micrometer,
            spu.siemens / spu.meter,
        ],
        PT.MFlux: [spu.weber],
        PT.MFluxDensity: [spu.tesla],
        PT.Inductance: [spu.henry],
        PT.EField: [
            spu.volt / spu.micrometer,
            spu.volt / spu.meter,
        ],
        PT.HField: [
            spu.ampere / spu.micrometer,
            spu.ampere / spu.meter,
        ],
        # Luminal
        PT.LumIntensity: [spu.candela],
        PT.LumFlux: [spu.candela * spu.steradian],
        PT.Illuminance: [spu.candela / spu.meter**2],
    }
    return units_by_type[self]
@staticmethod
def from_unit(unit: Unit | None, optional: bool = False) -> typ.Self | None:
    """Attempt to determine a matching `PhysicalType` from a unit.

    The unit is first reduced using `unit_to_unit_dim_deps()`.
    The result is then compared against the unit dimension of each `PhysicalType`, in the iteration order of `PhysicalType.unit_dims`; the first match is returned.

    NOTE: It is not guaranteed that `unit` is within `valid_units`, only that it can be converted to any unit in `valid_units`.

    Parameters:
        unit: The unit to find a physical type for.
            `None` always matches `PhysicalType.NonPhysical`.
        optional: Whether to return `None`, instead of raising, when nothing matches.

    Returns:
        The matched `PhysicalType`.

        If none could be matched, then either return `None` (if `optional` is set) or error.

    Raises:
        ValueError: If no `PhysicalType` could be matched, and `optional` is `False`.
    """
    # No Unit -> NonPhysical
    ## --> Refer to the enum directly, as the other lookups here do, rather than through an external module alias.
    if unit is None:
        return PhysicalType.NonPhysical

    # Find the First PhysicalType w/Matching Dimensions
    ## --> When the unit couldn't be reduced (`None`), no candidate is tried at all.
    unit_dim_deps = unit_to_unit_dim_deps(unit)
    if unit_dim_deps is not None:
        for physical_type, candidate_unit_dim in PhysicalType.unit_dims.items():
            if compare_unit_dim_to_unit_dim_deps(candidate_unit_dim, unit_dim_deps):
                return physical_type

    # No Match
    if optional:
        return None

    msg = f'Could not determine PhysicalType for {unit}'
    raise ValueError(msg)
@staticmethod
def from_unit_dim(
    unit_dim: SympyType | None, optional: bool = False
) -> typ.Self | None:
    """Attempts to match an arbitrary unit dimension expression to a corresponding `PhysicalType`.

    For comparing arbitrary unit dimensions (via expressions of `spu.dimensions.Dimension`), it is critical that equivalent dimensions are also compared as equal (ex. `mass*length/time^2 == force`).
    The comparison itself is delegated to `compare_unit_dims()`; candidates are tried in the iteration order of `PhysicalType.unit_dims`, and the first match wins.

    Parameters:
        unit_dim: The unit dimension expression to find a physical type for.
        optional: Whether to return `None`, instead of raising, when nothing matches.

    Returns:
        The matched `PhysicalType`.

        If none could be matched, then either return `None` (if `optional` is set) or error.

    Raises:
        ValueError: If no `PhysicalType` could be matched, and `optional` is `False`.
    """
    first_match = next(
        (
            candidate
            for candidate, candidate_dim in PhysicalType.unit_dims.items()
            if compare_unit_dims(unit_dim, candidate_dim)
        ),
        None,
    )
    if first_match is not None:
        return first_match

    if not optional:
        msg = f'Could not determine PhysicalType for {unit_dim}'
        raise ValueError(msg)
    return None
@functools.cached_property
def valid_shapes(self) -> list[typ.Literal[(3,), (2,)] | None]:
    """Shapes that a quantity of this physical type may take.

    Returns:
        `[None, (2,), (3,)]` for the physical types listed below, and `[None]` for every other physical type.
    """
    PT = PhysicalType
    vector_capable = {
        # Cartesian
        PT.Length,
        # Mechanical
        PT.Vel,
        PT.Accel,
        PT.Force,
        # Energy
        PT.Work,
        PT.PowerFlux,
        # Electrodynamics
        PT.CurrentDensity,
        PT.MFluxDensity,
        PT.EField,
        PT.HField,
        # Luminal
        PT.LumFlux,
    }
    if self in vector_capable:
        return [None, (2,), (3,)]
    return [None]
@functools.cached_property
def valid_mathtypes(self) -> list[MathType]:
    """Returns a list of valid mathematical types, especially whether it can be real- or complex-valued.

    Generally, all unit quantities are real, in the algebraic mathematical sense.

    However, in electrodynamics especially, it becomes enormously useful to bake in a _rotational component_ as an imaginary value, be it simply to model phase or oscillation-oriented damping.
    This imaginary part has physical meaning, which can be expressed using the same mathematical formalism associated with unit systems.
    In general, the value is a phasor.

    While it is difficult to arrive at a well-defined way of saying, "this is when a quantity is complex", an attempt has been made to form a sensible baseline based on when phasor math may apply.

    Notes:
        - **Freq**/**AngFreq**: The imaginary part represents growth/damping of the oscillation.
        - **Current**/**Voltage**: The imaginary part represents the phase.
            This also holds for any downstream units.
        - **Charge**: Generally, it is real.
            However, an imaginary phase term seems to have research applications when dealing with high-order harmonics in high-energy pulsed lasers: <https://iopscience.iop.org/article/10.1088/1361-6455/aac787>
        - **Conductance**: The imaginary part represents the extinction, in the Drude-model sense.

    Returns:
        Every `MathType` for `NonPhysical`, `[Real, Complex]` for the physical types listed below, and `[Real]` for every other physical type.
    """
    MT = MathType
    PT = PhysicalType

    # NonPhysical supports every math type.
    if self is PT.NonPhysical:
        return list(MT)

    complex_capable = {
        # Global
        PT.Freq,
        PT.AngFreq,
        # Electrodynamics
        PT.Current,
        PT.CurrentDensity,
        PT.Charge,
        PT.Voltage,
        PT.Capacitance,
        PT.Impedance,
        PT.Inductance,
        PT.Conductance,
        PT.Conductivity,
        PT.MFlux,
        PT.MFluxDensity,
        PT.EField,
        PT.HField,
    }
    if self in complex_capable:
        return [MT.Real, MT.Complex]
    return [MT.Real]
@staticmethod
def to_name(value: typ.Self) -> str:
    """Human-readable name of a physical type.

    Parameters:
        value: A `PhysicalType` member, or any value accepted by `PhysicalType(...)`.

    Returns:
        `'Unitless'` for `PhysicalType.NonPhysical`, otherwise the name of the enum member.
    """
    return (
        'Unitless'
        if value is PhysicalType.NonPhysical
        else PhysicalType(value).name
    )
@staticmethod
def to_icon(value: typ.Self) -> str:
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
    """Build the enum element tuple for this physical type.

    Parameters:
        i: Integer placed in the last slot of the tuple.

    Returns:
        `(str(self), name, name, icon, i)`, where `name` comes from `to_name()` and `icon` from `to_icon()`.
        Presumably this follows the `(identifier, name, description, icon, number)` layout of Blender enum items - confirm against `ct.BLEnumElement`.
    """
    # The same label fills both the second and third slot.
    label = PhysicalType.to_name(self)
    icon = PhysicalType.to_icon(self)
    return (str(self), label, label, icon, i)
####################
# - Standard Unit Systems
####################
# A unit system maps each physical type to the unit that quantities of that type are expressed in.
UnitSystem: typ.TypeAlias = dict[PhysicalType, Unit]

_PT = PhysicalType

# SI units for every physical type.
## Every `PhysicalType` should have an entry: `convert_to_unit_system()` indexes the unit system by the physical type of each unit found in an expression, so a missing entry surfaces as a `KeyError`.
UNITS_SI: UnitSystem = {
    _PT.NonPhysical: None,
    # Global
    _PT.Time: spu.second,
    _PT.Angle: spu.radian,
    _PT.SolidAngle: spu.steradian,
    _PT.Freq: spu.hertz,
    _PT.AngFreq: spu.radian * spu.hertz,
    # Cartesian
    _PT.Length: spu.meter,
    _PT.Area: spu.meter**2,
    _PT.Volume: spu.meter**3,
    # Mechanical
    _PT.Vel: spu.meter / spu.second,
    _PT.Accel: spu.meter / spu.second**2,
    _PT.Mass: spu.kilogram,
    _PT.Force: spu.newton,
    _PT.Pressure: spu.pascal,
    # Energy
    _PT.Work: spu.joule,
    _PT.Power: spu.watt,
    _PT.PowerFlux: spu.watt / spu.meter**2,
    _PT.Temp: spu.kelvin,
    # Electrodynamics
    _PT.Current: spu.ampere,
    _PT.CurrentDensity: spu.ampere / spu.meter**2,
    _PT.Charge: spu.coulomb,
    _PT.Voltage: spu.volt,
    _PT.Capacitance: spu.farad,
    _PT.Impedance: spu.ohm,
    _PT.Conductance: spu.siemens,
    _PT.Conductivity: spu.siemens / spu.meter,
    _PT.MFlux: spu.weber,
    _PT.MFluxDensity: spu.tesla,
    _PT.Inductance: spu.henry,
    _PT.EField: spu.volt / spu.meter,
    _PT.HField: spu.ampere / spu.meter,
    # Luminal
    _PT.LumIntensity: spu.candela,
    _PT.LumFlux: lumen,
    _PT.Illuminance: spu.lux,
}
####################
# - Sympy Utilities: Cast to Python
####################
def sympy_to_python(
    scalar: sp.Basic, use_jax_array: bool = False
) -> int | float | complex | tuple | jax.Array:
    """Convert a scalar sympy expression to the directly corresponding Python type.

    Matrices are also accepted: each element is converted individually, and the result is a tuple of row-tuples.
    A matrix with a single column is transposed first, and a result with a single row is returned as that row alone.

    Arguments:
        scalar: A sympy expression that has no symbols, but is expressed as a Sympy type.
            For expressions that are equivalent to a scalar (ex. "(2a + a)/a"), you must simplify the expression with ex. `sp.simplify()` before passing to this parameter.
        use_jax_array: When the input is a matrix, wrap the result in `jnp.array` instead of returning tuples.
            Has no effect on scalar input.

    Returns:
        A pure Python type that directly corresponds to the input scalar expression.

    Raises:
        ValueError: If the expression is neither a matrix, nor known (by its assumptions) to be integer, rational, real or complex.
    """
    if isinstance(scalar, sp.MatrixBase):
        # Single Column -> Single Row
        is_single_column = len(scalar.shape) == 2 and scalar.shape[1] == 1
        rows = (scalar.T if is_single_column else scalar).tolist()

        # Convert Elements -> Tuple of Row-Tuples
        nested = tuple(tuple(sympy_to_python(entry) for entry in row) for row in rows)

        # Single Row -> Strip the Pointless Outer Dimension
        ## --> The input may have had it, or it may come from the transposed column.
        converted = nested[0] if len(nested) == 1 else nested
        return jnp.array(converted) if use_jax_array else converted

    # Scalars: Choose the Narrowest Python Number Type
    if scalar.is_integer:
        return int(scalar)
    if scalar.is_rational or scalar.is_real:
        return float(scalar)
    if scalar.is_complex:
        return complex(scalar)

    msg = f'Cannot convert sympy scalar expression "{scalar}" to a Python type. Check the assumptions on the expr (current expr assumptions: "{scalar._assumptions}")'  # noqa: SLF001
    raise ValueError(msg)
####################
# - Convert to Unit System
####################
def strip_unit_system(
    sp_obj: SympyExpr, unit_system: UnitSystem | None = None
) -> SympyExpr:
    """Strip units occurring in the given unit system from the expression.

    Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`".
    Obviously, the semantic correctness of this operation depends entirely on _the units adding no semantic meaning to the expression_.

    Parameters:
        sp_obj: The expression to strip units from.
        unit_system: The unit system whose (non-`None`) units are substituted with `1`.
            When `None`, the module-level `UNIT_TO_1` substitution is applied instead.

    Notes:
        You should probably use `scale_to_unit_system()` or `convert_to_unit_system()`.
    """
    if unit_system is None:
        substitutions = UNIT_TO_1
    else:
        substitutions = {
            unit: 1 for unit in unit_system.values() if unit is not None
        }
    return sp_obj.subs(substitutions)
def convert_to_unit_system(
    sp_obj: SympyExpr, unit_system: UnitSystem | None
) -> SympyExpr:
    """Convert an expression to the units of a given unit system.

    Each unit found in the expression (via `get_units()`) is mapped to its `PhysicalType`, which is then looked up in `unit_system` to find the target unit.

    Parameters:
        sp_obj: The expression to convert.
        unit_system: The unit system to convert to.
            When `None`, the expression is returned unchanged.

    Returns:
        The result of `spu.convert_to()` with the set of target units.
    """
    if unit_system is None:
        return sp_obj

    # Target Units: One per Distinct PhysicalType in the Expression
    target_units = {
        unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)
    }
    return spu.convert_to(sp_obj, target_units)
def scale_to_unit_system(
    sp_obj: SympyExpr, unit_system: UnitSystem | None, use_jax_array: bool = False
) -> int | float | complex | tuple | jax.Array:
    """Convert an expression to the units of a given unit system, then strip all units of the unit system.

    Afterwards, it is converted to an appropriate Python type.

    Notes:
        For stability, and performance, reasons, this should only be used at the very last stage.

        Regarding performance: **This is not a fast function**.

    Parameters:
        sp_obj: An arbitrary sympy object, presumably with units.
        unit_system: A unit system mapping `PhysicalType` to particular choices of (compound) units.
        use_jax_array: Forwarded to `sympy_to_python()`.

    Returns:
        An appropriate pure Python type, after scaling to the unit system and stripping all units away.

        If the returned type is array-like, and `use_jax_array` is specified, then (and **only** then) will a `jax.Array` be returned instead of a nested `tuple`.
    """
    # Convert -> Strip -> Cast
    converted = convert_to_unit_system(sp_obj, unit_system)
    unitless = strip_unit_system(converted, unit_system)
    return sympy_to_python(unitless, use_jax_array=use_jax_array)