diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py
index 244e54b..b5def9c 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py
@@ -21,7 +21,7 @@ import typing as typ
import bpy
import sympy as sp
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .socket_types import SocketType
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py
index a6c2b54..1599f45 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py
@@ -22,7 +22,7 @@ import numpy as np
import pydantic as pyd
import sympy as sp
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
log = logger.get(__name__)
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/expr_info.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/expr_info.py
index 412cd05..c588fab 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/expr_info.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/expr_info.py
@@ -16,7 +16,7 @@
import typing as typ
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from . import FlowKind
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py
index 4ff917e..d892ece 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py
@@ -19,7 +19,7 @@ import functools
import typing as typ
from blender_maxwell.contracts import BLEnumElement
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils.staticproperty import staticproperty
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py
index 43dbfe5..58eba20 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py
@@ -18,8 +18,8 @@ import dataclasses
import functools
import typing as typ
-from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger, sim_symbols
+from blender_maxwell.utils import sympy_extra as spux
from .array import ArrayFlow
from .lazy_range import RangeFlow
@@ -89,6 +89,7 @@ class InfoFlow:
return None
def dim_by_idx(self, idx: int) -> sim_symbols.SimSymbol | None:
+ """Retrieve the dimension associated with a particular index."""
if idx > 0 and idx < len(self.dims) - 1:
return list(self.dims.keys())[idx]
return None
@@ -179,15 +180,23 @@ class InfoFlow:
While that sounds fancy and all, it boils down to:
$$
- \texttt{dims} + |\texttt{output}.\texttt{shape}|
+ |\texttt{dims}| + |\texttt{output}.\texttt{shape}|
$$
- Doing so characterizes the full dimensionality of the tensor, which also perfectly matches the length of the raw data's shape exactly.
+ Doing so characterizes the full dimensionality of the tensor, which also perfectly matches the length of the raw data's shape.
Notes:
Corresponds to `len(raw_data.shape)`, if `raw_data` is the n-dimensional array corresponding to this `InfoFlow`.
"""
- return len(self.input_mathtypes) + self.output_shape_len
+ return len(self.dims) + self.output_shape_len
+
+ @functools.cached_property
+ def is_scalar(self) -> tuple[spux.MathType, int, int]:
+ """Whether the described object can be described as "scalar".
+
+ True when `self.order == 0`.
+ """
+ return self.order == 0
####################
# - Properties
@@ -204,6 +213,58 @@ class InfoFlow:
for dim, dim_idx in self.dims.items()
}
+ ####################
+ # - Operations: Comparison
+ ####################
+ def compare_dims_identical(self, other: typ.Self) -> bool:
+ """Whether that the quantity and properites of all dimension `SimSymbol`s are "identical".
+
+ "Identical" is defined according to the semantics of `SimSymbol.compare()`, which generally means that everything but the exact name and unit are different.
+ """
+ return len(self.dims) == len(other.dims) and all(
+ dim_l.compare(dim_r)
+ for dim_l, dim_r in zip(self.dims, other.dims, strict=True)
+ )
+
+ def compare_addable(
+ self, other: typ.Self, allow_differing_unit: bool = False
+ ) -> bool:
+ """Whether the two `InfoFlows` can be added/subtracted elementwise.
+
+ Parameters:
+ allow_differing_unit: When set,
+ Forces the user to be explicit about specifying
+ """
+ return self.compare_dims_identical(other) and self.output.compare_addable(
+ other.output, allow_differing_unit=allow_differing_unit
+ )
+
+ def compare_multiplicable(self, other: typ.Self) -> bool:
+ """Whether the two `InfoFlow`s can be multiplied (elementwise).
+
+ - The output `SimSymbol`s must be able to be multiplied.
+ - Either the LHS is a scalar, the RHS is a scalar, or the dimensions are identical.
+ """
+ return self.output.compare_multiplicable(other.output) and (
+ (len(self.dims) == 0 and self.output.shape_len == 0)
+ or (len(other.dims) == 0 and other.output.shape_len == 0)
+ or self.compare_dims_identical(other)
+ )
+
+ def compare_exponentiable(self, other: typ.Self) -> bool:
+ """Whether the two `InfoFlow`s can be exponentiated.
+
+ In general, we follow the rules of the "Hadamard Power" operator, which is also in use in `numpy` broadcasting rules.
+
+ - The output `SimSymbol`s must be able to be exponentiated (mainly, the exponent can't have a unit).
+ - Either the LHS is a scalar, the RHS is a scalar, or the dimensions are identical.
+ """
+ return self.output.compare_exponentiable(other.output) and (
+ (len(self.dims) == 0 and self.output.shape_len == 0)
+ or (len(other.dims) == 0 and other.output.shape_len == 0)
+ or self.compare_dims_identical(other)
+ )
+
####################
# - Operations: Dimensions
####################
@@ -319,6 +380,7 @@ class InfoFlow:
op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
unit_op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
) -> spux.SympyExpr:
+ """Apply an operation between two the values and units of two `InfoFlow`s by reconstructing the properties of the new output `SimSymbol`."""
sym_name = sim_symbols.SimSymbolName.Expr
expr = op(self.output.sp_symbol_phy, other.output.sp_symbol_phy)
unit_expr = unit_op(self.output.unit_factor, other.output.unit_factor)
@@ -341,11 +403,11 @@ class InfoFlow:
cols = self.output.cols
match (rows, cols):
case (1, 1):
- new_output = self.output.set_size(len(last_idx), 1)
+ new_output = self.output.update(rows=len(last_idx), cols=1)
case (_, 1):
- new_output = self.output.set_size(rows, len(last_idx))
+ new_output = self.output.update(rows=rows, cols=len(last_idx))
case (1, _):
- new_output = self.output.set_size(len(last_idx), cols)
+ new_output = self.output.update(rows=len(last_idx), cols=cols)
case (_, _):
raise NotImplementedError ## Not yet :)
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py
index e881c14..025ef71 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py
@@ -226,7 +226,7 @@ import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger, sim_symbols
from .array import ArrayFlow
@@ -335,7 +335,7 @@ class FuncFlow(pyd.BaseModel):
disallow_jax: Don't use `self.func_jax` to evaluate, even if possible.
This is desirable when the overhead of `jax.jit()` is known in advance to exceed the performance benefits.
"""
- if self.supports_jax:
+ if self.supports_jax and not disallow_jax:
return self.func_jax(
*params.scaled_func_args(symbol_values),
**params.scaled_func_kwargs(symbol_values),
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py
index a204eef..a7fd668 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py
@@ -24,8 +24,8 @@ import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
-from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger, sim_symbols
+from blender_maxwell.utils import sympy_extra as spux
from .array import ArrayFlow
@@ -95,10 +95,12 @@ class RangeFlow(pyd.BaseModel):
stop: spux.ScalarUnitlessRealExpr
steps: int = 0
scaling: ScalingMode = ScalingMode.Lin
+ ## TODO: No support for non-Lin (yet)
unit: spux.Unit | None = None
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
+ ## TODO: No proper support for symbols (yet)
# Helper Attributes
pre_fourier_ideal_midpoint: spux.ScalarUnitlessRealExpr | None = None
@@ -112,18 +114,26 @@ class RangeFlow(pyd.BaseModel):
steps: int = 50,
scaling: ScalingMode | str = ScalingMode.Lin,
) -> typ.Self:
- if sym.domain.start.is_infinite or sym.domain.end.is_infinite:
- use_steps = 0
- else:
- use_steps = steps
+ if (
+ sym.mathtype is not spux.MathType.Complex
+ and sym.rows == 1
+ and sym.cols == 1
+ ):
+ if sym.domain.inf.is_infinite or sym.domain.sup.is_infinite:
+ _steps = 0
+ else:
+ _steps = steps
- return RangeFlow(
- start=sym.domain.start if sym.domain.start.is_finite else sp.S(-1),
- stop=sym.domain.end if sym.domain.end.is_finite else sp.S(1),
- steps=use_steps,
- scaling=ScalingMode(scaling),
- unit=sym.unit,
- )
+ return RangeFlow(
+ start=sym.domain.inf if sym.domain.inf.is_finite else sp.S(-1),
+ stop=sym.domain.sup if sym.domain.sup.is_finite else sp.S(1),
+ steps=_steps,
+ scaling=ScalingMode(scaling),
+ unit=sym.unit,
+ )
+
+ msg = f'RangeFlow is incompatible with SimSymbol {sym}'
+ raise ValueError(msg)
def to_sym(
self,
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py
index 20d6562..b7a9cb7 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py
@@ -23,7 +23,7 @@ import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger, sim_symbols
from .array import ArrayFlow
@@ -53,11 +53,6 @@ class ParamsFlow(pyd.BaseModel):
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
] = pyd.Field(default_factory=dict)
- @functools.cached_property
- def diff_symbols(self) -> set[sim_symbols.SimSymbol]:
- """Set of all unrealized `SimSymbol`s that can act as inputs when differentiating the function for which this `ParamsFlow` tracks arguments."""
- return {sym for sym in self.symbols if sym.can_diff}
-
####################
# - Symbols
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py
index 16eeca5..9765991 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py
@@ -29,7 +29,7 @@ import tidy3d as td
from blender_maxwell.contracts import BLEnumElement
from blender_maxwell.services import tdcloud
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .flow_kinds.info import InfoFlow
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py
index 0dc4752..ce6259b 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py
@@ -28,7 +28,7 @@ import typing as typ
import sympy.physics.units as spu
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
####################
# - Unit Systems
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py
index 4f16660..4242efd 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py
@@ -20,7 +20,7 @@ import typing as typ
import bpy
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .. import bl_socket_map
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/__init__.py
new file mode 100644
index 0000000..04c6341
--- /dev/null
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/__init__.py
@@ -0,0 +1,29 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from .filter import FilterOperation
+from .map import MapOperation
+from .operate import BinaryOperation
+from .reduce import ReduceOperation
+from .transform import TransformOperation
+
+__all__ = [
+ 'FilterOperation',
+ 'MapOperation',
+ 'BinaryOperation',
+ 'ReduceOperation',
+ 'TransformOperation',
+]
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/filter.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/filter.py
new file mode 100644
index 0000000..6bea59f
--- /dev/null
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/filter.py
@@ -0,0 +1,233 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import enum
+import typing as typ
+
+import jax.lax as jlax
+import jax.numpy as jnp
+
+from blender_maxwell.utils import logger, sim_symbols
+
+from .. import contracts as ct
+
+log = logger.get(__name__)
+
+
+class FilterOperation(enum.StrEnum):
+ """Valid operations for the `FilterMathNode`.
+
+ Attributes:
+ DimToVec: Shift last dimension to output.
+ DimsToMat: Shift last 2 dimensions to output.
+ PinLen1: Remove a len(1) dimension.
+ Pin: Remove a len(n) dimension by selecting a particular index.
+ Swap: Swap the positions of two dimensions.
+ """
+
+ # Slice
+ Slice = enum.auto()
+ SliceIdx = enum.auto()
+
+ # Pin
+ PinLen1 = enum.auto()
+ Pin = enum.auto()
+ PinIdx = enum.auto()
+
+ # Dimension
+ Swap = enum.auto()
+
+ ####################
+ # - UI
+ ####################
+ @staticmethod
+ def to_name(value: typ.Self) -> str:
+ FO = FilterOperation
+ return {
+ # Slice
+ FO.Slice: '≈a[v₁:v₂]',
+ FO.SliceIdx: '=a[i:j]',
+ # Pin
+ FO.PinLen1: 'a[0] → a',
+ FO.Pin: 'a[v] ⇝ a',
+ FO.PinIdx: 'a[i] → a',
+ # Reinterpret
+ FO.Swap: 'a₁ ↔ a₂',
+ }[value]
+
+ @staticmethod
+ def to_icon(value: typ.Self) -> str:
+ return ''
+
+ def bl_enum_element(self, i: int) -> ct.BLEnumElement:
+ FO = FilterOperation
+ return (
+ str(self),
+ FO.to_name(self),
+ FO.to_name(self),
+ FO.to_icon(self),
+ i,
+ )
+
+ ####################
+ # - Ops from Info
+ ####################
+ @staticmethod
+ def by_info(info: ct.InfoFlow) -> list[typ.Self]:
+ FO = FilterOperation
+ operations = []
+
+ # Slice
+ if info.dims:
+ operations.append(FO.SliceIdx)
+
+ # Pin
+ ## PinLen1
+ ## -> There must be a dimension with length 1.
+ if 1 in [dim_idx for dim_idx in info.dims.values() if dim_idx is not None]:
+ operations.append(FO.PinLen1)
+
+ ## Pin | PinIdx
+ ## -> There must be a dimension, full stop.
+ if info.dims:
+ operations += [FO.Pin, FO.PinIdx]
+
+ # Reinterpret
+ ## Swap
+ ## -> There must be at least two dimensions.
+ if len(info.dims) >= 2: # noqa: PLR2004
+ operations.append(FO.Swap)
+
+ return operations
+
+ ####################
+ # - Computed Properties
+ ####################
+ @property
+ def func_args(self) -> list[sim_symbols.SimSymbol]:
+ FO = FilterOperation
+ return {
+ # Pin
+ FO.Pin: [sim_symbols.idx(None)],
+ FO.PinIdx: [sim_symbols.idx(None)],
+ }.get(self, [])
+
+ ####################
+ # - Methods
+ ####################
+ @property
+ def num_dim_inputs(self) -> None:
+ FO = FilterOperation
+ return {
+ # Slice
+ FO.Slice: 1,
+ FO.SliceIdx: 1,
+ # Pin
+ FO.PinLen1: 1,
+ FO.Pin: 1,
+ FO.PinIdx: 1,
+ # Reinterpret
+ FO.Swap: 2,
+ }[self]
+
+ def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
+ FO = FilterOperation
+ match self:
+ # Slice
+ case FO.Slice:
+ return [dim for dim in info.dims if not info.has_idx_labels(dim)]
+
+ case FO.SliceIdx:
+ return [dim for dim in info.dims if not info.has_idx_labels(dim)]
+
+ # Pin
+ case FO.PinLen1:
+ return [
+ dim
+ for dim, dim_idx in info.dims.items()
+ if not info.has_idx_cont(dim) and len(dim_idx) == 1
+ ]
+
+ case FO.Pin:
+ return info.dims
+
+ case FO.PinIdx:
+ return [dim for dim in info.dims if not info.has_idx_cont(dim)]
+
+ # Dimension
+ case FO.Swap:
+ return info.dims
+
+ return []
+
+ def are_dims_valid(
+ self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None
+ ) -> bool:
+ """Check whether the given dimension inputs are valid in the context of this operation, and of the information."""
+ if self.num_dim_inputs == 1:
+ return dim_0 in self.valid_dims(info)
+
+ if self.num_dim_inputs == 2: # noqa: PLR2004
+ valid_dims = self.valid_dims(info)
+ return dim_0 in valid_dims and dim_1 in valid_dims
+
+ return False
+
+ ####################
+ # - UI
+ ####################
+ def jax_func(
+ self,
+ axis_0: int | None,
+ axis_1: int | None,
+ slice_tuple: tuple[int, int, int] | None = None,
+ ):
+ FO = FilterOperation
+ return {
+ # Pin
+ FO.Slice: lambda expr: jlax.slice_in_dim(
+ expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
+ ),
+ FO.SliceIdx: lambda expr: jlax.slice_in_dim(
+ expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
+ ),
+ # Pin
+ FO.PinLen1: lambda expr: jnp.squeeze(expr, axis_0),
+ FO.Pin: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
+ FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
+ # Dimension
+ FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1),
+ }[self]
+
+ def transform_info(
+ self,
+ info: ct.InfoFlow,
+ dim_0: sim_symbols.SimSymbol,
+ dim_1: sim_symbols.SimSymbol,
+ pin_idx: int | None = None,
+ slice_tuple: tuple[int, int, int] | None = None,
+ ):
+ FO = FilterOperation
+ return {
+ FO.Slice: lambda: info.slice_dim(dim_0, slice_tuple),
+ FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple),
+ # Pin
+ FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
+ FO.Pin: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
+ FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
+ # Reinterpret
+ FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
+ }[self]()
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/map.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/map.py
new file mode 100644
index 0000000..d708561
--- /dev/null
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/map.py
@@ -0,0 +1,365 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import enum
+import typing as typ
+
+import jax.numpy as jnp
+import sympy as sp
+
+from blender_maxwell.utils import logger, sim_symbols
+from blender_maxwell.utils import sympy_extra as spux
+
+from .. import contracts as ct
+
+log = logger.get(__name__)
+
+
+class MapOperation(enum.StrEnum):
+ """Valid operations for the `MapMathNode`.
+
+ Attributes:
+ Real: Compute the real part of the input.
+ Imag: Compute the imaginary part of the input.
+ Abs: Compute the absolute value of the input.
+ Sq: Square the input.
+ Sqrt: Compute the (principal) square root of the input.
+ InvSqrt: Compute the inverse square root of the input.
+ Cos: Compute the cosine of the input.
+ Sin: Compute the sine of the input.
+ Tan: Compute the tangent of the input.
+ Acos: Compute the inverse cosine of the input.
+ Asin: Compute the inverse sine of the input.
+ Atan: Compute the inverse tangent of the input.
+ Norm2: Compute the 2-norm (aka. length) of the input vector.
+ Det: Compute the determinant of the input matrix.
+ Cond: Compute the condition number of the input matrix.
+ NormFro: Compute the frobenius norm of the input matrix.
+ Rank: Compute the rank of the input matrix.
+ Diag: Compute the diagonal vector of the input matrix.
+ EigVals: Compute the eigenvalues vector of the input matrix.
+ SvdVals: Compute the singular values vector of the input matrix.
+ Inv: Compute the inverse matrix of the input matrix.
+ Tra: Compute the transpose matrix of the input matrix.
+ Qr: Compute the QR-factorized matrices of the input matrix.
+ Chol: Compute the Cholesky-factorized matrices of the input matrix.
+ Svd: Compute the SVD-factorized matrices of the input matrix.
+ """
+
+ # By Number
+ Real = enum.auto()
+ Imag = enum.auto()
+ Abs = enum.auto()
+ Sq = enum.auto()
+ Sqrt = enum.auto()
+ InvSqrt = enum.auto()
+ Cos = enum.auto()
+ Sin = enum.auto()
+ Tan = enum.auto()
+ Acos = enum.auto()
+ Asin = enum.auto()
+ Atan = enum.auto()
+ Sinc = enum.auto()
+ # By Vector
+ Norm2 = enum.auto()
+ # By Matrix
+ Det = enum.auto()
+ Cond = enum.auto()
+ NormFro = enum.auto()
+ Rank = enum.auto()
+ Diag = enum.auto()
+ EigVals = enum.auto()
+ SvdVals = enum.auto()
+ Inv = enum.auto()
+ Tra = enum.auto()
+ Qr = enum.auto()
+ Chol = enum.auto()
+ Svd = enum.auto()
+
+ ####################
+ # - UI
+ ####################
+ @staticmethod
+ def to_name(value: typ.Self) -> str:
+ """A human-readable UI-oriented name for a physical type."""
+ MO = MapOperation
+ return {
+ # By Number
+ MO.Real: 'ℝ(v)',
+ MO.Imag: 'Im(v)',
+ MO.Abs: '|v|',
+ MO.Sq: 'v²',
+ MO.Sqrt: '√v',
+ MO.InvSqrt: '1/√v',
+ MO.Cos: 'cos v',
+ MO.Sin: 'sin v',
+ MO.Tan: 'tan v',
+ MO.Acos: 'acos v',
+ MO.Asin: 'asin v',
+ MO.Atan: 'atan v',
+ MO.Sinc: 'sinc v',
+ # By Vector
+ MO.Norm2: '||v||₂',
+ # By Matrix
+ MO.Det: 'det V',
+ MO.Cond: 'κ(V)',
+ MO.NormFro: '||V||_F',
+ MO.Rank: 'rank V',
+ MO.Diag: 'diag V',
+ MO.EigVals: 'eigvals V',
+ MO.SvdVals: 'svdvals V',
+ MO.Inv: 'V⁻¹',
+ MO.Tra: 'Vt',
+ MO.Qr: 'qr V',
+ MO.Chol: 'chol V',
+ MO.Svd: 'svd V',
+ }[value]
+
+ @staticmethod
+ def to_icon(_: typ.Self) -> str:
+ """No icons."""
+ return ''
+
+ def bl_enum_element(self, i: int) -> ct.BLEnumElement:
+ """Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`."""
+ MO = MapOperation
+ return (
+ str(self),
+ MO.to_name(self),
+ MO.to_name(self),
+ MO.to_icon(self),
+ i,
+ )
+
+ ####################
+ # - Ops from Shape
+ ####################
+ @staticmethod
+ def by_expr_info(info: ct.InfoFlow) -> list[typ.Self]:
+ ## TODO: By info, not shape.
+ ## TODO: Check valid domains/mathtypes for some functions.
+ MO = MapOperation
+ element_ops = [
+ MO.Real,
+ MO.Imag,
+ MO.Abs,
+ MO.Sq,
+ MO.Sqrt,
+ MO.InvSqrt,
+ MO.Cos,
+ MO.Sin,
+ MO.Tan,
+ MO.Acos,
+ MO.Asin,
+ MO.Atan,
+ MO.Sinc,
+ ]
+
+ match (info.output.rows, info.output.cols):
+ case (1, 1):
+ return element_ops
+
+ case (_, 1):
+ return [*element_ops, MO.Norm2]
+
+ case (rows, cols) if rows == cols:
+ ## TODO: Check hermitian/posdef for cholesky.
+ ## - Can we even do this with just the output symbol approach?
+ return [
+ *element_ops,
+ MO.Det,
+ MO.Cond,
+ MO.NormFro,
+ MO.Rank,
+ MO.Diag,
+ MO.EigVals,
+ MO.SvdVals,
+ MO.Inv,
+ MO.Tra,
+ MO.Qr,
+ MO.Chol,
+ MO.Svd,
+ ]
+
+ case (rows, cols):
+ return [
+ *element_ops,
+ MO.Cond,
+ MO.NormFro,
+ MO.Rank,
+ MO.SvdVals,
+ MO.Inv,
+ MO.Tra,
+ MO.Svd,
+ ]
+
+ return []
+
+ ####################
+ # - Function Properties
+ ####################
+ @property
+ def sp_func(self):
+ MO = MapOperation
+ return {
+ # By Number
+ MO.Real: lambda expr: sp.re(expr),
+ MO.Imag: lambda expr: sp.im(expr),
+ MO.Abs: lambda expr: sp.Abs(expr),
+ MO.Sq: lambda expr: expr**2,
+ MO.Sqrt: lambda expr: sp.sqrt(expr),
+ MO.InvSqrt: lambda expr: 1 / sp.sqrt(expr),
+ MO.Cos: lambda expr: sp.cos(expr),
+ MO.Sin: lambda expr: sp.sin(expr),
+ MO.Tan: lambda expr: sp.tan(expr),
+ MO.Acos: lambda expr: sp.acos(expr),
+ MO.Asin: lambda expr: sp.asin(expr),
+ MO.Atan: lambda expr: sp.atan(expr),
+ MO.Sinc: lambda expr: sp.sinc(expr),
+ # By Vector
+ # Vector -> Number
+ MO.Norm2: lambda expr: sp.sqrt(expr.T @ expr)[0],
+ # By Matrix
+ # Matrix -> Number
+ MO.Det: lambda expr: sp.det(expr),
+ MO.Cond: lambda expr: expr.condition_number(),
+ MO.NormFro: lambda expr: expr.norm(ord='fro'),
+ MO.Rank: lambda expr: expr.rank(),
+ # Matrix -> Vec
+ MO.Diag: lambda expr: expr.diagonal(),
+ MO.EigVals: lambda expr: sp.Matrix(list(expr.eigenvals().keys())),
+ MO.SvdVals: lambda expr: expr.singular_values(),
+ # Matrix -> Matrix
+ MO.Inv: lambda expr: expr.inv(),
+ MO.Tra: lambda expr: expr.T,
+ # Matrix -> Matrices
+ MO.Qr: lambda expr: expr.QRdecomposition(),
+ MO.Chol: lambda expr: expr.cholesky(),
+ MO.Svd: lambda expr: expr.singular_value_decomposition(),
+ }[self]
+
+ @property
+ def jax_func(self):
+ MO = MapOperation
+ return {
+ # By Number
+ MO.Real: lambda expr: jnp.real(expr),
+ MO.Imag: lambda expr: jnp.imag(expr),
+ MO.Abs: lambda expr: jnp.abs(expr),
+ MO.Sq: lambda expr: jnp.square(expr),
+ MO.Sqrt: lambda expr: jnp.sqrt(expr),
+ MO.InvSqrt: lambda expr: 1 / jnp.sqrt(expr),
+ MO.Cos: lambda expr: jnp.cos(expr),
+ MO.Sin: lambda expr: jnp.sin(expr),
+ MO.Tan: lambda expr: jnp.tan(expr),
+ MO.Acos: lambda expr: jnp.acos(expr),
+ MO.Asin: lambda expr: jnp.asin(expr),
+ MO.Atan: lambda expr: jnp.atan(expr),
+ MO.Sinc: lambda expr: jnp.sinc(expr),
+ # By Vector
+ # Vector -> Number
+ MO.Norm2: lambda expr: jnp.linalg.norm(expr, ord=2, axis=-1),
+ # By Matrix
+ # Matrix -> Number
+ MO.Det: lambda expr: jnp.linalg.det(expr),
+ MO.Cond: lambda expr: jnp.linalg.cond(expr),
+ MO.NormFro: lambda expr: jnp.linalg.matrix_norm(expr, ord='fro'),
+ MO.Rank: lambda expr: jnp.linalg.matrix_rank(expr),
+ # Matrix -> Vec
+ MO.Diag: lambda expr: jnp.diagonal(expr, axis1=-2, axis2=-1),
+ MO.EigVals: lambda expr: jnp.linalg.eigvals(expr),
+ MO.SvdVals: lambda expr: jnp.linalg.svdvals(expr),
+ # Matrix -> Matrix
+ MO.Inv: lambda expr: jnp.linalg.inv(expr),
+ MO.Tra: lambda expr: jnp.matrix_transpose(expr),
+ # Matrix -> Matrices
+ MO.Qr: lambda expr: jnp.linalg.qr(expr),
+ MO.Chol: lambda expr: jnp.linalg.cholesky(expr),
+ MO.Svd: lambda expr: jnp.linalg.svd(expr),
+ }[self]
+
+ def transform_info(self, info: ct.InfoFlow):
+ MO = MapOperation
+
+ return {
+ # By Number
+ MO.Real: lambda: info.update_output(mathtype=spux.MathType.Real),
+ MO.Imag: lambda: info.update_output(mathtype=spux.MathType.Real),
+ MO.Abs: lambda: info.update_output(mathtype=spux.MathType.Real),
+ MO.Sq: lambda: info,
+ MO.Sqrt: lambda: info,
+ MO.InvSqrt: lambda: info,
+ MO.Cos: lambda: info,
+ MO.Sin: lambda: info,
+ MO.Tan: lambda: info,
+ MO.Acos: lambda: info,
+ MO.Asin: lambda: info,
+ MO.Atan: lambda: info,
+ MO.Sinc: lambda: info,
+ # By Vector
+ MO.Norm2: lambda: info.update_output(
+ mathtype=spux.MathType.Real,
+ rows=1,
+ cols=1,
+ # Interval
+ interval_finite_re=(0, sim_symbols.float_max),
+ interval_inf=(False, True),
+ interval_closed=(True, False),
+ ),
+ # By Matrix
+ MO.Det: lambda: info.update_output(
+ rows=1,
+ cols=1,
+ ),
+ MO.Cond: lambda: info.update_output(
+ mathtype=spux.MathType.Real,
+ rows=1,
+ cols=1,
+ physical_type=spux.PhysicalType.NonPhysical,
+ unit=None,
+ ),
+ MO.NormFro: lambda: info.update_output(
+ mathtype=spux.MathType.Real,
+ rows=1,
+ cols=1,
+ # Interval
+ interval_finite_re=(0, sim_symbols.float_max),
+ interval_inf=(False, True),
+ interval_closed=(True, False),
+ ),
+ MO.Rank: lambda: info.update_output(
+ mathtype=spux.MathType.Integer,
+ rows=1,
+ cols=1,
+ physical_type=spux.PhysicalType.NonPhysical,
+ unit=None,
+ # Interval
+ interval_finite_re=(0, sim_symbols.int_max),
+ interval_inf=(False, True),
+ interval_closed=(True, False),
+ ),
+ # Matrix -> Vector ## TODO: ALL OF THESE
+ MO.Diag: lambda: info,
+ MO.EigVals: lambda: info,
+ MO.SvdVals: lambda: info,
+ # Matrix -> Matrix ## TODO: ALL OF THESE
+ MO.Inv: lambda: info,
+ MO.Tra: lambda: info,
+ # Matrix -> Matrices ## TODO: ALL OF THESE
+ MO.Qr: lambda: info,
+ MO.Chol: lambda: info,
+ MO.Svd: lambda: info,
+ }[self]()
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/operate.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/operate.py
new file mode 100644
index 0000000..cf3d228
--- /dev/null
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/operate.py
@@ -0,0 +1,476 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import enum
+import typing as typ
+
+import jax.numpy as jnp
+import sympy as sp
+import sympy.physics.quantum as spq
+import sympy.physics.units as spu
+
+from blender_maxwell.utils import logger, sim_symbols
+from blender_maxwell.utils import sympy_extra as spux
+
+from .. import contracts as ct
+
+log = logger.get(__name__)
+
+
+def hadamard_power(lhs: spux.SympyType, rhs: spux.SympyType) -> spux.SympyType:
+ """Implement the Hadamard Power.
+
+ Follows the specification in , which also conforms to `numpy` broadcasting rules for `**` on `np.ndarray`.
+ """
+ match (isinstance(lhs, sp.MatrixBase), isinstance(rhs, sp.MatrixBase)):
+ case (False, False):
+ msg = f"Hadamard Power for two scalars is valid, but shouldn't be used - use normal power instead: {lhs} | {rhs}"
+ raise ValueError(msg)
+
+ case (True, False):
+ return lhs.applyfunc(lambda el: el**rhs)
+
+ case (False, True):
+ return rhs.applyfunc(lambda el: lhs**el)
+
+ case (True, True) if lhs.shape == rhs.shape:
+ common_shape = lhs.shape
+ return sp.ImmutableMatrix(
+ *common_shape, lambda i, j: lhs[i, j] ** rhs[i, j]
+ )
+
+ case _:
+ msg = f'Incompatible lhs and rhs for hadamard power: {lhs} | {rhs}'
+ raise ValueError(msg)
+
+
+class BinaryOperation(enum.StrEnum):
+ """Valid operations for the `OperateMathNode`.
+
+ Attributes:
+ Mul: Scalar multiplication.
+ Div: Scalar division.
+ Pow: Scalar exponentiation.
+ Add: Elementwise addition.
+ Sub: Elementwise subtraction.
+ HadamMul: Elementwise multiplication (hadamard product).
+ HadamPow: Principled shape-aware exponentiation (hadamard power).
+ Atan2: Quadrant-respecting 2D arctangent.
+ VecVecDot: Dot product for identically shaped vectors w/transpose.
+ Cross: Cross product between identically shaped 3D vectors.
+ VecVecOuter: Vector-vector outer product.
+ LinSolve: Solve a linear system.
+ LsqSolve: Minimize error of an underdetermined linear system.
+ VecMatOuter: Vector-matrix outer product.
+ MatMatDot: Matrix-matrix dot product.
+ """
+
+ # Number | Number
+ Mul = enum.auto()
+ Div = enum.auto()
+ Pow = enum.auto()
+
+ # Elements | Elements
+ Add = enum.auto()
+ Sub = enum.auto()
+ HadamMul = enum.auto()
+ HadamPow = enum.auto()
+ HadamDiv = enum.auto()
+ Atan2 = enum.auto()
+
+ # Vector | Vector
+ VecVecDot = enum.auto()
+ Cross = enum.auto()
+ VecVecOuter = enum.auto()
+
+ # Matrix | Vector
+ LinSolve = enum.auto()
+ LsqSolve = enum.auto()
+
+ # Vector | Matrix
+ VecMatOuter = enum.auto()
+
+ # Matrix | Matrix
+ MatMatDot = enum.auto()
+
+ ####################
+ # - UI
+ ####################
+ @staticmethod
+ def to_name(value: typ.Self) -> str:
+ """A human-readable UI-oriented name for a physical type."""
+ BO = BinaryOperation
+ return {
+ # Number | Number
+ BO.Mul: 'ℓ · r',
+ BO.Div: 'ℓ / r',
+ BO.Pow: 'ℓ ^ r', ## Also for square-matrix powers.
+ # Elements | Elements
+ BO.Add: 'ℓ + r',
+ BO.Sub: 'ℓ - r',
+ BO.HadamMul: '𝐋 ⊙ 𝐑',
+ BO.HadamDiv: '𝐋 ⊙/ 𝐑',
+ BO.HadamPow: '𝐥 ⊙^ 𝐫',
+ BO.Atan2: 'atan2(ℓ:x, r:y)',
+ # Vector | Vector
+ BO.VecVecDot: '𝐥 · 𝐫',
+ BO.Cross: 'cross(𝐥,𝐫)',
+ BO.VecVecOuter: '𝐥 ⊗ 𝐫',
+ # Matrix | Vector
+ BO.LinSolve: '𝐋 ∖ 𝐫',
+ BO.LsqSolve: 'argminₓ∥𝐋𝐱−𝐫∥₂',
+ # Vector | Matrix
+ BO.VecMatOuter: '𝐋 ⊗ 𝐫',
+ # Matrix | Matrix
+ BO.MatMatDot: '𝐋 · 𝐑',
+ }[value]
+
+ @staticmethod
+ def to_icon(value: typ.Self) -> str:
+ """No icons."""
+ return ''
+
+ def bl_enum_element(self, i: int) -> ct.BLEnumElement:
+ """Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`."""
+ BO = BinaryOperation
+ return (
+ str(self),
+ BO.to_name(self),
+ BO.to_name(self),
+ BO.to_icon(self),
+ i,
+ )
+
+ def bl_enum_elements(
+ self, info_l: ct.InfoFlow, info_r: ct.InfoFlow
+ ) -> list[ct.BLEnumElement]:
+ """Generate a list of guaranteed-valid operations based on the passed `InfoFlow`s.
+
+ Returns a `bpy.props.EnumProperty.items`-compatible list.
+ """
+ return [
+ operation.bl_enum_element(i)
+ for i, operation in enumerate(BinaryOperation.by_infos(info_l, info_r))
+ ]
+
+ ####################
+ # - Ops from Shape
+ ####################
+ @staticmethod
+ def by_infos(info_l: ct.InfoFlow, info_r: ct.InfoFlow) -> list[typ.Self]:
+ """Deduce valid binary operations from the shapes of the inputs."""
+ BO = BinaryOperation
+ ops = []
+
+ # Add/Sub
+ if info_l.compare_addable(info_r, allow_differing_unit=True):
+ ops += [BO.Add, BO.Sub]
+
+ # Mul/Div
+ ## -> Mul is ambiguous; we differentiate Hadamard and Standard.
+ ## -> Div additionally requires non-zero guarantees.
+ if info_l.compare_multiplicable(info_r):
+ match (info_l.order, info_r.order, info_r.output.is_nonzero):
+ case (ordl, ordr, True) if ordl == 0 and ordr == 0:
+ ops += [BO.Mul, BO.Div]
+ case (ordl, ordr, True) if ordl > 0 and ordr == 0:
+ ops += [BO.Mul, BO.Div]
+ case (ordl, ordr, True) if ordl == 0 and ordr > 0:
+ ops += [BO.Mul]
+ case (ordl, ordr, True) if ordl > 0 and ordr > 0:
+ ops += [BO.HadamMul, BO.HadamDiv]
+
+ case (ordl, ordr, False) if ordl == 0 and ordr == 0:
+ ops += [BO.Mul]
+ case (ordl, ordr, False) if ordl > 0 and ordr == 0:
+ ops += [BO.Mul]
+ case (ordl, ordr, True) if ordl == 0 and ordr > 0:
+ ops += [BO.Mul]
+ case (ordl, ordr, False) if ordl > 0 and ordr > 0:
+ ops += [BO.HadamMul]
+
+ # Pow
+ ## -> We distinguish between "Hadamard Power" and "Power".
+ ## -> For scalars, they are the same (but we only expose "power").
+ ## -> For matrices, square matrices can be exp'ed by int powers.
+ ## -> Any other combination is well-defined by the Hadamard Power.
+ if info_l.compare_exponentiable(info_r):
+ match (info_l.order, info_r.order, info_r.output.mathtype):
+ case (ordl, ordr, _) if ordl == 0 and ordr == 0:
+ ops += [BO.Pow]
+
+ case (ordl, ordr, spux.MathType.Integer) if (
+ ordl > 0 and ordr == 0 and info_l.output.rows == info_l.output.cols
+ ):
+ ops += [BO.Pow, BO.HadamPow]
+
+ case _:
+ ops += [BO.HadamPow]
+
+ # Operations by-Output Length
+ match (
+ info_l.output.shape_len,
+ info_r.output.shape_len,
+ ):
+ # Number | Number
+ case (0, 0) if info_l.is_scalar and info_r.is_scalar:
+ # atan2: PhysicalType Must Both be Length | NonPhysical
+ ## -> atan2() produces radians from Cartesian coordinates.
+ ## -> This wouldn't make sense on non-Length / non-Unitless.
+ if (
+ info_l.output.physical_type is spux.PhysicalType.Length
+ and info_r.output.physical_type is spux.PhysicalType.Length
+ ) or (
+ info_l.output.physical_type is spux.PhysicalType.NonPhysical
+ and info_l.output.unit is None
+ and info_r.output.physical_type is spux.PhysicalType.NonPhysical
+ and info_r.output.unit is None
+ ):
+ ops += [BO.Atan2]
+
+ return ops
+
+ # Vector | Vector
+ case (1, 1) if info_l.compare_dims_identical(info_r):
+ outl = info_l.output
+ outr = info_r.output
+
+ # 1D Orders: Outer Product is Valid
+ ## -> We can't do per-element outer product.
+ ## -> However, it's still super useful on its own.
+ if info_l.order == 1 and info_r.order == 1:
+ ops += [BO.VecVecOuter]
+
+ # Vector | Vector
+ if outl.rows > outl.cols and outr.rows > outr.cols:
+ ops += [BO.VecVecDot]
+
+ # Covector | Vector
+ if outl.rows < outl.cols and outr.rows > outr.cols:
+ ops += [BO.MatMatDot]
+
+ # Vector | Covector
+ if outl.rows > outl.cols and outr.rows < outr.cols:
+ ops += [BO.MatMatDot]
+
+ # Covector | Covector
+ if outl.rows < outl.cols and outr.rows < outr.cols:
+ ops += [BO.VecVecDot]
+
+ # Cross Product
+ ## -> Works great element-wise.
+ ## -> Enforce that both are 3x1 or 1x3.
+ ## -> See https://docs.sympy.org/latest/modules/matrices/matrices.html#sympy.matrices.matrices.MatrixBase.cross
+ if (outl.rows == 3 and outr.rows == 3) or (
+ outl.cols == 3 and outl.cols == 3
+ ):
+ ops += [BO.Cross]
+
+ # Vector | Matrix
+ ## -> We can't do per-element outer product.
+ ## -> However, it's still super useful on its own.
+ case (1, 2) if info_l.compare_dims_identical(
+ info_r
+ ) and info_l.order == 1 and info_r.order == 2:
+ ops += [BO.VecMatOuter]
+
+ # Matrix | Vector
+ case (2, 1) if info_l.compare_dims_identical(info_r):
+ # Mat-Vec Dot: Enforce RHS Column Vector
+ if outr.rows > outl.cols:
+ ops += [BO.MatMatDot]
+
+ ops += [BO.LinSolve, BO.LsqSolve]
+
+ ## Matrix | Matrix
+ case (2, 2):
+ ops += [BO.MatMatDot]
+
+ return ops
+
+ ####################
+ # - Function Properties
+ ####################
+ @property
+ def sp_func(self):
+ """Deduce an appropriate sympy-based function that implements the binary operation for symbolic inputs."""
+ BO = BinaryOperation
+
+ ## TODO: Make this compatible with sp.Matrix inputs
+ return {
+ # Number | Number
+ BO.Mul: lambda exprs: exprs[0] * exprs[1],
+ BO.Div: lambda exprs: exprs[0] / exprs[1],
+ BO.Pow: lambda exprs: exprs[0] ** exprs[1],
+ # Elements | Elements
+ BO.Add: lambda exprs: exprs[0] + exprs[1],
+ BO.Sub: lambda exprs: exprs[0] - exprs[1],
+ BO.HadamMul: lambda exprs: exprs[0].multiply_elementwise(exprs[1]),
+ BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
+ BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]),
+ # Vector | Vector
+ BO.VecVecDot: lambda exprs: (exprs[0].T @ exprs[1])[0],
+ BO.Cross: lambda exprs: exprs[0].cross(exprs[1]),
+ BO.VecVecOuter: lambda exprs: exprs[0] @ exprs[1].T,
+ # Matrix | Vector
+ BO.LinSolve: lambda exprs: exprs[0].solve(exprs[1]),
+ BO.LsqSolve: lambda exprs: exprs[0].solve_least_squares(exprs[1]),
+ # Vector | Matrix
+ BO.VecMatOuter: lambda exprs: spq.TensorProduct(exprs[0], exprs[1]),
+ # Matrix | Matrix
+ BO.MatMatDot: lambda exprs: exprs[0] @ exprs[1],
+ }[self]
+
+ @property
+ def unit_func(self):
+ """The binary function to apply to both unit expressions, in order to deduce the unit expression of the output."""
+ BO = BinaryOperation
+
+ ## TODO: Make this compatible with sp.Matrix inputs
+ return {
+ # Number | Number
+ BO.Mul: BO.Mul.sp_func,
+ BO.Div: BO.Div.sp_func,
+ BO.Pow: BO.Pow.sp_func,
+ # Elements | Elements
+ BO.Add: BO.Add.sp_func,
+ BO.Sub: BO.Sub.sp_func,
+ BO.HadamMul: BO.Mul.sp_func,
+ # BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
+ BO.Atan2: lambda _: spu.radian,
+ # Vector | Vector
+ BO.VecVecDot: BO.Mul.sp_func,
+ BO.Cross: BO.Mul.sp_func,
+ BO.VecVecOuter: BO.Mul.sp_func,
+ # Matrix | Vector
+ ## -> A,b in Ax = b have units, and the equality must hold.
+ ## -> Therefore, A \ b must have the units [b]/[A].
+ BO.LinSolve: lambda exprs: exprs[1] / exprs[0],
+ BO.LsqSolve: lambda exprs: exprs[1] / exprs[0],
+ # Vector | Matrix
+ BO.VecMatOuter: BO.Mul.sp_func,
+ # Matrix | Matrix
+ BO.MatMatDot: BO.Mul.sp_func,
+ }[self]
+
+ @property
+ def jax_func(self):
+ """Deduce an appropriate jax-based function that implements the binary operation for array inputs."""
+ ## TODO: Scale the units of one side to the other.
+ BO = BinaryOperation
+
+ return {
+ # Number | Number
+ BO.Mul: lambda exprs: exprs[0] * exprs[1],
+ BO.Div: lambda exprs: exprs[0] / exprs[1],
+ BO.Pow: lambda exprs: exprs[0] ** exprs[1],
+ # Elements | Elements
+ BO.Add: lambda exprs: exprs[0] + exprs[1],
+ BO.Sub: lambda exprs: exprs[0] - exprs[1],
+ BO.HadamMul: lambda exprs: exprs[0].multiply_elementwise(exprs[1]),
+ BO.HadamDiv: lambda exprs: exprs[0].multiply_elementwise(
+ exprs[1].applyfunc(lambda el: 1 / el)
+ ),
+ BO.HadamPow: lambda exprs: hadamard_power(exprs[0], exprs[1]),
+ BO.Atan2: lambda exprs: jnp.atan2(exprs[1], exprs[0]),
+ # Vector | Vector
+ BO.VecVecDot: lambda exprs: jnp.linalg.vecdot(exprs[0], exprs[1]),
+ BO.Cross: lambda exprs: jnp.cross(exprs[0], exprs[1]),
+ BO.VecVecOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
+ # Matrix | Vector
+ BO.LinSolve: lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]),
+ BO.LsqSolve: lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]),
+ # Vector | Matrix
+ BO.VecMatOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
+ # Matrix | Matrix
+ BO.MatMatDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]),
+ }[self]
+
+ ####################
+ # - Transforms
+ ####################
+ def transform_funcs(self, func_l: ct.FuncFlow, func_r: ct.FuncFlow) -> ct.FuncFlow:
+ """Transform two input functions according to the current operation."""
+ BO = BinaryOperation
+
+ # Add/Sub: Normalize Unit of RHS to LHS
+ ## -> We can only add/sub identical units.
+ ## -> To be nice, we only require identical PhysicalType.
+ ## -> The result of a binary operation should have one unit.
+ if self is BO.Add or self is BO.Sub:
+ norm_func_r = func_r.scale_to_unit(func_l.func_output.unit)
+ else:
+ norm_func_r = func_r
+
+ return (func_l, norm_func_r).compose_within(
+ self.jax_func,
+ enclosing_func_output=self.transform_outputs(
+ func_l.func_output, norm_func_r.func_output
+ ),
+ supports_jax=True,
+ )
+
+ def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow):
+ """Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression.
+
+ Warnings:
+ `self` MUST be an element of `BinaryOperation.by_infos(info_l, info_r).
+
+ If not, bad things will happen.
+ """
+ return info_l.operate_output(
+ info_r,
+ lambda a, b: self.sp_func([a, b]),
+ lambda a, b: self.unit_func([a, b]),
+ )
+
+ ####################
+ # - InfoFlow Transform
+ ####################
+ def transform_outputs(
+ self, output_l: sim_symbols.SimSymbol, output_r: sim_symbols.SimSymbol
+ ) -> sim_symbols.SimSymbol:
+ # TO = TransformOperation
+ return None
+ # match self:
+ # # Number | Number
+ # case TO.Mul:
+ # return
+ # case TO.Div:
+ # case TO.Pow:
+
+ # # Elements | Elements
+ # Add = enum.auto()
+ # Sub = enum.auto()
+ # HadamMul = enum.auto()
+ # HadamPow = enum.auto()
+ # HadamDiv = enum.auto()
+ # Atan2 = enum.auto()
+
+ # # Vector | Vector
+ # VecVecDot = enum.auto()
+ # Cross = enum.auto()
+ # VecVecOuter = enum.auto()
+
+ # # Matrix | Vector
+ # LinSolve = enum.auto()
+ # LsqSolve = enum.auto()
+
+ # # Vector | Matrix
+ # VecMatOuter = enum.auto()
+
+ # # Matrix | Matrix
+ # MatMatDot = enum.auto()
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/reduce.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/reduce.py
new file mode 100644
index 0000000..4d26985
--- /dev/null
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/reduce.py
@@ -0,0 +1,116 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import enum
+import typing as typ
+
+import jax.numpy as jnp
+import sympy as sp
+
+from blender_maxwell.utils import logger, sim_symbols
+from blender_maxwell.utils import sympy_extra as spux
+
+from .. import contracts as ct
+
+log = logger.get(__name__)
+
+
+class ReduceOperation(enum.StrEnum):
+ # Summary
+ Count = enum.auto()
+
+ # Statistics
+ Mean = enum.auto()
+ Std = enum.auto()
+ Var = enum.auto()
+
+ StdErr = enum.auto()
+
+ Min = enum.auto()
+ Q25 = enum.auto()
+ Median = enum.auto()
+ Q75 = enum.auto()
+ Max = enum.auto()
+
+ Mode = enum.auto()
+
+ # Reductions
+ Sum = enum.auto()
+ Prod = enum.auto()
+
+ ####################
+ # - UI
+ ####################
+ @staticmethod
+ def to_name(value: typ.Self) -> str:
+ """A human-readable UI-oriented name for a physical type."""
+ RO = ReduceOperation
+ return {
+ # Summary
+ RO.Count: '# [a]',
+ RO.Mode: 'mode [a]',
+ # Statistics
+ RO.Mean: 'μ [a]',
+ RO.Std: 'σ [a]',
+ RO.Var: 'σ² [a]',
+ RO.StdErr: 'stderr [a]',
+ RO.Min: 'min [a]',
+ RO.Q25: 'q₂₅ [a]',
+ RO.Median: 'median [a]',
+ RO.Q75: 'q₇₅ [a]',
+ RO.Min: 'max [a]',
+ # Reductions
+ RO.Sum: 'sum [a]',
+ RO.Prod: 'prod [a]',
+ }[value]
+
+ @staticmethod
+ def to_icon(_: typ.Self) -> str:
+ """No icons."""
+ return ''
+
+ def bl_enum_element(self, i: int) -> ct.BLEnumElement:
+ """Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`."""
+ RO = ReduceOperation
+ return (
+ str(self),
+ RO.to_name(self),
+ RO.to_name(self),
+ RO.to_icon(self),
+ i,
+ )
+
+ ####################
+ # - Derivation
+ ####################
+ @staticmethod
+ def from_info(info: ct.InfoFlow) -> list[typ.Self]:
+ """Derive valid reduction operations from the `InfoFlow` of the operand."""
+ pass
+
+ ####################
+ # - Composable Functions
+ ####################
+ @property
+ def jax_func(self):
+ RO = ReduceOperation
+ return {}[self]
+
+ ####################
+ # - Transforms
+ ####################
+ def transform_info(self, info: ct.InfoFlow):
+ pass
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/transform.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/transform.py
new file mode 100644
index 0000000..499c082
--- /dev/null
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/transform.py
@@ -0,0 +1,336 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import enum
+import typing as typ
+
+import jax.numpy as jnp
+import jaxtyping as jtyp
+
+from blender_maxwell.utils import logger, sci_constants, sim_symbols
+from blender_maxwell.utils import sympy_extra as spux
+
+from .. import contracts as ct
+
+log = logger.get(__name__)
+
+
+class TransformOperation(enum.StrEnum):
+ """Valid operations for the `TransformMathNode`.
+
+ Attributes:
+ FreqToVacWL: Transform an frequency dimension to vacuum wavelength.
+ VacWLToFreq: Transform a vacuum wavelength dimension to frequency.
+ ConvertIdxUnit: Convert the unit of a dimension to a compatible unit.
+ SetIdxUnit: Set all properties of a dimension.
+ FirstColToFirstIdx: Extract the first data column and set the first dimension's index array equal to it.
+ **For 2D integer-indexed data only**.
+
+ IntDimToComplex: Fold a last length-2 integer dimension into the output, transforming it from a real-like type to complex type.
+ DimToVec: Fold the last dimension into the scalar output, creating a vector output type.
+ DimsToMat: Fold the last two dimensions into the scalar output, creating a matrix output type.
+ FT: Compute the 1D fourier transform along a dimension.
+ New dimensional bounds are computing using the Nyquist Limit.
+ For higher dimensions, simply repeat along more dimensions.
+ InvFT1D: Compute the inverse 1D fourier transform along a dimension.
+ New dimensional bounds are computing using the Nyquist Limit.
+ For higher dimensions, simply repeat along more dimensions.
+ """
+
+ # Covariant Transform
+ FreqToVacWL = enum.auto()
+ VacWLToFreq = enum.auto()
+ ConvertIdxUnit = enum.auto()
+ SetIdxUnit = enum.auto()
+ FirstColToFirstIdx = enum.auto()
+
+ # Fold
+ IntDimToComplex = enum.auto()
+ DimToVec = enum.auto()
+ DimsToMat = enum.auto()
+
+ # Fourier
+ FT1D = enum.auto()
+ InvFT1D = enum.auto()
+
+ # TODO: Affine
+ ## TODO
+
+ ####################
+ # - UI
+ ####################
+ @staticmethod
+ def to_name(value: typ.Self) -> str:
+ TO = TransformOperation
+ return {
+ # Covariant Transform
+ TO.FreqToVacWL: '𝑓 → λᵥ',
+ TO.VacWLToFreq: 'λᵥ → 𝑓',
+ TO.ConvertIdxUnit: 'Convert Dim',
+ TO.SetIdxUnit: 'Set Dim',
+ TO.FirstColToFirstIdx: '1st Col → 1st Dim',
+ # Fold
+ TO.IntDimToComplex: '→ ℂ',
+ TO.DimToVec: '→ Vector',
+ TO.DimsToMat: '→ Matrix',
+ ## TODO: Vector to new last-dim integer
+ ## TODO: Matrix to two last-dim integers
+ # Fourier
+ TO.FT1D: 'FT',
+ TO.InvFT1D: 'iFT',
+ }[value]
+
+ @property
+ def name(self) -> str:
+ return TransformOperation.to_name(self)
+
+ @staticmethod
+ def to_icon(_: typ.Self) -> str:
+ return ''
+
+ def bl_enum_element(self, i: int) -> ct.BLEnumElement:
+ TO = TransformOperation
+ return (
+ str(self),
+ TO.to_name(self),
+ TO.to_name(self),
+ TO.to_icon(self),
+ i,
+ )
+
+ ####################
+ # - Methods
+ ####################
+ def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
+ TO = TransformOperation
+ match self:
+ case TO.FreqToVacWL:
+ return [
+ dim
+ for dim in info.dims
+ if dim.physical_type is spux.PhysicalType.Freq
+ ]
+
+ case TO.VacWLToFreq:
+ return [
+ dim
+ for dim in info.dims
+ if dim.physical_type is spux.PhysicalType.Length
+ ]
+
+ case TO.ConvertIdxUnit:
+ return [
+ dim
+ for dim in info.dims
+ if not info.has_idx_labels(dim)
+ and spux.PhysicalType.from_unit(dim.unit, optional=True) is not None
+ ]
+
+ case TO.SetIdxUnit:
+ return [dim for dim in info.dims if not info.has_idx_labels(dim)]
+
+ ## ColDimToComplex: Implicit Last Dimension
+ ## DimToVec: Implicit Last Dimension
+ ## DimsToMat: Implicit Last 2 Dimensions
+
+ case TO.FT1D | TO.InvFT1D:
+ # Filter by Axis Uniformity
+ ## -> FT requires uniform axis (aka. must be RangeFlow).
+ ## -> NOTE: If FT isn't popping up, check ExtractDataNode.
+ return [dim for dim in info.dims if info.is_idx_uniform(dim)]
+
+ return []
+
+ @staticmethod
+ def by_info(info: ct.InfoFlow) -> list[typ.Self]:
+ TO = TransformOperation
+ operations = []
+
+ # Covariant Transform
+ ## Freq -> VacWL
+ if TO.FreqToVacWL.valid_dims(info):
+ operations += [TO.FreqToVacWL]
+
+ ## VacWL -> Freq
+ if TO.VacWLToFreq.valid_dims(info):
+ operations += [TO.VacWLToFreq]
+
+ ## Convert Index Unit
+ if TO.ConvertIdxUnit.valid_dims(info):
+ operations += [TO.ConvertIdxUnit]
+
+ if TO.SetIdxUnit.valid_dims(info):
+ operations += [TO.SetIdxUnit]
+
+ ## Column to First Index (Array)
+ if (
+ len(info.dims) == 2 # noqa: PLR2004
+ and info.first_dim.mathtype is spux.MathType.Integer
+ and info.last_dim.mathtype is spux.MathType.Integer
+ and info.output.shape_len == 0
+ ):
+ operations += [TO.FirstColToFirstIdx]
+
+ # Fold
+ ## Last Dim -> Complex
+ if (
+ len(info.dims) >= 1
+ and (
+ info.output.mathtype
+ in [spux.MathType.Integer, spux.MathType.Rational, spux.MathType.Real]
+ )
+ and info.last_dim.mathtype is spux.MathType.Integer
+ and info.has_idx_labels(info.last_dim)
+ and len(info.dims[info.last_dim]) == 2 # noqa: PLR2004
+ ):
+ operations += [TO.IntDimToComplex]
+
+ ## Last Dim -> Vector
+ if len(info.dims) >= 1 and info.output.shape_len == 0:
+ operations += [TO.DimToVec]
+
+ ## Last Dim -> Matrix
+ if len(info.dims) >= 2 and info.output.shape_len == 0: # noqa: PLR2004
+ operations += [TO.DimsToMat]
+
+ # Fourier
+ if TO.FT1D.valid_dims(info):
+ operations += [TO.FT1D]
+
+ if TO.InvFT1D.valid_dims(info):
+ operations += [TO.InvFT1D]
+
+ return operations
+
+ ####################
+ # - Function Properties
+ ####################
+ def jax_func(self, axis: int | None = None):
+ TO = TransformOperation
+ return {
+ # Covariant Transform
+ ## -> Freq <-> WL is a rescale (noop) AND flip (not noop).
+ TO.FreqToVacWL: lambda expr: jnp.flip(expr, axis=axis),
+ TO.VacWLToFreq: lambda expr: jnp.flip(expr, axis=axis),
+ TO.ConvertIdxUnit: lambda expr: expr,
+ TO.SetIdxUnit: lambda expr: expr,
+ TO.FirstColToFirstIdx: lambda expr: jnp.delete(expr, 0, axis=1),
+ # Fold
+ ## -> To Complex: This should generally be a no-op.
+ TO.IntDimToComplex: lambda expr: jnp.squeeze(
+ expr.view(dtype=jnp.complex64), axis=-1
+ ),
+ TO.DimToVec: lambda expr: expr,
+ TO.DimsToMat: lambda expr: expr,
+ # Fourier
+ TO.FT1D: lambda expr: jnp.fft(expr, axis=axis),
+ TO.InvFT1D: lambda expr: jnp.ifft(expr, axis=axis),
+ }[self]
+
+ def transform_info(
+ self,
+ info: ct.InfoFlow,
+ dim: sim_symbols.SimSymbol | None = None,
+ data_col: jtyp.Shaped[jtyp.Array, ' size'] | None = None,
+ new_dim_name: str | None = None,
+ unit: spux.Unit | None = None,
+ physical_type: spux.PhysicalType | None = None,
+ ) -> ct.InfoFlow:
+ TO = TransformOperation
+ return {
+ # Covariant Transform
+ TO.FreqToVacWL: lambda: info.replace_dim(
+ (f_dim := dim),
+ sim_symbols.wl(unit),
+ info.dims[f_dim].rescale(
+ lambda el: sci_constants.vac_speed_of_light / el,
+ reverse=True,
+ new_unit=unit,
+ ),
+ ),
+ TO.VacWLToFreq: lambda: info.replace_dim(
+ (wl_dim := dim),
+ sim_symbols.freq(unit),
+ info.dims[wl_dim].rescale(
+ lambda el: sci_constants.vac_speed_of_light / el,
+ reverse=True,
+ new_unit=unit,
+ ),
+ ),
+ TO.ConvertIdxUnit: lambda: info.replace_dim(
+ dim,
+ dim.update(unit=unit),
+ (
+ info.dims[dim].rescale_to_unit(unit)
+ if info.has_idx_discrete(dim)
+ else None ## Continuous -- dim SimSymbol already scaled
+ ),
+ ),
+ TO.SetIdxUnit: lambda: info.replace_dim(
+ dim,
+ dim.update(
+ sym_name=new_dim_name,
+ physical_type=physical_type,
+ unit=unit,
+ ),
+ (
+ info.dims[dim].correct_unit(unit)
+ if info.has_idx_discrete(dim)
+ else None ## Continuous -- dim SimSymbol already scaled
+ ),
+ ),
+ TO.FirstColToFirstIdx: lambda: info.replace_dim(
+ info.first_dim,
+ info.first_dim.update(
+ sym_name=new_dim_name,
+ mathtype=spux.MathType.from_jax_array(data_col),
+ physical_type=physical_type,
+ unit=unit,
+ ),
+ ct.RangeFlow.try_from_array(ct.ArrayFlow(values=data_col, unit=unit)),
+ ).slice_dim(info.last_dim, (1, len(info.dims[info.last_dim]), 1)),
+ # Fold
+ TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output(
+ mathtype=spux.MathType.Complex
+ ),
+ TO.DimToVec: lambda: info.fold_last_input(),
+ TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(),
+ # Fourier
+ TO.FT1D: lambda: info.replace_dim(
+ dim,
+ [
+ # FT'ed Unit: Reciprocal of the Original Unit
+ dim.update(
+ unit=1 / dim.unit if dim.unit is not None else 1
+ ), ## TODO: Okay to not scale interval?
+ # FT'ed Bounds: Reciprocal of the Original Unit
+ info.dims[dim].bound_fourier_transform,
+ ],
+ ),
+ TO.InvFT1D: lambda: info.replace_dim(
+ info.last_dim,
+ [
+ # FT'ed Unit: Reciprocal of the Original Unit
+ dim.update(
+ unit=1 / dim.unit if dim.unit is not None else 1
+ ), ## TODO: Okay to not scale interval?
+ # FT'ed Bounds: Reciprocal of the Original Unit
+ ## -> Note the midpoint may revert to 0.
+ ## -> See docs for `RangeFlow.bound_inv_fourier_transform` for more.
+ info.dims[dim].bound_inv_fourier_transform,
+ ],
+ ),
+ }[self]()
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py
index e036857..a2f54db 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py
@@ -26,7 +26,7 @@ import sympy.physics.units as spu
import tidy3d as td
from blender_maxwell.utils import bl_cache, logger, sim_symbols
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import sockets
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py
index eddeff4..6c02ec1 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py
@@ -20,225 +20,15 @@ import enum
import typing as typ
import bpy
-import jax.lax as jlax
-import jax.numpy as jnp
import sympy as sp
-from blender_maxwell.utils import bl_cache, logger, sim_symbols
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import bl_cache, sim_symbols
+from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
-from .... import sockets
+from .... import math_system, sockets
from ... import base, events
-log = logger.get(__name__)
-
-
-class FilterOperation(enum.StrEnum):
- """Valid operations for the `FilterMathNode`.
-
- Attributes:
- DimToVec: Shift last dimension to output.
- DimsToMat: Shift last 2 dimensions to output.
- PinLen1: Remove a len(1) dimension.
- Pin: Remove a len(n) dimension by selecting a particular index.
- Swap: Swap the positions of two dimensions.
- """
-
- # Slice
- Slice = enum.auto()
- SliceIdx = enum.auto()
-
- # Pin
- PinLen1 = enum.auto()
- Pin = enum.auto()
- PinIdx = enum.auto()
-
- # Dimension
- Swap = enum.auto()
-
- ####################
- # - UI
- ####################
- @staticmethod
- def to_name(value: typ.Self) -> str:
- FO = FilterOperation
- return {
- # Slice
- FO.Slice: '≈a[v₁:v₂]',
- FO.SliceIdx: '=a[i:j]',
- # Pin
- FO.PinLen1: 'a[0] → a',
- FO.Pin: 'a[v] ⇝ a',
- FO.PinIdx: 'a[i] → a',
- # Reinterpret
- FO.Swap: 'a₁ ↔ a₂',
- }[value]
-
- @staticmethod
- def to_icon(value: typ.Self) -> str:
- return ''
-
- def bl_enum_element(self, i: int) -> ct.BLEnumElement:
- FO = FilterOperation
- return (
- str(self),
- FO.to_name(self),
- FO.to_name(self),
- FO.to_icon(self),
- i,
- )
-
- ####################
- # - Ops from Info
- ####################
- @staticmethod
- def by_info(info: ct.InfoFlow) -> list[typ.Self]:
- FO = FilterOperation
- operations = []
-
- # Slice
- if info.dims:
- operations.append(FO.SliceIdx)
-
- # Pin
- ## PinLen1
- ## -> There must be a dimension with length 1.
- if 1 in [dim_idx for dim_idx in info.dims.values() if dim_idx is not None]:
- operations.append(FO.PinLen1)
-
- ## Pin | PinIdx
- ## -> There must be a dimension, full stop.
- if info.dims:
- operations += [FO.Pin, FO.PinIdx]
-
- # Reinterpret
- ## Swap
- ## -> There must be at least two dimensions.
- if len(info.dims) >= 2: # noqa: PLR2004
- operations.append(FO.Swap)
-
- return operations
-
- ####################
- # - Computed Properties
- ####################
- @property
- def func_args(self) -> list[sim_symbols.SimSymbol]:
- FO = FilterOperation
- return {
- # Pin
- FO.Pin: [sim_symbols.idx(None)],
- FO.PinIdx: [sim_symbols.idx(None)],
- }.get(self, [])
-
- ####################
- # - Methods
- ####################
- @property
- def num_dim_inputs(self) -> None:
- FO = FilterOperation
- return {
- # Slice
- FO.Slice: 1,
- FO.SliceIdx: 1,
- # Pin
- FO.PinLen1: 1,
- FO.Pin: 1,
- FO.PinIdx: 1,
- # Reinterpret
- FO.Swap: 2,
- }[self]
-
- def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
- FO = FilterOperation
- match self:
- # Slice
- case FO.Slice:
- return [dim for dim in info.dims if not info.has_idx_labels(dim)]
-
- case FO.SliceIdx:
- return [dim for dim in info.dims if not info.has_idx_labels(dim)]
-
- # Pin
- case FO.PinLen1:
- return [
- dim
- for dim, dim_idx in info.dims.items()
- if not info.has_idx_cont(dim) and len(dim_idx) == 1
- ]
-
- case FO.Pin:
- return info.dims
-
- case FO.PinIdx:
- return [dim for dim in info.dims if not info.has_idx_cont(dim)]
-
- # Dimension
- case FO.Swap:
- return info.dims
-
- return []
-
- def are_dims_valid(
- self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None
- ) -> bool:
- """Check whether the given dimension inputs are valid in the context of this operation, and of the information."""
- if self.num_dim_inputs == 1:
- return dim_0 in self.valid_dims(info)
-
- if self.num_dim_inputs == 2: # noqa: PLR2004
- valid_dims = self.valid_dims(info)
- return dim_0 in valid_dims and dim_1 in valid_dims
-
- return False
-
- ####################
- # - UI
- ####################
- def jax_func(
- self,
- axis_0: int | None,
- axis_1: int | None,
- slice_tuple: tuple[int, int, int] | None = None,
- ):
- FO = FilterOperation
- return {
- # Pin
- FO.Slice: lambda expr: jlax.slice_in_dim(
- expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
- ),
- FO.SliceIdx: lambda expr: jlax.slice_in_dim(
- expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
- ),
- # Pin
- FO.PinLen1: lambda expr: jnp.squeeze(expr, axis_0),
- FO.Pin: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
- FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
- # Dimension
- FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1),
- }[self]
-
- def transform_info(
- self,
- info: ct.InfoFlow,
- dim_0: sim_symbols.SimSymbol,
- dim_1: sim_symbols.SimSymbol,
- pin_idx: int | None = None,
- slice_tuple: tuple[int, int, int] | None = None,
- ):
- FO = FilterOperation
- return {
- FO.Slice: lambda: info.slice_dim(dim_0, slice_tuple),
- FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple),
- # Pin
- FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
- FO.Pin: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
- FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
- # Reinterpret
- FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
- }[self]()
-
class FilterMathNode(base.MaxwellSimNode):
r"""Applies a function that operates on the shape of the array.
@@ -304,7 +94,7 @@ class FilterMathNode(base.MaxwellSimNode):
####################
# - Properties: Operation
####################
- operation: FilterOperation = bl_cache.BLField(
+ operation: math_system.FilterOperation = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_info'},
)
@@ -313,7 +103,9 @@ class FilterMathNode(base.MaxwellSimNode):
if self.expr_info is not None:
return [
operation.bl_enum_element(i)
- for i, operation in enumerate(FilterOperation.by_info(self.expr_info))
+ for i, operation in enumerate(
+ math_system.FilterOperation.by_info(self.expr_info)
+ )
]
return []
@@ -358,7 +150,7 @@ class FilterMathNode(base.MaxwellSimNode):
# - UI
####################
def draw_label(self):
- FO = FilterOperation
+ FO = math_system.FilterOperation
match self.operation:
# Slice
case FO.SliceIdx:
@@ -398,7 +190,7 @@ class FilterMathNode(base.MaxwellSimNode):
row.prop(self, self.blfields['active_dim_0'], text='')
row.prop(self, self.blfields['active_dim_1'], text='')
- if self.operation is FilterOperation.SliceIdx:
+ if self.operation is math_system.FilterOperation.SliceIdx:
layout.prop(self, self.blfields['slice_tuple'], text='')
####################
@@ -434,7 +226,7 @@ class FilterMathNode(base.MaxwellSimNode):
## -> Works with continuous / discrete indexes.
## -> The user will be given a socket w/correct mathtype, unit, etc. .
if (
- props['operation'] is FilterOperation.Pin
+ props['operation'] is math_system.FilterOperation.Pin
and dim_0 is not None
and (info.has_idx_cont(dim_0) or info.has_idx_discrete(dim_0))
):
@@ -460,7 +252,7 @@ class FilterMathNode(base.MaxwellSimNode):
# Loose Sockets: Pin Dim by-Value
## -> Works with discrete points / labelled integers.
elif (
- props['operation'] is FilterOperation.PinIdx
+ props['operation'] is math_system.FilterOperation.PinIdx
and dim_0 is not None
and (info.has_idx_discrete(dim_0) or info.has_idx_labels(dim_0))
):
@@ -594,7 +386,10 @@ class FilterMathNode(base.MaxwellSimNode):
# Pin by-Value: Compute Nearest IDX
## -> Presume a sorted index array to be able to use binary search.
- if props['operation'] is FilterOperation.Pin and has_pinned_value:
+ if (
+ props['operation'] is math_system.FilterOperation.Pin
+ and has_pinned_value
+ ):
nearest_idx_to_value = info.dims[dim_0].nearest_idx_of(
pinned_value, require_sorted=True
)
@@ -605,7 +400,10 @@ class FilterMathNode(base.MaxwellSimNode):
)
# Pin by-Index
- if props['operation'] is FilterOperation.PinIdx and has_pinned_axis:
+ if (
+ props['operation'] is math_system.FilterOperation.PinIdx
+ and has_pinned_axis
+ ):
return params.compose_within(
enclosing_arg_targets=[sim_symbols.idx(None)],
enclosing_func_args=[sp.S(pinned_axis)],
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py
index 04765a8..4959822 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py
@@ -16,363 +16,19 @@
"""Declares `MapMathNode`."""
-import enum
import typing as typ
import bpy
-import jax.numpy as jnp
-import sympy as sp
-from blender_maxwell.utils import bl_cache, logger, sim_symbols
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import bl_cache, logger
from .... import contracts as ct
-from .... import sockets
+from .... import math_system, sockets
from ... import base, events
log = logger.get(__name__)
-####################
-# - Operation Enum
-####################
-class MapOperation(enum.StrEnum):
- """Valid operations for the `MapMathNode`.
-
- Attributes:
- Real: Compute the real part of the input.
- Imag: Compute the imaginary part of the input.
- Abs: Compute the absolute value of the input.
- Sq: Square the input.
- Sqrt: Compute the (principal) square root of the input.
- InvSqrt: Compute the inverse square root of the input.
- Cos: Compute the cosine of the input.
- Sin: Compute the sine of the input.
- Tan: Compute the tangent of the input.
- Acos: Compute the inverse cosine of the input.
- Asin: Compute the inverse sine of the input.
- Atan: Compute the inverse tangent of the input.
- Norm2: Compute the 2-norm (aka. length) of the input vector.
- Det: Compute the determinant of the input matrix.
- Cond: Compute the condition number of the input matrix.
- NormFro: Compute the frobenius norm of the input matrix.
- Rank: Compute the rank of the input matrix.
- Diag: Compute the diagonal vector of the input matrix.
- EigVals: Compute the eigenvalues vector of the input matrix.
- SvdVals: Compute the singular values vector of the input matrix.
- Inv: Compute the inverse matrix of the input matrix.
- Tra: Compute the transpose matrix of the input matrix.
- Qr: Compute the QR-factorized matrices of the input matrix.
- Chol: Compute the Cholesky-factorized matrices of the input matrix.
- Svd: Compute the SVD-factorized matrices of the input matrix.
- """
-
- # By Number
- Real = enum.auto()
- Imag = enum.auto()
- Abs = enum.auto()
- Sq = enum.auto()
- Sqrt = enum.auto()
- InvSqrt = enum.auto()
- Cos = enum.auto()
- Sin = enum.auto()
- Tan = enum.auto()
- Acos = enum.auto()
- Asin = enum.auto()
- Atan = enum.auto()
- Sinc = enum.auto()
- # By Vector
- Norm2 = enum.auto()
- # By Matrix
- Det = enum.auto()
- Cond = enum.auto()
- NormFro = enum.auto()
- Rank = enum.auto()
- Diag = enum.auto()
- EigVals = enum.auto()
- SvdVals = enum.auto()
- Inv = enum.auto()
- Tra = enum.auto()
- Qr = enum.auto()
- Chol = enum.auto()
- Svd = enum.auto()
-
- ####################
- # - UI
- ####################
- @staticmethod
- def to_name(value: typ.Self) -> str:
- MO = MapOperation
- return {
- # By Number
- MO.Real: 'ℝ(v)',
- MO.Imag: 'Im(v)',
- MO.Abs: '|v|',
- MO.Sq: 'v²',
- MO.Sqrt: '√v',
- MO.InvSqrt: '1/√v',
- MO.Cos: 'cos v',
- MO.Sin: 'sin v',
- MO.Tan: 'tan v',
- MO.Acos: 'acos v',
- MO.Asin: 'asin v',
- MO.Atan: 'atan v',
- MO.Sinc: 'sinc v',
- # By Vector
- MO.Norm2: '||v||₂',
- # By Matrix
- MO.Det: 'det V',
- MO.Cond: 'κ(V)',
- MO.NormFro: '||V||_F',
- MO.Rank: 'rank V',
- MO.Diag: 'diag V',
- MO.EigVals: 'eigvals V',
- MO.SvdVals: 'svdvals V',
- MO.Inv: 'V⁻¹',
- MO.Tra: 'Vt',
- MO.Qr: 'qr V',
- MO.Chol: 'chol V',
- MO.Svd: 'svd V',
- }[value]
-
- @staticmethod
- def to_icon(value: typ.Self) -> str:
- return ''
-
- def bl_enum_element(self, i: int) -> ct.BLEnumElement:
- MO = MapOperation
- return (
- str(self),
- MO.to_name(self),
- MO.to_name(self),
- MO.to_icon(self),
- i,
- )
-
- ####################
- # - Ops from Shape
- ####################
- @staticmethod
- def by_expr_info(info: ct.InfoFlow) -> list[typ.Self]:
- ## TODO: By info, not shape.
- ## TODO: Check valid domains/mathtypes for some functions.
- MO = MapOperation
- element_ops = [
- MO.Real,
- MO.Imag,
- MO.Abs,
- MO.Sq,
- MO.Sqrt,
- MO.InvSqrt,
- MO.Cos,
- MO.Sin,
- MO.Tan,
- MO.Acos,
- MO.Asin,
- MO.Atan,
- MO.Sinc,
- ]
-
- match (info.output.rows, info.output.cols):
- case (1, 1):
- return element_ops
-
- case (_, 1):
- return [*element_ops, MO.Norm2]
-
- case (rows, cols) if rows == cols:
- ## TODO: Check hermitian/posdef for cholesky.
- ## - Can we even do this with just the output symbol approach?
- return [
- *element_ops,
- MO.Det,
- MO.Cond,
- MO.NormFro,
- MO.Rank,
- MO.Diag,
- MO.EigVals,
- MO.SvdVals,
- MO.Inv,
- MO.Tra,
- MO.Qr,
- MO.Chol,
- MO.Svd,
- ]
-
- case (rows, cols):
- return [
- *element_ops,
- MO.Cond,
- MO.NormFro,
- MO.Rank,
- MO.SvdVals,
- MO.Inv,
- MO.Tra,
- MO.Svd,
- ]
-
- return []
-
- ####################
- # - Function Properties
- ####################
- @property
- def sp_func(self):
- MO = MapOperation
- return {
- # By Number
- MO.Real: lambda expr: sp.re(expr),
- MO.Imag: lambda expr: sp.im(expr),
- MO.Abs: lambda expr: sp.Abs(expr),
- MO.Sq: lambda expr: expr**2,
- MO.Sqrt: lambda expr: sp.sqrt(expr),
- MO.InvSqrt: lambda expr: 1 / sp.sqrt(expr),
- MO.Cos: lambda expr: sp.cos(expr),
- MO.Sin: lambda expr: sp.sin(expr),
- MO.Tan: lambda expr: sp.tan(expr),
- MO.Acos: lambda expr: sp.acos(expr),
- MO.Asin: lambda expr: sp.asin(expr),
- MO.Atan: lambda expr: sp.atan(expr),
- MO.Sinc: lambda expr: sp.sinc(expr),
- # By Vector
- # Vector -> Number
- MO.Norm2: lambda expr: sp.sqrt(expr.T @ expr)[0],
- # By Matrix
- # Matrix -> Number
- MO.Det: lambda expr: sp.det(expr),
- MO.Cond: lambda expr: expr.condition_number(),
- MO.NormFro: lambda expr: expr.norm(ord='fro'),
- MO.Rank: lambda expr: expr.rank(),
- # Matrix -> Vec
- MO.Diag: lambda expr: expr.diagonal(),
- MO.EigVals: lambda expr: sp.Matrix(list(expr.eigenvals().keys())),
- MO.SvdVals: lambda expr: expr.singular_values(),
- # Matrix -> Matrix
- MO.Inv: lambda expr: expr.inv(),
- MO.Tra: lambda expr: expr.T,
- # Matrix -> Matrices
- MO.Qr: lambda expr: expr.QRdecomposition(),
- MO.Chol: lambda expr: expr.cholesky(),
- MO.Svd: lambda expr: expr.singular_value_decomposition(),
- }[self]
-
- @property
- def jax_func(self):
- MO = MapOperation
- return {
- # By Number
- MO.Real: lambda expr: jnp.real(expr),
- MO.Imag: lambda expr: jnp.imag(expr),
- MO.Abs: lambda expr: jnp.abs(expr),
- MO.Sq: lambda expr: jnp.square(expr),
- MO.Sqrt: lambda expr: jnp.sqrt(expr),
- MO.InvSqrt: lambda expr: 1 / jnp.sqrt(expr),
- MO.Cos: lambda expr: jnp.cos(expr),
- MO.Sin: lambda expr: jnp.sin(expr),
- MO.Tan: lambda expr: jnp.tan(expr),
- MO.Acos: lambda expr: jnp.acos(expr),
- MO.Asin: lambda expr: jnp.asin(expr),
- MO.Atan: lambda expr: jnp.atan(expr),
- MO.Sinc: lambda expr: jnp.sinc(expr),
- # By Vector
- # Vector -> Number
- MO.Norm2: lambda expr: jnp.linalg.norm(expr, ord=2, axis=-1),
- # By Matrix
- # Matrix -> Number
- MO.Det: lambda expr: jnp.linalg.det(expr),
- MO.Cond: lambda expr: jnp.linalg.cond(expr),
- MO.NormFro: lambda expr: jnp.linalg.matrix_norm(expr, ord='fro'),
- MO.Rank: lambda expr: jnp.linalg.matrix_rank(expr),
- # Matrix -> Vec
- MO.Diag: lambda expr: jnp.diagonal(expr, axis1=-2, axis2=-1),
- MO.EigVals: lambda expr: jnp.linalg.eigvals(expr),
- MO.SvdVals: lambda expr: jnp.linalg.svdvals(expr),
- # Matrix -> Matrix
- MO.Inv: lambda expr: jnp.linalg.inv(expr),
- MO.Tra: lambda expr: jnp.matrix_transpose(expr),
- # Matrix -> Matrices
- MO.Qr: lambda expr: jnp.linalg.qr(expr),
- MO.Chol: lambda expr: jnp.linalg.cholesky(expr),
- MO.Svd: lambda expr: jnp.linalg.svd(expr),
- }[self]
-
- def transform_info(self, info: ct.InfoFlow):
- MO = MapOperation
-
- return {
- # By Number
- MO.Real: lambda: info.update_output(mathtype=spux.MathType.Real),
- MO.Imag: lambda: info.update_output(mathtype=spux.MathType.Real),
- MO.Abs: lambda: info.update_output(mathtype=spux.MathType.Real),
- MO.Sq: lambda: info,
- MO.Sqrt: lambda: info,
- MO.InvSqrt: lambda: info,
- MO.Cos: lambda: info,
- MO.Sin: lambda: info,
- MO.Tan: lambda: info,
- MO.Acos: lambda: info,
- MO.Asin: lambda: info,
- MO.Atan: lambda: info,
- MO.Sinc: lambda: info,
- # By Vector
- MO.Norm2: lambda: info.update_output(
- mathtype=spux.MathType.Real,
- rows=1,
- cols=1,
- # Interval
- interval_finite_re=(0, sim_symbols.float_max),
- interval_inf=(False, True),
- interval_closed=(True, False),
- ),
- # By Matrix
- MO.Det: lambda: info.update_output(
- rows=1,
- cols=1,
- ),
- MO.Cond: lambda: info.update_output(
- mathtype=spux.MathType.Real,
- rows=1,
- cols=1,
- physical_type=spux.PhysicalType.NonPhysical,
- unit=None,
- ),
- MO.NormFro: lambda: info.update_output(
- mathtype=spux.MathType.Real,
- rows=1,
- cols=1,
- # Interval
- interval_finite_re=(0, sim_symbols.float_max),
- interval_inf=(False, True),
- interval_closed=(True, False),
- ),
- MO.Rank: lambda: info.update_output(
- mathtype=spux.MathType.Integer,
- rows=1,
- cols=1,
- physical_type=spux.PhysicalType.NonPhysical,
- unit=None,
- # Interval
- interval_finite_re=(0, sim_symbols.int_max),
- interval_inf=(False, True),
- interval_closed=(True, False),
- ),
- # Matrix -> Vector ## TODO: ALL OF THESE
- MO.Diag: lambda: info,
- MO.EigVals: lambda: info,
- MO.SvdVals: lambda: info,
- # Matrix -> Matrix ## TODO: ALL OF THESE
- MO.Inv: lambda: info,
- MO.Tra: lambda: info,
- # Matrix -> Matrices ## TODO: ALL OF THESE
- MO.Qr: lambda: info,
- MO.Chol: lambda: info,
- MO.Svd: lambda: info,
- }[self]()
-
-
-####################
-# - Node
-####################
class MapMathNode(base.MaxwellSimNode):
r"""Applies a function by-structure to the data.
@@ -495,7 +151,7 @@ class MapMathNode(base.MaxwellSimNode):
return info
return None
- operation: MapOperation = bl_cache.BLField(
+ operation: math_system.MapOperation = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_info'},
)
@@ -504,7 +160,9 @@ class MapMathNode(base.MaxwellSimNode):
if self.expr_info is not None:
return [
operation.bl_enum_element(i)
- for i, operation in enumerate(MapOperation.by_expr_info(self.expr_info))
+ for i, operation in enumerate(
+ math_system.MapOperation.by_expr_info(self.expr_info)
+ )
]
return []
@@ -513,7 +171,7 @@ class MapMathNode(base.MaxwellSimNode):
####################
def draw_label(self):
if self.operation is not None:
- return 'Map: ' + MapOperation.to_name(self.operation)
+ return 'Map: ' + math_system.MapOperation.to_name(self.operation)
return self.bl_label
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py
index f59bde4..920f863 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py
@@ -14,354 +14,24 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-import enum
+"""Implements the `OperateMathNode`.
+
+See `blender_maxwell.maxwell_sim_nodes.math_system` for the actual mathematics implementation.
+"""
+
import typing as typ
import bpy
-import jax.numpy as jnp
-import sympy as sp
-import sympy.physics.quantum as spq
-import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, logger
-from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct
-from .... import sockets
+from .... import math_system, sockets
from ... import base, events
log = logger.get(__name__)
-####################
-# - Operation Enum
-####################
-class BinaryOperation(enum.StrEnum):
- """Valid operations for the `OperateMathNode`.
-
- Attributes:
- Mul: Scalar multiplication.
- Div: Scalar division.
- Pow: Scalar exponentiation.
- Add: Elementwise addition.
- Sub: Elementwise subtraction.
- HadamMul: Elementwise multiplication (hadamard product).
- HadamPow: Principled shape-aware exponentiation (hadamard power).
- Atan2: Quadrant-respecting 2D arctangent.
- VecVecDot: Dot product for identically shaped vectors w/transpose.
- Cross: Cross product between identically shaped 3D vectors.
- VecVecOuter: Vector-vector outer product.
- LinSolve: Solve a linear system.
- LsqSolve: Minimize error of an underdetermined linear system.
- VecMatOuter: Vector-matrix outer product.
- MatMatDot: Matrix-matrix dot product.
- """
-
- # Number | Number
- Mul = enum.auto()
- Div = enum.auto()
- Pow = enum.auto()
-
- # Elements | Elements
- Add = enum.auto()
- Sub = enum.auto()
- HadamMul = enum.auto()
- # HadamPow = enum.auto() ## TODO: Sympy's HadamardPower is problematic.
- Atan2 = enum.auto()
-
- # Vector | Vector
- VecVecDot = enum.auto()
- Cross = enum.auto()
- VecVecOuter = enum.auto()
-
- # Matrix | Vector
- LinSolve = enum.auto()
- LsqSolve = enum.auto()
-
- # Vector | Matrix
- VecMatOuter = enum.auto()
-
- # Matrix | Matrix
- MatMatDot = enum.auto()
-
- ####################
- # - UI
- ####################
- @staticmethod
- def to_name(value: typ.Self) -> str:
- BO = BinaryOperation
- return {
- # Number | Number
- BO.Mul: 'ℓ · r',
- BO.Div: 'ℓ / r',
- BO.Pow: 'ℓ ^ r',
- # Elements | Elements
- BO.Add: 'ℓ + r',
- BO.Sub: 'ℓ - r',
- BO.HadamMul: '𝐋 ⊙ 𝐑',
- # BO.HadamPow: '𝐥 ⊙^ 𝐫',
- BO.Atan2: 'atan2(ℓ:x, r:y)',
- # Vector | Vector
- BO.VecVecDot: '𝐥 · 𝐫',
- BO.Cross: 'cross(𝐥,𝐫)',
- BO.VecVecOuter: '𝐥 ⊗ 𝐫',
- # Matrix | Vector
- BO.LinSolve: '𝐋 ∖ 𝐫',
- BO.LsqSolve: 'argminₓ∥𝐋𝐱−𝐫∥₂',
- # Vector | Matrix
- BO.VecMatOuter: '𝐋 ⊗ 𝐫',
- # Matrix | Matrix
- BO.MatMatDot: '𝐋 · 𝐑',
- }[value]
-
- @staticmethod
- def to_icon(value: typ.Self) -> str:
- return ''
-
- def bl_enum_element(self, i: int) -> ct.BLEnumElement:
- BO = BinaryOperation
- return (
- str(self),
- BO.to_name(self),
- BO.to_name(self),
- BO.to_icon(self),
- i,
- )
-
- ####################
- # - Ops from Shape
- ####################
- @staticmethod
- def by_infos(info_l: int, info_r: int) -> list[typ.Self]:
- """Deduce valid binary operations from the shapes of the inputs."""
- BO = BinaryOperation
-
- ops_el_el = [
- BO.Add,
- BO.Sub,
- BO.HadamMul,
- # BO.HadamPow,
- ]
-
- outl = info_l.output
- outr = info_r.output
- match (outl.shape_len, outr.shape_len):
- # Number | *
- ## Number | Number
- case (0, 0):
- ops = [
- BO.Add,
- BO.Sub,
- BO.Mul,
- ]
-
- # Check Non-Zero Right Hand Side
- ## -> Obviously, we can't ever divide by zero.
- ## -> Sympy's assumptions system must always guarantee rhs != 0.
- ## -> If it can't, then we simply don't expose division.
- ## -> The is_zero assumption must be provided elsewhere.
- ## -> NOTE: This may prevent some valid uses of division.
- ## -> Watch out for "division is missing" bugs.
- if info_r.output.is_nonzero:
- ops.append(BO.Div)
-
- if (
- info_l.output.physical_type == spux.PhysicalType.Length
- and info_l.output.unit == info_r.output.unit
- ):
- ops += [BO.Atan2]
-
- return [*ops, BO.Pow]
-
- ## Number | Vector
- case (0, 1):
- return [BO.Mul] # , BO.HadamPow]
-
- ## Number | Matrix
- case (0, 2):
- return [BO.Mul] # , BO.HadamPow]
-
- # Vector | *
- ## Vector | Number
- case (1, 0):
- return [BO.Mul] # , BO.HadamPow]
-
- ## Vector | Vector
- case (1, 1):
- ops = []
-
- # Vector | Vector
- ## -> Dot: Convenience; utilize special vec-vec dot w/transp.
- if outl.rows > outl.cols and outr.rows > outr.cols:
- ops += [BO.VecVecDot, BO.VecVecOuter]
-
- # Covector | Vector
- ## -> Dot: Directly use matrix-matrix dot, as it's now correct.
- if outl.rows < outl.cols and outr.rows > outr.cols:
- ops += [BO.MatMatDot, BO.VecVecOuter]
-
- # Vector | Covector
- ## -> Dot: Directly use matrix-matrix dot, as it's now correct.
- ## -> These are both the same operation, in this case.
- if outl.rows > outl.cols and outr.rows < outr.cols:
- ops += [BO.MatMatDot, BO.VecVecOuter]
-
- # Covector | Covector
- ## -> Dot: Convenience; utilize special vec-vec dot w/transp.
- if outl.rows < outl.cols and outr.rows < outr.cols:
- ops += [BO.VecVecDot, BO.VecVecOuter]
-
- # Cross Product
- ## -> Enforce that both are 3x1 or 1x3.
- ## -> See https://docs.sympy.org/latest/modules/matrices/matrices.html#sympy.matrices.matrices.MatrixBase.cross
- if (outl.rows == 3 and outr.rows == 3) or (
- outl.cols == 3 and outl.cols == 3
- ):
- ops += [BO.Cross]
-
- return ops_el_el + ops
-
- ## Vector | Matrix
- case (1, 2):
- return [BO.VecMatOuter]
-
- # Matrix | *
- ## Matrix | Number
- case (2, 0):
- return [BO.Mul] # , BO.HadamPow]
-
- ## Matrix | Vector
- case (2, 1):
- prepend_ops = []
-
- # Mat-Vec Dot: Enforce RHS Column Vector
- if outr.rows > outl.cols:
- prepend_ops += [BO.MatMatDot]
-
- return [*ops, BO.LinSolve, BO.LsqSolve] # , BO.HadamPow]
-
- ## Matrix | Matrix
- case (2, 2):
- return [*ops_el_el, BO.MatMatDot]
-
- return []
-
- ####################
- # - Function Properties
- ####################
- @property
- def sp_func(self):
- """Deduce an appropriate sympy-based function that implements the binary operation for symbolic inputs."""
- BO = BinaryOperation
-
- ## TODO: Make this compatible with sp.Matrix inputs
- return {
- # Number | Number
- BO.Mul: lambda exprs: exprs[0] * exprs[1],
- BO.Div: lambda exprs: exprs[0] / exprs[1],
- BO.Pow: lambda exprs: exprs[0] ** exprs[1],
- # Elements | Elements
- BO.Add: lambda exprs: exprs[0] + exprs[1],
- BO.Sub: lambda exprs: exprs[0] - exprs[1],
- BO.HadamMul: lambda exprs: sp.hadamard_product(exprs[0], exprs[1]),
- # BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
- BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]),
- # Vector | Vector
- BO.VecVecDot: lambda exprs: (exprs[0].T @ exprs[1])[0],
- BO.Cross: lambda exprs: exprs[0].cross(exprs[1]),
- BO.VecVecOuter: lambda exprs: exprs[0] @ exprs[1].T,
- # Matrix | Vector
- BO.LinSolve: lambda exprs: exprs[0].solve(exprs[1]),
- BO.LsqSolve: lambda exprs: exprs[0].solve_least_squares(exprs[1]),
- # Vector | Matrix
- BO.VecMatOuter: lambda exprs: spq.TensorProduct(exprs[0], exprs[1]),
- # Matrix | Matrix
- BO.MatMatDot: lambda exprs: exprs[0] @ exprs[1],
- }[self]
-
- @property
- def unit_func(self):
- """The binary function to apply to both unit expressions, in order to deduce the unit expression of the output."""
- BO = BinaryOperation
-
- ## TODO: Make this compatible with sp.Matrix inputs
- return {
- # Number | Number
- BO.Mul: BO.Mul.sp_func,
- BO.Div: BO.Div.sp_func,
- BO.Pow: BO.Pow.sp_func,
- # Elements | Elements
- BO.Add: BO.Add.sp_func,
- BO.Sub: BO.Sub.sp_func,
- BO.HadamMul: BO.Mul.sp_func,
- # BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
- BO.Atan2: lambda _: spu.radian,
- # Vector | Vector
- BO.VecVecDot: BO.Mul.sp_func,
- BO.Cross: BO.Mul.sp_func,
- BO.VecVecOuter: BO.Mul.sp_func,
- # Matrix | Vector
- ## -> A,b in Ax = b have units, and the equality must hold.
- ## -> Therefore, A \ b must have the units [b]/[A].
- BO.LinSolve: lambda exprs: exprs[1] / exprs[0],
- BO.LsqSolve: lambda exprs: exprs[1] / exprs[0],
- # Vector | Matrix
- BO.VecMatOuter: BO.Mul.sp_func,
- # Matrix | Matrix
- BO.MatMatDot: BO.Mul.sp_func,
- }[self]
-
- @property
- def jax_func(self):
- """Deduce an appropriate jax-based function that implements the binary operation for array inputs."""
- ## TODO: Scale the units of one side to the other.
- BO = BinaryOperation
-
- return {
- # Number | Number
- BO.Mul: lambda exprs: exprs[0] * exprs[1],
- BO.Div: lambda exprs: exprs[0] / exprs[1],
- BO.Pow: lambda exprs: exprs[0] ** exprs[1],
- # Elements | Elements
- BO.Add: lambda exprs: exprs[0] + exprs[1],
- BO.Sub: lambda exprs: exprs[0] - exprs[1],
- BO.HadamMul: lambda exprs: exprs[0] * exprs[1],
- # BO.HadamPow: lambda exprs: exprs[0] ** exprs[1],
- BO.Atan2: lambda exprs: jnp.atan2(exprs[1], exprs[0]),
- # Vector | Vector
- BO.VecVecDot: lambda exprs: jnp.dot(exprs[0], exprs[1]),
- BO.Cross: lambda exprs: jnp.cross(exprs[0], exprs[1]),
- BO.VecVecOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
- # Matrix | Vector
- BO.LinSolve: lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]),
- BO.LsqSolve: lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]),
- # Vector | Matrix
- BO.VecMatOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
- # Matrix | Matrix
- BO.MatMatDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]),
- }[self]
-
- ####################
- # - InfoFlow Transform
- ####################
- def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow):
- """Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression.
-
- Warnings:
- `self` MUST be an element of `BinaryOperation.by_infos(info_l, info_r).
-
- If not, bad things will happen.
- """
- return info_l.operate_output(
- info_r,
- lambda a, b: self.sp_func([a, b]),
- lambda a, b: self.unit_func([a, b]),
- )
-
-
-####################
-# - Node
-####################
class OperateMathNode(base.MaxwellSimNode):
r"""Applies a binary function between two expressions.
@@ -386,7 +56,7 @@ class OperateMathNode(base.MaxwellSimNode):
}
####################
- # - Properties
+ # - Properties: Incoming InfoFlows
####################
@events.on_value_changed(
# Trigger
@@ -415,6 +85,7 @@ class OperateMathNode(base.MaxwellSimNode):
@bl_cache.cached_bl_property()
def expr_infos(self) -> tuple[ct.InfoFlow, ct.InfoFlow] | None:
+ """Computed `InfoFlow`s of both expressions."""
info_l = self._compute_input('Expr L', kind=ct.FlowKind.Info)
info_r = self._compute_input('Expr R', kind=ct.FlowKind.Info)
@@ -426,19 +97,18 @@ class OperateMathNode(base.MaxwellSimNode):
return None
- operation: BinaryOperation = bl_cache.BLField(
+ ####################
+ # - Property: Operation
+ ####################
+ operation: math_system.BinaryOperation = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_infos'},
)
def search_operations(self) -> list[ct.BLEnumElement]:
+ """Retrieve valid operations based on the input `InfoFlow`s."""
if self.expr_infos is not None:
- return [
- operation.bl_enum_element(i)
- for i, operation in enumerate(
- BinaryOperation.by_infos(*self.expr_infos)
- )
- ]
+ return math_system.BinaryOperation.bl_enum_elements(*self.expr_infos)
return []
####################
@@ -451,7 +121,7 @@ class OperateMathNode(base.MaxwellSimNode):
Called by Blender to determine the text to place in the node's header.
"""
if self.operation is not None:
- return 'Op: ' + BinaryOperation.to_name(self.operation)
+ return 'Op: ' + math_system.BinaryOperation.to_name(self.operation)
return self.bl_label
@@ -464,7 +134,7 @@ class OperateMathNode(base.MaxwellSimNode):
layout.prop(self, self.blfields['operation'], text='')
####################
- # - FlowKind.Value|Func
+ # - FlowKind.Value
####################
@events.computes_output_socket(
'Expr',
@@ -477,20 +147,22 @@ class OperateMathNode(base.MaxwellSimNode):
},
)
def compute_value(self, props: dict, input_sockets: dict):
- operation = props['operation']
+ """Binary operation on two symbolic input expressions."""
expr_l = input_sockets['Expr L']
expr_r = input_sockets['Expr R']
has_expr_l_value = not ct.FlowSignal.check(expr_l)
has_expr_r_value = not ct.FlowSignal.check(expr_r)
- # Compute Sympy Function
- ## -> The operation enum directly provides the appropriate function.
+ operation = props['operation']
if has_expr_l_value and has_expr_r_value and operation is not None:
return operation.sp_func([expr_l, expr_r])
return ct.FlowSignal.FlowPending
+ ####################
+ # - FlowKind.Func
+ ####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Func,
@@ -505,10 +177,7 @@ class OperateMathNode(base.MaxwellSimNode):
output_socket_kinds={'Expr': ct.FlowKind.Info},
)
def compute_func(self, props, input_sockets, output_sockets):
- operation = props['operation']
- if operation is None:
- return ct.FlowSignal.FlowPending
-
+ """Binary operation on two lazy-defined input expressions."""
expr_l = input_sockets['Expr L']
expr_r = input_sockets['Expr R']
output_info = output_sockets['Expr']
@@ -517,14 +186,9 @@ class OperateMathNode(base.MaxwellSimNode):
has_expr_r = not ct.FlowSignal.check(expr_r)
has_output_info = not ct.FlowSignal.check(output_info)
- # Compute Jax Function
- ## -> The operation enum directly provides the appropriate function.
- if has_expr_l and has_expr_r and has_output_info:
- return (expr_l | expr_r).compose_within(
- operation.jax_func,
- enclosing_func_output=output_info.output,
- supports_jax=True,
- )
+ operation = props['operation']
+ if operation is not None and has_expr_l and has_expr_r and has_output_info:
+ return self.operation.transform_funcs(expr_l, expr_r)
return ct.FlowSignal.FlowPending
####################
@@ -541,22 +205,17 @@ class OperateMathNode(base.MaxwellSimNode):
},
)
def compute_info(self, props, input_sockets) -> ct.InfoFlow:
- BO = BinaryOperation
-
- operation = props['operation']
+ """Transform the input information of both lazy inputs."""
info_l = input_sockets['Expr L']
info_r = input_sockets['Expr R']
has_info_l = not ct.FlowSignal.check(info_l)
has_info_r = not ct.FlowSignal.check(info_r)
- # Compute Info
- ## -> The operation enum directly provides the appropriate transform.
+ operation = props['operation']
if (
- has_info_l
- and has_info_r
- and operation is not None
- and operation in BO.by_infos(info_l, info_r)
+ has_info_l and has_info_r and operation is not None
+ # and operation in BO.by_infos(info_l, info_r)
):
return operation.transform_infos(info_l, info_r)
@@ -576,15 +235,14 @@ class OperateMathNode(base.MaxwellSimNode):
},
)
def compute_params(self, props, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
- operation = props['operation']
+ """Merge the lazy input parameters."""
params_l = input_sockets['Expr L']
params_r = input_sockets['Expr R']
has_params_l = not ct.FlowSignal.check(params_l)
has_params_r = not ct.FlowSignal.check(params_r)
- # Compute Params
- ## -> Operations don't add new parameters, so just concatenate L|R.
+ operation = props['operation']
if has_params_l and has_params_r and operation is not None:
return params_l | params_r
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py
index 6744f7d..aa6ed07 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py
@@ -20,334 +20,18 @@ import enum
import typing as typ
import bpy
-import jax.numpy as jnp
-import jaxtyping as jtyp
import sympy as sp
-from blender_maxwell.utils import bl_cache, logger, sci_constants, sim_symbols
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import bl_cache, logger, sim_symbols
+from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
-from .... import sockets
+from .... import math_system, sockets
from ... import base, events
log = logger.get(__name__)
-####################
-# - Operation Enum
-####################
-class TransformOperation(enum.StrEnum):
- """Valid operations for the `TransformMathNode`.
-
- Attributes:
- FreqToVacWL: Transform an frequency dimension to vacuum wavelength.
- VacWLToFreq: Transform a vacuum wavelength dimension to frequency.
- ConvertIdxUnit: Convert the unit of a dimension to a compatible unit.
- SetIdxUnit: Set all properties of a dimension.
- FirstColToFirstIdx: Extract the first data column and set the first dimension's index array equal to it.
- **For 2D integer-indexed data only**.
-
- IntDimToComplex: Fold a last length-2 integer dimension into the output, transforming it from a real-like type to complex type.
- DimToVec: Fold the last dimension into the scalar output, creating a vector output type.
- DimsToMat: Fold the last two dimensions into the scalar output, creating a matrix output type.
- FT: Compute the 1D fourier transform along a dimension.
- New dimensional bounds are computing using the Nyquist Limit.
- For higher dimensions, simply repeat along more dimensions.
- InvFT1D: Compute the inverse 1D fourier transform along a dimension.
- New dimensional bounds are computing using the Nyquist Limit.
- For higher dimensions, simply repeat along more dimensions.
- """
-
- # Covariant Transform
- FreqToVacWL = enum.auto()
- VacWLToFreq = enum.auto()
- ConvertIdxUnit = enum.auto()
- SetIdxUnit = enum.auto()
- FirstColToFirstIdx = enum.auto()
-
- # Fold
- IntDimToComplex = enum.auto()
- DimToVec = enum.auto()
- DimsToMat = enum.auto()
-
- # Fourier
- FT1D = enum.auto()
- InvFT1D = enum.auto()
-
- # TODO: Affine
- ## TODO
-
- ####################
- # - UI
- ####################
- @staticmethod
- def to_name(value: typ.Self) -> str:
- TO = TransformOperation
- return {
- # Covariant Transform
- TO.FreqToVacWL: '𝑓 → λᵥ',
- TO.VacWLToFreq: 'λᵥ → 𝑓',
- TO.ConvertIdxUnit: 'Convert Dim',
- TO.SetIdxUnit: 'Set Dim',
- TO.FirstColToFirstIdx: '1st Col → 1st Dim',
- # Fold
- TO.IntDimToComplex: '→ ℂ',
- TO.DimToVec: '→ Vector',
- TO.DimsToMat: '→ Matrix',
- ## TODO: Vector to new last-dim integer
- ## TODO: Matrix to two last-dim integers
- # Fourier
- TO.FT1D: 'FT',
- TO.InvFT1D: 'iFT',
- }[value]
-
- @property
- def name(self) -> str:
- return TransformOperation.to_name(self)
-
- @staticmethod
- def to_icon(_: typ.Self) -> str:
- return ''
-
- def bl_enum_element(self, i: int) -> ct.BLEnumElement:
- TO = TransformOperation
- return (
- str(self),
- TO.to_name(self),
- TO.to_name(self),
- TO.to_icon(self),
- i,
- )
-
- ####################
- # - Methods
- ####################
- def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
- TO = TransformOperation
- match self:
- case TO.FreqToVacWL:
- return [
- dim
- for dim in info.dims
- if dim.physical_type is spux.PhysicalType.Freq
- ]
-
- case TO.VacWLToFreq:
- return [
- dim
- for dim in info.dims
- if dim.physical_type is spux.PhysicalType.Length
- ]
-
- case TO.ConvertIdxUnit:
- return [
- dim
- for dim in info.dims
- if not info.has_idx_labels(dim)
- and spux.PhysicalType.from_unit(dim.unit, optional=True) is not None
- ]
-
- case TO.SetIdxUnit:
- return [dim for dim in info.dims if not info.has_idx_labels(dim)]
-
- ## ColDimToComplex: Implicit Last Dimension
- ## DimToVec: Implicit Last Dimension
- ## DimsToMat: Implicit Last 2 Dimensions
-
- case TO.FT1D | TO.InvFT1D:
- # Filter by Axis Uniformity
- ## -> FT requires uniform axis (aka. must be RangeFlow).
- ## -> NOTE: If FT isn't popping up, check ExtractDataNode.
- return [dim for dim in info.dims if info.is_idx_uniform(dim)]
-
- return []
-
- @staticmethod
- def by_info(info: ct.InfoFlow) -> list[typ.Self]:
- TO = TransformOperation
- operations = []
-
- # Covariant Transform
- ## Freq -> VacWL
- if TO.FreqToVacWL.valid_dims(info):
- operations += [TO.FreqToVacWL]
-
- ## VacWL -> Freq
- if TO.VacWLToFreq.valid_dims(info):
- operations += [TO.VacWLToFreq]
-
- ## Convert Index Unit
- if TO.ConvertIdxUnit.valid_dims(info):
- operations += [TO.ConvertIdxUnit]
-
- if TO.SetIdxUnit.valid_dims(info):
- operations += [TO.SetIdxUnit]
-
- ## Column to First Index (Array)
- if (
- len(info.dims) == 2 # noqa: PLR2004
- and info.first_dim.mathtype is spux.MathType.Integer
- and info.last_dim.mathtype is spux.MathType.Integer
- and info.output.shape_len == 0
- ):
- operations += [TO.FirstColToFirstIdx]
-
- # Fold
- ## Last Dim -> Complex
- if (
- len(info.dims) >= 1
- and (
- info.output.mathtype
- in [spux.MathType.Integer, spux.MathType.Rational, spux.MathType.Real]
- )
- and info.last_dim.mathtype is spux.MathType.Integer
- and info.has_idx_labels(info.last_dim)
- and len(info.dims[info.last_dim]) == 2 # noqa: PLR2004
- ):
- operations += [TO.IntDimToComplex]
-
- ## Last Dim -> Vector
- if len(info.dims) >= 1 and info.output.shape_len == 0:
- operations += [TO.DimToVec]
-
- ## Last Dim -> Matrix
- if len(info.dims) >= 2 and info.output.shape_len == 0: # noqa: PLR2004
- operations += [TO.DimsToMat]
-
- # Fourier
- if TO.FT1D.valid_dims(info):
- operations += [TO.FT1D]
-
- if TO.InvFT1D.valid_dims(info):
- operations += [TO.InvFT1D]
-
- return operations
-
- ####################
- # - Function Properties
- ####################
- def jax_func(self, axis: int | None = None):
- TO = TransformOperation
- return {
- # Covariant Transform
- ## -> Freq <-> WL is a rescale (noop) AND flip (not noop).
- TO.FreqToVacWL: lambda expr: jnp.flip(expr, axis=axis),
- TO.VacWLToFreq: lambda expr: jnp.flip(expr, axis=axis),
- TO.ConvertIdxUnit: lambda expr: expr,
- TO.SetIdxUnit: lambda expr: expr,
- TO.FirstColToFirstIdx: lambda expr: jnp.delete(expr, 0, axis=1),
- # Fold
- ## -> To Complex: This should generally be a no-op.
- TO.IntDimToComplex: lambda expr: jnp.squeeze(
- expr.view(dtype=jnp.complex64), axis=-1
- ),
- TO.DimToVec: lambda expr: expr,
- TO.DimsToMat: lambda expr: expr,
- # Fourier
- TO.FT1D: lambda expr: jnp.fft(expr, axis=axis),
- TO.InvFT1D: lambda expr: jnp.ifft(expr, axis=axis),
- }[self]
-
- def transform_info(
- self,
- info: ct.InfoFlow,
- dim: sim_symbols.SimSymbol | None = None,
- data_col: jtyp.Shaped[jtyp.Array, ' size'] | None = None,
- new_dim_name: str | None = None,
- unit: spux.Unit | None = None,
- physical_type: spux.PhysicalType | None = None,
- ) -> ct.InfoFlow:
- TO = TransformOperation
- return {
- # Covariant Transform
- TO.FreqToVacWL: lambda: info.replace_dim(
- (f_dim := dim),
- sim_symbols.wl(unit),
- info.dims[f_dim].rescale(
- lambda el: sci_constants.vac_speed_of_light / el,
- reverse=True,
- new_unit=unit,
- ),
- ),
- TO.VacWLToFreq: lambda: info.replace_dim(
- (wl_dim := dim),
- sim_symbols.freq(unit),
- info.dims[wl_dim].rescale(
- lambda el: sci_constants.vac_speed_of_light / el,
- reverse=True,
- new_unit=unit,
- ),
- ),
- TO.ConvertIdxUnit: lambda: info.replace_dim(
- dim,
- dim.update(unit=unit),
- (
- info.dims[dim].rescale_to_unit(unit)
- if info.has_idx_discrete(dim)
- else None ## Continuous -- dim SimSymbol already scaled
- ),
- ),
- TO.SetIdxUnit: lambda: info.replace_dim(
- dim,
- dim.update(
- sym_name=new_dim_name,
- physical_type=physical_type,
- unit=unit,
- ),
- (
- info.dims[dim].correct_unit(unit)
- if info.has_idx_discrete(dim)
- else None ## Continuous -- dim SimSymbol already scaled
- ),
- ),
- TO.FirstColToFirstIdx: lambda: info.replace_dim(
- info.first_dim,
- info.first_dim.update(
- sym_name=new_dim_name,
- mathtype=spux.MathType.from_jax_array(data_col),
- physical_type=physical_type,
- unit=unit,
- ),
- ct.RangeFlow.try_from_array(ct.ArrayFlow(values=data_col, unit=unit)),
- ).slice_dim(info.last_dim, (1, len(info.dims[info.last_dim]), 1)),
- # Fold
- TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output(
- mathtype=spux.MathType.Complex
- ),
- TO.DimToVec: lambda: info.fold_last_input(),
- TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(),
- # Fourier
- TO.FT1D: lambda: info.replace_dim(
- dim,
- [
- # FT'ed Unit: Reciprocal of the Original Unit
- dim.update(
- unit=1 / dim.unit if dim.unit is not None else 1
- ), ## TODO: Okay to not scale interval?
- # FT'ed Bounds: Reciprocal of the Original Unit
- info.dims[dim].bound_fourier_transform,
- ],
- ),
- TO.InvFT1D: lambda: info.replace_dim(
- info.last_dim,
- [
- # FT'ed Unit: Reciprocal of the Original Unit
- dim.update(
- unit=1 / dim.unit if dim.unit is not None else 1
- ), ## TODO: Okay to not scale interval?
- # FT'ed Bounds: Reciprocal of the Original Unit
- ## -> Note the midpoint may revert to 0.
- ## -> See docs for `RangeFlow.bound_inv_fourier_transform` for more.
- info.dims[dim].bound_inv_fourier_transform,
- ],
- ),
- }[self]()
-
-
-####################
-# - Node
-####################
class TransformMathNode(base.MaxwellSimNode):
r"""Applies a function to the array as a whole, with arbitrary results.
@@ -409,7 +93,7 @@ class TransformMathNode(base.MaxwellSimNode):
####################
# - Properties: Operation
####################
- operation: TransformOperation = bl_cache.BLField(
+ operation: math_system.TransformOperation = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_info'},
)
@@ -419,7 +103,7 @@ class TransformMathNode(base.MaxwellSimNode):
return [
operation.bl_enum_element(i)
for i, operation in enumerate(
- TransformOperation.by_info(self.expr_info)
+ math_system.TransformOperation.by_info(self.expr_info)
)
]
return []
@@ -461,7 +145,7 @@ class TransformMathNode(base.MaxwellSimNode):
)
def search_units(self) -> list[ct.BLEnumElement]:
- TO = TransformOperation
+ TO = math_system.TransformOperation
match self.operation:
# Covariant Transform
case TO.ConvertIdxUnit if self.dim is not None:
@@ -521,7 +205,7 @@ class TransformMathNode(base.MaxwellSimNode):
return spux.sp_to_str(self.new_unit)
def draw_label(self):
- TO = TransformOperation
+ TO = math_system.TransformOperation
match self.operation:
case TO.FreqToVacWL if self.dim is not None:
return f'T: {self.dim.name_pretty} | 𝑓 → {self.new_unit_str}'
@@ -556,7 +240,7 @@ class TransformMathNode(base.MaxwellSimNode):
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, self.blfields['operation'], text='')
- TO = TransformOperation
+ TO = math_system.TransformOperation
match self.operation:
case TO.ConvertIdxUnit:
row = layout.row(align=True)
@@ -613,7 +297,7 @@ class TransformMathNode(base.MaxwellSimNode):
self, props, input_sockets, output_sockets
) -> ct.FuncFlow | ct.FlowSignal:
"""Transform the input `InfoFlow` depending on the transform operation."""
- TO = TransformOperation
+ TO = math_system.TransformOperation
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
info = input_sockets['Expr'][ct.FlowKind.Info]
@@ -662,7 +346,7 @@ class TransformMathNode(base.MaxwellSimNode):
self, props: dict, input_sockets: dict
) -> ct.InfoFlow | typ.Literal[ct.FlowSignal.FlowPending]:
"""Transform the input `InfoFlow` depending on the transform operation."""
- TO = TransformOperation
+ TO = math_system.TransformOperation
operation = props['operation']
info = input_sockets['Expr'][ct.FlowKind.Info]
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py
index ad44ef8..dda78ff 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py
@@ -24,7 +24,7 @@ import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, image_ops, logger, sim_symbols
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/absorbing_bound_cond.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/absorbing_bound_cond.py
index 2f8f276..f03eb77 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/absorbing_bound_cond.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/absorbing_bound_cond.py
@@ -22,7 +22,7 @@ import bpy
import sympy as sp
import tidy3d as td
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .... import contracts as ct
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/pml_bound_cond.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/pml_bound_cond.py
index 91faec1..1490bdd 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/pml_bound_cond.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/pml_bound_cond.py
@@ -22,7 +22,7 @@ import bpy
import sympy as sp
import tidy3d as td
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .... import contracts as ct
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py
index aa94c09..9a94de8 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py
@@ -19,7 +19,7 @@ import inspect
import typing as typ
from types import MappingProxyType
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .. import contracts as ct
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py
index dba473a..11068e3 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py
@@ -21,7 +21,7 @@ import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, sci_constants, sim_symbols
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py
index db9df67..213746a 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py
@@ -16,12 +16,13 @@
import enum
import typing as typ
+from fractions import Fraction
import bpy
import sympy as sp
from blender_maxwell.utils import bl_cache, logger, sim_symbols
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets
@@ -50,6 +51,8 @@ class SymbolConstantNode(base.MaxwellSimNode):
)
size: spux.NumberSize1D = bl_cache.BLField(spux.NumberSize1D.Scalar)
+ ## Use of NumberSize1D implicitly guarantees UI-realizability later.
+
mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real)
physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical)
@@ -109,31 +112,110 @@ class SymbolConstantNode(base.MaxwellSimNode):
preview_value_re: float = bl_cache.BLField(0.0)
preview_value_im: float = bl_cache.BLField(0.0)
- ####################
- # - Computed Properties
- ####################
@bl_cache.cached_bl_property(
depends_on={
- 'sym_name',
- 'size',
'mathtype',
- 'physical_type',
- 'unit',
'interval_finite_z',
'interval_finite_q',
'interval_finite_re',
- 'interval_inf',
- 'interval_closed',
'interval_finite_im',
- 'interval_inf_im',
- 'interval_closed_im',
+ }
+ )
+ def interval_finite(
+ self,
+ ) -> (
+ tuple[int | Fraction | float, int | Fraction | float]
+ | tuple[tuple[float, float], tuple[float, float]]
+ ):
+ """Return the appropriate finite interval from the UI, as guided by `self.mathtype`."""
+ MT = spux.MathType
+ match self.mathtype:
+ case MT.Integer:
+ return self.interval_finite_z
+ case MT.Rational:
+ return [Fraction(*q) for q in self.interval_finite_q]
+ case MT.Real:
+ return self.interval_finite_re
+ case MT.Complex:
+ return (self.interval_finite_re, self.interval_finite_im)
+
+ @bl_cache.cached_bl_property(
+ depends_on={
+ 'mathtype',
'preview_value_z',
'preview_value_q',
'preview_value_re',
'preview_value_im',
}
)
+ def preview_value(
+ self,
+ ) -> int | Fraction | float | complex:
+ """Return the appropriate finite interval from the UI, as guided by `self.mathtype`."""
+ MT = spux.MathType
+ match self.mathtype:
+ case MT.Integer:
+ return self.preview_value_z
+ case MT.Rational:
+ return Fraction(*self.preview_value_q)
+ case MT.Real:
+ return self.preview_value_re
+ case MT.Complex:
+ return complex(self.preview_value_re, self.preview_value_im)
+
+ @bl_cache.cached_bl_property(
+ depends_on={
+ 'mathtype',
+ 'interval_finite',
+ 'interval_inf',
+ 'interval_inf_im',
+ 'interval_closed',
+ 'interval_closed_im',
+ }
+ )
+ def domain(
+ self,
+ ) -> sp.Interval | sp.sets.fancysets.CartesianComplexRegion:
+ """Deduce the domain specified in the UI."""
+ MT = spux.MathType
+ match self.mathtype:
+ case MT.Integer | MT.Real | MT.Rational:
+ return sim_symbols.mk_interval(
+ self.interval_finite,
+ self.interval_inf,
+ self.interval_closed,
+ )
+
+ case MT.Complex:
+ region = self.interval_finite
+ domain_re = sim_symbols.mk_interval(
+ region[0],
+ self.interval_inf,
+ self.interval_closed,
+ )
+ domain_im = sim_symbols.mk_interval(
+ region[1],
+ self.interval_inf_im,
+ self.interval_closed_im,
+ )
+ return sp.ComplexRegion(domain_re, domain_im, polar=False)
+
+ ####################
+ # - Computed Properties
+ ####################
+ @bl_cache.cached_bl_property(
+ depends_on={
+ 'sym_name',
+ 'mathtype',
+ 'physical_type',
+ 'unit',
+ 'size',
+ 'domain',
+ 'preview_value',
+ }
+ )
def symbol(self) -> sim_symbols.SimSymbol:
+ """Generate the `SimSymbol` matching the user-specification."""
return sim_symbols.SimSymbol(
sym_name=self.sym_name,
mathtype=self.mathtype,
@@ -141,18 +223,8 @@ class SymbolConstantNode(base.MaxwellSimNode):
unit=self.unit,
rows=self.size.rows,
cols=self.size.cols,
- interval_finite_z=self.interval_finite_z,
- interval_finite_q=self.interval_finite_q,
- interval_finite_re=self.interval_finite_re,
- interval_inf=self.interval_inf,
- interval_closed=self.interval_closed,
- interval_finite_im=self.interval_finite_im,
- interval_inf_im=self.interval_inf_im,
- interval_closed_im=self.interval_closed_im,
- preview_value_z=self.preview_value_z,
- preview_value_q=self.preview_value_q,
- preview_value_re=self.preview_value_re,
- preview_value_im=self.preview_value_im,
+ domain=self.domain,
+ preview_value=self.preview_value,
)
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py
index 252dbed..e9134fc 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py
@@ -23,7 +23,7 @@ import sympy as sp
import tidy3d as td
from blender_maxwell.utils import bl_cache, logger, sim_symbols
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/scene.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/scene.py
index 2acafa7..8b6fc43 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/scene.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/scene.py
@@ -23,7 +23,7 @@ import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, logger
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import sockets
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py
index 68d6d49..55f96c6 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py
@@ -23,7 +23,7 @@ import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, logger, sci_constants
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import sockets
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/library_medium.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/library_medium.py
index da00004..e9c275b 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/library_medium.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/library_medium.py
@@ -25,7 +25,7 @@ from tidy3d.material_library.material_library import MaterialItem as Tidy3DMediu
from tidy3d.material_library.material_library import VariantItem as Tidy3DMediumVariant
from blender_maxwell.utils import bl_cache, logger, sci_constants
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py
index 7d40ba2..d9ca2e8 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py
@@ -23,7 +23,7 @@ import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import bl_cache, logger
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py
index 319ce0b..d40e0ab 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py
@@ -23,7 +23,7 @@ import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import bl_cache, logger
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/permittivity_monitor.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/permittivity_monitor.py
index b5a287c..85d845b 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/permittivity_monitor.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/permittivity_monitor.py
@@ -20,7 +20,7 @@ import sympy as sp
import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from ... import contracts as ct
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/file_exporters/data_file_exporter.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/file_exporters/data_file_exporter.py
index 4cf1503..963186b 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/file_exporters/data_file_exporter.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/file_exporters/data_file_exporter.py
@@ -20,7 +20,7 @@ from pathlib import Path
import bpy
from blender_maxwell.utils import bl_cache, logger
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py
index 585793b..0ad5d02 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py
@@ -21,7 +21,7 @@ import sympy as sp
import tidy3d as td
from blender_maxwell.utils import bl_cache, logger
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import sockets
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py
index 0197f5f..7c525ff 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py
@@ -22,7 +22,7 @@ import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from ... import contracts as ct
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/gaussian_beam_source.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/gaussian_beam_source.py
index 2bbd6b0..94eb827 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/gaussian_beam_source.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/gaussian_beam_source.py
@@ -22,7 +22,7 @@ import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import bl_cache, logger
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/plane_wave_source.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/plane_wave_source.py
index d1d3ab5..3f60b9e 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/plane_wave_source.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/plane_wave_source.py
@@ -22,7 +22,7 @@ import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import bl_cache, logger
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py
index 0ed364b..0457790 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py
@@ -22,7 +22,7 @@ import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import bl_cache, logger
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py
index 1f48f31..c307143 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py
@@ -28,7 +28,7 @@ from tidy3d.components.data.data_array import TimeDataArray as td_TimeDataArray
from tidy3d.components.data.dataset import TimeDataset as td_TimeDataset
from blender_maxwell.utils import bl_cache, logger, sim_symbols
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py
index 27369d5..6669a50 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py
@@ -20,7 +20,7 @@ import sympy as sp
import sympy.physics.units as spu
import tidy3d as td
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from ... import bl_socket_map, managed_objs, sockets
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py
index 2109185..754ac29 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py
@@ -24,7 +24,7 @@ import tidy3d.plugins.adjoint as tdadj
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import bl_cache, logger
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import managed_objs, sockets
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py
index 072756d..1177dbd 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py
@@ -21,7 +21,7 @@ import sympy.physics.units as spu
import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .... import contracts as ct
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/sphere_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/sphere_structure.py
index d5a83b4..8a51183 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/sphere_structure.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/sphere_structure.py
@@ -21,7 +21,7 @@ import sympy.physics.units as spu
import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .... import contracts as ct
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py
index 84e077b..6c553a7 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py
@@ -24,7 +24,7 @@ import pydantic as pyd
import sympy as sp
from blender_maxwell.utils import bl_cache, logger, sim_symbols
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from .. import contracts as ct
from . import base
@@ -155,12 +155,14 @@ class ExprBLSocket(base.MaxwellSimSocket):
unit=self.unit,
rows=self.size.rows,
cols=self.size.cols,
+ is_constant=True,
+ ## TODO: Should we set preview values
exclude_zero=(
not self.value.is_zero
if self.value.is_zero is not None
else False
),
- ## TODO: Does this work for matrix elements?
+ ## TODO: Does this 0-check work for matrix elements?
)
case ct.FlowKind.Range if self.symbols:
@@ -208,7 +210,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
sim_symbols.SimSymbolName.Expr
)
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
- sim_symbols.SimSymbolName.Expr
+ sim_symbols.SimSymbolName.Constant
)
symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([])
@@ -441,12 +443,11 @@ class ExprBLSocket(base.MaxwellSimSocket):
####################
def _to_raw_value(self, expr: spux.SympyExpr, force_complex: bool = False):
"""Cast the given expression to the appropriate raw value, with scaling guided by `self.unit`."""
- if self.unit is not None:
- pyvalue = spux.sympy_to_python(spux.scale_to_unit(expr, self.unit))
- else:
- pyvalue = spux.sympy_to_python(expr)
+ pyvalue = spux.scale_to_unit(expr, self.unit)
# Cast complex -> tuple[float, float]
+ ## -> We can't set complex to BLProps.
+ ## -> We must deconstruct it appropriately.
if isinstance(pyvalue, complex) or (
isinstance(pyvalue, int | float) and force_complex
):
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py
index 2c2119f..aa41351 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py
@@ -24,7 +24,7 @@ import tidy3d as td
import tidy3d.plugins.adjoint as tdadj
from blender_maxwell.utils import bl_cache, logger
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from .. import base
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/pol.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/pol.py
index 7cf5cf0..a39fcbf 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/pol.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/pol.py
@@ -19,7 +19,7 @@ import sympy as sp
import sympy.physics.optics.polarization as spo_pol
import sympy.physics.units as spu
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from .. import base
diff --git a/src/blender_maxwell/utils/__init__.py b/src/blender_maxwell/utils/__init__.py
index a44be30..e1e2d51 100644
--- a/src/blender_maxwell/utils/__init__.py
+++ b/src/blender_maxwell/utils/__init__.py
@@ -17,22 +17,22 @@
from ..nodeps.utils import blender_type_enum, pydeps
from . import (
bl_cache,
- extra_sympy_units,
image_ops,
logger,
sci_constants,
serialize,
staticproperty,
+ sympy_extra,
)
__all__ = [
'blender_type_enum',
'pydeps',
'bl_cache',
- 'extra_sympy_units',
'image_ops',
'logger',
'sci_constants',
'serialize',
'staticproperty',
+ 'sympy_extra',
]
diff --git a/src/blender_maxwell/utils/extra_sympy_units.py b/src/blender_maxwell/utils/extra_sympy_units.py
deleted file mode 100644
index fa31709..0000000
--- a/src/blender_maxwell/utils/extra_sympy_units.py
+++ /dev/null
@@ -1,1699 +0,0 @@
-# blender_maxwell
-# Copyright (C) 2024 blender_maxwell Project Contributors
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see .
-
-"""Declares useful sympy units and functions, to make it easier to work with `sympy` as the basis for a unit-aware system.
-
-Attributes:
- UNIT_BY_SYMBOL: Maps all abbreviated Sympy symbols to their corresponding Sympy unit.
- This is essential for parsing string expressions that use units, since a pure parse of ex. `a*m + m` would not otherwise be able to differentiate between `sp.Symbol(m)` and `spu.meter`.
- SympyType: A simple union of valid `sympy` types, used to check whether arbitrary objects should be handled using `sympy` functions.
- For simple `isinstance` checks, this should be preferred, as it is most performant.
- For general use, `SympyExpr` should be preferred.
- SympyExpr: A `SympyType` that is compatible with `pydantic`, including serialization/deserialization.
- Should be used via the `ConstrSympyExpr`, which also adds expression validation.
-"""
-
-import enum
-import functools
-import sys
-import typing as typ
-from fractions import Fraction
-
-import jax
-import jax.numpy as jnp
-import jaxtyping as jtyp
-import pydantic as pyd
-import sympy as sp
-import sympy.physics.units as spu
-import typing_extensions as typx
-from pydantic_core import core_schema as pyd_core_schema
-
-from blender_maxwell import contracts as ct
-
-from . import logger
-from .staticproperty import staticproperty
-
-log = logger.get(__name__)
-
-SympyType = (
- sp.Basic
- | sp.Expr
- | sp.MatrixBase
- | sp.MutableDenseMatrix
- | spu.Quantity
- | spu.Dimension
-)
-
-
-####################
-# - Math Type
-####################
-class MathType(enum.StrEnum):
- """Type identifiers that encompass common sets of mathematical objects."""
-
- Integer = enum.auto()
- Rational = enum.auto()
- Real = enum.auto()
- Complex = enum.auto()
-
- @staticmethod
- def combine(*mathtypes: list[typ.Self], optional: bool = False) -> typ.Self | None:
- if MathType.Complex in mathtypes:
- return MathType.Complex
- if MathType.Real in mathtypes:
- return MathType.Real
- if MathType.Rational in mathtypes:
- return MathType.Rational
- if MathType.Integer in mathtypes:
- return MathType.Integer
-
- if optional:
- return None
-
- msg = f"Can't combine mathtypes {mathtypes}"
- raise ValueError(msg)
-
- def is_compatible(self, other: typ.Self) -> bool:
- MT = MathType
- return (
- other
- in {
- MT.Integer: [MT.Integer],
- MT.Rational: [MT.Integer, MT.Rational],
- MT.Real: [MT.Integer, MT.Rational, MT.Real],
- MT.Complex: [MT.Integer, MT.Rational, MT.Real, MT.Complex],
- }[self]
- )
-
- def coerce_compatible_pyobj(
- self, pyobj: bool | int | Fraction | float | complex
- ) -> int | Fraction | float | complex:
- MT = MathType
- match self:
- case MT.Integer:
- return int(pyobj)
- case MT.Rational if isinstance(pyobj, int):
- return Fraction(pyobj, 1)
- case MT.Rational if isinstance(pyobj, Fraction):
- return pyobj
- case MT.Real:
- return float(pyobj)
- case MT.Complex if isinstance(pyobj, int | Fraction):
- return complex(float(pyobj), 0)
- case MT.Complex if isinstance(pyobj, float):
- return complex(pyobj, 0)
-
- @staticmethod
- def from_expr(sp_obj: SympyType, optional: bool = False) -> type | None:
- if isinstance(sp_obj, sp.MatrixBase):
- return MathType.combine(
- *[MathType.from_expr(v) for v in sp.flatten(sp_obj)]
- )
-
- if sp_obj.is_integer:
- return MathType.Integer
- if sp_obj.is_rational:
- return MathType.Rational
- if sp_obj.is_real:
- return MathType.Real
- if sp_obj.is_complex:
- return MathType.Complex
-
- # Infinities
- if sp_obj in [sp.oo, -sp.oo]:
- return MathType.Real ## TODO: Strictly, could be ex. integer...
- if sp_obj in [sp.zoo, -sp.zoo]:
- return MathType.Complex
-
- if optional:
- return None
-
- msg = f"Can't determine MathType from sympy object: {sp_obj}"
- raise ValueError(msg)
-
- @staticmethod
- def from_pytype(dtype: type) -> type:
- return {
- int: MathType.Integer,
- Fraction: MathType.Rational,
- float: MathType.Real,
- complex: MathType.Complex,
- }[dtype]
-
- @staticmethod
- def from_jax_array(data: jtyp.Shaped[jtyp.Array, '...']) -> type:
- """Deduce the MathType corresponding to a JAX array.
-
- We go about this by leveraging that:
- - `data` is of a homogeneous type.
- - `data.item(0)` returns a single element of the array w/pure-python type.
-
- By combing this with `type()` and `MathType.from_pytype`, we can effectively deduce the `MathType` of the entire array with relative efficiency.
-
- Notes:
- Should also work with numpy arrays.
- """
- return MathType.from_pytype(type(data.item(0)))
-
- @staticmethod
- def has_mathtype(obj: typ.Any) -> typ.Literal['pytype', 'expr'] | None:
- if isinstance(obj, bool | int | Fraction | float | complex):
- return 'pytype'
- if isinstance(obj, sp.Basic | sp.MatrixBase | sp.MutableDenseMatrix):
- return 'expr'
-
- return None
-
- @property
- def pytype(self) -> type:
- MT = MathType
- return {
- MT.Integer: int,
- MT.Rational: Fraction,
- MT.Real: float,
- MT.Complex: complex,
- }[self]
-
- @property
- def symbolic_set(self) -> type:
- MT = MathType
- return {
- MT.Integer: sp.Integers,
- MT.Rational: sp.Rationals,
- MT.Real: sp.Reals,
- MT.Complex: sp.Complexes,
- }[self]
-
- @property
- def inf_finite(self) -> type:
- """Opinionated finite representation of "infinity" within this `MathType`.
-
- These are chosen using `sys.maxsize` and `sys.float_info`.
- As such, while not arbitrary, this "finite representation of infinity" certainly is opinionated.
-
- **Note** that, in practice, most systems will have no trouble working with values that exceed those defined here.
-
- Notes:
- Values should be presumed to vary by-platform, as the `sys` attributes may be influenced by CPU architecture, OS, runtime environment, etc. .
-
- These values can be used directly in `jax` arrays, but at the cost of an overflow warning (in part because `jax` generally only allows the use of `float32`).
- In this case, the warning doesn't matter, as the value will be cast to `jnp.inf` anyway.
-
- However, it's generally cleaner to directly use `jnp.inf` if infinite values must be defined in an array context.
- """
- MT = MathType
- Z = MT.Integer
- R = MT.Integer
- return {
- MT.Integer: (-sys.maxsize, sys.maxsize),
- MT.Rational: (
- Fraction(Z.inf_finite[0], 1),
- Fraction(Z.inf_finite[1], 1),
- ),
- MT.Real: -(sys.float_info.min, sys.float_info.max),
- MT.Complex: (
- complex(R.inf_finite[0], R.inf_finite[0]),
- complex(R.inf_finite[1], R.inf_finite[1]),
- ),
- }[self]
-
- @property
- def sp_symbol_a(self) -> type:
- MT = MathType
- return {
- MT.Integer: sp.Symbol('a', integer=True),
- MT.Rational: sp.Symbol('a', rational=True),
- MT.Real: sp.Symbol('a', real=True),
- MT.Complex: sp.Symbol('a', complex=True),
- }[self]
-
- @staticmethod
- def to_str(value: typ.Self) -> type:
- return {
- MathType.Integer: 'ℤ',
- MathType.Rational: 'ℚ',
- MathType.Real: 'ℝ',
- MathType.Complex: 'ℂ',
- }[value]
-
- @property
- def label_pretty(self) -> str:
- return MathType.to_str(self)
-
- @staticmethod
- def to_name(value: typ.Self) -> str:
- return MathType.to_str(value)
-
- @staticmethod
- def to_icon(value: typ.Self) -> str:
- return ''
-
- def bl_enum_element(self, i: int) -> ct.BLEnumElement:
- return (
- str(self),
- MathType.to_name(self),
- MathType.to_name(self),
- MathType.to_icon(self),
- i,
- )
-
-
-####################
-# - Size: 1D
-####################
-class NumberSize1D(enum.StrEnum):
- """Valid 1D-constrained shape."""
-
- Scalar = enum.auto()
- Vec2 = enum.auto()
- Vec3 = enum.auto()
- Vec4 = enum.auto()
-
- @staticmethod
- def to_name(value: typ.Self) -> str:
- NS = NumberSize1D
- return {
- NS.Scalar: 'Scalar',
- NS.Vec2: '2D',
- NS.Vec3: '3D',
- NS.Vec4: '4D',
- }[value]
-
- @staticmethod
- def to_icon(value: typ.Self) -> str:
- NS = NumberSize1D
- return {
- NS.Scalar: '',
- NS.Vec2: '',
- NS.Vec3: '',
- NS.Vec4: '',
- }[value]
-
- def bl_enum_element(self, i: int) -> ct.BLEnumElement:
- return (
- str(self),
- NumberSize1D.to_name(self),
- NumberSize1D.to_name(self),
- NumberSize1D.to_icon(self),
- i,
- )
-
- @staticmethod
- def has_shape(shape: tuple[int, ...] | None):
- return shape in [None, (2,), (3,), (4,), (2, 1), (3, 1), (4, 1)]
-
- def supports_shape(self, shape: tuple[int, ...] | None):
- NS = NumberSize1D
- match self:
- case NS.Scalar:
- return shape is None
- case NS.Vec2:
- return shape in ((2,), (2, 1))
- case NS.Vec3:
- return shape in ((3,), (3, 1))
- case NS.Vec4:
- return shape in ((4,), (4, 1))
-
- @staticmethod
- def from_shape(shape: tuple[typ.Literal[2, 3]] | None) -> typ.Self:
- NS = NumberSize1D
- return {
- None: NS.Scalar,
- (2,): NS.Vec2,
- (3,): NS.Vec3,
- (4,): NS.Vec4,
- (2, 1): NS.Vec2,
- (3, 1): NS.Vec3,
- (4, 1): NS.Vec4,
- }[shape]
-
- @property
- def rows(self):
- NS = NumberSize1D
- return {
- NS.Scalar: 1,
- NS.Vec2: 2,
- NS.Vec3: 3,
- NS.Vec4: 4,
- }[self]
-
- @property
- def cols(self):
- return 1
-
- @property
- def shape(self):
- NS = NumberSize1D
- return {
- NS.Scalar: None,
- NS.Vec2: (2,),
- NS.Vec3: (3,),
- NS.Vec4: (4,),
- }[self]
-
-
-def symbol_range(sym: sp.Symbol) -> str:
- return f'{sym.name} ∈ ' + (
- 'ℂ'
- if sym.is_complex
- else ('ℝ' if sym.is_real else ('ℤ' if sym.is_integer else '?'))
- )
-
-
-####################
-# - Symbol Sizes
-####################
-class SimpleSize2D(enum.StrEnum):
- """Simple subset of sizes for rank-2 tensors."""
-
- Scalar = enum.auto()
-
- # Vectors
- Vec2 = enum.auto() ## 2x1
- Vec3 = enum.auto() ## 3x1
- Vec4 = enum.auto() ## 4x1
-
- # Covectors
- CoVec2 = enum.auto() ## 1x2
- CoVec3 = enum.auto() ## 1x3
- CoVec4 = enum.auto() ## 1x4
-
- # Square Matrices
- Mat22 = enum.auto() ## 2x2
- Mat33 = enum.auto() ## 3x3
- Mat44 = enum.auto() ## 4x4
-
-
-####################
-# - Unit Dimensions
-####################
-class DimsMeta(type):
- def __getattr__(cls, attr: str) -> spu.Dimension:
- if (
- attr in spu.definitions.dimension_definitions.__dir__()
- and not attr.startswith('__')
- ):
- return getattr(spu.definitions.dimension_definitions, attr)
-
- raise AttributeError(name=attr, obj=Dims)
-
-
-class Dims(metaclass=DimsMeta):
- """Access `sympy.physics.units` dimensions with less hassle.
-
- Any unit dimension available in `sympy.physics.units.definitions.dimension_definitions` can be accessed as an attribute of `Dims`.
-
- An `AttributeError` is raised if the unit cannot be found in `sympy`.
-
- Examples:
- The objects returned are a direct alias to `sympy`, with less hassle:
- ```python
- assert Dims.length == (
- sympy.physics.units.definitions.dimension_definitions.length
- )
- ```
- """
-
-
-####################
-# - Units
-####################
-femtosecond = fs = spu.Quantity('femtosecond', abbrev='fs')
-femtosecond.set_global_relative_scale_factor(spu.femto, spu.second)
-
-# Length
-femtometer = fm = spu.Quantity('femtometer', abbrev='fm')
-femtometer.set_global_relative_scale_factor(spu.femto, spu.meter)
-
-# Lum Flux
-lumen = lm = spu.Quantity('lumen', abbrev='lm')
-lumen.set_global_relative_scale_factor(1, spu.candela * spu.steradian)
-
-# Force
-nanonewton = nN = spu.Quantity('nanonewton', abbrev='nN') # noqa: N816
-nanonewton.set_global_relative_scale_factor(spu.nano, spu.newton)
-
-micronewton = uN = spu.Quantity('micronewton', abbrev='μN') # noqa: N816
-micronewton.set_global_relative_scale_factor(spu.micro, spu.newton)
-
-millinewton = mN = spu.Quantity('micronewton', abbrev='mN') # noqa: N816
-micronewton.set_global_relative_scale_factor(spu.milli, spu.newton)
-
-# Frequency
-kilohertz = KHz = spu.Quantity('kilohertz', abbrev='KHz')
-kilohertz.set_global_relative_scale_factor(spu.kilo, spu.hertz)
-
-megahertz = MHz = spu.Quantity('megahertz', abbrev='MHz')
-kilohertz.set_global_relative_scale_factor(spu.kilo, spu.hertz)
-
-gigahertz = GHz = spu.Quantity('gigahertz', abbrev='GHz')
-gigahertz.set_global_relative_scale_factor(spu.giga, spu.hertz)
-
-terahertz = THz = spu.Quantity('terahertz', abbrev='THz')
-terahertz.set_global_relative_scale_factor(spu.tera, spu.hertz)
-
-petahertz = PHz = spu.Quantity('petahertz', abbrev='PHz')
-petahertz.set_global_relative_scale_factor(spu.peta, spu.hertz)
-
-exahertz = EHz = spu.Quantity('exahertz', abbrev='EHz')
-exahertz.set_global_relative_scale_factor(spu.exa, spu.hertz)
-
-# Pressure
-millibar = mbar = spu.Quantity('millibar', abbrev='mbar')
-millibar.set_global_relative_scale_factor(spu.milli, spu.bar)
-
-hectopascal = hPa = spu.Quantity('hectopascal', abbrev='hPa') # noqa: N816
-hectopascal.set_global_relative_scale_factor(spu.hecto, spu.pascal)
-
-UNIT_BY_SYMBOL: dict[sp.Symbol, spu.Quantity] = {
- unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity)
-} | {unit.name: unit for unit in globals().values() if isinstance(unit, spu.Quantity)}
-
-UNIT_TO_1: dict[spu.Quantity, 1] = {unit: 1 for unit in UNIT_BY_SYMBOL.values()}
-
-
-####################
-# - Expr Analysis: Units
-####################
-## TODO: Caching w/srepr'ed expression.
-## TODO: An LFU cache could do better than an LRU.
-def uses_units(sp_obj: SympyType) -> bool:
- """Determines if an expression uses any units.
-
- Notes:
- The expression graph is traversed depth-first with `sp.postorder_traversal`, to search for `sp.Quantity` elements.
- Depth-first was chosen since `sp.Quantity`s are likelier to be found among individual symbols, rather than complete subexpressions.
-
- The **worst-case** runtime is when there are no units, in which case the **entire expression graph will be traversed**.
-
- Parameters:
- expr: The sympy expression that may contain units.
-
- Returns:
- Whether or not there are units used within the expression.
- """
- return sp_obj.has(spu.Quantity)
- # return any(
- # isinstance(subexpr, spu.Quantity) for subexpr in sp.postorder_traversal(sp_obj)
- # )
-
-
-## TODO: Caching w/srepr'ed expression.
-## TODO: An LFU cache could do better than an LRU.
-def get_units(expr: sp.Expr) -> set[spu.Quantity]:
- """Finds all units used by the expression, and returns them as a set.
-
- No information about _the relationship between units_ is exposed.
- For example, compound units like `spu.meter / spu.second` would be mapped to `{spu.meter, spu.second}`.
-
-
- Notes:
- The expression graph is traversed depth-first with `sp.postorder_traversal`, to search for `sp.Quantity` elements.
-
- The performance is comparable to the performance of `sp.postorder_traversal`, since the **entire expression graph will always be traversed**, with the added overhead of one `isinstance` call per expression-graph-node.
-
- Parameters:
- expr: The sympy expression that may contain units.
-
- Returns:
- All units (`spu.Quantity`) used within the expression.
- """
- return {
- subexpr
- for subexpr in sp.postorder_traversal(expr)
- if isinstance(subexpr, spu.Quantity)
- }
-
-
-def parse_shape(sp_obj: SympyType) -> int | None:
- if isinstance(sp_obj, sp.MatrixBase):
- return sp_obj.shape
-
- return None
-
-
-####################
-# - Pydantic-Validated SympyExpr
-####################
-class _SympyExpr:
- """Low-level `pydantic`, schema describing how to serialize/deserialize fields that have a `SympyType` (like `sp.Expr`), so we can cleanly use `sympy` types in `pyd.BaseModel`.
-
- Notes:
- You probably want to use `SympyExpr`.
-
- Examples:
- To be usable as a type annotation on `pyd.BaseModel`, attach this to `SympyType` using `typx.Annotated`:
-
- ```python
- SympyExpr = typx.Annotated[SympyType, _SympyExpr]
-
- class Spam(pyd.BaseModel):
- line: SympyExpr = sp.Eq(sp.y, 2*sp.Symbol(x, real=True) - 3)
- ```
- """
-
- @classmethod
- def __get_pydantic_core_schema__(
- cls,
- _source_type: SympyType,
- _handler: pyd.GetCoreSchemaHandler,
- ) -> pyd_core_schema.CoreSchema:
- """Compute a schema that allows `pydantic` to validate a `sympy` type."""
-
- def validate_from_str(sp_str: str | typ.Any) -> SympyType | typ.Any:
- """Parse and validate a string expression.
-
- Parameters:
- sp_str: A stringified `sympy` object, that will be parsed to a sympy type.
- Before use, `isinstance(expr_str, str)` is checked.
- If the object isn't a string, then the validation will be skipped.
-
- Returns:
- Either a `sympy` object, if the input is parseable, or the same untouched object.
-
- Raises:
- ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression.
- """
- # Constrain to String
- if not isinstance(sp_str, str):
- return sp_str
-
- # Parse String -> Sympy
- try:
- expr = sp.sympify(sp_str)
- except ValueError as ex:
- msg = f'String {sp_str} is not a valid sympy expression'
- raise ValueError(msg) from ex
-
- # Substitute Symbol -> Quantity
- return expr.subs(UNIT_BY_SYMBOL)
-
- def validate_from_pytype(
- sp_pytype: int | Fraction | float | complex,
- ) -> SympyType | typ.Any:
- """Parse and validate a pure Python type.
-
- Parameters:
- sp_str: A stringified `sympy` object, that will be parsed to a sympy type.
- Before use, `isinstance(expr_str, str)` is checked.
- If the object isn't a string, then the validation will be skipped.
-
- Returns:
- Either a `sympy` object, if the input is parseable, or the same untouched object.
-
- Raises:
- ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression.
- """
- # Constrain to String
- if not isinstance(sp_pytype, int | Fraction | float | complex):
- return sp_pytype
-
- if isinstance(sp_pytype, int):
- return sp.Integer(sp_pytype)
- if isinstance(sp_pytype, Fraction):
- return sp.Rational(sp_pytype.numerator, sp_pytype.denominator)
- if isinstance(sp_pytype, float):
- return sp.Float(sp_pytype)
-
- # sp_pytype => Complex
- return sp_pytype.real + sp.I * sp_pytype.imag
-
- sympy_expr_schema = pyd_core_schema.chain_schema(
- [
- pyd_core_schema.no_info_plain_validator_function(validate_from_str),
- pyd_core_schema.no_info_plain_validator_function(validate_from_pytype),
- pyd_core_schema.is_instance_schema(SympyType),
- ]
- )
- return pyd_core_schema.json_or_python_schema(
- json_schema=sympy_expr_schema,
- python_schema=sympy_expr_schema,
- serialization=pyd_core_schema.plain_serializer_function_ser_schema(
- lambda sp_obj: sp.srepr(sp_obj)
- ),
- )
-
-
-SympyExpr = typx.Annotated[
- sp.Basic, ## Treat all sympy types as sp.Basic
- _SympyExpr,
-]
-## TODO: The type game between SympyType, SympyExpr, and the various flavors of ConstrSympyExpr(), is starting to be a bit much. Let's consolidate.
-
-
-def ConstrSympyExpr( # noqa: N802, PLR0913
- # Features
- allow_variables: bool = True,
- allow_units: bool = True,
- # Structures
- allowed_sets: set[typ.Literal['integer', 'rational', 'real', 'complex']]
- | None = None,
- allowed_structures: set[typ.Literal['scalar', 'matrix']] | None = None,
- # Element Class
- max_symbols: int | None = None,
- allowed_symbols: set[sp.Symbol] | None = None,
- allowed_units: set[spu.Quantity] | None = None,
- # Shape Class
- allowed_matrix_shapes: set[tuple[int, int]] | None = None,
-) -> SympyType:
- """Constructs a `SympyExpr` type, which will validate `sympy` types when used in a `pyd.BaseModel`.
-
- Relies on the `sympy` assumptions system.
- See
-
- Parameters (TBD):
-
- Returns:
- A type that represents a constrained `sympy` expression.
- """
-
- def validate_expr(expr: SympyType):
- if not (isinstance(expr, SympyType),):
- msg = f"expr '{expr}' is not an allowed Sympy expression ({SympyType})"
- raise ValueError(msg)
-
- msgs = set()
-
- # Validate Feature Class
- if (not allow_variables) and (len(expr.free_symbols) > 0):
- msgs.add(
- f'allow_variables={allow_variables} does not match expression {expr}.'
- )
- if (not allow_units) and uses_units(expr):
- msgs.add(f'allow_units={allow_units} does not match expression {expr}.')
-
- # Validate Structure Class
- if (
- allowed_sets
- and isinstance(expr, sp.Expr)
- and not any(
- {
- 'integer': expr.is_integer,
- 'rational': expr.is_rational,
- 'real': expr.is_real,
- 'complex': expr.is_complex,
- }[allowed_set]
- for allowed_set in allowed_sets
- )
- ):
- msgs.add(
- f"allowed_sets={allowed_sets} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))"
- )
- if allowed_structures and not any(
- {
- 'scalar': True,
- 'matrix': isinstance(expr, sp.MatrixBase),
- }[allowed_set]
- for allowed_set in allowed_structures
- ):
- msgs.add(
- f"allowed_structures={allowed_structures} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))"
- )
-
- # Validate Element Class
- if max_symbols and len(expr.free_symbols) > max_symbols:
- msgs.add(f'max_symbols={max_symbols} does not match expression {expr}')
- if allowed_symbols and expr.free_symbols.issubset(allowed_symbols):
- msgs.add(
- f'allowed_symbols={allowed_symbols} does not match expression {expr}'
- )
- if allowed_units and get_units(expr).issubset(allowed_units):
- msgs.add(f'allowed_units={allowed_units} does not match expression {expr}')
-
- # Validate Shape Class
- if (
- allowed_matrix_shapes and isinstance(expr, sp.MatrixBase)
- ) and expr.shape not in allowed_matrix_shapes:
- msgs.add(
- f'allowed_matrix_shapes={allowed_matrix_shapes} does not match expression {expr} with shape {expr.shape}'
- )
-
- # Error or Return
- if msgs:
- raise ValueError(str(msgs))
- return expr
-
- return typx.Annotated[
- sp.Basic,
- _SympyExpr,
- pyd.AfterValidator(validate_expr),
- ]
-
-
-####################
-# - Common ConstrSympyExpr
-####################
-# Expression
-ScalarUnitlessRealExpr: typ.TypeAlias = ConstrSympyExpr(
- allow_variables=False,
- allow_units=False,
- allowed_structures={'scalar'},
- allowed_sets={'integer', 'rational', 'real'},
-)
-ScalarUnitlessComplexExpr: typ.TypeAlias = ConstrSympyExpr(
- allow_variables=False,
- allow_units=False,
- allowed_structures={'scalar'},
- allowed_sets={'integer', 'rational', 'real', 'complex'},
-)
-
-# Symbol
-IntSymbol: typ.TypeAlias = ConstrSympyExpr(
- allow_variables=True,
- allow_units=False,
- allowed_sets={'integer'},
- max_symbols=1,
-)
-RationalSymbol: typ.TypeAlias = ConstrSympyExpr(
- allow_variables=True,
- allow_units=False,
- allowed_sets={'integer', 'rational'},
- max_symbols=1,
-)
-RealSymbol: typ.TypeAlias = ConstrSympyExpr(
- allow_variables=True,
- allow_units=False,
- allowed_sets={'integer', 'rational', 'real'},
- max_symbols=1,
-)
-ComplexSymbol: typ.TypeAlias = ConstrSympyExpr(
- allow_variables=True,
- allow_units=False,
- allowed_sets={'integer', 'rational', 'real', 'complex'},
- max_symbols=1,
-)
-Symbol: typ.TypeAlias = IntSymbol | RealSymbol | ComplexSymbol
-
-# Unit
-UnitDimension: typ.TypeAlias = SympyExpr ## Actually spu.Dimension
-
-## Technically a "unit expression", which includes compound types.
-## Support for this is the reason to prefer over raw spu.Quantity.
-Unit: typ.TypeAlias = ConstrSympyExpr(
- allow_variables=False,
- allow_units=True,
- allowed_structures={'scalar'},
-)
-
-# Number
-IntNumber: typ.TypeAlias = ConstrSympyExpr(
- allow_variables=False,
- allow_units=False,
- allowed_sets={'integer'},
- allowed_structures={'scalar'},
-)
-RealNumber: typ.TypeAlias = ConstrSympyExpr(
- allow_variables=False,
- allow_units=False,
- allowed_sets={'integer', 'rational', 'real'},
- allowed_structures={'scalar'},
-)
-ComplexNumber: typ.TypeAlias = ConstrSympyExpr(
- allow_variables=False,
- allow_units=False,
- allowed_sets={'integer', 'rational', 'real', 'complex'},
- allowed_structures={'scalar'},
-)
-Number: typ.TypeAlias = IntNumber | RealNumber | ComplexNumber
-
-# Number
-PhysicalRealNumber: typ.TypeAlias = ConstrSympyExpr(
- allow_variables=False,
- allow_units=True,
- allowed_sets={'integer', 'rational', 'real'},
- allowed_structures={'scalar'},
-)
-PhysicalComplexNumber: typ.TypeAlias = ConstrSympyExpr(
- allow_variables=False,
- allow_units=True,
- allowed_sets={'integer', 'rational', 'real', 'complex'},
- allowed_structures={'scalar'},
-)
-PhysicalNumber: typ.TypeAlias = PhysicalRealNumber | PhysicalComplexNumber
-
-# Vector
-Real3DVector: typ.TypeAlias = ConstrSympyExpr(
- allow_variables=False,
- allow_units=False,
- allowed_sets={'integer', 'rational', 'real'},
- allowed_structures={'matrix'},
- allowed_matrix_shapes={(3, 1)},
-)
-
-
-####################
-# - Sympy Utilities: Printing
-####################
-_SYMPY_EXPR_PRINTER_STR = sp.printing.str.StrPrinter(
- settings={
- 'abbrev': True,
- }
-)
-
-
-def sp_to_str(sp_obj: SympyExpr) -> str:
- """Converts a sympy object to an output-oriented string (w/abbreviated units), using a dedicated StrPrinter.
-
- This should be used whenever a **string for UI use** is needed from a `sympy` object.
-
- Notes:
- This should **NOT** be used in cases where the string will be `sp.sympify()`ed back into a sympy expression.
- For such cases, rely on `sp.srepr()`, which uses an _explicit_ representation.
-
- Parameters:
- sp_obj: The `sympy` object to convert to a string.
-
- Returns:
- A string representing the expression for human use.
- _The string is not re-encodable to the expression._
- """
- ## TODO: A bool flag property that does a lot of find/replace to make it super pretty
- return _SYMPY_EXPR_PRINTER_STR.doprint(sp_obj)
-
-
-def pretty_symbol(sym: sp.Symbol) -> str:
- return f'{sym.name} ∈ ' + (
- 'ℤ'
- if sym.is_integer
- else ('ℝ' if sym.is_real else ('ℂ' if sym.is_complex else '?'))
- )
-
-
-####################
-# - Unit Utilities
-####################
-def scale_to_unit(sp_obj: SympyType, unit: spu.Quantity) -> Number:
- """Convert an expression that uses units to a different unit, then strip all units, leaving only a unitless `sympy` value.
-
- This is used whenever the unitless part of an expression is needed, but guaranteed expressed in a particular unit, aka. **unit system normalization**.
-
- Notes:
- The unitless output is still an `sp.Expr`, which may contain ex. symbols.
-
- If you know that the output **should** work as a corresponding Python type (ex. `sp.Integer` vs. `int`), but it doesn't, you can use `sympy_to_python()` to produce a pure-Python type.
- In this way, with a little care, broad compatiblity can be bridged between the `sympy.physics.units` unit system and the wider Python ecosystem.
-
- Parameters:
- expr: The unit-containing expression to convert.
- unit_to: The unit that is converted to.
-
- Returns:
- The unitless part of `expr`, after scaling the entire expression to `unit`.
-
- Raises:
- ValueError: If the result of unit-conversion and -stripping still has units, as determined by `uses_units()`.
- """
- unitless_expr = spu.convert_to(sp_obj, unit) / unit if unit is not None else sp_obj
- if not uses_units(unitless_expr):
- return unitless_expr
-
- msg = f'Sympy object "{sp_obj}" was scaled to the unit "{unit}" with the expectation that the result would be unitless, but the result "{unitless_expr}" has units "{get_units(unitless_expr)}"'
- raise ValueError(msg)
-
-
-def scaling_factor(unit_from: spu.Quantity, unit_to: spu.Quantity) -> Number:
- """Compute the numerical scaling factor imposed on the unitless part of the expression when converting from one unit to another.
-
- Parameters:
- unit_from: The unit that is converted from.
- unit_to: The unit that is converted to.
-
- Returns:
- The numerical scaling factor between the two units.
-
- Raises:
- ValueError: If the two units don't share a common dimension.
- """
- if unit_from.dimension == unit_to.dimension:
- return scale_to_unit(unit_from, unit_to)
-
- msg = f"Dimension of unit_from={unit_from} ({unit_from.dimension}) doesn't match the dimension of unit_to={unit_to} ({unit_to.dimension}); therefore, there is no scaling factor between them"
- raise ValueError(msg)
-
-
-@functools.cache
-def unit_str_to_unit(unit_str: str) -> Unit | None:
- # Edge Case: Manually Parse Degrees
- ## -> sp.sympify('degree') actually produces the sp.degree() function.
- ## -> Therefore, we must special case this particular unit.
- if unit_str == 'degree':
- expr = spu.degree
- else:
- expr = sp.sympify(unit_str).subs(UNIT_BY_SYMBOL)
-
- if expr.has(spu.Quantity):
- return expr
-
- msg = f'No valid unit for unit string {unit_str}'
- raise ValueError(msg)
-
-
-####################
-# - "Physical" Type
-####################
-def unit_dim_to_unit_dim_deps(
- unit_dims: SympyType,
-) -> dict[spu.dimensions.Dimension, int] | None:
- dimsys_SI = spu.systems.si.dimsys_SI
-
- # Retrieve Dimensional Dependencies
- try:
- return dimsys_SI.get_dimensional_dependencies(unit_dims)
-
- # Catch TypeError
- ## -> Happens if `+` or `-` is in `unit`.
- ## -> Generally, it doesn't make sense to add/subtract differing unit dims.
- ## -> Thus, when trying to figure out the unit dimension, there isn't one.
- except TypeError:
- return None
-
-
-def unit_to_unit_dim_deps(
- unit: SympyType,
-) -> dict[spu.dimensions.Dimension, int] | None:
- # Retrieve Dimensional Dependencies
- ## -> NOTE: .subs() alone seems to produce sp.Symbol atoms.
- ## -> This is extremely problematic; `Dims` arithmetic has key properties.
- ## -> So we have to go all the way to the dimensional dependencies.
- ## -> This isn't really respecting the args, but it seems to work :)
- return unit_dim_to_unit_dim_deps(
- unit.subs({arg: arg.dimension for arg in unit.atoms(spu.Quantity)})
- )
-
-
-def compare_unit_dims(unit_dim_l: SympyType, unit_dim_r: SympyType) -> bool:
- return unit_dim_to_unit_dim_deps(unit_dim_l) == unit_dim_to_unit_dim_deps(
- unit_dim_r
- )
-
-
-def compare_unit_dim_to_unit_dim_deps(
- unit_dim: SympyType, unit_dim_deps: dict[spu.dimensions.Dimension, int]
-) -> bool:
- return unit_dim_to_unit_dim_deps(unit_dim) == unit_dim_deps
-
-
-class PhysicalType(enum.StrEnum):
- """Type identifiers for expressions with both `MathType` and a unit, aka a "physical" type."""
-
- # Unitless
- NonPhysical = enum.auto()
-
- # Global
- Time = enum.auto()
- Angle = enum.auto()
- SolidAngle = enum.auto()
- ## TODO: Some kind of 3D-specific orientation ex. a quaternion
- Freq = enum.auto()
- AngFreq = enum.auto() ## rad*hertz
- # Cartesian
- Length = enum.auto()
- Area = enum.auto()
- Volume = enum.auto()
- # Mechanical
- Vel = enum.auto()
- Accel = enum.auto()
- Mass = enum.auto()
- Force = enum.auto()
- Pressure = enum.auto()
- # Energy
- Work = enum.auto() ## joule
- Power = enum.auto() ## watt
- PowerFlux = enum.auto() ## watt
- Temp = enum.auto()
- # Electrodynamics
- Current = enum.auto() ## ampere
- CurrentDensity = enum.auto()
- Charge = enum.auto() ## coulomb
- Voltage = enum.auto()
- Capacitance = enum.auto() ## farad
- Impedance = enum.auto() ## ohm
- Conductance = enum.auto() ## siemens
- Conductivity = enum.auto() ## siemens / length
- MFlux = enum.auto() ## weber
- MFluxDensity = enum.auto() ## tesla
- Inductance = enum.auto() ## henry
- EField = enum.auto()
- HField = enum.auto()
- # Luminal
- LumIntensity = enum.auto()
- LumFlux = enum.auto()
- Illuminance = enum.auto()
-
- @functools.cached_property
- def unit_dim(self) -> SympyType:
- PT = PhysicalType
- return {
- PT.NonPhysical: None,
- # Global
- PT.Time: Dims.time,
- PT.Angle: Dims.angle,
- PT.SolidAngle: spu.steradian.dimension, ## MISSING
- PT.Freq: Dims.frequency,
- PT.AngFreq: Dims.angle * Dims.frequency,
- # Cartesian
- PT.Length: Dims.length,
- PT.Area: Dims.length**2,
- PT.Volume: Dims.length**3,
- # Mechanical
- PT.Vel: Dims.length / Dims.time,
- PT.Accel: Dims.length / Dims.time**2,
- PT.Mass: Dims.mass,
- PT.Force: Dims.force,
- PT.Pressure: Dims.pressure,
- # Energy
- PT.Work: Dims.energy,
- PT.Power: Dims.power,
- PT.PowerFlux: Dims.power / Dims.length**2,
- PT.Temp: Dims.temperature,
- # Electrodynamics
- PT.Current: Dims.current,
- PT.CurrentDensity: Dims.current / Dims.length**2,
- PT.Charge: Dims.charge,
- PT.Voltage: Dims.voltage,
- PT.Capacitance: Dims.capacitance,
- PT.Impedance: Dims.impedance,
- PT.Conductance: Dims.conductance,
- PT.Conductivity: Dims.conductance / Dims.length,
- PT.MFlux: Dims.magnetic_flux,
- PT.MFluxDensity: Dims.magnetic_density,
- PT.Inductance: Dims.inductance,
- PT.EField: Dims.voltage / Dims.length,
- PT.HField: Dims.current / Dims.length,
- # Luminal
- PT.LumIntensity: Dims.luminous_intensity,
- PT.LumFlux: Dims.luminous_intensity * spu.steradian.dimension,
- PT.Illuminance: Dims.luminous_intensity / Dims.length**2,
- }[self]
-
- @staticproperty
- def unit_dims() -> dict[typ.Self, SympyType]:
- return {
- physical_type: physical_type.unit_dim
- for physical_type in list(PhysicalType)
- }
-
- @functools.cached_property
- def color(self):
- """A color corresponding to the physical type.
-
- The color selections were initially generated using AI, as this is a rote task that's better adjusted than invented.
- The LLM provided the following rationale for its choices:
-
- > Non-Physical: Grey signifies neutrality and non-physical nature.
- > Global:
- > Time: Blue is often associated with calmness and the passage of time.
- > Angle and Solid Angle: Different shades of blue and cyan suggest angular dimensions and spatial aspects.
- > Frequency and Angular Frequency: Darker shades of blue to maintain the link to time.
- > Cartesian:
- > Length, Area, Volume: Shades of green to represent spatial dimensions, with intensity increasing with dimension.
- > Mechanical:
- > Velocity and Acceleration: Red signifies motion and dynamics, with lighter reds for related quantities.
- > Mass: Dark red for the fundamental property.
- > Force and Pressure: Shades of red indicating intensity.
- > Energy:
- > Work and Power: Orange signifies energy transformation, with lighter oranges for related quantities.
- > Temperature: Yellow for heat.
- > Electrodynamics:
- > Current and related quantities: Cyan shades indicating flow.
- > Voltage, Capacitance: Greenish and blueish cyan for electrical potential.
- > Impedance, Conductance, Conductivity: Purples and magentas to signify resistance and conductance.
- > Magnetic properties: Magenta shades for magnetism.
- > Electric Field: Light blue.
- > Magnetic Field: Grey, as it can be considered neutral in terms of direction.
- > Luminal:
- > Luminous properties: Yellows to signify light and illumination.
- >
- > This color mapping helps maintain intuitive connections for users interacting with these physical types.
- """
- PT = PhysicalType
- return {
- PT.NonPhysical: (0.75, 0.75, 0.75, 1.0), # Light Grey: Non-physical
- # Global
- PT.Time: (0.5, 0.5, 1.0, 1.0), # Light Blue: Time
- PT.Angle: (0.5, 0.75, 1.0, 1.0), # Light Blue: Angle
- PT.SolidAngle: (0.5, 0.75, 0.75, 1.0), # Light Cyan: Solid Angle
- PT.Freq: (0.5, 0.5, 0.9, 1.0), # Light Blue: Frequency
- PT.AngFreq: (0.5, 0.5, 0.8, 1.0), # Light Blue: Angular Frequency
- # Cartesian
- PT.Length: (0.5, 1.0, 0.5, 1.0), # Light Green: Length
- PT.Area: (0.6, 1.0, 0.6, 1.0), # Light Green: Area
- PT.Volume: (0.7, 1.0, 0.7, 1.0), # Light Green: Volume
- # Mechanical
- PT.Vel: (1.0, 0.5, 0.5, 1.0), # Light Red: Velocity
- PT.Accel: (1.0, 0.6, 0.6, 1.0), # Light Red: Acceleration
- PT.Mass: (0.75, 0.5, 0.5, 1.0), # Light Red: Mass
- PT.Force: (0.9, 0.5, 0.5, 1.0), # Light Red: Force
- PT.Pressure: (1.0, 0.7, 0.7, 1.0), # Light Red: Pressure
- # Energy
- PT.Work: (1.0, 0.75, 0.5, 1.0), # Light Orange: Work
- PT.Power: (1.0, 0.85, 0.5, 1.0), # Light Orange: Power
- PT.PowerFlux: (1.0, 0.8, 0.6, 1.0), # Light Orange: Power Flux
- PT.Temp: (1.0, 1.0, 0.5, 1.0), # Light Yellow: Temperature
- # Electrodynamics
- PT.Current: (0.5, 1.0, 1.0, 1.0), # Light Cyan: Current
- PT.CurrentDensity: (0.5, 0.9, 0.9, 1.0), # Light Cyan: Current Density
- PT.Charge: (0.5, 0.85, 0.85, 1.0), # Light Cyan: Charge
- PT.Voltage: (0.5, 1.0, 0.75, 1.0), # Light Greenish Cyan: Voltage
- PT.Capacitance: (0.5, 0.75, 1.0, 1.0), # Light Blueish Cyan: Capacitance
- PT.Impedance: (0.6, 0.5, 0.75, 1.0), # Light Purple: Impedance
- PT.Conductance: (0.7, 0.5, 0.8, 1.0), # Light Purple: Conductance
- PT.Conductivity: (0.8, 0.5, 0.9, 1.0), # Light Purple: Conductivity
- PT.MFlux: (0.75, 0.5, 0.75, 1.0), # Light Magenta: Magnetic Flux
- PT.MFluxDensity: (
- 0.85,
- 0.5,
- 0.85,
- 1.0,
- ), # Light Magenta: Magnetic Flux Density
- PT.Inductance: (0.8, 0.5, 0.8, 1.0), # Light Magenta: Inductance
- PT.EField: (0.75, 0.75, 1.0, 1.0), # Light Blue: Electric Field
- PT.HField: (0.75, 0.75, 0.75, 1.0), # Light Grey: Magnetic Field
- # Luminal
- PT.LumIntensity: (1.0, 0.95, 0.5, 1.0), # Light Yellow: Luminous Intensity
- PT.LumFlux: (1.0, 0.95, 0.6, 1.0), # Light Yellow: Luminous Flux
- PT.Illuminance: (1.0, 1.0, 0.75, 1.0), # Pale Yellow: Illuminance
- }[self]
-
- @functools.cached_property
- def default_unit(self) -> list[Unit]:
- PT = PhysicalType
- return {
- PT.NonPhysical: None,
- # Global
- PT.Time: spu.picosecond,
- PT.Angle: spu.radian,
- PT.SolidAngle: spu.steradian,
- PT.Freq: terahertz,
- PT.AngFreq: spu.radian * terahertz,
- # Cartesian
- PT.Length: spu.micrometer,
- PT.Area: spu.um**2,
- PT.Volume: spu.um**3,
- # Mechanical
- PT.Vel: spu.um / spu.second,
- PT.Accel: spu.um / spu.second,
- PT.Mass: spu.microgram,
- PT.Force: micronewton,
- PT.Pressure: millibar,
- # Energy
- PT.Work: spu.joule,
- PT.Power: spu.watt,
- PT.PowerFlux: spu.watt / spu.meter**2,
- PT.Temp: spu.kelvin,
- # Electrodynamics
- PT.Current: spu.ampere,
- PT.CurrentDensity: spu.ampere / spu.meter**2,
- PT.Charge: spu.coulomb,
- PT.Voltage: spu.volt,
- PT.Capacitance: spu.farad,
- PT.Impedance: spu.ohm,
- PT.Conductance: spu.siemens,
- PT.Conductivity: spu.siemens / spu.micrometer,
- PT.MFlux: spu.weber,
- PT.MFluxDensity: spu.tesla,
- PT.Inductance: spu.henry,
- PT.EField: spu.volt / spu.micrometer,
- PT.HField: spu.ampere / spu.micrometer,
- # Luminal
- PT.LumIntensity: spu.candela,
- PT.LumFlux: spu.candela * spu.steradian,
- PT.Illuminance: spu.candela / spu.meter**2,
- }[self]
-
- @functools.cached_property
- def valid_units(self) -> list[Unit]:
- """Retrieve an ordered (by subjective usefulness) list of units for this physical type.
-
- Notes:
- The order in which valid units are declared is the exact same order that UI dropdowns display them.
-
- **Altering the order of units breaks backwards compatibility**.
- """
- PT = PhysicalType
- return {
- PT.NonPhysical: [None],
- # Global
- PT.Time: [
- spu.picosecond,
- femtosecond,
- spu.nanosecond,
- spu.microsecond,
- spu.millisecond,
- spu.second,
- spu.minute,
- spu.hour,
- spu.day,
- ],
- PT.Angle: [
- spu.radian,
- spu.degree,
- ],
- PT.SolidAngle: [
- spu.steradian,
- ],
- PT.Freq: (
- _valid_freqs := [
- terahertz,
- spu.hertz,
- kilohertz,
- megahertz,
- gigahertz,
- petahertz,
- exahertz,
- ]
- ),
- PT.AngFreq: [spu.radian * _unit for _unit in _valid_freqs],
- # Cartesian
- PT.Length: (
- _valid_lens := [
- spu.micrometer,
- spu.nanometer,
- spu.picometer,
- spu.angstrom,
- spu.millimeter,
- spu.centimeter,
- spu.meter,
- spu.inch,
- spu.foot,
- spu.yard,
- spu.mile,
- ]
- ),
- PT.Area: [_unit**2 for _unit in _valid_lens],
- PT.Volume: [_unit**3 for _unit in _valid_lens],
- # Mechanical
- PT.Vel: [_unit / spu.second for _unit in _valid_lens],
- PT.Accel: [_unit / spu.second**2 for _unit in _valid_lens],
- PT.Mass: [
- spu.kilogram,
- spu.electron_rest_mass,
- spu.dalton,
- spu.microgram,
- spu.milligram,
- spu.gram,
- spu.metric_ton,
- ],
- PT.Force: [
- micronewton,
- nanonewton,
- millinewton,
- spu.newton,
- spu.kg * spu.meter / spu.second**2,
- ],
- PT.Pressure: [
- spu.bar,
- millibar,
- spu.pascal,
- hectopascal,
- spu.atmosphere,
- spu.psi,
- spu.mmHg,
- spu.torr,
- ],
- # Energy
- PT.Work: [
- spu.joule,
- spu.electronvolt,
- ],
- PT.Power: [
- spu.watt,
- ],
- PT.PowerFlux: [
- spu.watt / spu.meter**2,
- ],
- PT.Temp: [
- spu.kelvin,
- ],
- # Electrodynamics
- PT.Current: [
- spu.ampere,
- ],
- PT.CurrentDensity: [
- spu.ampere / spu.meter**2,
- ],
- PT.Charge: [
- spu.coulomb,
- ],
- PT.Voltage: [
- spu.volt,
- ],
- PT.Capacitance: [
- spu.farad,
- ],
- PT.Impedance: [
- spu.ohm,
- ],
- PT.Conductance: [
- spu.siemens,
- ],
- PT.Conductivity: [
- spu.siemens / spu.micrometer,
- spu.siemens / spu.meter,
- ],
- PT.MFlux: [
- spu.weber,
- ],
- PT.MFluxDensity: [
- spu.tesla,
- ],
- PT.Inductance: [
- spu.henry,
- ],
- PT.EField: [
- spu.volt / spu.micrometer,
- spu.volt / spu.meter,
- ],
- PT.HField: [
- spu.ampere / spu.micrometer,
- spu.ampere / spu.meter,
- ],
- # Luminal
- PT.LumIntensity: [
- spu.candela,
- ],
- PT.LumFlux: [
- spu.candela * spu.steradian,
- ],
- PT.Illuminance: [
- spu.candela / spu.meter**2,
- ],
- }[self]
-
- @staticmethod
- def from_unit(unit: Unit | None, optional: bool = False) -> typ.Self | None:
- """Attempt to determine a matching `PhysicalType` from a unit.
-
- NOTE: It is not guaranteed that `unit` is within `valid_units`, only that it can be converted to any unit in `valid_units`.
-
- Returns:
- The matched `PhysicalType`.
-
- If none could be matched, then either return `None` (if `optional` is set) or error.
-
- Raises:
- ValueError: If no `PhysicalType` could be matched, and `optional` is `False`.
- """
- if unit is None:
- return PhysicalType.NonPhysical
-
- ## TODO_ This enough?
- if unit in [spu.radian, spu.degree]:
- return PhysicalType.Angle
-
- unit_dim_deps = unit_to_unit_dim_deps(unit)
- if unit_dim_deps is not None:
- for physical_type, candidate_unit_dim in PhysicalType.unit_dims.items():
- if compare_unit_dim_to_unit_dim_deps(candidate_unit_dim, unit_dim_deps):
- return physical_type
-
- if optional:
- return None
- msg = f'Could not determine PhysicalType for {unit}'
- raise ValueError(msg)
-
- @staticmethod
- def from_unit_dim(
- unit_dim: SympyType | None, optional: bool = False
- ) -> typ.Self | None:
- """Attempts to match an arbitrary unit dimension expression to a corresponding `PhysicalType`.
-
- For comparing arbitrary unit dimensions (via expressions of `spu.dimensions.Dimension`), it is critical that equivalent dimensions are also compared as equal (ex. `mass*length/time^2 == force`).
- To do so, we employ the `SI` unit conventions, for extracting the fundamental dimensional dependencies of unit dimension expressions.
-
- Returns:
- The matched `PhysicalType`.
-
- If none could be matched, then either return `None` (if `optional` is set) or error.
-
- Raises:
- ValueError: If no `PhysicalType` could be matched, and `optional` is `False`.
- """
- for physical_type, candidate_unit_dim in PhysicalType.unit_dims.items():
- if compare_unit_dims(unit_dim, candidate_unit_dim):
- return physical_type
-
- if optional:
- return None
- msg = f'Could not determine PhysicalType for {unit_dim}'
- raise ValueError(msg)
-
- @functools.cached_property
- def valid_shapes(self) -> list[typ.Literal[(3,), (2,)] | None]:
- PT = PhysicalType
- overrides = {
- # Cartesian
- PT.Length: [None, (2,), (3,)],
- # Mechanical
- PT.Vel: [None, (2,), (3,)],
- PT.Accel: [None, (2,), (3,)],
- PT.Force: [None, (2,), (3,)],
- # Energy
- PT.Work: [None, (2,), (3,)],
- PT.PowerFlux: [None, (2,), (3,)],
- # Electrodynamics
- PT.CurrentDensity: [None, (2,), (3,)],
- PT.MFluxDensity: [None, (2,), (3,)],
- PT.EField: [None, (2,), (3,)],
- PT.HField: [None, (2,), (3,)],
- # Luminal
- PT.LumFlux: [None, (2,), (3,)],
- }
-
- return overrides.get(self, [None])
-
- @functools.cached_property
- def valid_mathtypes(self) -> list[MathType]:
- """Returns a list of valid mathematical types, especially whether it can be real- or complex-valued.
-
- Generally, all unit quantities are real, in the algebraic mathematical sense.
- However, in electrodynamics especially, it becomes enormously useful to bake in a _rotational component_ as an imaginary value, be it simply to model phase or oscillation-oriented dampening.
- This imaginary part has physical meaning, which can be expressed using the same mathematical formalism associated with unit systems.
- In general, the value is a phasor.
-
- While it is difficult to arrive at a well-defined way of saying, "this is when a quantity is complex", an attempt has been made to form a sensible baseline based on when phasor math may apply.
-
- Notes:
- - **Freq**/**AngFreq**: The imaginary part represents growth/dampening of the oscillation.
- - **Current**/**Voltage**: The imaginary part represents the phase.
- This also holds for any downstream units.
- - **Charge**: Generally, it is real.
- However, an imaginary phase term seems to have research applications when dealing with high-order harmonics in high-energy pulsed lasers:
- - **Conductance**: The imaginary part represents the extinction, in the Drude-model sense.
-
- """
- MT = MathType
- PT = PhysicalType
- overrides = {
- PT.NonPhysical: list(MT), ## Support All
- # Cartesian
- PT.Freq: [MT.Real, MT.Complex], ## Im -> Growth/Damping
- PT.AngFreq: [MT.Real, MT.Complex], ## Im -> Growth/Damping
- # Mechanical
- # Energy
- # Electrodynamics
- PT.Current: [MT.Real, MT.Complex], ## Im -> Phase
- PT.CurrentDensity: [MT.Real, MT.Complex], ## Im -> Phase
- PT.Charge: [MT.Real, MT.Complex], ## Im -> Phase
- PT.Voltage: [MT.Real, MT.Complex], ## Im -> Phase
- PT.Capacitance: [MT.Real, MT.Complex], ## Im -> Phase
- PT.Impedance: [MT.Real, MT.Complex], ## Im -> Reactance
- PT.Inductance: [MT.Real, MT.Complex], ## Im -> Extinction
- PT.Conductance: [MT.Real, MT.Complex], ## Im -> Extinction
- PT.Conductivity: [MT.Real, MT.Complex], ## Im -> Extinction
- PT.MFlux: [MT.Real, MT.Complex], ## Im -> Phase
- PT.MFluxDensity: [MT.Real, MT.Complex], ## Im -> Phase
- PT.EField: [MT.Real, MT.Complex], ## Im -> Phase
- PT.HField: [MT.Real, MT.Complex], ## Im -> Phase
- # Luminal
- }
-
- return overrides.get(self, [MT.Real])
-
- @staticmethod
- def to_name(value: typ.Self) -> str:
- if value is PhysicalType.NonPhysical:
- return 'Unitless'
- return PhysicalType(value).name
-
- @staticmethod
- def to_icon(value: typ.Self) -> str:
- return ''
-
- def bl_enum_element(self, i: int) -> ct.BLEnumElement:
- PT = PhysicalType
- return (
- str(self),
- PT.to_name(self),
- PT.to_name(self),
- PT.to_icon(self),
- i,
- )
-
-
-####################
-# - Standard Unit Systems
-####################
-UnitSystem: typ.TypeAlias = dict[PhysicalType, Unit]
-
-_PT = PhysicalType
-UNITS_SI: UnitSystem = {
- _PT.NonPhysical: None,
- # Global
- _PT.Time: spu.second,
- _PT.Angle: spu.radian,
- _PT.SolidAngle: spu.steradian,
- _PT.Freq: spu.hertz,
- _PT.AngFreq: spu.radian * spu.hertz,
- # Cartesian
- _PT.Length: spu.meter,
- _PT.Area: spu.meter**2,
- _PT.Volume: spu.meter**3,
- # Mechanical
- _PT.Vel: spu.meter / spu.second,
- _PT.Accel: spu.meter / spu.second**2,
- _PT.Mass: spu.kilogram,
- _PT.Force: spu.newton,
- # Energy
- _PT.Work: spu.joule,
- _PT.Power: spu.watt,
- _PT.PowerFlux: spu.watt / spu.meter**2,
- _PT.Temp: spu.kelvin,
- # Electrodynamics
- _PT.Current: spu.ampere,
- _PT.CurrentDensity: spu.ampere / spu.meter**2,
- _PT.Voltage: spu.volt,
- _PT.Capacitance: spu.farad,
- _PT.Impedance: spu.ohm,
- _PT.Conductance: spu.siemens,
- _PT.Conductivity: spu.siemens / spu.meter,
- _PT.MFlux: spu.weber,
- _PT.MFluxDensity: spu.tesla,
- _PT.Inductance: spu.henry,
- _PT.EField: spu.volt / spu.meter,
- _PT.HField: spu.ampere / spu.meter,
- # Luminal
- _PT.LumIntensity: spu.candela,
- _PT.LumFlux: lumen,
- _PT.Illuminance: spu.lux,
-}
-
-
-####################
-# - Sympy Utilities: Cast to Python
-####################
-def sympy_to_python(
- scalar: sp.Basic, use_jax_array: bool = False
-) -> int | float | complex | tuple | jax.Array:
- """Convert a scalar sympy expression to the directly corresponding Python type.
-
- Arguments:
- scalar: A sympy expression that has no symbols, but is expressed as a Sympy type.
- For expressions that are equivalent to a scalar (ex. "(2a + a)/a"), you must simplify the expression with ex. `sp.simplify()` before passing to this parameter.
-
- Returns:
- A pure Python type that directly corresponds to the input scalar expression.
- """
- if isinstance(scalar, sp.MatrixBase):
- # Detect Single Column Vector
- ## --> Flatten to Single Row Vector
- if len(scalar.shape) == 2 and scalar.shape[1] == 1:
- _scalar = scalar.T
- else:
- _scalar = scalar
-
- # Convert to Tuple of Tuples
- matrix = tuple(
- [tuple([sympy_to_python(el) for el in row]) for row in _scalar.tolist()]
- )
-
- # Detect Single Row Vector
- ## --> This could be because the scalar had it.
- ## --> This could also be because we flattened a column vector.
- ## Either way, we should strip the pointless dimensions.
- if len(matrix) == 1:
- return matrix[0] if not use_jax_array else jnp.array(matrix[0])
-
- return matrix if not use_jax_array else jnp.array(matrix)
- if scalar.is_integer:
- return int(scalar)
- if scalar.is_rational or scalar.is_real:
- return float(scalar)
- if scalar.is_complex:
- return complex(scalar)
-
- msg = f'Cannot convert sympy scalar expression "{scalar}" to a Python type. Check the assumptions on the expr (current expr assumptions: "{scalar._assumptions}")' # noqa: SLF001
- raise ValueError(msg)
-
-
-####################
-# - Convert to Unit System
-####################
-def strip_unit_system(
- sp_obj: SympyExpr, unit_system: UnitSystem | None = None
-) -> SympyExpr:
- """Strip units occurring in the given unit system from the expression.
-
- Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`".
- Obviously, the semantic correctness of this operation depends entirely on _the units adding no semantic meaning to the expression_.
-
- Notes:
- You should probably use `scale_to_unit_system()` or `convert_to_unit_system()`.
- """
- if unit_system is None:
- return sp_obj.subs(UNIT_TO_1)
- return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None})
-
-
-def convert_to_unit_system(
- sp_obj: SympyExpr, unit_system: UnitSystem | None
-) -> SympyExpr:
- """Convert an expression to the units of a given unit system."""
- if unit_system is None:
- return sp_obj
-
- return spu.convert_to(
- sp_obj,
- {unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)},
- )
-
-
-def scale_to_unit_system(
- sp_obj: SympyExpr, unit_system: UnitSystem | None, use_jax_array: bool = False
-) -> int | float | complex | tuple | jax.Array:
- """Convert an expression to the units of a given unit system, then strip all units of the unit system.
-
- Afterwards, it is converted to an appropriate Python type.
-
- Notes:
- For stability, and performance, reasons, this should only be used at the very last stage.
-
- Regarding performance: **This is not a fast function**.
-
- Parameters:
- sp_obj: An arbitrary sympy object, presumably with units.
- unit_system: A unit system mapping `PhysicalType` to particular choices of (compound) units.
- Note that, in this context, only `unit_system.values()` is used.
-
- Returns:
- An appropriate pure Python type, after scaling to the unit system and stripping all units away.
-
- If the returned type is array-like, and `use_jax_array` is specified, then (and **only** then) will a `jax.Array` be returned instead of a nested `tuple`.
- """
- return sympy_to_python(
- strip_unit_system(convert_to_unit_system(sp_obj, unit_system), unit_system),
- use_jax_array=use_jax_array,
- )
diff --git a/src/blender_maxwell/utils/image_ops.py b/src/blender_maxwell/utils/image_ops.py
index 6ca4412..68820c3 100644
--- a/src/blender_maxwell/utils/image_ops.py
+++ b/src/blender_maxwell/utils/image_ops.py
@@ -30,7 +30,7 @@ import matplotlib.figure
import seaborn as sns
from blender_maxwell import contracts as ct
-from blender_maxwell.utils import extra_sympy_units as spux
+from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
sns.set_theme()
diff --git a/src/blender_maxwell/utils/sci_constants.py b/src/blender_maxwell/utils/sci_constants.py
index 920bb2d..935e24a 100644
--- a/src/blender_maxwell/utils/sci_constants.py
+++ b/src/blender_maxwell/utils/sci_constants.py
@@ -32,7 +32,7 @@ import scipy as sc
import sympy as sp
import sympy.physics.units as spu
-from . import extra_sympy_units as spux
+from . import sympy_extra as spux
SUPPORTED_SCIPY_PREFIX = '1.12'
if not sc.version.full_version.startswith(SUPPORTED_SCIPY_PREFIX):
diff --git a/src/blender_maxwell/utils/serialize.py b/src/blender_maxwell/utils/serialize.py
index bab40a0..b58aebc 100644
--- a/src/blender_maxwell/utils/serialize.py
+++ b/src/blender_maxwell/utils/serialize.py
@@ -31,8 +31,8 @@ import uuid
import msgspec
import sympy as sp
-from . import extra_sympy_units as spux
from . import logger
+from . import sympy_extra as spux
log = logger.get(__name__)
diff --git a/src/blender_maxwell/utils/sim_symbols.py b/src/blender_maxwell/utils/sim_symbols.py
index d3e4366..f663d36 100644
--- a/src/blender_maxwell/utils/sim_symbols.py
+++ b/src/blender_maxwell/utils/sim_symbols.py
@@ -25,8 +25,8 @@ import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
-from . import extra_sympy_units as spux
from . import logger, serialize
+from . import sympy_extra as spux
int_min = -(2**64)
int_max = 2**64
@@ -101,6 +101,12 @@ class SimSymbolName(enum.StrEnum):
BlochY = enum.auto()
BlochZ = enum.auto()
+ # New Backwards Compatible Entries
+ ## -> Ordered lists carry a particular enum integer index.
+ ## -> Therefore, anything but adding an index breaks backwards compat.
+ ## -> ...With all previous files.
+ ConstantRange = enum.auto()
+
####################
# - UI
####################
@@ -143,7 +149,8 @@ class SimSymbolName(enum.StrEnum):
}
| {
# Generic
- SSN.Constant: 'constant',
+ SSN.Constant: 'cst',
+ SSN.ConstantRange: 'cst_range',
SSN.Expr: 'expr',
SSN.Data: 'data',
# Greek Letters
@@ -210,24 +217,43 @@ def mk_interval(
interval_finite: tuple[int | Fraction | float, int | Fraction | float],
interval_inf: tuple[bool, bool],
interval_closed: tuple[bool, bool],
- unit_factor: typ.Literal[1] | spux.Unit,
) -> sp.Interval:
"""Create a symbolic interval from the tuples (and unit) defining it."""
return sp.Interval(
- start=(interval_finite[0] * unit_factor if not interval_inf[0] else -sp.oo),
- end=(interval_finite[1] * unit_factor if not interval_inf[1] else sp.oo),
+ start=(interval_finite[0] if not interval_inf[0] else -sp.oo),
+ end=(interval_finite[1] if not interval_inf[1] else sp.oo),
left_open=(True if interval_inf[0] else not interval_closed[0]),
right_open=(True if interval_inf[1] else not interval_closed[1]),
)
class SimSymbol(pyd.BaseModel):
- """A declarative representation of a symbolic variable.
+ """A convenient, constrained representation of a symbolic variable suitable for many tasks.
- `sympy`'s symbols aren't quite flexible enough for our needs: The symbols that we're transporting often need exact domain information, an associated unit dimension, and a great deal of determinism in checks thereof.
+ The original motivation was to enhance `sp.Symbol` with greater flexibility, semantic context, and a UI-friendly representation.
+ Today, `SimSymbol` is a fully capable primitive for defining the interfaces between externally tracked mathematical elements, and planning the required operations between them.
+
+ A symbol represented as `SimSymbol` carries all the semantic meaning of that symbol, and comes with a comprehensive library of useful (computed) properties and methods.
+ It is immutable, hashable, and serializable, and as a `pydantic.BaseModel` with aggressive property caching, its performance properties should also be well-suited for use in the hot-loops of ex. UI draw methods.
+
+ Attributes:
+ sym_name: For humans and computers, symbol names induces a lot of implicit semantics.
+ mathtype: Symbols are associated with some set of valid values.
+ We choose to constrain `SimSymbol` to only associate with _mathematical_ (aka. number-like) sets.
+ This prohibits ex. booleans and predicate-logic applications, but eases a lot of burdens associated with actually using `SimSymbol`.
+ physical_type: Symbols may be associated with a particular unit dimension expression.
+ This allows the symbol to have _physical meaning_.
+ This information is **generally not** encoded in auxiliary attributes like `self.domain`, but **generally is** encoded by computed properties/methods.
+ unit: Symbols may be associated with a particular unit, which must be compatible with the `PhysicalType`.
+ **NOTE**: Unit expressions may well have physical meaning, without being strictly conformable to a pre-blessed `PhysicalType`s.
+ We do try to avoid such cases, but for the sake of correctness, our chosen convention is to let `self.physical_type` be "`NonPhysical`", while still allowing a unit.
+ size: Symbols may themselves have shape.
+ **NOTE**: We deliberately choose to constrain `SimSymbol`s to two dimensions, allowing them to represent scalars, vectors, covectors, and matrices, but **not** arbitrary tensors.
+ This is a practical tradeoff, made both to make it easier (in terms of mathematical analysis) to implement `SimSymbol`, but also to make it easier to define UI elements that drive / are driven by `SimSymbol`s.
+ domain: Symbols are associated with a _domain of valid values_, expressed with any mathematical set implemented as a subclass of `sympy.Set`.
+ By using a true symbolic set, we gain unbounded flexibility in how to define the validity of a set, including an extremely capable `* in self.domain` operator encapsulating a lot of otherwise very manual logic.
+ **NOTE** that `self.unit` is **not** baked into the domain, due to practicalities associated with subclasses of `sp.Set`.
- This dataclass is UI-friendly, as it only uses field type annotations/defaults supported by `bl_cache.BLProp`.
- It's easy to persist, easy to transport, and has many helpful properties which greatly simplify working with symbols.
"""
model_config = pyd.ConfigDict(frozen=True)
@@ -238,11 +264,8 @@ class SimSymbol(pyd.BaseModel):
# Units
## -> 'None' indicates that no particular unit has yet been chosen.
- ## -> Not exposed in the UI; must be set some other way.
+ ## -> When 'self.physical_type' is NonPhysical, can only be None.
unit: spux.Unit | None = None
- ## -> TODO: We currently allowing units that don't match PhysicalType
- ## -> -- In particular, NonPhysical w/units means "unknown units".
- ## -> -- This is essential for the Scientific Constant Node.
# Size
## -> All SimSymbol sizes are "2D", but interpreted by convention.
@@ -253,39 +276,96 @@ class SimSymbol(pyd.BaseModel):
rows: int = 1
cols: int = 1
- # Scalar Domain: "Interval"
- ## -> NOTE: interval_finite_*[0] must be strictly smaller than [1].
- ## -> See self.domain.
- ## -> We have to deconstruct symbolic interval semantics a bit for UI.
- is_constant: bool = False
- exclude_zero: bool = False
+ # Valid Domain
+ ## -> Declares the valid set of values that may be given to this symbol.
+ ## -> By convention, units are not encoded in the domain sp.Set.
+ ## -> 'sp.Set's are extremely expressive and cool.
+ domain: spux.SympyExpr | None = None
- interval_finite_z: tuple[int, int] = (0, 1)
- interval_finite_q: tuple[tuple[int, int], tuple[int, int]] = ((0, 1), (1, 1))
- interval_finite_re: tuple[float, float] = (0.0, 1.0)
- interval_inf: tuple[bool, bool] = (True, True)
- interval_closed: tuple[bool, bool] = (False, False)
+ @functools.cached_property
+ def domain_mat(self) -> sp.Set | sp.matrices.MatrixSet:
+ if self.rows > 1 or self.cols > 1:
+ return sp.matrices.MatrixSet(self.rows, self.cols, self.domain)
+ return self.domain
- interval_finite_im: tuple[float, float] = (0.0, 1.0)
- interval_inf_im: tuple[bool, bool] = (True, True)
- interval_closed_im: tuple[bool, bool] = (False, False)
-
- preview_value_z: int = 0
- preview_value_q: tuple[int, int] = (0, 1)
- preview_value_re: float = 0.0
- preview_value_im: float = 0.0
+ preview_value: spux.SympyExpr | None = None
####################
- # - Core
+ # - Validators
+ ####################
+ ## TODO: Check domain against MathType
+ ## -- Surprisingly hard without a lot of special-casing.
+
+ ## TODO: Check that size is valid for the PhysicalType.
+
+ ## TODO: Check that constant value (domain=FiniteSet(cst)) is compatible with the MathType.
+
+ ## TODO: Check that preview_value is in the domain.
+
+ @pyd.model_validator(mode='after')
+ def set_undefined_domain_from_mathtype(self) -> typ.Self:
+ """When the domain is not set, then set it using the symbolic set of the MathType."""
+ if self.domain is None:
+ object.__setattr__(self, 'domain', self.mathtype.symbolic_set)
+ return self
+
+ @pyd.model_validator(mode='after')
+ def conform_undefined_preview_value_to_constant(self) -> typ.Self:
+ """When the `SimSymbol` is a constant, but the preview value is not set, then set the preview value from the constant."""
+ if self.is_constant and not self.preview_value:
+ object.__setattr__(self, 'preview_value', self.constant_value)
+ return self
+
+ @pyd.model_validator(mode='after')
+ def conform_preview_value(self) -> typ.Self:
+ """Conform the given preview value to the `SimSymbol`."""
+ if self.is_constant and not self.preview_value:
+ object.__setattr__(
+ self,
+ 'preview_value',
+ self.conform(self.preview_value, strip_units=True),
+ )
+ return self
+
+ ####################
+ # - Domain
####################
@functools.cached_property
- def name(self) -> str:
- """Usable name for the symbol."""
- return self.sym_name.name
+ def is_constant(self) -> bool:
+ """When the symbol domain is a single-element `sp.FiniteSet`, then the symbol can be considered to be a constant."""
+ return isinstance(self.domain, sp.FiniteSet) and len(self.domain) == 1
+
+ @functools.cached_property
+ def constant_value(self) -> bool:
+ """Get the constant when `is_constant` is True.
+
+ The `self.unit_factor` is multiplied onto the constant at this point.
+ """
+ if self.is_constant:
+ return next(iter(self.domain)) * self.unit_factor
+
+ msg = 'Tried to get constant value of non-constant SimSymbol.'
+ raise ValueError(msg)
+
+ @functools.cached_property
+ def is_nonzero(self) -> bool:
+ """Whether $0$ is a valid value for this symbol.
+
+ When shaped, $0$ refers to the relevant shaped object with all elements $0$.
+
+ Notes:
+ Most notably, this symbol cannot be used as the right hand side of a division operation when this property is `False`.
+ """
+ return 0 in self.domain
####################
# - Labels
####################
+ @functools.cached_property
+ def name(self) -> str:
+ """Usable string name for the symbol."""
+ return self.sym_name.name
+
@functools.cached_property
def name_pretty(self) -> str:
"""Pretty (possibly unicode) name for the thing."""
@@ -340,7 +420,8 @@ class SimSymbol(pyd.BaseModel):
return self.unit if self.unit is not None else sp.S(1)
@functools.cached_property
- def size(self) -> tuple[int, ...] | None:
+ def size(self) -> spux.NumberSize1D | None:
+ """The 1D number size of this `SimSymbol`, if it has one; else None."""
return {
(1, 1): spux.NumberSize1D.Scalar,
(2, 1): spux.NumberSize1D.Vec2,
@@ -350,13 +431,17 @@ class SimSymbol(pyd.BaseModel):
@functools.cached_property
def shape(self) -> tuple[int, ...]:
+ """Deterministic chosen shape of this `SimSymbol`.
+
+ Derived from `self.rows` and `self.cols`.
+
+ Is never `None`; instead, empty tuple `()` is used.
+ """
match (self.rows, self.cols):
case (1, 1):
return ()
case (_, 1):
return (self.rows,)
- case (1, _):
- return (1, self.rows)
case (_, _):
return (self.rows, self.cols)
@@ -365,116 +450,6 @@ class SimSymbol(pyd.BaseModel):
"""Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking."""
return len(self.shape)
- @functools.cached_property
- def domain(self) -> sp.Interval | sp.Set:
- """Return the scalar domain of valid values for each element of the symbol.
-
- For integer/rational/real symbols, the domain is an interval defined using the `interval_*` properties.
- This interval **must** have the property`start <= stop`.
-
- Otherwise, the domain is the symbolic set corresponding to `self.mathtype`.
- """
- match self.mathtype:
- case spux.MathType.Integer:
- return mk_interval(
- self.interval_finite_z,
- self.interval_inf,
- self.interval_closed,
- self.unit_factor,
- )
-
- case spux.MathType.Rational:
- return mk_interval(
- Fraction(*self.interval_finite_q),
- self.interval_inf,
- self.interval_closed,
- self.unit_factor,
- )
-
- case spux.MathType.Real:
- return mk_interval(
- self.interval_finite_re,
- self.interval_inf,
- self.interval_closed,
- self.unit_factor,
- )
-
- case spux.MathType.Complex:
- return (
- mk_interval(
- self.interval_finite_re,
- self.interval_inf,
- self.interval_closed,
- self.unit_factor,
- ),
- mk_interval(
- self.interval_finite_im,
- self.interval_inf_im,
- self.interval_closed_im,
- self.unit_factor,
- ),
- )
-
- @functools.cached_property
- def valid_domain_value(self) -> spux.SympyExpr:
- """A single value guaranteed to be conformant to this `SimSymbol` and within `self.domain`."""
- match (self.domain.start.is_finite, self.domain.end.is_finite):
- case (True, True):
- if self.mathtype is spux.MathType.Integer:
- return (self.domain.start + self.domain.end) // 2
- return (self.domain.start + self.domain.end) / 2
-
- case (True, False):
- one = sp.S(self.mathtype.coerce_compatible_pyobj(-1))
- return self.domain.start + one
-
- case (False, True):
- one = sp.S(self.mathtype.coerce_compatible_pyobj(-1))
- return self.domain.end - one
-
- case (False, False):
- return sp.S(self.mathtype.coerce_compatible_pyobj(-1))
-
- @functools.cached_property
- def is_nonzero(self) -> bool:
- """Whether or not the value of this symbol can ever be $0$.
-
- Notes:
- Most notably, this symbol cannot be used as the right hand side of a division operation when this property is `False`.
- """
- if self.exclude_zero:
- return True
-
- def check_real_domain(real_domain):
- return (
- (
- real_domain.left == 0
- and real_domain.left_open
- or real_domain.right == 0
- and real_domain.right_open
- )
- or real_domain.left > 0
- or real_domain.right < 0
- )
-
- if self.mathtype is spux.MathType.Complex:
- return check_real_domain(self.domain[0]) and check_real_domain(
- self.domain[1]
- )
- return check_real_domain(self.domain)
-
- @functools.cached_property
- def can_diff(self) -> bool:
- """Whether this symbol can be used as the input / output variable when differentiating."""
- # Check Constants
- ## -> Constants (w/pinned values) are never differentiable.
- if self.is_constant:
- return False
-
- # TODO: Discontinuities (especially across 0)?
-
- return self.mathtype in [spux.MathType.Real, spux.MathType.Complex]
-
####################
# - Properties
####################
@@ -511,9 +486,9 @@ class SimSymbol(pyd.BaseModel):
# Positive/Negative Assumption
if self.mathtype is not spux.MathType.Complex:
- if self.domain.left >= 0:
+ if self.domain.inf >= 0:
mathtype_kwargs |= {'positive': True}
- elif self.domain.right <= 0:
+ elif self.domain.sup < 0:
mathtype_kwargs |= {'negative': True}
# Scalar: Return Symbol
@@ -571,7 +546,7 @@ class SimSymbol(pyd.BaseModel):
"""
if self.size is not None:
if self.unit in self.physical_type.valid_units:
- return {
+ socket_info = {
'output_name': self.sym_name,
# Socket Interface
'size': self.size,
@@ -580,23 +555,42 @@ class SimSymbol(pyd.BaseModel):
# Defaults: Units
'default_unit': self.unit,
'default_symbols': [],
- # Defaults: FlowKind.Value
- 'default_value': self.conform(
- self.valid_domain_value, strip_unit=True
- ),
- # Defaults: FlowKind.Range
- 'default_min': self.conform(self.domain.start, strip_unit=True),
- 'default_max': self.conform(self.domain.end, strip_unit=True),
}
+
+ # Defaults: FlowKind.Value
+ if self.preview_value:
+ socket_info |= {
+ 'default_value': self.conform(
+ self.preview_value, strip_unit=True
+ )
+ }
+
+ # Defaults: FlowKind.Range
+ if (
+ self.mathtype is not spux.MathType.Complex
+ and self.rows == 1
+ and self.cols == 1
+ ):
+ socket_info |= {
+ 'default_min': self.domain.inf,
+ 'default_max': self.domain.sup,
+ }
+ ## TODO: Handle discontinuities / disjointness / open boundaries.
+
msg = f'Tried to generate an ExprSocket from a SymSymbol "{self.name}", but its unit ({self.unit}) is not a valid unit of its physical type ({self.physical_type}) (SimSymbol={self})'
raise NotImplementedError(msg)
+
msg = f'Tried to generate an ExprSocket from a SymSymbol "{self.name}", but its size ({self.rows} by {self.cols}) is incompatible with ExprSocket (SimSymbol={self})'
raise NotImplementedError(msg)
####################
- # - Operations
+ # - Operations: Raw Update
####################
def update(self, **kwargs) -> typ.Self:
+ """Create a new `SimSymbol`, such that the given keyword arguments override the existing values."""
+ if not kwargs:
+ return self
+
def get_attr(attr: str):
_notfound = 'notfound'
if kwargs.get(attr, _notfound) is _notfound:
@@ -610,61 +604,101 @@ class SimSymbol(pyd.BaseModel):
unit=get_attr('unit'),
rows=get_attr('rows'),
cols=get_attr('cols'),
- interval_finite_z=get_attr('interval_finite_z'),
- interval_finite_q=get_attr('interval_finite_q'),
- interval_finite_re=get_attr('interval_finite_re'),
- interval_inf=get_attr('interval_inf'),
- interval_closed=get_attr('interval_closed'),
- interval_finite_im=get_attr('interval_finite_im'),
- interval_inf_im=get_attr('interval_inf_im'),
- interval_closed_im=get_attr('interval_closed_im'),
+ domain=get_attr('domain'),
)
- def set_finite_domain( # noqa: PLR0913
- self,
- start: int | float,
- end: int | float,
- start_closed: bool = True,
- end_closed: bool = True,
- start_im: bool = float,
- end_im: bool = float,
- start_closed_im: bool = True,
- end_closed_im: bool = True,
- ) -> typ.Self:
- """Update the symbol with a finite range."""
- closed_re = (start_closed, end_closed)
- closed_im = (start_closed_im, end_closed_im)
- match self.mathtype:
- case spux.MathType.Integer:
- return self.update(
- interval_finite_z=(start, end),
- interval_inf=(False, False),
- interval_closed=closed_re,
- )
- case spux.MathType.Rational:
- return self.update(
- interval_finite_q=(start, end),
- interval_inf=(False, False),
- interval_closed=closed_re,
- )
- case spux.MathType.Real:
- return self.update(
- interval_finite_re=(start, end),
- interval_inf=(False, False),
- interval_closed=closed_re,
- )
- case spux.MathType.Complex:
- return self.update(
- interval_finite_re=(start, end),
- interval_finite_im=(start_im, end_im),
- interval_inf=(False, False),
- interval_closed=closed_re,
- interval_closed_im=closed_im,
- )
+ ####################
+ # - Operations: Comparison
+ ####################
+ def compare(self, other: typ.Self) -> typ.Self:
+ """Whether this SimSymbol can be considered equivalent to another, and thus universally usable in arbitrary mathematical operations together.
- def set_size(self, rows: int, cols: int) -> typ.Self:
- return self.update(rows=rows, cols=cols)
+ In particular, two attributes are ignored:
+ - **Name**: The particluar choices of name are not generally important.
+ - **Unit**: The particulars of unit equivilancy are not generally important; only that the `PhysicalType` is equal, and thus that they are compatible.
+ While not usable in all cases, this method ends up being very helpful for simplifying certain checks that would otherwise take up a lot of space.
+ """
+ return (
+ self.mathtype is other.mathtype
+ and self.physical_type is other.physical_type
+ and self.compare_size(other)
+ and self.domain == other.domain
+ )
+
+ def compare_size(self, other: typ.Self) -> typ.Self:
+ """Compare the size of this `SimSymbol` with another."""
+ return self.rows == other.rows and self.cols == other.cols
+
+ def compare_addable(
+ self, other: typ.Self, allow_differing_unit: bool = False
+ ) -> bool:
+ """Whether two `SimSymbol`s can be added."""
+ common = (
+ self.compare_size(other.output)
+ and self.physical_type is other.physical_type
+ and not (
+ self.physical_type is spux.NonPhysical
+ and self.unit is not None
+ and self.unit != other.unit
+ )
+ and not (
+ other.physical_type is spux.NonPhysical
+ and other.unit is not None
+ and self.unit != other.unit
+ )
+ )
+ if not allow_differing_unit:
+ return common and self.output.unit == other.output.unit
+ return common
+
+ def compare_multiplicable(self, other: typ.Self) -> bool:
+ """Whether two `SimSymbol`s can be multiplied."""
+ return self.shape_len == 0 or self.compare_size(other)
+
+ def compare_exponentiable(self, other: typ.Self) -> bool:
+ """Whether two `SimSymbol`s can be exponentiated.
+
+ "Hadamard Power" is defined for any combination of scalar/vector/matrix operands, for any `MathType` combination.
+ The only important thing to check is that the exponent cannot have a physical unit.
+
+ Sometimes, people write equations with units in the exponent.
+ This is a notational shorthand that only works in the context of an implicit, cancelling factor.
+ We reject such things.
+
+ See https://physics.stackexchange.com/questions/109995/exponential-or-logarithm-of-a-dimensionful-quantity
+ """
+ return (
+ other.physical_type is spux.PhysicalType.NonPhysical and other.unit is None
+ )
+
+ ####################
+ # - Operations: Copying Setters
+ ####################
+ def set_constant(self, constant_value: spux.SympyType) -> typ.Self:
+ """Set the constant value of this `SimSymbol`, by setting it as the only value in a `sp.FiniteSet` domain.
+
+ The `constant_value` will be conformed and stripped (with `self.conform()`) before being injected into the new `sp.FiniteSet` domain.
+
+ Warnings:
+ Keep in mind that domains do not encode units, for practical reasons related to the diverging ways in which various `sp.Set` subclasses interpret units.
+
+ This isn't noticeable in normal constant-symbol workflows, where the constant is retrieved using `self.constant_value` (which adds `self.unit_factor`).
+ However, **remember that retrieving the domain directly won't add the unit**.
+
+ Ye been warned!
+ """
+ if self.is_constant:
+ return self.update(
+ domain=sp.FiniteSet(self.conform(constant_value, strip_unit=True))
+ )
+
+ msg = 'Tried to set constant value of non-constant SimSymbol.'
+ raise ValueError(msg)
+
+ ####################
+ # - Operations: Conforming Mappers
+ ####################
def conform(
self, sp_obj: spux.SympyType, strip_unit: bool = False
) -> spux.SympyType:
@@ -732,6 +766,9 @@ class SimSymbol(pyd.BaseModel):
return res # noqa: RET504
+ ####################
+ # - Creation
+ ####################
@staticmethod
def from_expr(
sym_name: SimSymbolName,
diff --git a/src/blender_maxwell/utils/sympy_extra/__init__.py b/src/blender_maxwell/utils/sympy_extra/__init__.py
new file mode 100644
index 0000000..6ac4fca
--- /dev/null
+++ b/src/blender_maxwell/utils/sympy_extra/__init__.py
@@ -0,0 +1,173 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+"""Declares many useful primitives to greatly simplify working with `sympy` in the context of a unit-aware system."""
+
+from .math_type import MathType
+from .number_size import NumberSize1D, NumberSize2D
+from .parse_cast import parse_shape, pretty_symbol, sp_to_str, sympy_to_python
+from .physical_type import Dims, PhysicalType
+from .sympy_expr import (
+ ComplexNumber,
+ ComplexSymbol,
+ ConstrSympyExpr,
+ IntNumber,
+ IntSymbol,
+ Number,
+ PhysicalComplexNumber,
+ PhysicalNumber,
+ PhysicalRealNumber,
+ RationalSymbol,
+ Real3DVector,
+ RealNumber,
+ RealSymbol,
+ ScalarUnitlessComplexExpr,
+ ScalarUnitlessRealExpr,
+ Symbol,
+ SympyExpr,
+ Unit,
+ UnitDimension,
+)
+from .sympy_type import SympyType
+from .unit_analysis import (
+ compare_unit_dims,
+ compare_units_by_unit_dims,
+ convert_to_unit,
+ get_units,
+ scale_to_unit,
+ scaling_factor,
+ strip_units,
+ unit_dim_to_unit_dim_deps,
+ unit_str_to_unit,
+ unit_to_unit_dim_deps,
+ uses_units,
+)
+from .unit_system_analysis import (
+ convert_to_unit_system,
+ scale_to_unit_system,
+ strip_unit_system,
+)
+from .unit_systems import UNITS_SI, UnitSystem
+from .units import (
+ UNIT_BY_SYMBOL,
+ UNIT_TO_1,
+ EHz,
+ GHz,
+ KHz,
+ MHz,
+ PHz,
+ THz,
+ exahertz,
+ femtometer,
+ femtosecond,
+ fm,
+ fs,
+ gigahertz,
+ hectopascal,
+ hPa,
+ kilohertz,
+ lm,
+ lumen,
+ mbar,
+ megahertz,
+ micronewton,
+ millibar,
+ millinewton,
+ mN,
+ nanonewton,
+ nN,
+ petahertz,
+ terahertz,
+ uN,
+)
+
+__all__ = [
+ 'MathType',
+ 'NumberSize1D',
+ 'NumberSize2D',
+ 'parse_shape',
+ 'pretty_symbol',
+ 'sp_to_str',
+ 'sympy_to_python',
+ 'Dims',
+ 'PhysicalType',
+ 'ComplexNumber',
+ 'ComplexSymbol',
+ 'ConstrSympyExpr',
+ 'IntNumber',
+ 'IntSymbol',
+ 'Number',
+ 'PhysicalComplexNumber',
+ 'PhysicalNumber',
+ 'PhysicalRealNumber',
+ 'RationalSymbol',
+ 'Real3DVector',
+ 'RealNumber',
+ 'RealSymbol',
+ 'ScalarUnitlessComplexExpr',
+ 'ScalarUnitlessRealExpr',
+ 'Symbol',
+ 'SympyExpr',
+ 'Unit',
+ 'UnitDimension',
+ 'SympyType',
+ 'compare_unit_dims',
+ 'compare_units_by_unit_dims',
+ 'convert_to_unit',
+ 'get_units',
+ 'scale_to_unit',
+ 'scaling_factor',
+ 'strip_units',
+ 'unit_dim_to_unit_dim_deps',
+ 'unit_str_to_unit',
+ 'unit_to_unit_dim_deps',
+ 'uses_units',
+ 'strip_unit_system',
+ 'UNITS_SI',
+ 'UnitSystem',
+ 'convert_to_unit_system',
+ 'scale_to_unit_system',
+ 'UNIT_BY_SYMBOL',
+ 'UNIT_TO_1',
+ 'EHz',
+ 'GHz',
+ 'KHz',
+ 'MHz',
+ 'PHz',
+ 'THz',
+ 'exahertz',
+ 'femtometer',
+ 'femtosecond',
+ 'fm',
+ 'fs',
+ 'gigahertz',
+ 'hectopascal',
+ 'hPa',
+ 'kilohertz',
+ 'lm',
+ 'lumen',
+ 'mbar',
+ 'megahertz',
+ 'micronewton',
+ 'millibar',
+ 'millinewton',
+ 'mN',
+ 'nanonewton',
+ 'nN',
+ 'petahertz',
+ 'terahertz',
+ 'uN',
+]
diff --git a/src/blender_maxwell/utils/sympy_extra/math_type.py b/src/blender_maxwell/utils/sympy_extra/math_type.py
new file mode 100644
index 0000000..8830d85
--- /dev/null
+++ b/src/blender_maxwell/utils/sympy_extra/math_type.py
@@ -0,0 +1,362 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+"""Implements `MathType`, a convenient UI-friendly identifier of numerical identity."""
+
+import enum
+import sys
+import typing as typ
+from fractions import Fraction
+
+import jax
+import jaxtyping as jtyp
+import sympy as sp
+
+from blender_maxwell import contracts as ct
+
+from .. import logger
+from .sympy_type import SympyType
+
+log = logger.get(__name__)
+
+
+class MathType(enum.StrEnum):
+ """A convenient, UI-friendly identifier of a numerical object's identity."""
+
+ Integer = enum.auto()
+ Rational = enum.auto()
+ Real = enum.auto()
+ Complex = enum.auto()
+
+ ####################
+ # - Checks
+ ####################
+ @staticmethod
+ def has_mathtype(obj: typ.Any) -> typ.Literal['pytype', 'jax', 'expr'] | None:
+ """Determine whether an object of arbitrary type can be considered to have a `MathType`.
+
+ - **Pure Python**: The numerical Python types (`int | Fraction | float | complex`) are all valid.
+ - **Expression**: Sympy types / expression are in general considered to have a valid MathType.
+ - **Jax**: Non-empty `jax` arrays with a valid numerical Python type as the first element are valid.
+
+ Returns:
+ A string literal indicating how to parse the object for a valid `MathType`.
+
+ If the presence of a MathType couldn't be deduced, then return None.
+ """
+ if isinstance(obj, int | Fraction | float | complex):
+ return 'pytype'
+
+ if (
+ isinstance(obj, jax.Array)
+ and obj
+ and isinstance(obj.item(0), int | Fraction | float | complex)
+ ):
+ return 'jax'
+
+ if isinstance(obj, sp.Basic | sp.MatrixBase):
+ return 'expr'
+ ## TODO: Should we check deeper?
+
+ return None
+
+ ####################
+ # - Creation
+ ####################
+ @staticmethod
+ def from_expr(sp_obj: SympyType, optional: bool = False) -> type | None: # noqa: PLR0911
+ """Deduce the `MathType` of an arbitrary sympy object (/expression).
+
+ The "assumptions" system of `sympy` is relied on to determine the key properties of the expression.
+ To this end, it's important to note several of the shortcomings of the "assumptions" system:
+
+ - All elements, especially symbols, must have well-defined assumptions, ex. `real=True`.
+ - Only the "narrowest" possible `MathType` will be deduced, ex. `5` may well be the result of a complex expression, but since it is now an integer, it will parse to `MathType.Integer`. This may break some
+ - For infinities, only real and complex infinities are distinguished between in `sympy` (`sp.oo` vs. `sp.zoo`) - aka. there is no "integer infinity" which will parse to `Integer` with this method.
+
+ Warnings:
+ Using the "assumptions" system like this requires a lot of rigor in the entire program.
+
+ Notes:
+ Any matrix-like object will have `MathType.combine()` run on all of its (flattened) elements.
+ This is an extremely **slow** operation, but accurate, according to the semantics of `MathType.combine()`.
+
+ Note that `sp.MatrixSymbol` _cannot have assumptions_, and thus shouldn't be used in `sp_obj`.
+
+ Returns:
+ A corresponding `MathType`; else, if `optional=True`, return `None`.
+
+ Raises:
+ ValueError: If no corresponding `MathType` could be determined, and `optional=False`.
+
+ """
+ if isinstance(sp_obj, sp.MatrixBase):
+ return MathType.combine(
+ *[MathType.from_expr(v) for v in sp.flatten(sp_obj)]
+ )
+
+ if sp_obj.is_integer:
+ return MathType.Integer
+ if sp_obj.is_rational:
+ return MathType.Rational
+ if sp_obj.is_real:
+ return MathType.Real
+ if sp_obj.is_complex:
+ return MathType.Complex
+
+ # Infinities
+ if sp_obj in [sp.oo, -sp.oo]:
+ return MathType.Real
+ if sp_obj in [sp.zoo, -sp.zoo]:
+ return MathType.Complex
+
+ if optional:
+ return None
+
+ msg = f"Can't determine MathType from sympy object: {sp_obj}"
+ raise ValueError(msg)
+
+ @staticmethod
+ def from_pytype(dtype: type) -> type:
+ return {
+ int: MathType.Integer,
+ Fraction: MathType.Rational,
+ float: MathType.Real,
+ complex: MathType.Complex,
+ }[dtype]
+
+ @staticmethod
+ def from_jax_array(data: jtyp.Shaped[jtyp.Array, '...']) -> type:
+ """Deduce the MathType corresponding to a JAX array.
+
+ We go about this by leveraging that:
+ - `data` is of a homogeneous type.
+ - `data.item(0)` returns a single element of the array w/pure-python type.
+
+ By combing this with `type()` and `MathType.from_pytype`, we can effectively deduce the `MathType` of the entire array with relative efficiency.
+
+ Notes:
+ Should also work with numpy arrays.
+ """
+ if len(data) > 0:
+ return MathType.from_pytype(type(data.item(0)))
+
+ msg = 'Cannot determine MathType from empty jax array.'
+ raise ValueError(msg)
+
+ ####################
+ # - Operations
+ ####################
+ @staticmethod
+ def combine(*mathtypes: list[typ.Self], optional: bool = False) -> typ.Self | None:
+ if MathType.Complex in mathtypes:
+ return MathType.Complex
+ if MathType.Real in mathtypes:
+ return MathType.Real
+ if MathType.Rational in mathtypes:
+ return MathType.Rational
+ if MathType.Integer in mathtypes:
+ return MathType.Integer
+
+ if optional:
+ return None
+
+ msg = f"Can't combine mathtypes {mathtypes}"
+ raise ValueError(msg)
+
+ def is_compatible(self, other: typ.Self) -> bool:
+ MT = MathType
+ return (
+ other
+ in {
+ MT.Integer: [MT.Integer],
+ MT.Rational: [MT.Integer, MT.Rational],
+ MT.Real: [MT.Integer, MT.Rational, MT.Real],
+ MT.Complex: [MT.Integer, MT.Rational, MT.Real, MT.Complex],
+ }[self]
+ )
+
+ def coerce_compatible_pyobj(
+ self, pyobj: bool | int | Fraction | float | complex
+ ) -> int | Fraction | float | complex:
+ """Coerce a pure-python object of numerical type to the _exact_ type indicated by this `MathType`.
+
+ This is needed when ex. one has an integer, but it is important that that integer be passed as a complex number.
+ """
+ MT = MathType
+ match self:
+ case MT.Integer:
+ return int(pyobj)
+ case MT.Rational if isinstance(pyobj, int):
+ return Fraction(pyobj, 1)
+ case MT.Rational if isinstance(pyobj, Fraction):
+ return pyobj
+ case MT.Real:
+ return float(pyobj)
+ case MT.Complex if isinstance(pyobj, int | Fraction):
+ return complex(float(pyobj), 0)
+ case MT.Complex if isinstance(pyobj, float):
+ return complex(pyobj, 0)
+
+ @staticmethod
+ def from_symbolic_set(
+ s: typ.Literal[
+ sp.Naturals
+ | sp.Naturals0
+ | sp.Integers
+ | sp.Rationals
+ | sp.Reals
+ | sp.Complexes
+ ]
+ | sp.Set,
+ optional: bool = False,
+ ) -> typ.Self | None:
+ """Deduce the `MathType` from a particular symbolic set.
+
+ Currently hard-coded.
+ Any deviation that might be expected to work, ex. `sp.Reals - {0}`, currently won't (currently).
+
+ Raises:
+ ValueError: If a non-hardcoded symbolic set is passed.
+ """
+ MT = MathType
+ match s:
+ case sp.Naturals | sp.Naturals0 | sp.Integers:
+ return MT.Integer
+ case sp.Rationals:
+ return MT.Rational
+ case sp.Reals:
+ return MT.Real
+ case sp.Complexes:
+ return MT.Complex
+
+ if optional:
+ return None
+
+ msg = f"Can't deduce MathType from symbolic set {s}"
+ raise ValueError(msg)
+
+ ####################
+ # - Casting: Pytype
+ ####################
+ @property
+ def pytype(self) -> type:
+ """Deduce the pure-Python type that corresponds to this `MathType`."""
+ MT = MathType
+ return {
+ MT.Integer: int,
+ MT.Rational: Fraction,
+ MT.Real: float,
+ MT.Complex: complex,
+ }[self]
+
+ @property
+ def inf_finite(self) -> type:
+ """Opinionated finite representation of "infinity" within this `MathType`.
+
+ These are chosen using `sys.maxsize` and `sys.float_info`.
+ As such, while not arbitrary, this "finite representation of infinity" certainly is opinionated.
+
+ **Note** that, in practice, most systems will have no trouble working with values that exceed those defined here.
+
+ Notes:
+ Values should be presumed to vary by-platform, as the `sys` attributes may be influenced by CPU architecture, OS, runtime environment, etc. .
+
+ These values can be used directly in `jax` arrays, but at the cost of an overflow warning (in part because `jax` generally only allows the use of `float32`).
+ In this case, the warning doesn't matter, as the value will be cast to `jnp.inf` anyway.
+
+ However, it's generally cleaner to directly use `jnp.inf` if infinite values must be defined in an array context.
+ """
+ MT = MathType
+ Z = MT.Integer
+ R = MT.Integer
+ return {
+ MT.Integer: (-sys.maxsize, sys.maxsize),
+ MT.Rational: (
+ Fraction(Z.inf_finite[0], 1),
+ Fraction(Z.inf_finite[1], 1),
+ ),
+ MT.Real: -(sys.float_info.min, sys.float_info.max),
+ MT.Complex: (
+ complex(R.inf_finite[0], R.inf_finite[0]),
+ complex(R.inf_finite[1], R.inf_finite[1]),
+ ),
+ }[self]
+
+ ####################
+ # - Casting: Symbolic
+ ####################
+ @property
+ def symbolic_set(self) -> sp.Set:
+ """Deduce the symbolic `sp.Set` type that corresponds to this `MathType`."""
+ MT = MathType
+ return {
+ MT.Integer: sp.Integers,
+ MT.Rational: sp.Rationals,
+ MT.Real: sp.Reals,
+ MT.Complex: sp.Complexes,
+ }[self]
+
+ @property
+ def sp_symbol_a(self) -> type:
+ MT = MathType
+ return {
+ MT.Integer: sp.Symbol('a', integer=True),
+ MT.Rational: sp.Symbol('a', rational=True),
+ MT.Real: sp.Symbol('a', real=True),
+ MT.Complex: sp.Symbol('a', complex=True),
+ }[self]
+
+ ####################
+ # - Labels
+ ####################
+ @staticmethod
+ def to_str(value: typ.Self) -> type:
+ return {
+ MathType.Integer: 'ℤ',
+ MathType.Rational: 'ℚ',
+ MathType.Real: 'ℝ',
+ MathType.Complex: 'ℂ',
+ }[value]
+
+ @property
+ def name(self) -> str:
+ """Simple non-unicode name of the math type."""
+ return str(self)
+
+ @property
+ def label_pretty(self) -> str:
+ return MathType.to_str(self)
+
+ ####################
+ # - UI
+ ####################
+ @staticmethod
+ def to_name(value: typ.Self) -> str:
+ return MathType.to_str(value)
+
+ @staticmethod
+ def to_icon(value: typ.Self) -> str:
+ return ''
+
+ def bl_enum_element(self, i: int) -> ct.BLEnumElement:
+ return (
+ str(self),
+ MathType.to_name(self),
+ MathType.to_name(self),
+ MathType.to_icon(self),
+ i,
+ )
diff --git a/src/blender_maxwell/utils/sympy_extra/number_size.py b/src/blender_maxwell/utils/sympy_extra/number_size.py
new file mode 100644
index 0000000..14bf34d
--- /dev/null
+++ b/src/blender_maxwell/utils/sympy_extra/number_size.py
@@ -0,0 +1,148 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import enum
+import typing as typ
+
+import sympy as sp
+
+from blender_maxwell import contracts as ct
+
+
+####################
+# - Size: 1D
+####################
+class NumberSize1D(enum.StrEnum):
+ """Valid 1D-constrained shape."""
+
+ Scalar = enum.auto()
+ Vec2 = enum.auto()
+ Vec3 = enum.auto()
+ Vec4 = enum.auto()
+
+ @staticmethod
+ def to_name(value: typ.Self) -> str:
+ NS = NumberSize1D
+ return {
+ NS.Scalar: 'Scalar',
+ NS.Vec2: '2D',
+ NS.Vec3: '3D',
+ NS.Vec4: '4D',
+ }[value]
+
+ @staticmethod
+ def to_icon(value: typ.Self) -> str:
+ NS = NumberSize1D
+ return {
+ NS.Scalar: '',
+ NS.Vec2: '',
+ NS.Vec3: '',
+ NS.Vec4: '',
+ }[value]
+
+ def bl_enum_element(self, i: int) -> ct.BLEnumElement:
+ return (
+ str(self),
+ NumberSize1D.to_name(self),
+ NumberSize1D.to_name(self),
+ NumberSize1D.to_icon(self),
+ i,
+ )
+
+ @staticmethod
+ def has_shape(shape: tuple[int, ...] | None):
+ return shape in [None, (2,), (3,), (4,), (2, 1), (3, 1), (4, 1)]
+
+ def supports_shape(self, shape: tuple[int, ...] | None):
+ NS = NumberSize1D
+ match self:
+ case NS.Scalar:
+ return shape is None
+ case NS.Vec2:
+ return shape in ((2,), (2, 1))
+ case NS.Vec3:
+ return shape in ((3,), (3, 1))
+ case NS.Vec4:
+ return shape in ((4,), (4, 1))
+
+ @staticmethod
+ def from_shape(shape: tuple[typ.Literal[2, 3]] | None) -> typ.Self:
+ NS = NumberSize1D
+ return {
+ None: NS.Scalar,
+ (2,): NS.Vec2,
+ (3,): NS.Vec3,
+ (4,): NS.Vec4,
+ (2, 1): NS.Vec2,
+ (3, 1): NS.Vec3,
+ (4, 1): NS.Vec4,
+ }[shape]
+
+ @property
+ def rows(self):
+ NS = NumberSize1D
+ return {
+ NS.Scalar: 1,
+ NS.Vec2: 2,
+ NS.Vec3: 3,
+ NS.Vec4: 4,
+ }[self]
+
+ @property
+ def cols(self):
+ return 1
+
+ @property
+ def shape(self):
+ NS = NumberSize1D
+ return {
+ NS.Scalar: None,
+ NS.Vec2: (2,),
+ NS.Vec3: (3,),
+ NS.Vec4: (4,),
+ }[self]
+
+
+def symbol_range(sym: sp.Symbol) -> str:
+ return f'{sym.name} ∈ ' + (
+ 'ℂ'
+ if sym.is_complex
+ else ('ℝ' if sym.is_real else ('ℤ' if sym.is_integer else '?'))
+ )
+
+
+####################
+# - Symbol Sizes
+####################
+class NumberSize2D(enum.StrEnum):
+ """Simple subset of sizes for rank-2 tensors."""
+
+ Scalar = enum.auto()
+
+ # Vectors
+ Vec2 = enum.auto() ## 2x1
+ Vec3 = enum.auto() ## 3x1
+ Vec4 = enum.auto() ## 4x1
+
+ # Covectors
+ CoVec2 = enum.auto() ## 1x2
+ CoVec3 = enum.auto() ## 1x3
+ CoVec4 = enum.auto() ## 1x4
+
+ # Square Matrices
+ Mat22 = enum.auto() ## 2x2
+ Mat33 = enum.auto() ## 3x3
+ Mat44 = enum.auto() ## 4x4
diff --git a/src/blender_maxwell/utils/sympy_extra/parse_cast.py b/src/blender_maxwell/utils/sympy_extra/parse_cast.py
new file mode 100644
index 0000000..37672f9
--- /dev/null
+++ b/src/blender_maxwell/utils/sympy_extra/parse_cast.py
@@ -0,0 +1,119 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import jax
+import jax.numpy as jnp
+import sympy as sp
+
+from .. import logger
+from .sympy_type import SympyType
+
+log = logger.get(__name__)
+
+
+####################
+# - Parsing: Info from SympyType
+####################
+def parse_shape(sp_obj: SympyType) -> int | None:
+ if isinstance(sp_obj, sp.MatrixBase):
+ return sp_obj.shape
+
+ return None
+
+
+####################
+# - Casting: Python
+####################
+def sympy_to_python(
+ scalar: sp.Basic, use_jax_array: bool = False
+) -> int | float | complex | tuple | jax.Array:
+ """Convert a scalar sympy expression to the directly corresponding Python type.
+
+ Arguments:
+ scalar: A sympy expression that has no symbols, but is expressed as a Sympy type.
+ For expressions that are equivalent to a scalar (ex. "(2a + a)/a"), you must simplify the expression with ex. `sp.simplify()` before passing to this parameter.
+
+ Returns:
+ A pure Python type that directly corresponds to the input scalar expression.
+ """
+ if isinstance(scalar, sp.MatrixBase):
+ # Detect Single Column Vector
+ ## --> Flatten to Single Row Vector
+ if len(scalar.shape) == 2 and scalar.shape[1] == 1:
+ _scalar = scalar.T
+ else:
+ _scalar = scalar
+
+ # Convert to Tuple of Tuples
+ matrix = tuple(
+ [tuple([sympy_to_python(el) for el in row]) for row in _scalar.tolist()]
+ )
+
+ # Detect Single Row Vector
+ ## --> This could be because the scalar had it.
+ ## --> This could also be because we flattened a column vector.
+ ## Either way, we should strip the pointless dimensions.
+ if len(matrix) == 1:
+ return matrix[0] if not use_jax_array else jnp.array(matrix[0])
+
+ return matrix if not use_jax_array else jnp.array(matrix)
+ if scalar.is_integer:
+ return int(scalar)
+ if scalar.is_rational or scalar.is_real:
+ return float(scalar)
+ if scalar.is_complex:
+ return complex(scalar)
+
+ msg = f'Cannot convert sympy scalar expression "{scalar}" to a Python type. Check the assumptions on the expr (current expr assumptions: "{scalar._assumptions}")' # noqa: SLF001
+ raise ValueError(msg)
+
+
+####################
+# - Casting: Printing
+####################
+_SYMPY_EXPR_PRINTER_STR = sp.printing.str.StrPrinter(
+ settings={
+ 'abbrev': True,
+ }
+)
+
+
+def sp_to_str(sp_obj: SympyType) -> str:
+ """Converts a sympy object to an output-oriented string (w/abbreviated units), using a dedicated StrPrinter.
+
+ This should be used whenever a **string for UI use** is needed from a `sympy` object.
+
+ Notes:
+ This should **NOT** be used in cases where the string will be `sp.sympify()`ed back into a sympy expression.
+ For such cases, rely on `sp.srepr()`, which uses an _explicit_ representation.
+
+ Parameters:
+ sp_obj: The `sympy` object to convert to a string.
+
+ Returns:
+ A string representing the expression for human use.
+ _The string is not re-encodable to the expression._
+ """
+ ## TODO: A bool flag property that does a lot of find/replace to make it super pretty
+ return _SYMPY_EXPR_PRINTER_STR.doprint(sp_obj)
+
+
+def pretty_symbol(sym: sp.Symbol) -> str:
+ return f'{sym.name} ∈ ' + (
+ 'ℤ'
+ if sym.is_integer
+ else ('ℝ' if sym.is_real else ('ℂ' if sym.is_complex else '?'))
+ )
diff --git a/src/blender_maxwell/utils/sympy_extra/physical_type.py b/src/blender_maxwell/utils/sympy_extra/physical_type.py
new file mode 100644
index 0000000..2adedda
--- /dev/null
+++ b/src/blender_maxwell/utils/sympy_extra/physical_type.py
@@ -0,0 +1,644 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+"""Implements `PhysicalType`, a convenient, UI-friendly way of deterministically handling the unit-dimensionality of arbitrary objects."""
+
+import enum
+import functools
+import typing as typ
+
+import sympy.physics.units as spu
+
+from blender_maxwell import contracts as ct
+
+from ..staticproperty import staticproperty
+from . import units as spux
+from .math_type import MathType
+from .sympy_expr import Unit
+from .sympy_type import SympyType
+from .unit_analysis import (
+ compare_unit_dim_to_unit_dim_deps,
+ compare_unit_dims,
+ unit_to_unit_dim_deps,
+)
+
+
+####################
+# - Unit Dimensions
+####################
+class DimsMeta(type):
+ """Metaclass allowing an implementing (ideally empty) class to access `spu.definitions.dimension_definitions` attributes directly via its own attribute."""
+
+ def __getattr__(cls, attr: str) -> spu.Dimension:
+ """Alias for `spu.definitions.dimension_definitions.*` (isn't that a mouthful?).
+
+ Raises:
+ AttributeError: If the name cannot be found.
+ """
+ if (
+ attr in spu.definitions.dimension_definitions.__dir__()
+ and not attr.startswith('__')
+ ):
+ return getattr(spu.definitions.dimension_definitions, attr)
+
+ raise AttributeError(name=attr, obj=Dims)
+
+
+class Dims(metaclass=DimsMeta):
+ """Access `sympy.physics.units` dimensions with less hassle.
+
+ Any unit dimension available in `sympy.physics.units.definitions.dimension_definitions` can be accessed as an attribute of `Dims`.
+
+ An `AttributeError` is raised if the unit cannot be found in `sympy`.
+
+ Examples:
+ The objects returned are a direct alias to `sympy`, with less hassle:
+ ```python
+ assert Dims.length == (
+ sympy.physics.units.definitions.dimension_definitions.length
+ )
+ ```
+ """
+
+
+####################
+# - Physical Type
+####################
+class PhysicalType(enum.StrEnum):
+ """An identifier of unit dimensionality with many useful properties."""
+
+ # Unitless
+ NonPhysical = enum.auto()
+
+ # Global
+ Time = enum.auto()
+ Angle = enum.auto()
+ SolidAngle = enum.auto()
+ ## TODO: Some kind of 3D-specific orientation ex. a quaternion
+ Freq = enum.auto()
+ AngFreq = enum.auto() ## rad*hertz
+ # Cartesian
+ Length = enum.auto()
+ Area = enum.auto()
+ Volume = enum.auto()
+ # Mechanical
+ Vel = enum.auto()
+ Accel = enum.auto()
+ Mass = enum.auto()
+ Force = enum.auto()
+ Pressure = enum.auto()
+ # Energy
+ Work = enum.auto() ## joule
+ Power = enum.auto() ## watt
+ PowerFlux = enum.auto() ## watt
+ Temp = enum.auto()
+ # Electrodynamics
+ Current = enum.auto() ## ampere
+ CurrentDensity = enum.auto()
+ Charge = enum.auto() ## coulomb
+ Voltage = enum.auto()
+ Capacitance = enum.auto() ## farad
+ Impedance = enum.auto() ## ohm
+ Conductance = enum.auto() ## siemens
+ Conductivity = enum.auto() ## siemens / length
+ MFlux = enum.auto() ## weber
+ MFluxDensity = enum.auto() ## tesla
+ Inductance = enum.auto() ## henry
+ EField = enum.auto()
+ HField = enum.auto()
+ # Luminal
+ LumIntensity = enum.auto()
+ LumFlux = enum.auto()
+ Illuminance = enum.auto()
+
+ ####################
+ # - Unit Dimensions
+ ####################
+ @functools.cached_property
+ def unit_dim(self) -> SympyType:
+ """The unit dimension expression associated with the `PhysicalType`.
+
+ A `PhysicalType` is, in its essence, merely an identifier for a particular unit dimension expression.
+ """
+ PT = PhysicalType
+ return {
+ PT.NonPhysical: None,
+ # Global
+ PT.Time: Dims.time,
+ PT.Angle: Dims.angle,
+ PT.SolidAngle: spu.steradian.dimension, ## MISSING
+ PT.Freq: Dims.frequency,
+ PT.AngFreq: Dims.angle * Dims.frequency,
+ # Cartesian
+ PT.Length: Dims.length,
+ PT.Area: Dims.length**2,
+ PT.Volume: Dims.length**3,
+ # Mechanical
+ PT.Vel: Dims.length / Dims.time,
+ PT.Accel: Dims.length / Dims.time**2,
+ PT.Mass: Dims.mass,
+ PT.Force: Dims.force,
+ PT.Pressure: Dims.pressure,
+ # Energy
+ PT.Work: Dims.energy,
+ PT.Power: Dims.power,
+ PT.PowerFlux: Dims.power / Dims.length**2,
+ PT.Temp: Dims.temperature,
+ # Electrodynamics
+ PT.Current: Dims.current,
+ PT.CurrentDensity: Dims.current / Dims.length**2,
+ PT.Charge: Dims.charge,
+ PT.Voltage: Dims.voltage,
+ PT.Capacitance: Dims.capacitance,
+ PT.Impedance: Dims.impedance,
+ PT.Conductance: Dims.conductance,
+ PT.Conductivity: Dims.conductance / Dims.length,
+ PT.MFlux: Dims.magnetic_flux,
+ PT.MFluxDensity: Dims.magnetic_density,
+ PT.Inductance: Dims.inductance,
+ PT.EField: Dims.voltage / Dims.length,
+ PT.HField: Dims.current / Dims.length,
+ # Luminal
+ PT.LumIntensity: Dims.luminous_intensity,
+ PT.LumFlux: Dims.luminous_intensity * spu.steradian.dimension,
+ PT.Illuminance: Dims.luminous_intensity / Dims.length**2,
+ }[self]
+
+ @staticproperty
+ def unit_dims() -> dict[typ.Self, SympyType]:
+ """All unit dimensions supported by all `PhysicalType`s."""
+ return {
+ physical_type: physical_type.unit_dim
+ for physical_type in list(PhysicalType)
+ }
+
+ ####################
+ # - Convenience Properties
+ ####################
+ @functools.cached_property
+ def default_unit(self) -> list[Unit]:
+ """Subjective choice of 'default' unit from `self.valid_units`.
+
+ There is no requirement to use this.
+ """
+ PT = PhysicalType
+ return {
+ PT.NonPhysical: None,
+ # Global
+ PT.Time: spu.picosecond,
+ PT.Angle: spu.radian,
+ PT.SolidAngle: spu.steradian,
+ PT.Freq: spux.terahertz,
+ PT.AngFreq: spu.radian * spux.terahertz,
+ # Cartesian
+ PT.Length: spu.micrometer,
+ PT.Area: spu.um**2,
+ PT.Volume: spu.um**3,
+ # Mechanical
+ PT.Vel: spu.um / spu.second,
+ PT.Accel: spu.um / spu.second,
+ PT.Mass: spu.microgram,
+ PT.Force: spux.micronewton,
+ PT.Pressure: spux.millibar,
+ # Energy
+ PT.Work: spu.joule,
+ PT.Power: spu.watt,
+ PT.PowerFlux: spu.watt / spu.meter**2,
+ PT.Temp: spu.kelvin,
+ # Electrodynamics
+ PT.Current: spu.ampere,
+ PT.CurrentDensity: spu.ampere / spu.meter**2,
+ PT.Charge: spu.coulomb,
+ PT.Voltage: spu.volt,
+ PT.Capacitance: spu.farad,
+ PT.Impedance: spu.ohm,
+ PT.Conductance: spu.siemens,
+ PT.Conductivity: spu.siemens / spu.micrometer,
+ PT.MFlux: spu.weber,
+ PT.MFluxDensity: spu.tesla,
+ PT.Inductance: spu.henry,
+ PT.EField: spu.volt / spu.micrometer,
+ PT.HField: spu.ampere / spu.micrometer,
+ # Luminal
+ PT.LumIntensity: spu.candela,
+ PT.LumFlux: spu.candela * spu.steradian,
+ PT.Illuminance: spu.candela / spu.meter**2,
+ }[self]
+
+ ####################
+ # - Creation
+ ####################
+ @staticmethod
+ def from_unit(unit: Unit | None, optional: bool = False) -> typ.Self | None:
+ """Attempt to determine a matching `PhysicalType` from a unit.
+
+ NOTE: It is not guaranteed that `unit` is within `valid_units`, only that it can be converted to any unit in `valid_units`.
+
+ Returns:
+ The matched `PhysicalType`.
+
+ If none could be matched, then either return `None` (if `optional` is set) or error.
+
+ Raises:
+ ValueError: If no `PhysicalType` could be matched, and `optional` is `False`.
+ """
+ if unit is None:
+ return PhysicalType.NonPhysical
+
+ ## TODO_ This enough?
+ if unit in [spu.radian, spu.degree]:
+ return PhysicalType.Angle
+
+ unit_dim_deps = unit_to_unit_dim_deps(unit)
+ if unit_dim_deps is not None:
+ for physical_type, candidate_unit_dim in PhysicalType.unit_dims.items():
+ if compare_unit_dim_to_unit_dim_deps(candidate_unit_dim, unit_dim_deps):
+ return physical_type
+
+ if optional:
+ return None
+ msg = f'Could not determine PhysicalType for {unit}'
+ raise ValueError(msg)
+
+ @staticmethod
+ def from_unit_dim(
+ unit_dim: SympyType | None, optional: bool = False
+ ) -> typ.Self | None:
+ """Attempts to match an arbitrary unit dimension expression to a corresponding `PhysicalType`.
+
+ For comparing arbitrary unit dimensions (via expressions of `spu.dimensions.Dimension`), it is critical that equivalent dimensions are also compared as equal (ex. `mass*length/time^2 == force`).
+ To do so, we employ the `SI` unit conventions, for extracting the fundamental dimensional dependencies of unit dimension expressions.
+
+ Returns:
+ The matched `PhysicalType`.
+
+ If none could be matched, then either return `None` (if `optional` is set) or error.
+
+ Raises:
+ ValueError: If no `PhysicalType` could be matched, and `optional` is `False`.
+ """
+ for physical_type, candidate_unit_dim in PhysicalType.unit_dims.items():
+ if compare_unit_dims(unit_dim, candidate_unit_dim):
+ return physical_type
+
+ if optional:
+ return None
+ msg = f'Could not determine PhysicalType for {unit_dim}'
+ raise ValueError(msg)
+
+ ####################
+ # - Valid Properties
+ ####################
+ @functools.cached_property
+ def valid_units(self) -> list[Unit]:
+ """Retrieve an ordered (by subjective usefulness) list of units for this physical type.
+
+ Warnings:
+ **Altering the order of units hard-breaks backwards compatibility**, since enums based on it only keep an integer index.
+
+ Notes:
+ The order in which valid units are declared is the exact same order that UI dropdowns display them.
+ """
+ PT = PhysicalType
+ return {
+ PT.NonPhysical: [None],
+ # Global
+ PT.Time: [
+ spu.picosecond,
+ spux.femtosecond,
+ spu.nanosecond,
+ spu.microsecond,
+ spu.millisecond,
+ spu.second,
+ spu.minute,
+ spu.hour,
+ spu.day,
+ ],
+ PT.Angle: [
+ spu.radian,
+ spu.degree,
+ ],
+ PT.SolidAngle: [
+ spu.steradian,
+ ],
+ PT.Freq: (
+ _valid_freqs := [
+ spux.terahertz,
+ spu.hertz,
+ spux.kilohertz,
+ spux.megahertz,
+ spux.gigahertz,
+ spux.petahertz,
+ spux.exahertz,
+ ]
+ ),
+ PT.AngFreq: [spu.radian * _unit for _unit in _valid_freqs],
+ # Cartesian
+ PT.Length: (
+ _valid_lens := [
+ spu.micrometer,
+ spu.nanometer,
+ spu.picometer,
+ spu.angstrom,
+ spu.millimeter,
+ spu.centimeter,
+ spu.meter,
+ spu.inch,
+ spu.foot,
+ spu.yard,
+ spu.mile,
+ ]
+ ),
+ PT.Area: [_unit**2 for _unit in _valid_lens],
+ PT.Volume: [_unit**3 for _unit in _valid_lens],
+ # Mechanical
+ PT.Vel: [_unit / spu.second for _unit in _valid_lens],
+ PT.Accel: [_unit / spu.second**2 for _unit in _valid_lens],
+ PT.Mass: [
+ spu.kilogram,
+ spu.electron_rest_mass,
+ spu.dalton,
+ spu.microgram,
+ spu.milligram,
+ spu.gram,
+ spu.metric_ton,
+ ],
+ PT.Force: [
+ spux.micronewton,
+ spux.nanonewton,
+ spux.millinewton,
+ spu.newton,
+ spu.kg * spu.meter / spu.second**2,
+ ],
+ PT.Pressure: [
+ spu.bar,
+ spux.millibar,
+ spu.pascal,
+ spux.hectopascal,
+ spu.atmosphere,
+ spu.psi,
+ spu.mmHg,
+ spu.torr,
+ ],
+ # Energy
+ PT.Work: [
+ spu.joule,
+ spu.electronvolt,
+ ],
+ PT.Power: [
+ spu.watt,
+ ],
+ PT.PowerFlux: [
+ spu.watt / spu.meter**2,
+ ],
+ PT.Temp: [
+ spu.kelvin,
+ ],
+ # Electrodynamics
+ PT.Current: [
+ spu.ampere,
+ ],
+ PT.CurrentDensity: [
+ spu.ampere / spu.meter**2,
+ ],
+ PT.Charge: [
+ spu.coulomb,
+ ],
+ PT.Voltage: [
+ spu.volt,
+ ],
+ PT.Capacitance: [
+ spu.farad,
+ ],
+ PT.Impedance: [
+ spu.ohm,
+ ],
+ PT.Conductance: [
+ spu.siemens,
+ ],
+ PT.Conductivity: [
+ spu.siemens / spu.micrometer,
+ spu.siemens / spu.meter,
+ ],
+ PT.MFlux: [
+ spu.weber,
+ ],
+ PT.MFluxDensity: [
+ spu.tesla,
+ ],
+ PT.Inductance: [
+ spu.henry,
+ ],
+ PT.EField: [
+ spu.volt / spu.micrometer,
+ spu.volt / spu.meter,
+ ],
+ PT.HField: [
+ spu.ampere / spu.micrometer,
+ spu.ampere / spu.meter,
+ ],
+ # Luminal
+ PT.LumIntensity: [
+ spu.candela,
+ ],
+ PT.LumFlux: [
+ spu.candela * spu.steradian,
+ ],
+ PT.Illuminance: [
+ spu.candela / spu.meter**2,
+ ],
+ }[self]
+
+ @functools.cached_property
+ def valid_shapes(self) -> list[typ.Literal[(3,), (2,)] | None]:
+ """All shapes with physical meaning in the context of a particular unit dimension."""
+ PT = PhysicalType
+ overrides = {
+ # Cartesian
+ PT.Length: [None, (2,), (3,)],
+ # Mechanical
+ PT.Vel: [None, (2,), (3,)],
+ PT.Accel: [None, (2,), (3,)],
+ PT.Force: [None, (2,), (3,)],
+ # Energy
+ PT.Work: [None, (2,), (3,)],
+ PT.PowerFlux: [None, (2,), (3,)],
+ # Electrodynamics
+ PT.CurrentDensity: [None, (2,), (3,)],
+ PT.MFluxDensity: [None, (2,), (3,)],
+ PT.EField: [None, (2,), (3,)],
+ PT.HField: [None, (2,), (3,)],
+ # Luminal
+ PT.LumFlux: [None, (2,), (3,)],
+ }
+
+ return overrides.get(self, [None])
+
+ @functools.cached_property
+ def valid_mathtypes(self) -> list[MathType]:
+ """Returns a list of valid mathematical types, especially whether it can be real- or complex-valued.
+
+ Generally, all unit quantities are real, in the algebraic mathematical sense.
+ However, in electrodynamics especially, it becomes enormously useful to bake in a _rotational component_ as an imaginary value, be it simply to model phase or oscillation-oriented dampening.
+ This imaginary part has physical meaning, which can be expressed using the same mathematical formalism associated with unit systems.
+ In general, the value is a phasor.
+
+ While it is difficult to arrive at a well-defined way of saying, "this is when a quantity is complex", an attempt has been made to form a sensible baseline based on when phasor math may apply.
+
+ Notes:
+ - **Freq**/**AngFreq**: The imaginary part represents growth/dampening of the oscillation.
+ - **Current**/**Voltage**: The imaginary part represents the phase.
+ This also holds for any downstream units.
+ - **Charge**: Generally, it is real.
+ However, an imaginary phase term seems to have research applications when dealing with high-order harmonics in high-energy pulsed lasers:
+ - **Conductance**: The imaginary part represents the extinction, in the Drude-model sense.
+
+ """
+ MT = MathType
+ PT = PhysicalType
+ overrides = {
+ PT.NonPhysical: list(MT), ## Support All
+ # Cartesian
+ PT.Freq: [MT.Real, MT.Complex], ## Im -> Growth/Damping
+ PT.AngFreq: [MT.Real, MT.Complex], ## Im -> Growth/Damping
+ # Mechanical
+ # Energy
+ # Electrodynamics
+ PT.Current: [MT.Real, MT.Complex], ## Im -> Phase
+ PT.CurrentDensity: [MT.Real, MT.Complex], ## Im -> Phase
+ PT.Charge: [MT.Real, MT.Complex], ## Im -> Phase
+ PT.Voltage: [MT.Real, MT.Complex], ## Im -> Phase
+ PT.Capacitance: [MT.Real, MT.Complex], ## Im -> Phase
+ PT.Impedance: [MT.Real, MT.Complex], ## Im -> Reactance
+ PT.Inductance: [MT.Real, MT.Complex], ## Im -> Extinction
+ PT.Conductance: [MT.Real, MT.Complex], ## Im -> Extinction
+ PT.Conductivity: [MT.Real, MT.Complex], ## Im -> Extinction
+ PT.MFlux: [MT.Real, MT.Complex], ## Im -> Phase
+ PT.MFluxDensity: [MT.Real, MT.Complex], ## Im -> Phase
+ PT.EField: [MT.Real, MT.Complex], ## Im -> Phase
+ PT.HField: [MT.Real, MT.Complex], ## Im -> Phase
+ # Luminal
+ }
+
+ return overrides.get(self, [MT.Real])
+
+ ####################
+ # - UI
+ ####################
+ @staticmethod
+ def to_name(value: typ.Self) -> str:
+ """A human-readable UI-oriented name for a physical type."""
+ if value is PhysicalType.NonPhysical:
+ return 'Unitless'
+ return PhysicalType(value).name
+
+ @staticmethod
+ def to_icon(_: typ.Self) -> str:
+ """No icons."""
+ return ''
+
+ def bl_enum_element(self, i: int) -> ct.BLEnumElement:
+ """Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`."""
+ PT = PhysicalType
+ return (
+ str(self),
+ PT.to_name(self),
+ PT.to_name(self),
+ PT.to_icon(self),
+ i,
+ )
+
+ @functools.cached_property
+ def color(self):
+ """A color corresponding to the physical type.
+
+ The color selections were initially generated using AI, as this is a rote task that's better adjusted than invented.
+ The LLM provided the following rationale for its choices:
+
+ > Non-Physical: Grey signifies neutrality and non-physical nature.
+ > Global:
+ > Time: Blue is often associated with calmness and the passage of time.
+ > Angle and Solid Angle: Different shades of blue and cyan suggest angular dimensions and spatial aspects.
+ > Frequency and Angular Frequency: Darker shades of blue to maintain the link to time.
+ > Cartesian:
+ > Length, Area, Volume: Shades of green to represent spatial dimensions, with intensity increasing with dimension.
+ > Mechanical:
+ > Velocity and Acceleration: Red signifies motion and dynamics, with lighter reds for related quantities.
+ > Mass: Dark red for the fundamental property.
+ > Force and Pressure: Shades of red indicating intensity.
+ > Energy:
+ > Work and Power: Orange signifies energy transformation, with lighter oranges for related quantities.
+ > Temperature: Yellow for heat.
+ > Electrodynamics:
+ > Current and related quantities: Cyan shades indicating flow.
+ > Voltage, Capacitance: Greenish and blueish cyan for electrical potential.
+ > Impedance, Conductance, Conductivity: Purples and magentas to signify resistance and conductance.
+ > Magnetic properties: Magenta shades for magnetism.
+ > Electric Field: Light blue.
+ > Magnetic Field: Grey, as it can be considered neutral in terms of direction.
+ > Luminal:
+ > Luminous properties: Yellows to signify light and illumination.
+ >
+ > This color mapping helps maintain intuitive connections for users interacting with these physical types.
+ """
+ PT = PhysicalType
+ return {
+ PT.NonPhysical: (0.75, 0.75, 0.75, 1.0), # Light Grey: Non-physical
+ # Global
+ PT.Time: (0.5, 0.5, 1.0, 1.0), # Light Blue: Time
+ PT.Angle: (0.5, 0.75, 1.0, 1.0), # Light Blue: Angle
+ PT.SolidAngle: (0.5, 0.75, 0.75, 1.0), # Light Cyan: Solid Angle
+ PT.Freq: (0.5, 0.5, 0.9, 1.0), # Light Blue: Frequency
+ PT.AngFreq: (0.5, 0.5, 0.8, 1.0), # Light Blue: Angular Frequency
+ # Cartesian
+ PT.Length: (0.5, 1.0, 0.5, 1.0), # Light Green: Length
+ PT.Area: (0.6, 1.0, 0.6, 1.0), # Light Green: Area
+ PT.Volume: (0.7, 1.0, 0.7, 1.0), # Light Green: Volume
+ # Mechanical
+ PT.Vel: (1.0, 0.5, 0.5, 1.0), # Light Red: Velocity
+ PT.Accel: (1.0, 0.6, 0.6, 1.0), # Light Red: Acceleration
+ PT.Mass: (0.75, 0.5, 0.5, 1.0), # Light Red: Mass
+ PT.Force: (0.9, 0.5, 0.5, 1.0), # Light Red: Force
+ PT.Pressure: (1.0, 0.7, 0.7, 1.0), # Light Red: Pressure
+ # Energy
+ PT.Work: (1.0, 0.75, 0.5, 1.0), # Light Orange: Work
+ PT.Power: (1.0, 0.85, 0.5, 1.0), # Light Orange: Power
+ PT.PowerFlux: (1.0, 0.8, 0.6, 1.0), # Light Orange: Power Flux
+ PT.Temp: (1.0, 1.0, 0.5, 1.0), # Light Yellow: Temperature
+ # Electrodynamics
+ PT.Current: (0.5, 1.0, 1.0, 1.0), # Light Cyan: Current
+ PT.CurrentDensity: (0.5, 0.9, 0.9, 1.0), # Light Cyan: Current Density
+ PT.Charge: (0.5, 0.85, 0.85, 1.0), # Light Cyan: Charge
+ PT.Voltage: (0.5, 1.0, 0.75, 1.0), # Light Greenish Cyan: Voltage
+ PT.Capacitance: (0.5, 0.75, 1.0, 1.0), # Light Blueish Cyan: Capacitance
+ PT.Impedance: (0.6, 0.5, 0.75, 1.0), # Light Purple: Impedance
+ PT.Conductance: (0.7, 0.5, 0.8, 1.0), # Light Purple: Conductance
+ PT.Conductivity: (0.8, 0.5, 0.9, 1.0), # Light Purple: Conductivity
+ PT.MFlux: (0.75, 0.5, 0.75, 1.0), # Light Magenta: Magnetic Flux
+ PT.MFluxDensity: (
+ 0.85,
+ 0.5,
+ 0.85,
+ 1.0,
+ ), # Light Magenta: Magnetic Flux Density
+ PT.Inductance: (0.8, 0.5, 0.8, 1.0), # Light Magenta: Inductance
+ PT.EField: (0.75, 0.75, 1.0, 1.0), # Light Blue: Electric Field
+ PT.HField: (0.75, 0.75, 0.75, 1.0), # Light Grey: Magnetic Field
+ # Luminal
+ PT.LumIntensity: (1.0, 0.95, 0.5, 1.0), # Light Yellow: Luminous Intensity
+ PT.LumFlux: (1.0, 0.95, 0.6, 1.0), # Light Yellow: Luminous Flux
+ PT.Illuminance: (1.0, 1.0, 0.75, 1.0), # Pale Yellow: Illuminance
+ }[self]
diff --git a/src/blender_maxwell/utils/sympy_extra/sympy_expr.py b/src/blender_maxwell/utils/sympy_extra/sympy_expr.py
new file mode 100644
index 0000000..cafd421
--- /dev/null
+++ b/src/blender_maxwell/utils/sympy_extra/sympy_expr.py
@@ -0,0 +1,337 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import typing as typ
+from fractions import Fraction
+
+import pydantic as pyd
+import sympy as sp
+import sympy.physics.units as spu
+import typing_extensions as typx
+from pydantic_core import core_schema as pyd_core_schema
+
+from . import units as spux
+from .sympy_type import SympyType
+from .unit_analysis import get_units, uses_units
+
+
+####################
+# - Pydantic "Sympy Expr"
+####################
+class _SympyExpr:
+ """Low-level `pydantic`, schema describing how to serialize/deserialize fields that have a `SympyType` (like `sp.Expr`), so we can cleanly use `sympy` types in `pyd.BaseModel`.
+
+ Notes:
+ You probably want to use `SympyExpr`.
+
+ Examples:
+ To be usable as a type annotation on `pyd.BaseModel`, attach this to `SympyType` using `typx.Annotated`:
+
+ ```python
+ SympyExpr = typx.Annotated[SympyType, _SympyExpr]
+
+ class Spam(pyd.BaseModel):
+ line: SympyExpr = sp.Eq(sp.y, 2*sp.Symbol(x, real=True) - 3)
+ ```
+ """
+
+ @classmethod
+ def __get_pydantic_core_schema__(
+ cls,
+ _source_type: SympyType,
+ _handler: pyd.GetCoreSchemaHandler,
+ ) -> pyd_core_schema.CoreSchema:
+ """Compute a schema that allows `pydantic` to validate a `sympy` type."""
+
+ def validate_from_str(sp_str: str | typ.Any) -> SympyType | typ.Any:
+ """Parse and validate a string expression.
+
+ Parameters:
+ sp_str: A stringified `sympy` object, that will be parsed to a sympy type.
+ Before use, `isinstance(expr_str, str)` is checked.
+ If the object isn't a string, then the validation will be skipped.
+
+ Returns:
+ Either a `sympy` object, if the input is parseable, or the same untouched object.
+
+ Raises:
+ ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression.
+ """
+ # Constrain to String
+ if not isinstance(sp_str, str):
+ return sp_str
+
+ # Parse String -> Sympy
+ try:
+ expr = sp.sympify(sp_str)
+ except ValueError as ex:
+ msg = f'String {sp_str} is not a valid sympy expression'
+ raise ValueError(msg) from ex
+
+ # Substitute Symbol -> Quantity
+ return expr.subs(spux.UNIT_BY_SYMBOL)
+
+ def validate_from_pytype(
+ sp_pytype: int | Fraction | float | complex,
+ ) -> SympyType | typ.Any:
+ """Parse and validate a pure Python type.
+
+ Parameters:
+ sp_str: A stringified `sympy` object, that will be parsed to a sympy type.
+ Before use, `isinstance(expr_str, str)` is checked.
+ If the object isn't a string, then the validation will be skipped.
+
+ Returns:
+ Either a `sympy` object, if the input is parseable, or the same untouched object.
+
+ Raises:
+ ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression.
+ """
+ # Constrain to String
+ if not isinstance(sp_pytype, int | Fraction | float | complex):
+ return sp_pytype
+
+ if isinstance(sp_pytype, int):
+ return sp.Integer(sp_pytype)
+ if isinstance(sp_pytype, Fraction):
+ return sp.Rational(sp_pytype.numerator, sp_pytype.denominator)
+ if isinstance(sp_pytype, float):
+ return sp.Float(sp_pytype)
+
+ # sp_pytype => Complex
+ return sp_pytype.real + sp.I * sp_pytype.imag
+
+ sympy_expr_schema = pyd_core_schema.chain_schema(
+ [
+ pyd_core_schema.no_info_plain_validator_function(validate_from_str),
+ pyd_core_schema.no_info_plain_validator_function(validate_from_pytype),
+ pyd_core_schema.is_instance_schema(SympyType),
+ ]
+ )
+ return pyd_core_schema.json_or_python_schema(
+ json_schema=sympy_expr_schema,
+ python_schema=sympy_expr_schema,
+ serialization=pyd_core_schema.plain_serializer_function_ser_schema(
+ lambda sp_obj: sp.srepr(sp_obj)
+ ),
+ )
+
+
+SympyExpr = typx.Annotated[
+ sp.Basic, ## Treat all sympy types as sp.Basic
+ _SympyExpr,
+]
+## TODO: The type game between SympyType, SympyExpr, and the various flavors of ConstrSympyExpr(), is starting to be a bit much. Let's consolidate.
+
+
+def ConstrSympyExpr( # noqa: N802, PLR0913
+ # Features
+ allow_variables: bool = True,
+ allow_units: bool = True,
+ # Structures
+ allowed_sets: set[typ.Literal['integer', 'rational', 'real', 'complex']]
+ | None = None,
+ allowed_structures: set[typ.Literal['scalar', 'matrix']] | None = None,
+ # Element Class
+ max_symbols: int | None = None,
+ allowed_symbols: set[sp.Symbol] | None = None,
+ allowed_units: set[spu.Quantity] | None = None,
+ # Shape Class
+ allowed_matrix_shapes: set[tuple[int, int]] | None = None,
+) -> SympyType:
+ """Constructs a `SympyExpr` type, which will validate `sympy` types when used in a `pyd.BaseModel`.
+
+ Relies on the `sympy` assumptions system.
+ See
+
+ Parameters (TBD):
+
+ Returns:
+ A type that represents a constrained `sympy` expression.
+ """
+
+ def validate_expr(expr: SympyType):
+ if not (isinstance(expr, SympyType),):
+ msg = f"expr '{expr}' is not an allowed Sympy expression ({SympyType})"
+ raise ValueError(msg)
+
+ msgs = set()
+
+ # Validate Feature Class
+ if (not allow_variables) and (len(expr.free_symbols) > 0):
+ msgs.add(
+ f'allow_variables={allow_variables} does not match expression {expr}.'
+ )
+ if (not allow_units) and uses_units(expr):
+ msgs.add(f'allow_units={allow_units} does not match expression {expr}.')
+
+ # Validate Structure Class
+ if (
+ allowed_sets
+ and isinstance(expr, sp.Expr)
+ and not any(
+ {
+ 'integer': expr.is_integer,
+ 'rational': expr.is_rational,
+ 'real': expr.is_real,
+ 'complex': expr.is_complex,
+ }[allowed_set]
+ for allowed_set in allowed_sets
+ )
+ ):
+ msgs.add(
+ f"allowed_sets={allowed_sets} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))"
+ )
+ if allowed_structures and not any(
+ {
+ 'scalar': True,
+ 'matrix': isinstance(expr, sp.MatrixBase),
+ }[allowed_set]
+ for allowed_set in allowed_structures
+ ):
+ msgs.add(
+ f"allowed_structures={allowed_structures} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))"
+ )
+
+ # Validate Element Class
+ if max_symbols and len(expr.free_symbols) > max_symbols:
+ msgs.add(f'max_symbols={max_symbols} does not match expression {expr}')
+ if allowed_symbols and expr.free_symbols.issubset(allowed_symbols):
+ msgs.add(
+ f'allowed_symbols={allowed_symbols} does not match expression {expr}'
+ )
+ if allowed_units and get_units(expr).issubset(allowed_units):
+ msgs.add(f'allowed_units={allowed_units} does not match expression {expr}')
+
+ # Validate Shape Class
+ if (
+ allowed_matrix_shapes and isinstance(expr, sp.MatrixBase)
+ ) and expr.shape not in allowed_matrix_shapes:
+ msgs.add(
+ f'allowed_matrix_shapes={allowed_matrix_shapes} does not match expression {expr} with shape {expr.shape}'
+ )
+
+ # Error or Return
+ if msgs:
+ raise ValueError(str(msgs))
+ return expr
+
+ return typx.Annotated[
+ sp.Basic,
+ _SympyExpr,
+ pyd.AfterValidator(validate_expr),
+ ]
+
+
+####################
+# - Common ConstrSympyExpr
+####################
+# Expression
+ScalarUnitlessRealExpr: typ.TypeAlias = ConstrSympyExpr(
+ allow_variables=False,
+ allow_units=False,
+ allowed_structures={'scalar'},
+ allowed_sets={'integer', 'rational', 'real'},
+)
+ScalarUnitlessComplexExpr: typ.TypeAlias = ConstrSympyExpr(
+ allow_variables=False,
+ allow_units=False,
+ allowed_structures={'scalar'},
+ allowed_sets={'integer', 'rational', 'real', 'complex'},
+)
+
+# Symbol
+IntSymbol: typ.TypeAlias = ConstrSympyExpr(
+ allow_variables=True,
+ allow_units=False,
+ allowed_sets={'integer'},
+ max_symbols=1,
+)
+RationalSymbol: typ.TypeAlias = ConstrSympyExpr(
+ allow_variables=True,
+ allow_units=False,
+ allowed_sets={'integer', 'rational'},
+ max_symbols=1,
+)
+RealSymbol: typ.TypeAlias = ConstrSympyExpr(
+ allow_variables=True,
+ allow_units=False,
+ allowed_sets={'integer', 'rational', 'real'},
+ max_symbols=1,
+)
+ComplexSymbol: typ.TypeAlias = ConstrSympyExpr(
+ allow_variables=True,
+ allow_units=False,
+ allowed_sets={'integer', 'rational', 'real', 'complex'},
+ max_symbols=1,
+)
+Symbol: typ.TypeAlias = IntSymbol | RealSymbol | ComplexSymbol
+
+# Unit
+UnitDimension: typ.TypeAlias = SympyExpr ## Actually spu.Dimension
+
+## Technically a "unit expression", which includes compound types.
+## Support for this is the reason to prefer over raw spu.Quantity.
+Unit: typ.TypeAlias = ConstrSympyExpr(
+ allow_variables=False,
+ allow_units=True,
+ allowed_structures={'scalar'},
+)
+
+# Number
+IntNumber: typ.TypeAlias = ConstrSympyExpr(
+ allow_variables=False,
+ allow_units=False,
+ allowed_sets={'integer'},
+ allowed_structures={'scalar'},
+)
+RealNumber: typ.TypeAlias = ConstrSympyExpr(
+ allow_variables=False,
+ allow_units=False,
+ allowed_sets={'integer', 'rational', 'real'},
+ allowed_structures={'scalar'},
+)
+ComplexNumber: typ.TypeAlias = ConstrSympyExpr(
+ allow_variables=False,
+ allow_units=False,
+ allowed_sets={'integer', 'rational', 'real', 'complex'},
+ allowed_structures={'scalar'},
+)
+Number: typ.TypeAlias = IntNumber | RealNumber | ComplexNumber
+
+# Number
+PhysicalRealNumber: typ.TypeAlias = ConstrSympyExpr(
+ allow_variables=False,
+ allow_units=True,
+ allowed_sets={'integer', 'rational', 'real'},
+ allowed_structures={'scalar'},
+)
+PhysicalComplexNumber: typ.TypeAlias = ConstrSympyExpr(
+ allow_variables=False,
+ allow_units=True,
+ allowed_sets={'integer', 'rational', 'real', 'complex'},
+ allowed_structures={'scalar'},
+)
+PhysicalNumber: typ.TypeAlias = PhysicalRealNumber | PhysicalComplexNumber
+
+# Vector
+Real3DVector: typ.TypeAlias = ConstrSympyExpr(
+ allow_variables=False,
+ allow_units=False,
+ allowed_sets={'integer', 'rational', 'real'},
+ allowed_structures={'matrix'},
+ allowed_matrix_shapes={(3, 1)},
+)
diff --git a/src/blender_maxwell/utils/sympy_extra/sympy_type.py b/src/blender_maxwell/utils/sympy_extra/sympy_type.py
new file mode 100644
index 0000000..ecb736e
--- /dev/null
+++ b/src/blender_maxwell/utils/sympy_extra/sympy_type.py
@@ -0,0 +1,23 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import sympy as sp
+import sympy.physics.units as spu
+
+####################
+# - Underlying "Sympy Type"
+####################
+SympyType = sp.Basic | sp.MatrixBase | spu.Quantity | spu.Dimension
diff --git a/src/blender_maxwell/utils/sympy_extra/unit_analysis.py b/src/blender_maxwell/utils/sympy_extra/unit_analysis.py
new file mode 100644
index 0000000..3234407
--- /dev/null
+++ b/src/blender_maxwell/utils/sympy_extra/unit_analysis.py
@@ -0,0 +1,287 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+"""Functions for characterizaiton, conversion and casting of `sympy` objects that use units."""
+
+import functools
+
+import sympy as sp
+import sympy.physics.units as spu
+
+from . import units as spux
+from .parse_cast import sympy_to_python
+from .sympy_type import SympyType
+
+
+####################
+# - Unit Characterization
+####################
+## TODO: Caching w/srepr'ed expression.
+## TODO: An LFU cache could do better than an LRU.
+def uses_units(sp_obj: SympyType) -> bool:
+ """Determines if an expression uses any units.
+
+ Parameters:
+ expr: The sympy object that may contain units.
+
+ Returns:
+ Whether or not units are present in the object.
+ """
+ return sp_obj.has(spu.Quantity)
+
+
+## TODO: Caching w/srepr'ed expression.
+## TODO: An LFU cache could do better than an LRU.
+def get_units(expr: sp.Expr) -> set[spu.Quantity]:
+ """Finds all units used by the expression, and returns them as a set.
+
+ No information about _the relationship between units_ is exposed.
+ For example, compound units like `spu.meter / spu.second` would be mapped to `{spu.meter, spu.second}`.
+
+
+ Notes:
+ The expression graph is traversed depth-first with `sp.postorder_traversal`, to search for `sp.Quantity` elements.
+
+ The performance is comparable to the performance of `sp.postorder_traversal`, since the **entire expression graph will always be traversed**, with the added overhead of one `isinstance` call per expression-graph-node.
+
+ Parameters:
+ expr: The sympy expression that may contain units.
+
+ Returns:
+ All units (`spu.Quantity`) used within the expression.
+ """
+ return {
+ subexpr
+ for subexpr in sp.postorder_traversal(expr)
+ if isinstance(subexpr, spu.Quantity)
+ }
+
+
+####################
+# - Dimensional Characterization
+####################
+def unit_dim_to_unit_dim_deps(
+ unit_dims: SympyType,
+) -> dict[spu.dimensions.Dimension, int] | None:
+ """Normalize an expression to a mapping of its dimensional dependencies.
+
+ Comparing the dimensional dependencies of two `unit_dims` is a meaningful way of determining whether they are equivalent.
+
+ Notes:
+ We adhere to SI unit conventions when determining dimensional dependencies, to ensure that ex. `freq -> 1/time` equivalences are normalized away.
+ This allows the output of this method to be compared meaningfully, to determine whether two dimensional expressions are equivalent.
+
+ We choose to catch a `TypeError`, for cases where dimensional analysis is impossible (especially `+` or `-` between differing dimensions).
+ This may have a slight performance penalty.
+
+ Returns:
+ The dimensional dependencies of the dimensional expression.
+
+ If such a thing makes no sense, ex. if `+` or `-` is present between differing unit dimensions, then return None.
+ """
+ dimsys_SI = spu.systems.si.dimsys_SI
+
+ # Retrieve Dimensional Dependencies
+ try:
+ return dimsys_SI.get_dimensional_dependencies(unit_dims)
+
+ # Catch TypeError
+ ## -> Happens if `+` or `-` is in `unit`.
+ ## -> Generally, it doesn't make sense to add/subtract differing unit dims.
+ ## -> Thus, when trying to figure out the unit dimension, there isn't one.
+ except TypeError:
+ return None
+
+
+def unit_to_unit_dim_deps(
+ unit: SympyType,
+) -> dict[spu.dimensions.Dimension, int] | None:
+ """Deduce the dimensional dependencies of a unit.
+
+ Notes:
+ Using `.subs()` to replace `sp.Quantity`s with `spu.dimensions.Dimension`s seems to result in an expression that absolutely refuses to claim that it has anything other than raw `sp.Symbol`s.
+
+ This is extremely problematic - dimensional analysis relies on the arithmetic properties of proper `Dimension` objects.
+
+ For this reason, though we'd rather have a `unit_to_unit_dims()` function, we have not yet found a way to do this.
+ Luckily, most of our uses cases seem only to require the dimensional dictionary, which (surprisingly) seems accessible using `unit_dim_to_unit_dim_deps()`.
+
+ """
+ # Retrieve Dimensional Dependencies
+ ## -> NOTE: .subs() alone seems to produce sp.Symbol atoms.
+ ## -> This is extremely problematic; `Dims` arithmetic has key properties.
+ ## -> So we have to go all the way to the dimensional dependencies.
+ ## -> This isn't really respecting the args, but it seems to work :)
+ return unit_dim_to_unit_dim_deps(
+ unit.subs({arg: arg.dimension for arg in unit.atoms(spu.Quantity)})
+ )
+
+
+def compare_unit_dims(unit_dim_l: SympyType, unit_dim_r: SympyType) -> bool:
+ """Compare the dimensional dependencies of two unit dimensions.
+
+ Comparing the dimensional dependencies of two `unit_dims` is a meaningful way of determining whether they are equivalent.
+ """
+ return unit_dim_to_unit_dim_deps(unit_dim_l) == unit_dim_to_unit_dim_deps(
+ unit_dim_r
+ )
+
+
+def compare_units_by_unit_dims(unit_l: SympyType, unit_r: SympyType) -> bool:
+ """Compare two units by their unit dimensions."""
+ return unit_to_unit_dim_deps(unit_l) == unit_to_unit_dim_deps(unit_r)
+
+
+def compare_unit_dim_to_unit_dim_deps(
+ unit_dim: SympyType, unit_dim_deps: dict[spu.dimensions.Dimension, int]
+) -> bool:
+ """Compare the dimensional dependencies of unit dimensions to pre-defined unit dimensions."""
+ return unit_dim_to_unit_dim_deps(unit_dim) == unit_dim_deps
+
+
+####################
+# - Unit Casting
+####################
+def strip_units(sp_obj: SympyType) -> SympyType:
+ """Strip all units by replacing them to `1`.
+
+ This is a rather unsafe method.
+ You probably shouldn't use it.
+
+ Warnings:
+ Absolutely no effort is made to determine whether stripping units is a _meaningful thing to do_.
+
+ For example, using `+` expressions of compatible dimension, but different units, is a clear mistake.
+ For example, `8*meter + 9*millimeter` strips to `8(1) + 9(1) = 17`, which is a garbage result.
+
+ The **user of this method** must themselves perform appropriate checks on th eobject before stripping units.
+
+ Parameters:
+ sp_obj: A sympy object that contains unit symbols.
+ **NOTE**: Unit symbols (from `sympy.physics.units`) are not _free_ symbols, in that they are not unknown.
+ Nonetheless, they are not _numbers_ either, and thus they cannot be used in a numerical expression.
+
+ Returns:
+ The sympy object with all unit symbols replaced by `1`, effectively extracting the unitless part of the object.
+ """
+ return sp_obj.subs(spux.UNIT_TO_1)
+
+
+def convert_to_unit(sp_obj: SympyType, unit: SympyType | None) -> SympyType:
+ """Convert a sympy object to the given unit.
+
+ Supports a unit of `None`, which simply causes the object to have its units stripped.
+ """
+ if unit is None:
+ return strip_units(sp_obj)
+ return spu.convert_to(sp_obj, unit)
+
+ # msg = f'Sympy object "{sp_obj}" was scaled to the unit "{unit}" with the expectation that the result would be unitless, but the result "{unitless_expr}" has units "{get_units(unitless_expr)}"'
+ # raise ValueError(msg)
+
+
+## TODO: Include sympy_to_python in 'scale_to' to match semantics of 'scale_to_unit_system'
+## -- Introduce a 'strip_unit
+def scale_to_unit(
+ sp_obj: SympyType,
+ unit: spu.Quantity | None,
+ cast_to_pytype: bool = False,
+ use_jax_array: bool = False,
+) -> SympyType:
+ """Convert an expression that uses units to a different unit, then strip all units, leaving only a unitless `sympy` value.
+
+ This is used whenever the unitless part of an expression is needed, but guaranteed expressed in a particular unit, aka. **unit system normalization**.
+
+ Notes:
+ The unitless output is still an `sp.Expr`, which may contain ex. symbols.
+
+ If you know that the output **should** work as a corresponding Python type (ex. `sp.Integer` vs. `int`), but it doesn't, you can use `sympy_to_python()` to produce a pure-Python type.
+ In this way, with a little care, broad compatiblity can be bridged between the `sympy.physics.units` unit system and the wider Python ecosystem.
+
+ Parameters:
+ expr: The unit-containing expression to convert.
+ unit_to: The unit that is converted to.
+
+ Returns:
+ The unitless part of `expr`, after scaling the entire expression to `unit`.
+
+ Raises:
+ ValueError: If the result of unit-conversion and -stripping still has units, as determined by `uses_units()`.
+ """
+ sp_obj_stripped = strip_units(convert_to_unit(sp_obj, unit))
+ if cast_to_pytype:
+ return sympy_to_python(
+ sp_obj_stripped,
+ use_jax_array=use_jax_array,
+ )
+ return sp_obj_stripped
+
+
+def scaling_factor(
+ unit_from: SympyType, unit_to: SympyType
+) -> int | float | complex | tuple | None:
+ """Compute the numerical scaling factor imposed on the unitless part of the expression when converting from one unit to another.
+
+ Parameters:
+ unit_from: The unit that is converted from.
+ unit_to: The unit that is converted to.
+
+ Returns:
+ The numerical scaling factor between the two units.
+
+ If the units are incompatible, then we return None.
+
+ Raises:
+ ValueError: If the two units don't share a common dimension.
+ """
+ if compare_units_by_unit_dims(unit_from, unit_to):
+ return scale_to_unit(unit_from, unit_to)
+ return None
+
+
+@functools.cache
+def unit_str_to_unit(unit_str: str, optional: bool = False) -> SympyType | None:
+ """Determine the `sympy` unit expression that matches the given unit string.
+
+ Parameters:
+ unit_str: A string parseable with `sp.sympify`, which contains a unit expression.
+ optional: Whether to return
+ **NOTE**: `None` is itself a valid "unit", denoting dimensionlessness, in general.
+ Ensure that appropriate checks are performed to account for this nuance.
+
+ Returns:
+ The matching `sympy` unit.
+
+ Raises:
+ ValueError: When no valid unit can be matched to the unit string, and `optional` is `False`.
+ """
+ match unit_str:
+ # Special-Case 'degree'
+ ## -> sp.sympify('degree') produces the sp.degree().
+ ## -> TODO: Proper Analysis analysis.
+ case 'degree':
+ unit = spu.degree
+
+ case _:
+ unit = sp.sympify(unit_str).subs(spux.UNIT_BY_SYMBOL)
+
+ if uses_units(unit):
+ return unit
+
+ if optional:
+ return None
+ msg = f'No valid unit for unit string {unit_str}'
+ raise ValueError(msg)
diff --git a/src/blender_maxwell/utils/sympy_extra/unit_system_analysis.py b/src/blender_maxwell/utils/sympy_extra/unit_system_analysis.py
new file mode 100644
index 0000000..cb30375
--- /dev/null
+++ b/src/blender_maxwell/utils/sympy_extra/unit_system_analysis.py
@@ -0,0 +1,93 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+"""Functions for conversion and casting of `sympy` objects that use units, via unit systems."""
+
+import jax
+import sympy.physics.units as spu
+
+from . import units as spux
+from .parse_cast import sympy_to_python
+from .physical_type import PhysicalType
+from .sympy_type import SympyType
+from .unit_analysis import get_units
+from .unit_systems import UnitSystem
+
+
+####################
+# - Conversion
+####################
+def strip_unit_system(
+ sp_obj: SympyType, unit_system: UnitSystem | None = None
+) -> SympyType:
+ """Strip units occurring in the given unit system from the expression.
+
+ Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`".
+ Obviously, the semantic correctness of this operation depends entirely on _the units adding no semantic meaning to the expression_.
+
+ Notes:
+ You should probably use `scale_to_unit_system()` or `convert_to_unit_system()`.
+ """
+ if unit_system is None:
+ return sp_obj.subs(spux.UNIT_TO_1)
+
+ return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None})
+
+
+def convert_to_unit_system(
+ sp_obj: SympyType, unit_system: UnitSystem | None
+) -> SympyType:
+ """Convert an expression to the units of a given unit system."""
+ if unit_system is None:
+ return sp_obj
+
+ return spu.convert_to(
+ sp_obj,
+ {unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)},
+ )
+
+
+####################
+# - Casting
+####################
+def scale_to_unit_system(
+ sp_obj: SympyType,
+ unit_system: UnitSystem | None,
+ use_jax_array: bool = False,
+) -> int | float | complex | tuple | jax.Array:
+ """Convert an expression to the units of a given unit system, then strip all units of the unit system.
+
+ Afterwards, it is converted to an appropriate Python type.
+
+ Notes:
+ For stability, and performance, reasons, this should only be used at the very last stage.
+
+ Regarding performance: **This is not a fast function**.
+
+ Parameters:
+ sp_obj: An arbitrary sympy object, presumably with units.
+ unit_system: A unit system mapping `PhysicalType` to particular choices of (compound) units.
+ Note that, in this context, only `unit_system.values()` is used.
+
+ Returns:
+ An appropriate pure Python type, after scaling to the unit system and stripping all units away.
+
+ If the returned type is array-like, and `use_jax_array` is specified, then (and **only** then) will a `jax.Array` be returned instead of a nested `tuple`.
+ """
+ return sympy_to_python(
+ strip_unit_system(convert_to_unit_system(sp_obj, unit_system), unit_system),
+ use_jax_array=use_jax_array,
+ )
diff --git a/src/blender_maxwell/utils/sympy_extra/unit_systems.py b/src/blender_maxwell/utils/sympy_extra/unit_systems.py
new file mode 100644
index 0000000..1de298a
--- /dev/null
+++ b/src/blender_maxwell/utils/sympy_extra/unit_systems.py
@@ -0,0 +1,80 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+"""Defines a common unit system representation, as well as a few of the most common / useful unit systems.
+
+Attributes:
+ UnitSystem: Type of a unit system representation, as an exhaustive mapping from `PhysicalType` to a unit expression.
+ **Compatibility between `PhysicalType` and unit must be manually guaranteed** when defining new unit systems.
+ UNITS_SI: Pre-defined go-to choice of unit system, which can also be a useful base to build other unit systems on.
+"""
+
+import typing as typ
+
+import sympy.physics.units as spu
+
+from . import units as spux
+from .physical_type import PhysicalType as PT # noqa: N817
+from .sympy_expr import Unit
+
+####################
+# - Unit System Representation
+####################
+UnitSystem: typ.TypeAlias = dict[PT, Unit]
+
+####################
+# - Standard Unit Systems
+####################
+UNITS_SI: UnitSystem = {
+ PT.NonPhysical: None,
+ # Global
+ PT.Time: spu.second,
+ PT.Angle: spu.radian,
+ PT.SolidAngle: spu.steradian,
+ PT.Freq: spu.hertz,
+ PT.AngFreq: spu.radian * spu.hertz,
+ # Cartesian
+ PT.Length: spu.meter,
+ PT.Area: spu.meter**2,
+ PT.Volume: spu.meter**3,
+ # Mechanical
+ PT.Vel: spu.meter / spu.second,
+ PT.Accel: spu.meter / spu.second**2,
+ PT.Mass: spu.kilogram,
+ PT.Force: spu.newton,
+ # Energy
+ PT.Work: spu.joule,
+ PT.Power: spu.watt,
+ PT.PowerFlux: spu.watt / spu.meter**2,
+ PT.Temp: spu.kelvin,
+ # Electrodynamics
+ PT.Current: spu.ampere,
+ PT.CurrentDensity: spu.ampere / spu.meter**2,
+ PT.Voltage: spu.volt,
+ PT.Capacitance: spu.farad,
+ PT.Impedance: spu.ohm,
+ PT.Conductance: spu.siemens,
+ PT.Conductivity: spu.siemens / spu.meter,
+ PT.MFlux: spu.weber,
+ PT.MFluxDensity: spu.tesla,
+ PT.Inductance: spu.henry,
+ PT.EField: spu.volt / spu.meter,
+ PT.HField: spu.ampere / spu.meter,
+ # Luminal
+ PT.LumIntensity: spu.candela,
+ PT.LumFlux: spux.lumen,
+ PT.Illuminance: spu.lux,
+}
diff --git a/src/blender_maxwell/utils/sympy_extra/units.py b/src/blender_maxwell/utils/sympy_extra/units.py
new file mode 100644
index 0000000..9ffd4cf
--- /dev/null
+++ b/src/blender_maxwell/utils/sympy_extra/units.py
@@ -0,0 +1,77 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import typing as typ
+
+import sympy as sp
+import sympy.physics.units as spu
+
+####################
+# - Units
+####################
+# Time
+femtosecond = fs = spu.Quantity('femtosecond', abbrev='fs')
+femtosecond.set_global_relative_scale_factor(spu.femto, spu.second)
+
+# Length
+femtometer = fm = spu.Quantity('femtometer', abbrev='fm')
+femtometer.set_global_relative_scale_factor(spu.femto, spu.meter)
+
+# Lum Flux
+lumen = lm = spu.Quantity('lumen', abbrev='lm')
+lumen.set_global_relative_scale_factor(1, spu.candela * spu.steradian)
+
+# Force
+nanonewton = nN = spu.Quantity('nanonewton', abbrev='nN') # noqa: N816
+nanonewton.set_global_relative_scale_factor(spu.nano, spu.newton)
+
+micronewton = uN = spu.Quantity('micronewton', abbrev='μN') # noqa: N816
+micronewton.set_global_relative_scale_factor(spu.micro, spu.newton)
+
+millinewton = mN = spu.Quantity('micronewton', abbrev='mN') # noqa: N816
+micronewton.set_global_relative_scale_factor(spu.milli, spu.newton)
+
+# Frequency
+kilohertz = KHz = spu.Quantity('kilohertz', abbrev='KHz')
+kilohertz.set_global_relative_scale_factor(spu.kilo, spu.hertz)
+
+megahertz = MHz = spu.Quantity('megahertz', abbrev='MHz')
+kilohertz.set_global_relative_scale_factor(spu.kilo, spu.hertz)
+
+gigahertz = GHz = spu.Quantity('gigahertz', abbrev='GHz')
+gigahertz.set_global_relative_scale_factor(spu.giga, spu.hertz)
+
+terahertz = THz = spu.Quantity('terahertz', abbrev='THz')
+terahertz.set_global_relative_scale_factor(spu.tera, spu.hertz)
+
+petahertz = PHz = spu.Quantity('petahertz', abbrev='PHz')
+petahertz.set_global_relative_scale_factor(spu.peta, spu.hertz)
+
+exahertz = EHz = spu.Quantity('exahertz', abbrev='EHz')
+exahertz.set_global_relative_scale_factor(spu.exa, spu.hertz)
+
+# Pressure
+millibar = mbar = spu.Quantity('millibar', abbrev='mbar')
+millibar.set_global_relative_scale_factor(spu.milli, spu.bar)
+
+hectopascal = hPa = spu.Quantity('hectopascal', abbrev='hPa') # noqa: N816
+hectopascal.set_global_relative_scale_factor(spu.hecto, spu.pascal)
+
+UNIT_BY_SYMBOL: dict[sp.Symbol, spu.Quantity] = {
+ unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity)
+} | {unit.name: unit for unit in globals().values() if isinstance(unit, spu.Quantity)}
+
+UNIT_TO_1: dict[spu.Quantity, 1] = {unit: 1 for unit in UNIT_BY_SYMBOL.values()}