diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py index 244e54b..b5def9c 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py @@ -21,7 +21,7 @@ import typing as typ import bpy import sympy as sp -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from .socket_types import SocketType diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py index a6c2b54..1599f45 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py @@ -22,7 +22,7 @@ import numpy as np import pydantic as pyd import sympy as sp -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger log = logger.get(__name__) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/expr_info.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/expr_info.py index 412cd05..c588fab 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/expr_info.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/expr_info.py @@ -16,7 +16,7 @@ import typing as typ -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from . import FlowKind diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py index 4ff917e..d892ece 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py @@ -19,7 +19,7 @@ import functools import typing as typ from blender_maxwell.contracts import BLEnumElement -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from blender_maxwell.utils.staticproperty import staticproperty diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py index 43dbfe5..58eba20 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py @@ -18,8 +18,8 @@ import dataclasses import functools import typing as typ -from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import logger, sim_symbols +from blender_maxwell.utils import sympy_extra as spux from .array import ArrayFlow from .lazy_range import RangeFlow @@ -89,6 +89,7 @@ class InfoFlow: return None def dim_by_idx(self, idx: int) -> sim_symbols.SimSymbol | None: + """Retrieve the dimension associated with a particular index.""" if idx > 0 and idx < len(self.dims) - 1: return list(self.dims.keys())[idx] return None @@ -179,15 +180,23 @@ class InfoFlow: While that sounds fancy and all, it boils down to: $$ - \texttt{dims} + |\texttt{output}.\texttt{shape}| + |\texttt{dims}| + |\texttt{output}.\texttt{shape}| $$ - Doing so characterizes the full dimensionality of the tensor, which also perfectly matches the length of the raw data's shape exactly. + Doing so characterizes the full dimensionality of the tensor, which also perfectly matches the length of the raw data's shape. Notes: Corresponds to `len(raw_data.shape)`, if `raw_data` is the n-dimensional array corresponding to this `InfoFlow`. """ - return len(self.input_mathtypes) + self.output_shape_len + return len(self.dims) + self.output_shape_len + + @functools.cached_property + def is_scalar(self) -> tuple[spux.MathType, int, int]: + """Whether the described object can be described as "scalar". + + True when `self.order == 0`. + """ + return self.order == 0 #################### # - Properties @@ -204,6 +213,58 @@ class InfoFlow: for dim, dim_idx in self.dims.items() } + #################### + # - Operations: Comparison + #################### + def compare_dims_identical(self, other: typ.Self) -> bool: + """Whether that the quantity and properites of all dimension `SimSymbol`s are "identical". + + "Identical" is defined according to the semantics of `SimSymbol.compare()`, which generally means that everything but the exact name and unit are different. + """ + return len(self.dims) == len(other.dims) and all( + dim_l.compare(dim_r) + for dim_l, dim_r in zip(self.dims, other.dims, strict=True) + ) + + def compare_addable( + self, other: typ.Self, allow_differing_unit: bool = False + ) -> bool: + """Whether the two `InfoFlows` can be added/subtracted elementwise. + + Parameters: + allow_differing_unit: When set, + Forces the user to be explicit about specifying + """ + return self.compare_dims_identical(other) and self.output.compare_addable( + other.output, allow_differing_unit=allow_differing_unit + ) + + def compare_multiplicable(self, other: typ.Self) -> bool: + """Whether the two `InfoFlow`s can be multiplied (elementwise). + + - The output `SimSymbol`s must be able to be multiplied. + - Either the LHS is a scalar, the RHS is a scalar, or the dimensions are identical. + """ + return self.output.compare_multiplicable(other.output) and ( + (len(self.dims) == 0 and self.output.shape_len == 0) + or (len(other.dims) == 0 and other.output.shape_len == 0) + or self.compare_dims_identical(other) + ) + + def compare_exponentiable(self, other: typ.Self) -> bool: + """Whether the two `InfoFlow`s can be exponentiated. + + In general, we follow the rules of the "Hadamard Power" operator, which is also in use in `numpy` broadcasting rules. + + - The output `SimSymbol`s must be able to be exponentiated (mainly, the exponent can't have a unit). + - Either the LHS is a scalar, the RHS is a scalar, or the dimensions are identical. + """ + return self.output.compare_exponentiable(other.output) and ( + (len(self.dims) == 0 and self.output.shape_len == 0) + or (len(other.dims) == 0 and other.output.shape_len == 0) + or self.compare_dims_identical(other) + ) + #################### # - Operations: Dimensions #################### @@ -319,6 +380,7 @@ class InfoFlow: op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr], unit_op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr], ) -> spux.SympyExpr: + """Apply an operation between two the values and units of two `InfoFlow`s by reconstructing the properties of the new output `SimSymbol`.""" sym_name = sim_symbols.SimSymbolName.Expr expr = op(self.output.sp_symbol_phy, other.output.sp_symbol_phy) unit_expr = unit_op(self.output.unit_factor, other.output.unit_factor) @@ -341,11 +403,11 @@ class InfoFlow: cols = self.output.cols match (rows, cols): case (1, 1): - new_output = self.output.set_size(len(last_idx), 1) + new_output = self.output.update(rows=len(last_idx), cols=1) case (_, 1): - new_output = self.output.set_size(rows, len(last_idx)) + new_output = self.output.update(rows=rows, cols=len(last_idx)) case (1, _): - new_output = self.output.set_size(len(last_idx), cols) + new_output = self.output.update(rows=len(last_idx), cols=cols) case (_, _): raise NotImplementedError ## Not yet :) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py index e881c14..025ef71 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py @@ -226,7 +226,7 @@ import jaxtyping as jtyp import pydantic as pyd import sympy as sp -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger, sim_symbols from .array import ArrayFlow @@ -335,7 +335,7 @@ class FuncFlow(pyd.BaseModel): disallow_jax: Don't use `self.func_jax` to evaluate, even if possible. This is desirable when the overhead of `jax.jit()` is known in advance to exceed the performance benefits. """ - if self.supports_jax: + if self.supports_jax and not disallow_jax: return self.func_jax( *params.scaled_func_args(symbol_values), **params.scaled_func_kwargs(symbol_values), diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py index a204eef..a7fd668 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py @@ -24,8 +24,8 @@ import jaxtyping as jtyp import pydantic as pyd import sympy as sp -from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import logger, sim_symbols +from blender_maxwell.utils import sympy_extra as spux from .array import ArrayFlow @@ -95,10 +95,12 @@ class RangeFlow(pyd.BaseModel): stop: spux.ScalarUnitlessRealExpr steps: int = 0 scaling: ScalingMode = ScalingMode.Lin + ## TODO: No support for non-Lin (yet) unit: spux.Unit | None = None symbols: frozenset[sim_symbols.SimSymbol] = frozenset() + ## TODO: No proper support for symbols (yet) # Helper Attributes pre_fourier_ideal_midpoint: spux.ScalarUnitlessRealExpr | None = None @@ -112,18 +114,26 @@ class RangeFlow(pyd.BaseModel): steps: int = 50, scaling: ScalingMode | str = ScalingMode.Lin, ) -> typ.Self: - if sym.domain.start.is_infinite or sym.domain.end.is_infinite: - use_steps = 0 - else: - use_steps = steps + if ( + sym.mathtype is not spux.MathType.Complex + and sym.rows == 1 + and sym.cols == 1 + ): + if sym.domain.inf.is_infinite or sym.domain.sup.is_infinite: + _steps = 0 + else: + _steps = steps - return RangeFlow( - start=sym.domain.start if sym.domain.start.is_finite else sp.S(-1), - stop=sym.domain.end if sym.domain.end.is_finite else sp.S(1), - steps=use_steps, - scaling=ScalingMode(scaling), - unit=sym.unit, - ) + return RangeFlow( + start=sym.domain.inf if sym.domain.inf.is_finite else sp.S(-1), + stop=sym.domain.sup if sym.domain.sup.is_finite else sp.S(1), + steps=_steps, + scaling=ScalingMode(scaling), + unit=sym.unit, + ) + + msg = f'RangeFlow is incompatible with SimSymbol {sym}' + raise ValueError(msg) def to_sym( self, diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py index 20d6562..b7a9cb7 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py @@ -23,7 +23,7 @@ import jaxtyping as jtyp import pydantic as pyd import sympy as sp -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger, sim_symbols from .array import ArrayFlow @@ -53,11 +53,6 @@ class ParamsFlow(pyd.BaseModel): sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow ] = pyd.Field(default_factory=dict) - @functools.cached_property - def diff_symbols(self) -> set[sim_symbols.SimSymbol]: - """Set of all unrealized `SimSymbol`s that can act as inputs when differentiating the function for which this `ParamsFlow` tracks arguments.""" - return {sym for sym in self.symbols if sym.can_diff} - #################### # - Symbols #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py index 16eeca5..9765991 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py @@ -29,7 +29,7 @@ import tidy3d as td from blender_maxwell.contracts import BLEnumElement from blender_maxwell.services import tdcloud -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from .flow_kinds.info import InfoFlow diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py index 0dc4752..ce6259b 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py @@ -28,7 +28,7 @@ import typing as typ import sympy.physics.units as spu -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux #################### # - Unit Systems diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py index 4f16660..4242efd 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py @@ -20,7 +20,7 @@ import typing as typ import bpy -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from .. import bl_socket_map diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/__init__.py new file mode 100644 index 0000000..04c6341 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/__init__.py @@ -0,0 +1,29 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from .filter import FilterOperation +from .map import MapOperation +from .operate import BinaryOperation +from .reduce import ReduceOperation +from .transform import TransformOperation + +__all__ = [ + 'FilterOperation', + 'MapOperation', + 'BinaryOperation', + 'ReduceOperation', + 'TransformOperation', +] diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/filter.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/filter.py new file mode 100644 index 0000000..6bea59f --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/filter.py @@ -0,0 +1,233 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import enum +import typing as typ + +import jax.lax as jlax +import jax.numpy as jnp + +from blender_maxwell.utils import logger, sim_symbols + +from .. import contracts as ct + +log = logger.get(__name__) + + +class FilterOperation(enum.StrEnum): + """Valid operations for the `FilterMathNode`. + + Attributes: + DimToVec: Shift last dimension to output. + DimsToMat: Shift last 2 dimensions to output. + PinLen1: Remove a len(1) dimension. + Pin: Remove a len(n) dimension by selecting a particular index. + Swap: Swap the positions of two dimensions. + """ + + # Slice + Slice = enum.auto() + SliceIdx = enum.auto() + + # Pin + PinLen1 = enum.auto() + Pin = enum.auto() + PinIdx = enum.auto() + + # Dimension + Swap = enum.auto() + + #################### + # - UI + #################### + @staticmethod + def to_name(value: typ.Self) -> str: + FO = FilterOperation + return { + # Slice + FO.Slice: '≈a[v₁:v₂]', + FO.SliceIdx: '=a[i:j]', + # Pin + FO.PinLen1: 'a[0] → a', + FO.Pin: 'a[v] ⇝ a', + FO.PinIdx: 'a[i] → a', + # Reinterpret + FO.Swap: 'a₁ ↔ a₂', + }[value] + + @staticmethod + def to_icon(value: typ.Self) -> str: + return '' + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + FO = FilterOperation + return ( + str(self), + FO.to_name(self), + FO.to_name(self), + FO.to_icon(self), + i, + ) + + #################### + # - Ops from Info + #################### + @staticmethod + def by_info(info: ct.InfoFlow) -> list[typ.Self]: + FO = FilterOperation + operations = [] + + # Slice + if info.dims: + operations.append(FO.SliceIdx) + + # Pin + ## PinLen1 + ## -> There must be a dimension with length 1. + if 1 in [dim_idx for dim_idx in info.dims.values() if dim_idx is not None]: + operations.append(FO.PinLen1) + + ## Pin | PinIdx + ## -> There must be a dimension, full stop. + if info.dims: + operations += [FO.Pin, FO.PinIdx] + + # Reinterpret + ## Swap + ## -> There must be at least two dimensions. + if len(info.dims) >= 2: # noqa: PLR2004 + operations.append(FO.Swap) + + return operations + + #################### + # - Computed Properties + #################### + @property + def func_args(self) -> list[sim_symbols.SimSymbol]: + FO = FilterOperation + return { + # Pin + FO.Pin: [sim_symbols.idx(None)], + FO.PinIdx: [sim_symbols.idx(None)], + }.get(self, []) + + #################### + # - Methods + #################### + @property + def num_dim_inputs(self) -> None: + FO = FilterOperation + return { + # Slice + FO.Slice: 1, + FO.SliceIdx: 1, + # Pin + FO.PinLen1: 1, + FO.Pin: 1, + FO.PinIdx: 1, + # Reinterpret + FO.Swap: 2, + }[self] + + def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]: + FO = FilterOperation + match self: + # Slice + case FO.Slice: + return [dim for dim in info.dims if not info.has_idx_labels(dim)] + + case FO.SliceIdx: + return [dim for dim in info.dims if not info.has_idx_labels(dim)] + + # Pin + case FO.PinLen1: + return [ + dim + for dim, dim_idx in info.dims.items() + if not info.has_idx_cont(dim) and len(dim_idx) == 1 + ] + + case FO.Pin: + return info.dims + + case FO.PinIdx: + return [dim for dim in info.dims if not info.has_idx_cont(dim)] + + # Dimension + case FO.Swap: + return info.dims + + return [] + + def are_dims_valid( + self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None + ) -> bool: + """Check whether the given dimension inputs are valid in the context of this operation, and of the information.""" + if self.num_dim_inputs == 1: + return dim_0 in self.valid_dims(info) + + if self.num_dim_inputs == 2: # noqa: PLR2004 + valid_dims = self.valid_dims(info) + return dim_0 in valid_dims and dim_1 in valid_dims + + return False + + #################### + # - UI + #################### + def jax_func( + self, + axis_0: int | None, + axis_1: int | None, + slice_tuple: tuple[int, int, int] | None = None, + ): + FO = FilterOperation + return { + # Pin + FO.Slice: lambda expr: jlax.slice_in_dim( + expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0 + ), + FO.SliceIdx: lambda expr: jlax.slice_in_dim( + expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0 + ), + # Pin + FO.PinLen1: lambda expr: jnp.squeeze(expr, axis_0), + FO.Pin: lambda expr, idx: jnp.take(expr, idx, axis=axis_0), + FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0), + # Dimension + FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1), + }[self] + + def transform_info( + self, + info: ct.InfoFlow, + dim_0: sim_symbols.SimSymbol, + dim_1: sim_symbols.SimSymbol, + pin_idx: int | None = None, + slice_tuple: tuple[int, int, int] | None = None, + ): + FO = FilterOperation + return { + FO.Slice: lambda: info.slice_dim(dim_0, slice_tuple), + FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple), + # Pin + FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx), + FO.Pin: lambda: info.delete_dim(dim_0, pin_idx=pin_idx), + FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx), + # Reinterpret + FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1), + }[self]() diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/map.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/map.py new file mode 100644 index 0000000..d708561 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/map.py @@ -0,0 +1,365 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import enum +import typing as typ + +import jax.numpy as jnp +import sympy as sp + +from blender_maxwell.utils import logger, sim_symbols +from blender_maxwell.utils import sympy_extra as spux + +from .. import contracts as ct + +log = logger.get(__name__) + + +class MapOperation(enum.StrEnum): + """Valid operations for the `MapMathNode`. + + Attributes: + Real: Compute the real part of the input. + Imag: Compute the imaginary part of the input. + Abs: Compute the absolute value of the input. + Sq: Square the input. + Sqrt: Compute the (principal) square root of the input. + InvSqrt: Compute the inverse square root of the input. + Cos: Compute the cosine of the input. + Sin: Compute the sine of the input. + Tan: Compute the tangent of the input. + Acos: Compute the inverse cosine of the input. + Asin: Compute the inverse sine of the input. + Atan: Compute the inverse tangent of the input. + Norm2: Compute the 2-norm (aka. length) of the input vector. + Det: Compute the determinant of the input matrix. + Cond: Compute the condition number of the input matrix. + NormFro: Compute the frobenius norm of the input matrix. + Rank: Compute the rank of the input matrix. + Diag: Compute the diagonal vector of the input matrix. + EigVals: Compute the eigenvalues vector of the input matrix. + SvdVals: Compute the singular values vector of the input matrix. + Inv: Compute the inverse matrix of the input matrix. + Tra: Compute the transpose matrix of the input matrix. + Qr: Compute the QR-factorized matrices of the input matrix. + Chol: Compute the Cholesky-factorized matrices of the input matrix. + Svd: Compute the SVD-factorized matrices of the input matrix. + """ + + # By Number + Real = enum.auto() + Imag = enum.auto() + Abs = enum.auto() + Sq = enum.auto() + Sqrt = enum.auto() + InvSqrt = enum.auto() + Cos = enum.auto() + Sin = enum.auto() + Tan = enum.auto() + Acos = enum.auto() + Asin = enum.auto() + Atan = enum.auto() + Sinc = enum.auto() + # By Vector + Norm2 = enum.auto() + # By Matrix + Det = enum.auto() + Cond = enum.auto() + NormFro = enum.auto() + Rank = enum.auto() + Diag = enum.auto() + EigVals = enum.auto() + SvdVals = enum.auto() + Inv = enum.auto() + Tra = enum.auto() + Qr = enum.auto() + Chol = enum.auto() + Svd = enum.auto() + + #################### + # - UI + #################### + @staticmethod + def to_name(value: typ.Self) -> str: + """A human-readable UI-oriented name for a physical type.""" + MO = MapOperation + return { + # By Number + MO.Real: 'ℝ(v)', + MO.Imag: 'Im(v)', + MO.Abs: '|v|', + MO.Sq: 'v²', + MO.Sqrt: '√v', + MO.InvSqrt: '1/√v', + MO.Cos: 'cos v', + MO.Sin: 'sin v', + MO.Tan: 'tan v', + MO.Acos: 'acos v', + MO.Asin: 'asin v', + MO.Atan: 'atan v', + MO.Sinc: 'sinc v', + # By Vector + MO.Norm2: '||v||₂', + # By Matrix + MO.Det: 'det V', + MO.Cond: 'κ(V)', + MO.NormFro: '||V||_F', + MO.Rank: 'rank V', + MO.Diag: 'diag V', + MO.EigVals: 'eigvals V', + MO.SvdVals: 'svdvals V', + MO.Inv: 'V⁻¹', + MO.Tra: 'Vt', + MO.Qr: 'qr V', + MO.Chol: 'chol V', + MO.Svd: 'svd V', + }[value] + + @staticmethod + def to_icon(_: typ.Self) -> str: + """No icons.""" + return '' + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + """Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`.""" + MO = MapOperation + return ( + str(self), + MO.to_name(self), + MO.to_name(self), + MO.to_icon(self), + i, + ) + + #################### + # - Ops from Shape + #################### + @staticmethod + def by_expr_info(info: ct.InfoFlow) -> list[typ.Self]: + ## TODO: By info, not shape. + ## TODO: Check valid domains/mathtypes for some functions. + MO = MapOperation + element_ops = [ + MO.Real, + MO.Imag, + MO.Abs, + MO.Sq, + MO.Sqrt, + MO.InvSqrt, + MO.Cos, + MO.Sin, + MO.Tan, + MO.Acos, + MO.Asin, + MO.Atan, + MO.Sinc, + ] + + match (info.output.rows, info.output.cols): + case (1, 1): + return element_ops + + case (_, 1): + return [*element_ops, MO.Norm2] + + case (rows, cols) if rows == cols: + ## TODO: Check hermitian/posdef for cholesky. + ## - Can we even do this with just the output symbol approach? + return [ + *element_ops, + MO.Det, + MO.Cond, + MO.NormFro, + MO.Rank, + MO.Diag, + MO.EigVals, + MO.SvdVals, + MO.Inv, + MO.Tra, + MO.Qr, + MO.Chol, + MO.Svd, + ] + + case (rows, cols): + return [ + *element_ops, + MO.Cond, + MO.NormFro, + MO.Rank, + MO.SvdVals, + MO.Inv, + MO.Tra, + MO.Svd, + ] + + return [] + + #################### + # - Function Properties + #################### + @property + def sp_func(self): + MO = MapOperation + return { + # By Number + MO.Real: lambda expr: sp.re(expr), + MO.Imag: lambda expr: sp.im(expr), + MO.Abs: lambda expr: sp.Abs(expr), + MO.Sq: lambda expr: expr**2, + MO.Sqrt: lambda expr: sp.sqrt(expr), + MO.InvSqrt: lambda expr: 1 / sp.sqrt(expr), + MO.Cos: lambda expr: sp.cos(expr), + MO.Sin: lambda expr: sp.sin(expr), + MO.Tan: lambda expr: sp.tan(expr), + MO.Acos: lambda expr: sp.acos(expr), + MO.Asin: lambda expr: sp.asin(expr), + MO.Atan: lambda expr: sp.atan(expr), + MO.Sinc: lambda expr: sp.sinc(expr), + # By Vector + # Vector -> Number + MO.Norm2: lambda expr: sp.sqrt(expr.T @ expr)[0], + # By Matrix + # Matrix -> Number + MO.Det: lambda expr: sp.det(expr), + MO.Cond: lambda expr: expr.condition_number(), + MO.NormFro: lambda expr: expr.norm(ord='fro'), + MO.Rank: lambda expr: expr.rank(), + # Matrix -> Vec + MO.Diag: lambda expr: expr.diagonal(), + MO.EigVals: lambda expr: sp.Matrix(list(expr.eigenvals().keys())), + MO.SvdVals: lambda expr: expr.singular_values(), + # Matrix -> Matrix + MO.Inv: lambda expr: expr.inv(), + MO.Tra: lambda expr: expr.T, + # Matrix -> Matrices + MO.Qr: lambda expr: expr.QRdecomposition(), + MO.Chol: lambda expr: expr.cholesky(), + MO.Svd: lambda expr: expr.singular_value_decomposition(), + }[self] + + @property + def jax_func(self): + MO = MapOperation + return { + # By Number + MO.Real: lambda expr: jnp.real(expr), + MO.Imag: lambda expr: jnp.imag(expr), + MO.Abs: lambda expr: jnp.abs(expr), + MO.Sq: lambda expr: jnp.square(expr), + MO.Sqrt: lambda expr: jnp.sqrt(expr), + MO.InvSqrt: lambda expr: 1 / jnp.sqrt(expr), + MO.Cos: lambda expr: jnp.cos(expr), + MO.Sin: lambda expr: jnp.sin(expr), + MO.Tan: lambda expr: jnp.tan(expr), + MO.Acos: lambda expr: jnp.acos(expr), + MO.Asin: lambda expr: jnp.asin(expr), + MO.Atan: lambda expr: jnp.atan(expr), + MO.Sinc: lambda expr: jnp.sinc(expr), + # By Vector + # Vector -> Number + MO.Norm2: lambda expr: jnp.linalg.norm(expr, ord=2, axis=-1), + # By Matrix + # Matrix -> Number + MO.Det: lambda expr: jnp.linalg.det(expr), + MO.Cond: lambda expr: jnp.linalg.cond(expr), + MO.NormFro: lambda expr: jnp.linalg.matrix_norm(expr, ord='fro'), + MO.Rank: lambda expr: jnp.linalg.matrix_rank(expr), + # Matrix -> Vec + MO.Diag: lambda expr: jnp.diagonal(expr, axis1=-2, axis2=-1), + MO.EigVals: lambda expr: jnp.linalg.eigvals(expr), + MO.SvdVals: lambda expr: jnp.linalg.svdvals(expr), + # Matrix -> Matrix + MO.Inv: lambda expr: jnp.linalg.inv(expr), + MO.Tra: lambda expr: jnp.matrix_transpose(expr), + # Matrix -> Matrices + MO.Qr: lambda expr: jnp.linalg.qr(expr), + MO.Chol: lambda expr: jnp.linalg.cholesky(expr), + MO.Svd: lambda expr: jnp.linalg.svd(expr), + }[self] + + def transform_info(self, info: ct.InfoFlow): + MO = MapOperation + + return { + # By Number + MO.Real: lambda: info.update_output(mathtype=spux.MathType.Real), + MO.Imag: lambda: info.update_output(mathtype=spux.MathType.Real), + MO.Abs: lambda: info.update_output(mathtype=spux.MathType.Real), + MO.Sq: lambda: info, + MO.Sqrt: lambda: info, + MO.InvSqrt: lambda: info, + MO.Cos: lambda: info, + MO.Sin: lambda: info, + MO.Tan: lambda: info, + MO.Acos: lambda: info, + MO.Asin: lambda: info, + MO.Atan: lambda: info, + MO.Sinc: lambda: info, + # By Vector + MO.Norm2: lambda: info.update_output( + mathtype=spux.MathType.Real, + rows=1, + cols=1, + # Interval + interval_finite_re=(0, sim_symbols.float_max), + interval_inf=(False, True), + interval_closed=(True, False), + ), + # By Matrix + MO.Det: lambda: info.update_output( + rows=1, + cols=1, + ), + MO.Cond: lambda: info.update_output( + mathtype=spux.MathType.Real, + rows=1, + cols=1, + physical_type=spux.PhysicalType.NonPhysical, + unit=None, + ), + MO.NormFro: lambda: info.update_output( + mathtype=spux.MathType.Real, + rows=1, + cols=1, + # Interval + interval_finite_re=(0, sim_symbols.float_max), + interval_inf=(False, True), + interval_closed=(True, False), + ), + MO.Rank: lambda: info.update_output( + mathtype=spux.MathType.Integer, + rows=1, + cols=1, + physical_type=spux.PhysicalType.NonPhysical, + unit=None, + # Interval + interval_finite_re=(0, sim_symbols.int_max), + interval_inf=(False, True), + interval_closed=(True, False), + ), + # Matrix -> Vector ## TODO: ALL OF THESE + MO.Diag: lambda: info, + MO.EigVals: lambda: info, + MO.SvdVals: lambda: info, + # Matrix -> Matrix ## TODO: ALL OF THESE + MO.Inv: lambda: info, + MO.Tra: lambda: info, + # Matrix -> Matrices ## TODO: ALL OF THESE + MO.Qr: lambda: info, + MO.Chol: lambda: info, + MO.Svd: lambda: info, + }[self]() diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/operate.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/operate.py new file mode 100644 index 0000000..cf3d228 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/operate.py @@ -0,0 +1,476 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import enum +import typing as typ + +import jax.numpy as jnp +import sympy as sp +import sympy.physics.quantum as spq +import sympy.physics.units as spu + +from blender_maxwell.utils import logger, sim_symbols +from blender_maxwell.utils import sympy_extra as spux + +from .. import contracts as ct + +log = logger.get(__name__) + + +def hadamard_power(lhs: spux.SympyType, rhs: spux.SympyType) -> spux.SympyType: + """Implement the Hadamard Power. + + Follows the specification in , which also conforms to `numpy` broadcasting rules for `**` on `np.ndarray`. + """ + match (isinstance(lhs, sp.MatrixBase), isinstance(rhs, sp.MatrixBase)): + case (False, False): + msg = f"Hadamard Power for two scalars is valid, but shouldn't be used - use normal power instead: {lhs} | {rhs}" + raise ValueError(msg) + + case (True, False): + return lhs.applyfunc(lambda el: el**rhs) + + case (False, True): + return rhs.applyfunc(lambda el: lhs**el) + + case (True, True) if lhs.shape == rhs.shape: + common_shape = lhs.shape + return sp.ImmutableMatrix( + *common_shape, lambda i, j: lhs[i, j] ** rhs[i, j] + ) + + case _: + msg = f'Incompatible lhs and rhs for hadamard power: {lhs} | {rhs}' + raise ValueError(msg) + + +class BinaryOperation(enum.StrEnum): + """Valid operations for the `OperateMathNode`. + + Attributes: + Mul: Scalar multiplication. + Div: Scalar division. + Pow: Scalar exponentiation. + Add: Elementwise addition. + Sub: Elementwise subtraction. + HadamMul: Elementwise multiplication (hadamard product). + HadamPow: Principled shape-aware exponentiation (hadamard power). + Atan2: Quadrant-respecting 2D arctangent. + VecVecDot: Dot product for identically shaped vectors w/transpose. + Cross: Cross product between identically shaped 3D vectors. + VecVecOuter: Vector-vector outer product. + LinSolve: Solve a linear system. + LsqSolve: Minimize error of an underdetermined linear system. + VecMatOuter: Vector-matrix outer product. + MatMatDot: Matrix-matrix dot product. + """ + + # Number | Number + Mul = enum.auto() + Div = enum.auto() + Pow = enum.auto() + + # Elements | Elements + Add = enum.auto() + Sub = enum.auto() + HadamMul = enum.auto() + HadamPow = enum.auto() + HadamDiv = enum.auto() + Atan2 = enum.auto() + + # Vector | Vector + VecVecDot = enum.auto() + Cross = enum.auto() + VecVecOuter = enum.auto() + + # Matrix | Vector + LinSolve = enum.auto() + LsqSolve = enum.auto() + + # Vector | Matrix + VecMatOuter = enum.auto() + + # Matrix | Matrix + MatMatDot = enum.auto() + + #################### + # - UI + #################### + @staticmethod + def to_name(value: typ.Self) -> str: + """A human-readable UI-oriented name for a physical type.""" + BO = BinaryOperation + return { + # Number | Number + BO.Mul: 'ℓ · r', + BO.Div: 'ℓ / r', + BO.Pow: 'ℓ ^ r', ## Also for square-matrix powers. + # Elements | Elements + BO.Add: 'ℓ + r', + BO.Sub: 'ℓ - r', + BO.HadamMul: '𝐋 ⊙ 𝐑', + BO.HadamDiv: '𝐋 ⊙/ 𝐑', + BO.HadamPow: '𝐥 ⊙^ 𝐫', + BO.Atan2: 'atan2(ℓ:x, r:y)', + # Vector | Vector + BO.VecVecDot: '𝐥 · 𝐫', + BO.Cross: 'cross(𝐥,𝐫)', + BO.VecVecOuter: '𝐥 ⊗ 𝐫', + # Matrix | Vector + BO.LinSolve: '𝐋 ∖ 𝐫', + BO.LsqSolve: 'argminₓ∥𝐋𝐱−𝐫∥₂', + # Vector | Matrix + BO.VecMatOuter: '𝐋 ⊗ 𝐫', + # Matrix | Matrix + BO.MatMatDot: '𝐋 · 𝐑', + }[value] + + @staticmethod + def to_icon(value: typ.Self) -> str: + """No icons.""" + return '' + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + """Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`.""" + BO = BinaryOperation + return ( + str(self), + BO.to_name(self), + BO.to_name(self), + BO.to_icon(self), + i, + ) + + def bl_enum_elements( + self, info_l: ct.InfoFlow, info_r: ct.InfoFlow + ) -> list[ct.BLEnumElement]: + """Generate a list of guaranteed-valid operations based on the passed `InfoFlow`s. + + Returns a `bpy.props.EnumProperty.items`-compatible list. + """ + return [ + operation.bl_enum_element(i) + for i, operation in enumerate(BinaryOperation.by_infos(info_l, info_r)) + ] + + #################### + # - Ops from Shape + #################### + @staticmethod + def by_infos(info_l: ct.InfoFlow, info_r: ct.InfoFlow) -> list[typ.Self]: + """Deduce valid binary operations from the shapes of the inputs.""" + BO = BinaryOperation + ops = [] + + # Add/Sub + if info_l.compare_addable(info_r, allow_differing_unit=True): + ops += [BO.Add, BO.Sub] + + # Mul/Div + ## -> Mul is ambiguous; we differentiate Hadamard and Standard. + ## -> Div additionally requires non-zero guarantees. + if info_l.compare_multiplicable(info_r): + match (info_l.order, info_r.order, info_r.output.is_nonzero): + case (ordl, ordr, True) if ordl == 0 and ordr == 0: + ops += [BO.Mul, BO.Div] + case (ordl, ordr, True) if ordl > 0 and ordr == 0: + ops += [BO.Mul, BO.Div] + case (ordl, ordr, True) if ordl == 0 and ordr > 0: + ops += [BO.Mul] + case (ordl, ordr, True) if ordl > 0 and ordr > 0: + ops += [BO.HadamMul, BO.HadamDiv] + + case (ordl, ordr, False) if ordl == 0 and ordr == 0: + ops += [BO.Mul] + case (ordl, ordr, False) if ordl > 0 and ordr == 0: + ops += [BO.Mul] + case (ordl, ordr, True) if ordl == 0 and ordr > 0: + ops += [BO.Mul] + case (ordl, ordr, False) if ordl > 0 and ordr > 0: + ops += [BO.HadamMul] + + # Pow + ## -> We distinguish between "Hadamard Power" and "Power". + ## -> For scalars, they are the same (but we only expose "power"). + ## -> For matrices, square matrices can be exp'ed by int powers. + ## -> Any other combination is well-defined by the Hadamard Power. + if info_l.compare_exponentiable(info_r): + match (info_l.order, info_r.order, info_r.output.mathtype): + case (ordl, ordr, _) if ordl == 0 and ordr == 0: + ops += [BO.Pow] + + case (ordl, ordr, spux.MathType.Integer) if ( + ordl > 0 and ordr == 0 and info_l.output.rows == info_l.output.cols + ): + ops += [BO.Pow, BO.HadamPow] + + case _: + ops += [BO.HadamPow] + + # Operations by-Output Length + match ( + info_l.output.shape_len, + info_r.output.shape_len, + ): + # Number | Number + case (0, 0) if info_l.is_scalar and info_r.is_scalar: + # atan2: PhysicalType Must Both be Length | NonPhysical + ## -> atan2() produces radians from Cartesian coordinates. + ## -> This wouldn't make sense on non-Length / non-Unitless. + if ( + info_l.output.physical_type is spux.PhysicalType.Length + and info_r.output.physical_type is spux.PhysicalType.Length + ) or ( + info_l.output.physical_type is spux.PhysicalType.NonPhysical + and info_l.output.unit is None + and info_r.output.physical_type is spux.PhysicalType.NonPhysical + and info_r.output.unit is None + ): + ops += [BO.Atan2] + + return ops + + # Vector | Vector + case (1, 1) if info_l.compare_dims_identical(info_r): + outl = info_l.output + outr = info_r.output + + # 1D Orders: Outer Product is Valid + ## -> We can't do per-element outer product. + ## -> However, it's still super useful on its own. + if info_l.order == 1 and info_r.order == 1: + ops += [BO.VecVecOuter] + + # Vector | Vector + if outl.rows > outl.cols and outr.rows > outr.cols: + ops += [BO.VecVecDot] + + # Covector | Vector + if outl.rows < outl.cols and outr.rows > outr.cols: + ops += [BO.MatMatDot] + + # Vector | Covector + if outl.rows > outl.cols and outr.rows < outr.cols: + ops += [BO.MatMatDot] + + # Covector | Covector + if outl.rows < outl.cols and outr.rows < outr.cols: + ops += [BO.VecVecDot] + + # Cross Product + ## -> Works great element-wise. + ## -> Enforce that both are 3x1 or 1x3. + ## -> See https://docs.sympy.org/latest/modules/matrices/matrices.html#sympy.matrices.matrices.MatrixBase.cross + if (outl.rows == 3 and outr.rows == 3) or ( + outl.cols == 3 and outl.cols == 3 + ): + ops += [BO.Cross] + + # Vector | Matrix + ## -> We can't do per-element outer product. + ## -> However, it's still super useful on its own. + case (1, 2) if info_l.compare_dims_identical( + info_r + ) and info_l.order == 1 and info_r.order == 2: + ops += [BO.VecMatOuter] + + # Matrix | Vector + case (2, 1) if info_l.compare_dims_identical(info_r): + # Mat-Vec Dot: Enforce RHS Column Vector + if outr.rows > outl.cols: + ops += [BO.MatMatDot] + + ops += [BO.LinSolve, BO.LsqSolve] + + ## Matrix | Matrix + case (2, 2): + ops += [BO.MatMatDot] + + return ops + + #################### + # - Function Properties + #################### + @property + def sp_func(self): + """Deduce an appropriate sympy-based function that implements the binary operation for symbolic inputs.""" + BO = BinaryOperation + + ## TODO: Make this compatible with sp.Matrix inputs + return { + # Number | Number + BO.Mul: lambda exprs: exprs[0] * exprs[1], + BO.Div: lambda exprs: exprs[0] / exprs[1], + BO.Pow: lambda exprs: exprs[0] ** exprs[1], + # Elements | Elements + BO.Add: lambda exprs: exprs[0] + exprs[1], + BO.Sub: lambda exprs: exprs[0] - exprs[1], + BO.HadamMul: lambda exprs: exprs[0].multiply_elementwise(exprs[1]), + BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]), + BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]), + # Vector | Vector + BO.VecVecDot: lambda exprs: (exprs[0].T @ exprs[1])[0], + BO.Cross: lambda exprs: exprs[0].cross(exprs[1]), + BO.VecVecOuter: lambda exprs: exprs[0] @ exprs[1].T, + # Matrix | Vector + BO.LinSolve: lambda exprs: exprs[0].solve(exprs[1]), + BO.LsqSolve: lambda exprs: exprs[0].solve_least_squares(exprs[1]), + # Vector | Matrix + BO.VecMatOuter: lambda exprs: spq.TensorProduct(exprs[0], exprs[1]), + # Matrix | Matrix + BO.MatMatDot: lambda exprs: exprs[0] @ exprs[1], + }[self] + + @property + def unit_func(self): + """The binary function to apply to both unit expressions, in order to deduce the unit expression of the output.""" + BO = BinaryOperation + + ## TODO: Make this compatible with sp.Matrix inputs + return { + # Number | Number + BO.Mul: BO.Mul.sp_func, + BO.Div: BO.Div.sp_func, + BO.Pow: BO.Pow.sp_func, + # Elements | Elements + BO.Add: BO.Add.sp_func, + BO.Sub: BO.Sub.sp_func, + BO.HadamMul: BO.Mul.sp_func, + # BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]), + BO.Atan2: lambda _: spu.radian, + # Vector | Vector + BO.VecVecDot: BO.Mul.sp_func, + BO.Cross: BO.Mul.sp_func, + BO.VecVecOuter: BO.Mul.sp_func, + # Matrix | Vector + ## -> A,b in Ax = b have units, and the equality must hold. + ## -> Therefore, A \ b must have the units [b]/[A]. + BO.LinSolve: lambda exprs: exprs[1] / exprs[0], + BO.LsqSolve: lambda exprs: exprs[1] / exprs[0], + # Vector | Matrix + BO.VecMatOuter: BO.Mul.sp_func, + # Matrix | Matrix + BO.MatMatDot: BO.Mul.sp_func, + }[self] + + @property + def jax_func(self): + """Deduce an appropriate jax-based function that implements the binary operation for array inputs.""" + ## TODO: Scale the units of one side to the other. + BO = BinaryOperation + + return { + # Number | Number + BO.Mul: lambda exprs: exprs[0] * exprs[1], + BO.Div: lambda exprs: exprs[0] / exprs[1], + BO.Pow: lambda exprs: exprs[0] ** exprs[1], + # Elements | Elements + BO.Add: lambda exprs: exprs[0] + exprs[1], + BO.Sub: lambda exprs: exprs[0] - exprs[1], + BO.HadamMul: lambda exprs: exprs[0].multiply_elementwise(exprs[1]), + BO.HadamDiv: lambda exprs: exprs[0].multiply_elementwise( + exprs[1].applyfunc(lambda el: 1 / el) + ), + BO.HadamPow: lambda exprs: hadamard_power(exprs[0], exprs[1]), + BO.Atan2: lambda exprs: jnp.atan2(exprs[1], exprs[0]), + # Vector | Vector + BO.VecVecDot: lambda exprs: jnp.linalg.vecdot(exprs[0], exprs[1]), + BO.Cross: lambda exprs: jnp.cross(exprs[0], exprs[1]), + BO.VecVecOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]), + # Matrix | Vector + BO.LinSolve: lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]), + BO.LsqSolve: lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]), + # Vector | Matrix + BO.VecMatOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]), + # Matrix | Matrix + BO.MatMatDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]), + }[self] + + #################### + # - Transforms + #################### + def transform_funcs(self, func_l: ct.FuncFlow, func_r: ct.FuncFlow) -> ct.FuncFlow: + """Transform two input functions according to the current operation.""" + BO = BinaryOperation + + # Add/Sub: Normalize Unit of RHS to LHS + ## -> We can only add/sub identical units. + ## -> To be nice, we only require identical PhysicalType. + ## -> The result of a binary operation should have one unit. + if self is BO.Add or self is BO.Sub: + norm_func_r = func_r.scale_to_unit(func_l.func_output.unit) + else: + norm_func_r = func_r + + return (func_l, norm_func_r).compose_within( + self.jax_func, + enclosing_func_output=self.transform_outputs( + func_l.func_output, norm_func_r.func_output + ), + supports_jax=True, + ) + + def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow): + """Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression. + + Warnings: + `self` MUST be an element of `BinaryOperation.by_infos(info_l, info_r). + + If not, bad things will happen. + """ + return info_l.operate_output( + info_r, + lambda a, b: self.sp_func([a, b]), + lambda a, b: self.unit_func([a, b]), + ) + + #################### + # - InfoFlow Transform + #################### + def transform_outputs( + self, output_l: sim_symbols.SimSymbol, output_r: sim_symbols.SimSymbol + ) -> sim_symbols.SimSymbol: + # TO = TransformOperation + return None + # match self: + # # Number | Number + # case TO.Mul: + # return + # case TO.Div: + # case TO.Pow: + + # # Elements | Elements + # Add = enum.auto() + # Sub = enum.auto() + # HadamMul = enum.auto() + # HadamPow = enum.auto() + # HadamDiv = enum.auto() + # Atan2 = enum.auto() + + # # Vector | Vector + # VecVecDot = enum.auto() + # Cross = enum.auto() + # VecVecOuter = enum.auto() + + # # Matrix | Vector + # LinSolve = enum.auto() + # LsqSolve = enum.auto() + + # # Vector | Matrix + # VecMatOuter = enum.auto() + + # # Matrix | Matrix + # MatMatDot = enum.auto() diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/reduce.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/reduce.py new file mode 100644 index 0000000..4d26985 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/reduce.py @@ -0,0 +1,116 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import enum +import typing as typ + +import jax.numpy as jnp +import sympy as sp + +from blender_maxwell.utils import logger, sim_symbols +from blender_maxwell.utils import sympy_extra as spux + +from .. import contracts as ct + +log = logger.get(__name__) + + +class ReduceOperation(enum.StrEnum): + # Summary + Count = enum.auto() + + # Statistics + Mean = enum.auto() + Std = enum.auto() + Var = enum.auto() + + StdErr = enum.auto() + + Min = enum.auto() + Q25 = enum.auto() + Median = enum.auto() + Q75 = enum.auto() + Max = enum.auto() + + Mode = enum.auto() + + # Reductions + Sum = enum.auto() + Prod = enum.auto() + + #################### + # - UI + #################### + @staticmethod + def to_name(value: typ.Self) -> str: + """A human-readable UI-oriented name for a physical type.""" + RO = ReduceOperation + return { + # Summary + RO.Count: '# [a]', + RO.Mode: 'mode [a]', + # Statistics + RO.Mean: 'μ [a]', + RO.Std: 'σ [a]', + RO.Var: 'σ² [a]', + RO.StdErr: 'stderr [a]', + RO.Min: 'min [a]', + RO.Q25: 'q₂₅ [a]', + RO.Median: 'median [a]', + RO.Q75: 'q₇₅ [a]', + RO.Min: 'max [a]', + # Reductions + RO.Sum: 'sum [a]', + RO.Prod: 'prod [a]', + }[value] + + @staticmethod + def to_icon(_: typ.Self) -> str: + """No icons.""" + return '' + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + """Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`.""" + RO = ReduceOperation + return ( + str(self), + RO.to_name(self), + RO.to_name(self), + RO.to_icon(self), + i, + ) + + #################### + # - Derivation + #################### + @staticmethod + def from_info(info: ct.InfoFlow) -> list[typ.Self]: + """Derive valid reduction operations from the `InfoFlow` of the operand.""" + pass + + #################### + # - Composable Functions + #################### + @property + def jax_func(self): + RO = ReduceOperation + return {}[self] + + #################### + # - Transforms + #################### + def transform_info(self, info: ct.InfoFlow): + pass diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/transform.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/transform.py new file mode 100644 index 0000000..499c082 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/transform.py @@ -0,0 +1,336 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import enum +import typing as typ + +import jax.numpy as jnp +import jaxtyping as jtyp + +from blender_maxwell.utils import logger, sci_constants, sim_symbols +from blender_maxwell.utils import sympy_extra as spux + +from .. import contracts as ct + +log = logger.get(__name__) + + +class TransformOperation(enum.StrEnum): + """Valid operations for the `TransformMathNode`. + + Attributes: + FreqToVacWL: Transform an frequency dimension to vacuum wavelength. + VacWLToFreq: Transform a vacuum wavelength dimension to frequency. + ConvertIdxUnit: Convert the unit of a dimension to a compatible unit. + SetIdxUnit: Set all properties of a dimension. + FirstColToFirstIdx: Extract the first data column and set the first dimension's index array equal to it. + **For 2D integer-indexed data only**. + + IntDimToComplex: Fold a last length-2 integer dimension into the output, transforming it from a real-like type to complex type. + DimToVec: Fold the last dimension into the scalar output, creating a vector output type. + DimsToMat: Fold the last two dimensions into the scalar output, creating a matrix output type. + FT: Compute the 1D fourier transform along a dimension. + New dimensional bounds are computing using the Nyquist Limit. + For higher dimensions, simply repeat along more dimensions. + InvFT1D: Compute the inverse 1D fourier transform along a dimension. + New dimensional bounds are computing using the Nyquist Limit. + For higher dimensions, simply repeat along more dimensions. + """ + + # Covariant Transform + FreqToVacWL = enum.auto() + VacWLToFreq = enum.auto() + ConvertIdxUnit = enum.auto() + SetIdxUnit = enum.auto() + FirstColToFirstIdx = enum.auto() + + # Fold + IntDimToComplex = enum.auto() + DimToVec = enum.auto() + DimsToMat = enum.auto() + + # Fourier + FT1D = enum.auto() + InvFT1D = enum.auto() + + # TODO: Affine + ## TODO + + #################### + # - UI + #################### + @staticmethod + def to_name(value: typ.Self) -> str: + TO = TransformOperation + return { + # Covariant Transform + TO.FreqToVacWL: '𝑓 → λᵥ', + TO.VacWLToFreq: 'λᵥ → 𝑓', + TO.ConvertIdxUnit: 'Convert Dim', + TO.SetIdxUnit: 'Set Dim', + TO.FirstColToFirstIdx: '1st Col → 1st Dim', + # Fold + TO.IntDimToComplex: '→ ℂ', + TO.DimToVec: '→ Vector', + TO.DimsToMat: '→ Matrix', + ## TODO: Vector to new last-dim integer + ## TODO: Matrix to two last-dim integers + # Fourier + TO.FT1D: 'FT', + TO.InvFT1D: 'iFT', + }[value] + + @property + def name(self) -> str: + return TransformOperation.to_name(self) + + @staticmethod + def to_icon(_: typ.Self) -> str: + return '' + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + TO = TransformOperation + return ( + str(self), + TO.to_name(self), + TO.to_name(self), + TO.to_icon(self), + i, + ) + + #################### + # - Methods + #################### + def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]: + TO = TransformOperation + match self: + case TO.FreqToVacWL: + return [ + dim + for dim in info.dims + if dim.physical_type is spux.PhysicalType.Freq + ] + + case TO.VacWLToFreq: + return [ + dim + for dim in info.dims + if dim.physical_type is spux.PhysicalType.Length + ] + + case TO.ConvertIdxUnit: + return [ + dim + for dim in info.dims + if not info.has_idx_labels(dim) + and spux.PhysicalType.from_unit(dim.unit, optional=True) is not None + ] + + case TO.SetIdxUnit: + return [dim for dim in info.dims if not info.has_idx_labels(dim)] + + ## ColDimToComplex: Implicit Last Dimension + ## DimToVec: Implicit Last Dimension + ## DimsToMat: Implicit Last 2 Dimensions + + case TO.FT1D | TO.InvFT1D: + # Filter by Axis Uniformity + ## -> FT requires uniform axis (aka. must be RangeFlow). + ## -> NOTE: If FT isn't popping up, check ExtractDataNode. + return [dim for dim in info.dims if info.is_idx_uniform(dim)] + + return [] + + @staticmethod + def by_info(info: ct.InfoFlow) -> list[typ.Self]: + TO = TransformOperation + operations = [] + + # Covariant Transform + ## Freq -> VacWL + if TO.FreqToVacWL.valid_dims(info): + operations += [TO.FreqToVacWL] + + ## VacWL -> Freq + if TO.VacWLToFreq.valid_dims(info): + operations += [TO.VacWLToFreq] + + ## Convert Index Unit + if TO.ConvertIdxUnit.valid_dims(info): + operations += [TO.ConvertIdxUnit] + + if TO.SetIdxUnit.valid_dims(info): + operations += [TO.SetIdxUnit] + + ## Column to First Index (Array) + if ( + len(info.dims) == 2 # noqa: PLR2004 + and info.first_dim.mathtype is spux.MathType.Integer + and info.last_dim.mathtype is spux.MathType.Integer + and info.output.shape_len == 0 + ): + operations += [TO.FirstColToFirstIdx] + + # Fold + ## Last Dim -> Complex + if ( + len(info.dims) >= 1 + and ( + info.output.mathtype + in [spux.MathType.Integer, spux.MathType.Rational, spux.MathType.Real] + ) + and info.last_dim.mathtype is spux.MathType.Integer + and info.has_idx_labels(info.last_dim) + and len(info.dims[info.last_dim]) == 2 # noqa: PLR2004 + ): + operations += [TO.IntDimToComplex] + + ## Last Dim -> Vector + if len(info.dims) >= 1 and info.output.shape_len == 0: + operations += [TO.DimToVec] + + ## Last Dim -> Matrix + if len(info.dims) >= 2 and info.output.shape_len == 0: # noqa: PLR2004 + operations += [TO.DimsToMat] + + # Fourier + if TO.FT1D.valid_dims(info): + operations += [TO.FT1D] + + if TO.InvFT1D.valid_dims(info): + operations += [TO.InvFT1D] + + return operations + + #################### + # - Function Properties + #################### + def jax_func(self, axis: int | None = None): + TO = TransformOperation + return { + # Covariant Transform + ## -> Freq <-> WL is a rescale (noop) AND flip (not noop). + TO.FreqToVacWL: lambda expr: jnp.flip(expr, axis=axis), + TO.VacWLToFreq: lambda expr: jnp.flip(expr, axis=axis), + TO.ConvertIdxUnit: lambda expr: expr, + TO.SetIdxUnit: lambda expr: expr, + TO.FirstColToFirstIdx: lambda expr: jnp.delete(expr, 0, axis=1), + # Fold + ## -> To Complex: This should generally be a no-op. + TO.IntDimToComplex: lambda expr: jnp.squeeze( + expr.view(dtype=jnp.complex64), axis=-1 + ), + TO.DimToVec: lambda expr: expr, + TO.DimsToMat: lambda expr: expr, + # Fourier + TO.FT1D: lambda expr: jnp.fft(expr, axis=axis), + TO.InvFT1D: lambda expr: jnp.ifft(expr, axis=axis), + }[self] + + def transform_info( + self, + info: ct.InfoFlow, + dim: sim_symbols.SimSymbol | None = None, + data_col: jtyp.Shaped[jtyp.Array, ' size'] | None = None, + new_dim_name: str | None = None, + unit: spux.Unit | None = None, + physical_type: spux.PhysicalType | None = None, + ) -> ct.InfoFlow: + TO = TransformOperation + return { + # Covariant Transform + TO.FreqToVacWL: lambda: info.replace_dim( + (f_dim := dim), + sim_symbols.wl(unit), + info.dims[f_dim].rescale( + lambda el: sci_constants.vac_speed_of_light / el, + reverse=True, + new_unit=unit, + ), + ), + TO.VacWLToFreq: lambda: info.replace_dim( + (wl_dim := dim), + sim_symbols.freq(unit), + info.dims[wl_dim].rescale( + lambda el: sci_constants.vac_speed_of_light / el, + reverse=True, + new_unit=unit, + ), + ), + TO.ConvertIdxUnit: lambda: info.replace_dim( + dim, + dim.update(unit=unit), + ( + info.dims[dim].rescale_to_unit(unit) + if info.has_idx_discrete(dim) + else None ## Continuous -- dim SimSymbol already scaled + ), + ), + TO.SetIdxUnit: lambda: info.replace_dim( + dim, + dim.update( + sym_name=new_dim_name, + physical_type=physical_type, + unit=unit, + ), + ( + info.dims[dim].correct_unit(unit) + if info.has_idx_discrete(dim) + else None ## Continuous -- dim SimSymbol already scaled + ), + ), + TO.FirstColToFirstIdx: lambda: info.replace_dim( + info.first_dim, + info.first_dim.update( + sym_name=new_dim_name, + mathtype=spux.MathType.from_jax_array(data_col), + physical_type=physical_type, + unit=unit, + ), + ct.RangeFlow.try_from_array(ct.ArrayFlow(values=data_col, unit=unit)), + ).slice_dim(info.last_dim, (1, len(info.dims[info.last_dim]), 1)), + # Fold + TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output( + mathtype=spux.MathType.Complex + ), + TO.DimToVec: lambda: info.fold_last_input(), + TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(), + # Fourier + TO.FT1D: lambda: info.replace_dim( + dim, + [ + # FT'ed Unit: Reciprocal of the Original Unit + dim.update( + unit=1 / dim.unit if dim.unit is not None else 1 + ), ## TODO: Okay to not scale interval? + # FT'ed Bounds: Reciprocal of the Original Unit + info.dims[dim].bound_fourier_transform, + ], + ), + TO.InvFT1D: lambda: info.replace_dim( + info.last_dim, + [ + # FT'ed Unit: Reciprocal of the Original Unit + dim.update( + unit=1 / dim.unit if dim.unit is not None else 1 + ), ## TODO: Okay to not scale interval? + # FT'ed Bounds: Reciprocal of the Original Unit + ## -> Note the midpoint may revert to 0. + ## -> See docs for `RangeFlow.bound_inv_fourier_transform` for more. + info.dims[dim].bound_inv_fourier_transform, + ], + ), + }[self]() diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py index e036857..a2f54db 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py @@ -26,7 +26,7 @@ import sympy.physics.units as spu import tidy3d as td from blender_maxwell.utils import bl_cache, logger, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py index eddeff4..6c02ec1 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py @@ -20,225 +20,15 @@ import enum import typing as typ import bpy -import jax.lax as jlax -import jax.numpy as jnp import sympy as sp -from blender_maxwell.utils import bl_cache, logger, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import bl_cache, sim_symbols +from blender_maxwell.utils import sympy_extra as spux from .... import contracts as ct -from .... import sockets +from .... import math_system, sockets from ... import base, events -log = logger.get(__name__) - - -class FilterOperation(enum.StrEnum): - """Valid operations for the `FilterMathNode`. - - Attributes: - DimToVec: Shift last dimension to output. - DimsToMat: Shift last 2 dimensions to output. - PinLen1: Remove a len(1) dimension. - Pin: Remove a len(n) dimension by selecting a particular index. - Swap: Swap the positions of two dimensions. - """ - - # Slice - Slice = enum.auto() - SliceIdx = enum.auto() - - # Pin - PinLen1 = enum.auto() - Pin = enum.auto() - PinIdx = enum.auto() - - # Dimension - Swap = enum.auto() - - #################### - # - UI - #################### - @staticmethod - def to_name(value: typ.Self) -> str: - FO = FilterOperation - return { - # Slice - FO.Slice: '≈a[v₁:v₂]', - FO.SliceIdx: '=a[i:j]', - # Pin - FO.PinLen1: 'a[0] → a', - FO.Pin: 'a[v] ⇝ a', - FO.PinIdx: 'a[i] → a', - # Reinterpret - FO.Swap: 'a₁ ↔ a₂', - }[value] - - @staticmethod - def to_icon(value: typ.Self) -> str: - return '' - - def bl_enum_element(self, i: int) -> ct.BLEnumElement: - FO = FilterOperation - return ( - str(self), - FO.to_name(self), - FO.to_name(self), - FO.to_icon(self), - i, - ) - - #################### - # - Ops from Info - #################### - @staticmethod - def by_info(info: ct.InfoFlow) -> list[typ.Self]: - FO = FilterOperation - operations = [] - - # Slice - if info.dims: - operations.append(FO.SliceIdx) - - # Pin - ## PinLen1 - ## -> There must be a dimension with length 1. - if 1 in [dim_idx for dim_idx in info.dims.values() if dim_idx is not None]: - operations.append(FO.PinLen1) - - ## Pin | PinIdx - ## -> There must be a dimension, full stop. - if info.dims: - operations += [FO.Pin, FO.PinIdx] - - # Reinterpret - ## Swap - ## -> There must be at least two dimensions. - if len(info.dims) >= 2: # noqa: PLR2004 - operations.append(FO.Swap) - - return operations - - #################### - # - Computed Properties - #################### - @property - def func_args(self) -> list[sim_symbols.SimSymbol]: - FO = FilterOperation - return { - # Pin - FO.Pin: [sim_symbols.idx(None)], - FO.PinIdx: [sim_symbols.idx(None)], - }.get(self, []) - - #################### - # - Methods - #################### - @property - def num_dim_inputs(self) -> None: - FO = FilterOperation - return { - # Slice - FO.Slice: 1, - FO.SliceIdx: 1, - # Pin - FO.PinLen1: 1, - FO.Pin: 1, - FO.PinIdx: 1, - # Reinterpret - FO.Swap: 2, - }[self] - - def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]: - FO = FilterOperation - match self: - # Slice - case FO.Slice: - return [dim for dim in info.dims if not info.has_idx_labels(dim)] - - case FO.SliceIdx: - return [dim for dim in info.dims if not info.has_idx_labels(dim)] - - # Pin - case FO.PinLen1: - return [ - dim - for dim, dim_idx in info.dims.items() - if not info.has_idx_cont(dim) and len(dim_idx) == 1 - ] - - case FO.Pin: - return info.dims - - case FO.PinIdx: - return [dim for dim in info.dims if not info.has_idx_cont(dim)] - - # Dimension - case FO.Swap: - return info.dims - - return [] - - def are_dims_valid( - self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None - ) -> bool: - """Check whether the given dimension inputs are valid in the context of this operation, and of the information.""" - if self.num_dim_inputs == 1: - return dim_0 in self.valid_dims(info) - - if self.num_dim_inputs == 2: # noqa: PLR2004 - valid_dims = self.valid_dims(info) - return dim_0 in valid_dims and dim_1 in valid_dims - - return False - - #################### - # - UI - #################### - def jax_func( - self, - axis_0: int | None, - axis_1: int | None, - slice_tuple: tuple[int, int, int] | None = None, - ): - FO = FilterOperation - return { - # Pin - FO.Slice: lambda expr: jlax.slice_in_dim( - expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0 - ), - FO.SliceIdx: lambda expr: jlax.slice_in_dim( - expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0 - ), - # Pin - FO.PinLen1: lambda expr: jnp.squeeze(expr, axis_0), - FO.Pin: lambda expr, idx: jnp.take(expr, idx, axis=axis_0), - FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0), - # Dimension - FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1), - }[self] - - def transform_info( - self, - info: ct.InfoFlow, - dim_0: sim_symbols.SimSymbol, - dim_1: sim_symbols.SimSymbol, - pin_idx: int | None = None, - slice_tuple: tuple[int, int, int] | None = None, - ): - FO = FilterOperation - return { - FO.Slice: lambda: info.slice_dim(dim_0, slice_tuple), - FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple), - # Pin - FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx), - FO.Pin: lambda: info.delete_dim(dim_0, pin_idx=pin_idx), - FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx), - # Reinterpret - FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1), - }[self]() - class FilterMathNode(base.MaxwellSimNode): r"""Applies a function that operates on the shape of the array. @@ -304,7 +94,7 @@ class FilterMathNode(base.MaxwellSimNode): #################### # - Properties: Operation #################### - operation: FilterOperation = bl_cache.BLField( + operation: math_system.FilterOperation = bl_cache.BLField( enum_cb=lambda self, _: self.search_operations(), cb_depends_on={'expr_info'}, ) @@ -313,7 +103,9 @@ class FilterMathNode(base.MaxwellSimNode): if self.expr_info is not None: return [ operation.bl_enum_element(i) - for i, operation in enumerate(FilterOperation.by_info(self.expr_info)) + for i, operation in enumerate( + math_system.FilterOperation.by_info(self.expr_info) + ) ] return [] @@ -358,7 +150,7 @@ class FilterMathNode(base.MaxwellSimNode): # - UI #################### def draw_label(self): - FO = FilterOperation + FO = math_system.FilterOperation match self.operation: # Slice case FO.SliceIdx: @@ -398,7 +190,7 @@ class FilterMathNode(base.MaxwellSimNode): row.prop(self, self.blfields['active_dim_0'], text='') row.prop(self, self.blfields['active_dim_1'], text='') - if self.operation is FilterOperation.SliceIdx: + if self.operation is math_system.FilterOperation.SliceIdx: layout.prop(self, self.blfields['slice_tuple'], text='') #################### @@ -434,7 +226,7 @@ class FilterMathNode(base.MaxwellSimNode): ## -> Works with continuous / discrete indexes. ## -> The user will be given a socket w/correct mathtype, unit, etc. . if ( - props['operation'] is FilterOperation.Pin + props['operation'] is math_system.FilterOperation.Pin and dim_0 is not None and (info.has_idx_cont(dim_0) or info.has_idx_discrete(dim_0)) ): @@ -460,7 +252,7 @@ class FilterMathNode(base.MaxwellSimNode): # Loose Sockets: Pin Dim by-Value ## -> Works with discrete points / labelled integers. elif ( - props['operation'] is FilterOperation.PinIdx + props['operation'] is math_system.FilterOperation.PinIdx and dim_0 is not None and (info.has_idx_discrete(dim_0) or info.has_idx_labels(dim_0)) ): @@ -594,7 +386,10 @@ class FilterMathNode(base.MaxwellSimNode): # Pin by-Value: Compute Nearest IDX ## -> Presume a sorted index array to be able to use binary search. - if props['operation'] is FilterOperation.Pin and has_pinned_value: + if ( + props['operation'] is math_system.FilterOperation.Pin + and has_pinned_value + ): nearest_idx_to_value = info.dims[dim_0].nearest_idx_of( pinned_value, require_sorted=True ) @@ -605,7 +400,10 @@ class FilterMathNode(base.MaxwellSimNode): ) # Pin by-Index - if props['operation'] is FilterOperation.PinIdx and has_pinned_axis: + if ( + props['operation'] is math_system.FilterOperation.PinIdx + and has_pinned_axis + ): return params.compose_within( enclosing_arg_targets=[sim_symbols.idx(None)], enclosing_func_args=[sp.S(pinned_axis)], diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py index 04765a8..4959822 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py @@ -16,363 +16,19 @@ """Declares `MapMathNode`.""" -import enum import typing as typ import bpy -import jax.numpy as jnp -import sympy as sp -from blender_maxwell.utils import bl_cache, logger, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import bl_cache, logger from .... import contracts as ct -from .... import sockets +from .... import math_system, sockets from ... import base, events log = logger.get(__name__) -#################### -# - Operation Enum -#################### -class MapOperation(enum.StrEnum): - """Valid operations for the `MapMathNode`. - - Attributes: - Real: Compute the real part of the input. - Imag: Compute the imaginary part of the input. - Abs: Compute the absolute value of the input. - Sq: Square the input. - Sqrt: Compute the (principal) square root of the input. - InvSqrt: Compute the inverse square root of the input. - Cos: Compute the cosine of the input. - Sin: Compute the sine of the input. - Tan: Compute the tangent of the input. - Acos: Compute the inverse cosine of the input. - Asin: Compute the inverse sine of the input. - Atan: Compute the inverse tangent of the input. - Norm2: Compute the 2-norm (aka. length) of the input vector. - Det: Compute the determinant of the input matrix. - Cond: Compute the condition number of the input matrix. - NormFro: Compute the frobenius norm of the input matrix. - Rank: Compute the rank of the input matrix. - Diag: Compute the diagonal vector of the input matrix. - EigVals: Compute the eigenvalues vector of the input matrix. - SvdVals: Compute the singular values vector of the input matrix. - Inv: Compute the inverse matrix of the input matrix. - Tra: Compute the transpose matrix of the input matrix. - Qr: Compute the QR-factorized matrices of the input matrix. - Chol: Compute the Cholesky-factorized matrices of the input matrix. - Svd: Compute the SVD-factorized matrices of the input matrix. - """ - - # By Number - Real = enum.auto() - Imag = enum.auto() - Abs = enum.auto() - Sq = enum.auto() - Sqrt = enum.auto() - InvSqrt = enum.auto() - Cos = enum.auto() - Sin = enum.auto() - Tan = enum.auto() - Acos = enum.auto() - Asin = enum.auto() - Atan = enum.auto() - Sinc = enum.auto() - # By Vector - Norm2 = enum.auto() - # By Matrix - Det = enum.auto() - Cond = enum.auto() - NormFro = enum.auto() - Rank = enum.auto() - Diag = enum.auto() - EigVals = enum.auto() - SvdVals = enum.auto() - Inv = enum.auto() - Tra = enum.auto() - Qr = enum.auto() - Chol = enum.auto() - Svd = enum.auto() - - #################### - # - UI - #################### - @staticmethod - def to_name(value: typ.Self) -> str: - MO = MapOperation - return { - # By Number - MO.Real: 'ℝ(v)', - MO.Imag: 'Im(v)', - MO.Abs: '|v|', - MO.Sq: 'v²', - MO.Sqrt: '√v', - MO.InvSqrt: '1/√v', - MO.Cos: 'cos v', - MO.Sin: 'sin v', - MO.Tan: 'tan v', - MO.Acos: 'acos v', - MO.Asin: 'asin v', - MO.Atan: 'atan v', - MO.Sinc: 'sinc v', - # By Vector - MO.Norm2: '||v||₂', - # By Matrix - MO.Det: 'det V', - MO.Cond: 'κ(V)', - MO.NormFro: '||V||_F', - MO.Rank: 'rank V', - MO.Diag: 'diag V', - MO.EigVals: 'eigvals V', - MO.SvdVals: 'svdvals V', - MO.Inv: 'V⁻¹', - MO.Tra: 'Vt', - MO.Qr: 'qr V', - MO.Chol: 'chol V', - MO.Svd: 'svd V', - }[value] - - @staticmethod - def to_icon(value: typ.Self) -> str: - return '' - - def bl_enum_element(self, i: int) -> ct.BLEnumElement: - MO = MapOperation - return ( - str(self), - MO.to_name(self), - MO.to_name(self), - MO.to_icon(self), - i, - ) - - #################### - # - Ops from Shape - #################### - @staticmethod - def by_expr_info(info: ct.InfoFlow) -> list[typ.Self]: - ## TODO: By info, not shape. - ## TODO: Check valid domains/mathtypes for some functions. - MO = MapOperation - element_ops = [ - MO.Real, - MO.Imag, - MO.Abs, - MO.Sq, - MO.Sqrt, - MO.InvSqrt, - MO.Cos, - MO.Sin, - MO.Tan, - MO.Acos, - MO.Asin, - MO.Atan, - MO.Sinc, - ] - - match (info.output.rows, info.output.cols): - case (1, 1): - return element_ops - - case (_, 1): - return [*element_ops, MO.Norm2] - - case (rows, cols) if rows == cols: - ## TODO: Check hermitian/posdef for cholesky. - ## - Can we even do this with just the output symbol approach? - return [ - *element_ops, - MO.Det, - MO.Cond, - MO.NormFro, - MO.Rank, - MO.Diag, - MO.EigVals, - MO.SvdVals, - MO.Inv, - MO.Tra, - MO.Qr, - MO.Chol, - MO.Svd, - ] - - case (rows, cols): - return [ - *element_ops, - MO.Cond, - MO.NormFro, - MO.Rank, - MO.SvdVals, - MO.Inv, - MO.Tra, - MO.Svd, - ] - - return [] - - #################### - # - Function Properties - #################### - @property - def sp_func(self): - MO = MapOperation - return { - # By Number - MO.Real: lambda expr: sp.re(expr), - MO.Imag: lambda expr: sp.im(expr), - MO.Abs: lambda expr: sp.Abs(expr), - MO.Sq: lambda expr: expr**2, - MO.Sqrt: lambda expr: sp.sqrt(expr), - MO.InvSqrt: lambda expr: 1 / sp.sqrt(expr), - MO.Cos: lambda expr: sp.cos(expr), - MO.Sin: lambda expr: sp.sin(expr), - MO.Tan: lambda expr: sp.tan(expr), - MO.Acos: lambda expr: sp.acos(expr), - MO.Asin: lambda expr: sp.asin(expr), - MO.Atan: lambda expr: sp.atan(expr), - MO.Sinc: lambda expr: sp.sinc(expr), - # By Vector - # Vector -> Number - MO.Norm2: lambda expr: sp.sqrt(expr.T @ expr)[0], - # By Matrix - # Matrix -> Number - MO.Det: lambda expr: sp.det(expr), - MO.Cond: lambda expr: expr.condition_number(), - MO.NormFro: lambda expr: expr.norm(ord='fro'), - MO.Rank: lambda expr: expr.rank(), - # Matrix -> Vec - MO.Diag: lambda expr: expr.diagonal(), - MO.EigVals: lambda expr: sp.Matrix(list(expr.eigenvals().keys())), - MO.SvdVals: lambda expr: expr.singular_values(), - # Matrix -> Matrix - MO.Inv: lambda expr: expr.inv(), - MO.Tra: lambda expr: expr.T, - # Matrix -> Matrices - MO.Qr: lambda expr: expr.QRdecomposition(), - MO.Chol: lambda expr: expr.cholesky(), - MO.Svd: lambda expr: expr.singular_value_decomposition(), - }[self] - - @property - def jax_func(self): - MO = MapOperation - return { - # By Number - MO.Real: lambda expr: jnp.real(expr), - MO.Imag: lambda expr: jnp.imag(expr), - MO.Abs: lambda expr: jnp.abs(expr), - MO.Sq: lambda expr: jnp.square(expr), - MO.Sqrt: lambda expr: jnp.sqrt(expr), - MO.InvSqrt: lambda expr: 1 / jnp.sqrt(expr), - MO.Cos: lambda expr: jnp.cos(expr), - MO.Sin: lambda expr: jnp.sin(expr), - MO.Tan: lambda expr: jnp.tan(expr), - MO.Acos: lambda expr: jnp.acos(expr), - MO.Asin: lambda expr: jnp.asin(expr), - MO.Atan: lambda expr: jnp.atan(expr), - MO.Sinc: lambda expr: jnp.sinc(expr), - # By Vector - # Vector -> Number - MO.Norm2: lambda expr: jnp.linalg.norm(expr, ord=2, axis=-1), - # By Matrix - # Matrix -> Number - MO.Det: lambda expr: jnp.linalg.det(expr), - MO.Cond: lambda expr: jnp.linalg.cond(expr), - MO.NormFro: lambda expr: jnp.linalg.matrix_norm(expr, ord='fro'), - MO.Rank: lambda expr: jnp.linalg.matrix_rank(expr), - # Matrix -> Vec - MO.Diag: lambda expr: jnp.diagonal(expr, axis1=-2, axis2=-1), - MO.EigVals: lambda expr: jnp.linalg.eigvals(expr), - MO.SvdVals: lambda expr: jnp.linalg.svdvals(expr), - # Matrix -> Matrix - MO.Inv: lambda expr: jnp.linalg.inv(expr), - MO.Tra: lambda expr: jnp.matrix_transpose(expr), - # Matrix -> Matrices - MO.Qr: lambda expr: jnp.linalg.qr(expr), - MO.Chol: lambda expr: jnp.linalg.cholesky(expr), - MO.Svd: lambda expr: jnp.linalg.svd(expr), - }[self] - - def transform_info(self, info: ct.InfoFlow): - MO = MapOperation - - return { - # By Number - MO.Real: lambda: info.update_output(mathtype=spux.MathType.Real), - MO.Imag: lambda: info.update_output(mathtype=spux.MathType.Real), - MO.Abs: lambda: info.update_output(mathtype=spux.MathType.Real), - MO.Sq: lambda: info, - MO.Sqrt: lambda: info, - MO.InvSqrt: lambda: info, - MO.Cos: lambda: info, - MO.Sin: lambda: info, - MO.Tan: lambda: info, - MO.Acos: lambda: info, - MO.Asin: lambda: info, - MO.Atan: lambda: info, - MO.Sinc: lambda: info, - # By Vector - MO.Norm2: lambda: info.update_output( - mathtype=spux.MathType.Real, - rows=1, - cols=1, - # Interval - interval_finite_re=(0, sim_symbols.float_max), - interval_inf=(False, True), - interval_closed=(True, False), - ), - # By Matrix - MO.Det: lambda: info.update_output( - rows=1, - cols=1, - ), - MO.Cond: lambda: info.update_output( - mathtype=spux.MathType.Real, - rows=1, - cols=1, - physical_type=spux.PhysicalType.NonPhysical, - unit=None, - ), - MO.NormFro: lambda: info.update_output( - mathtype=spux.MathType.Real, - rows=1, - cols=1, - # Interval - interval_finite_re=(0, sim_symbols.float_max), - interval_inf=(False, True), - interval_closed=(True, False), - ), - MO.Rank: lambda: info.update_output( - mathtype=spux.MathType.Integer, - rows=1, - cols=1, - physical_type=spux.PhysicalType.NonPhysical, - unit=None, - # Interval - interval_finite_re=(0, sim_symbols.int_max), - interval_inf=(False, True), - interval_closed=(True, False), - ), - # Matrix -> Vector ## TODO: ALL OF THESE - MO.Diag: lambda: info, - MO.EigVals: lambda: info, - MO.SvdVals: lambda: info, - # Matrix -> Matrix ## TODO: ALL OF THESE - MO.Inv: lambda: info, - MO.Tra: lambda: info, - # Matrix -> Matrices ## TODO: ALL OF THESE - MO.Qr: lambda: info, - MO.Chol: lambda: info, - MO.Svd: lambda: info, - }[self]() - - -#################### -# - Node -#################### class MapMathNode(base.MaxwellSimNode): r"""Applies a function by-structure to the data. @@ -495,7 +151,7 @@ class MapMathNode(base.MaxwellSimNode): return info return None - operation: MapOperation = bl_cache.BLField( + operation: math_system.MapOperation = bl_cache.BLField( enum_cb=lambda self, _: self.search_operations(), cb_depends_on={'expr_info'}, ) @@ -504,7 +160,9 @@ class MapMathNode(base.MaxwellSimNode): if self.expr_info is not None: return [ operation.bl_enum_element(i) - for i, operation in enumerate(MapOperation.by_expr_info(self.expr_info)) + for i, operation in enumerate( + math_system.MapOperation.by_expr_info(self.expr_info) + ) ] return [] @@ -513,7 +171,7 @@ class MapMathNode(base.MaxwellSimNode): #################### def draw_label(self): if self.operation is not None: - return 'Map: ' + MapOperation.to_name(self.operation) + return 'Map: ' + math_system.MapOperation.to_name(self.operation) return self.bl_label diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py index f59bde4..920f863 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py @@ -14,354 +14,24 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import enum +"""Implements the `OperateMathNode`. + +See `blender_maxwell.maxwell_sim_nodes.math_system` for the actual mathematics implementation. +""" + import typing as typ import bpy -import jax.numpy as jnp -import sympy as sp -import sympy.physics.quantum as spq -import sympy.physics.units as spu from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux from .... import contracts as ct -from .... import sockets +from .... import math_system, sockets from ... import base, events log = logger.get(__name__) -#################### -# - Operation Enum -#################### -class BinaryOperation(enum.StrEnum): - """Valid operations for the `OperateMathNode`. - - Attributes: - Mul: Scalar multiplication. - Div: Scalar division. - Pow: Scalar exponentiation. - Add: Elementwise addition. - Sub: Elementwise subtraction. - HadamMul: Elementwise multiplication (hadamard product). - HadamPow: Principled shape-aware exponentiation (hadamard power). - Atan2: Quadrant-respecting 2D arctangent. - VecVecDot: Dot product for identically shaped vectors w/transpose. - Cross: Cross product between identically shaped 3D vectors. - VecVecOuter: Vector-vector outer product. - LinSolve: Solve a linear system. - LsqSolve: Minimize error of an underdetermined linear system. - VecMatOuter: Vector-matrix outer product. - MatMatDot: Matrix-matrix dot product. - """ - - # Number | Number - Mul = enum.auto() - Div = enum.auto() - Pow = enum.auto() - - # Elements | Elements - Add = enum.auto() - Sub = enum.auto() - HadamMul = enum.auto() - # HadamPow = enum.auto() ## TODO: Sympy's HadamardPower is problematic. - Atan2 = enum.auto() - - # Vector | Vector - VecVecDot = enum.auto() - Cross = enum.auto() - VecVecOuter = enum.auto() - - # Matrix | Vector - LinSolve = enum.auto() - LsqSolve = enum.auto() - - # Vector | Matrix - VecMatOuter = enum.auto() - - # Matrix | Matrix - MatMatDot = enum.auto() - - #################### - # - UI - #################### - @staticmethod - def to_name(value: typ.Self) -> str: - BO = BinaryOperation - return { - # Number | Number - BO.Mul: 'ℓ · r', - BO.Div: 'ℓ / r', - BO.Pow: 'ℓ ^ r', - # Elements | Elements - BO.Add: 'ℓ + r', - BO.Sub: 'ℓ - r', - BO.HadamMul: '𝐋 ⊙ 𝐑', - # BO.HadamPow: '𝐥 ⊙^ 𝐫', - BO.Atan2: 'atan2(ℓ:x, r:y)', - # Vector | Vector - BO.VecVecDot: '𝐥 · 𝐫', - BO.Cross: 'cross(𝐥,𝐫)', - BO.VecVecOuter: '𝐥 ⊗ 𝐫', - # Matrix | Vector - BO.LinSolve: '𝐋 ∖ 𝐫', - BO.LsqSolve: 'argminₓ∥𝐋𝐱−𝐫∥₂', - # Vector | Matrix - BO.VecMatOuter: '𝐋 ⊗ 𝐫', - # Matrix | Matrix - BO.MatMatDot: '𝐋 · 𝐑', - }[value] - - @staticmethod - def to_icon(value: typ.Self) -> str: - return '' - - def bl_enum_element(self, i: int) -> ct.BLEnumElement: - BO = BinaryOperation - return ( - str(self), - BO.to_name(self), - BO.to_name(self), - BO.to_icon(self), - i, - ) - - #################### - # - Ops from Shape - #################### - @staticmethod - def by_infos(info_l: int, info_r: int) -> list[typ.Self]: - """Deduce valid binary operations from the shapes of the inputs.""" - BO = BinaryOperation - - ops_el_el = [ - BO.Add, - BO.Sub, - BO.HadamMul, - # BO.HadamPow, - ] - - outl = info_l.output - outr = info_r.output - match (outl.shape_len, outr.shape_len): - # Number | * - ## Number | Number - case (0, 0): - ops = [ - BO.Add, - BO.Sub, - BO.Mul, - ] - - # Check Non-Zero Right Hand Side - ## -> Obviously, we can't ever divide by zero. - ## -> Sympy's assumptions system must always guarantee rhs != 0. - ## -> If it can't, then we simply don't expose division. - ## -> The is_zero assumption must be provided elsewhere. - ## -> NOTE: This may prevent some valid uses of division. - ## -> Watch out for "division is missing" bugs. - if info_r.output.is_nonzero: - ops.append(BO.Div) - - if ( - info_l.output.physical_type == spux.PhysicalType.Length - and info_l.output.unit == info_r.output.unit - ): - ops += [BO.Atan2] - - return [*ops, BO.Pow] - - ## Number | Vector - case (0, 1): - return [BO.Mul] # , BO.HadamPow] - - ## Number | Matrix - case (0, 2): - return [BO.Mul] # , BO.HadamPow] - - # Vector | * - ## Vector | Number - case (1, 0): - return [BO.Mul] # , BO.HadamPow] - - ## Vector | Vector - case (1, 1): - ops = [] - - # Vector | Vector - ## -> Dot: Convenience; utilize special vec-vec dot w/transp. - if outl.rows > outl.cols and outr.rows > outr.cols: - ops += [BO.VecVecDot, BO.VecVecOuter] - - # Covector | Vector - ## -> Dot: Directly use matrix-matrix dot, as it's now correct. - if outl.rows < outl.cols and outr.rows > outr.cols: - ops += [BO.MatMatDot, BO.VecVecOuter] - - # Vector | Covector - ## -> Dot: Directly use matrix-matrix dot, as it's now correct. - ## -> These are both the same operation, in this case. - if outl.rows > outl.cols and outr.rows < outr.cols: - ops += [BO.MatMatDot, BO.VecVecOuter] - - # Covector | Covector - ## -> Dot: Convenience; utilize special vec-vec dot w/transp. - if outl.rows < outl.cols and outr.rows < outr.cols: - ops += [BO.VecVecDot, BO.VecVecOuter] - - # Cross Product - ## -> Enforce that both are 3x1 or 1x3. - ## -> See https://docs.sympy.org/latest/modules/matrices/matrices.html#sympy.matrices.matrices.MatrixBase.cross - if (outl.rows == 3 and outr.rows == 3) or ( - outl.cols == 3 and outl.cols == 3 - ): - ops += [BO.Cross] - - return ops_el_el + ops - - ## Vector | Matrix - case (1, 2): - return [BO.VecMatOuter] - - # Matrix | * - ## Matrix | Number - case (2, 0): - return [BO.Mul] # , BO.HadamPow] - - ## Matrix | Vector - case (2, 1): - prepend_ops = [] - - # Mat-Vec Dot: Enforce RHS Column Vector - if outr.rows > outl.cols: - prepend_ops += [BO.MatMatDot] - - return [*ops, BO.LinSolve, BO.LsqSolve] # , BO.HadamPow] - - ## Matrix | Matrix - case (2, 2): - return [*ops_el_el, BO.MatMatDot] - - return [] - - #################### - # - Function Properties - #################### - @property - def sp_func(self): - """Deduce an appropriate sympy-based function that implements the binary operation for symbolic inputs.""" - BO = BinaryOperation - - ## TODO: Make this compatible with sp.Matrix inputs - return { - # Number | Number - BO.Mul: lambda exprs: exprs[0] * exprs[1], - BO.Div: lambda exprs: exprs[0] / exprs[1], - BO.Pow: lambda exprs: exprs[0] ** exprs[1], - # Elements | Elements - BO.Add: lambda exprs: exprs[0] + exprs[1], - BO.Sub: lambda exprs: exprs[0] - exprs[1], - BO.HadamMul: lambda exprs: sp.hadamard_product(exprs[0], exprs[1]), - # BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]), - BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]), - # Vector | Vector - BO.VecVecDot: lambda exprs: (exprs[0].T @ exprs[1])[0], - BO.Cross: lambda exprs: exprs[0].cross(exprs[1]), - BO.VecVecOuter: lambda exprs: exprs[0] @ exprs[1].T, - # Matrix | Vector - BO.LinSolve: lambda exprs: exprs[0].solve(exprs[1]), - BO.LsqSolve: lambda exprs: exprs[0].solve_least_squares(exprs[1]), - # Vector | Matrix - BO.VecMatOuter: lambda exprs: spq.TensorProduct(exprs[0], exprs[1]), - # Matrix | Matrix - BO.MatMatDot: lambda exprs: exprs[0] @ exprs[1], - }[self] - - @property - def unit_func(self): - """The binary function to apply to both unit expressions, in order to deduce the unit expression of the output.""" - BO = BinaryOperation - - ## TODO: Make this compatible with sp.Matrix inputs - return { - # Number | Number - BO.Mul: BO.Mul.sp_func, - BO.Div: BO.Div.sp_func, - BO.Pow: BO.Pow.sp_func, - # Elements | Elements - BO.Add: BO.Add.sp_func, - BO.Sub: BO.Sub.sp_func, - BO.HadamMul: BO.Mul.sp_func, - # BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]), - BO.Atan2: lambda _: spu.radian, - # Vector | Vector - BO.VecVecDot: BO.Mul.sp_func, - BO.Cross: BO.Mul.sp_func, - BO.VecVecOuter: BO.Mul.sp_func, - # Matrix | Vector - ## -> A,b in Ax = b have units, and the equality must hold. - ## -> Therefore, A \ b must have the units [b]/[A]. - BO.LinSolve: lambda exprs: exprs[1] / exprs[0], - BO.LsqSolve: lambda exprs: exprs[1] / exprs[0], - # Vector | Matrix - BO.VecMatOuter: BO.Mul.sp_func, - # Matrix | Matrix - BO.MatMatDot: BO.Mul.sp_func, - }[self] - - @property - def jax_func(self): - """Deduce an appropriate jax-based function that implements the binary operation for array inputs.""" - ## TODO: Scale the units of one side to the other. - BO = BinaryOperation - - return { - # Number | Number - BO.Mul: lambda exprs: exprs[0] * exprs[1], - BO.Div: lambda exprs: exprs[0] / exprs[1], - BO.Pow: lambda exprs: exprs[0] ** exprs[1], - # Elements | Elements - BO.Add: lambda exprs: exprs[0] + exprs[1], - BO.Sub: lambda exprs: exprs[0] - exprs[1], - BO.HadamMul: lambda exprs: exprs[0] * exprs[1], - # BO.HadamPow: lambda exprs: exprs[0] ** exprs[1], - BO.Atan2: lambda exprs: jnp.atan2(exprs[1], exprs[0]), - # Vector | Vector - BO.VecVecDot: lambda exprs: jnp.dot(exprs[0], exprs[1]), - BO.Cross: lambda exprs: jnp.cross(exprs[0], exprs[1]), - BO.VecVecOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]), - # Matrix | Vector - BO.LinSolve: lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]), - BO.LsqSolve: lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]), - # Vector | Matrix - BO.VecMatOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]), - # Matrix | Matrix - BO.MatMatDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]), - }[self] - - #################### - # - InfoFlow Transform - #################### - def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow): - """Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression. - - Warnings: - `self` MUST be an element of `BinaryOperation.by_infos(info_l, info_r). - - If not, bad things will happen. - """ - return info_l.operate_output( - info_r, - lambda a, b: self.sp_func([a, b]), - lambda a, b: self.unit_func([a, b]), - ) - - -#################### -# - Node -#################### class OperateMathNode(base.MaxwellSimNode): r"""Applies a binary function between two expressions. @@ -386,7 +56,7 @@ class OperateMathNode(base.MaxwellSimNode): } #################### - # - Properties + # - Properties: Incoming InfoFlows #################### @events.on_value_changed( # Trigger @@ -415,6 +85,7 @@ class OperateMathNode(base.MaxwellSimNode): @bl_cache.cached_bl_property() def expr_infos(self) -> tuple[ct.InfoFlow, ct.InfoFlow] | None: + """Computed `InfoFlow`s of both expressions.""" info_l = self._compute_input('Expr L', kind=ct.FlowKind.Info) info_r = self._compute_input('Expr R', kind=ct.FlowKind.Info) @@ -426,19 +97,18 @@ class OperateMathNode(base.MaxwellSimNode): return None - operation: BinaryOperation = bl_cache.BLField( + #################### + # - Property: Operation + #################### + operation: math_system.BinaryOperation = bl_cache.BLField( enum_cb=lambda self, _: self.search_operations(), cb_depends_on={'expr_infos'}, ) def search_operations(self) -> list[ct.BLEnumElement]: + """Retrieve valid operations based on the input `InfoFlow`s.""" if self.expr_infos is not None: - return [ - operation.bl_enum_element(i) - for i, operation in enumerate( - BinaryOperation.by_infos(*self.expr_infos) - ) - ] + return math_system.BinaryOperation.bl_enum_elements(*self.expr_infos) return [] #################### @@ -451,7 +121,7 @@ class OperateMathNode(base.MaxwellSimNode): Called by Blender to determine the text to place in the node's header. """ if self.operation is not None: - return 'Op: ' + BinaryOperation.to_name(self.operation) + return 'Op: ' + math_system.BinaryOperation.to_name(self.operation) return self.bl_label @@ -464,7 +134,7 @@ class OperateMathNode(base.MaxwellSimNode): layout.prop(self, self.blfields['operation'], text='') #################### - # - FlowKind.Value|Func + # - FlowKind.Value #################### @events.computes_output_socket( 'Expr', @@ -477,20 +147,22 @@ class OperateMathNode(base.MaxwellSimNode): }, ) def compute_value(self, props: dict, input_sockets: dict): - operation = props['operation'] + """Binary operation on two symbolic input expressions.""" expr_l = input_sockets['Expr L'] expr_r = input_sockets['Expr R'] has_expr_l_value = not ct.FlowSignal.check(expr_l) has_expr_r_value = not ct.FlowSignal.check(expr_r) - # Compute Sympy Function - ## -> The operation enum directly provides the appropriate function. + operation = props['operation'] if has_expr_l_value and has_expr_r_value and operation is not None: return operation.sp_func([expr_l, expr_r]) return ct.FlowSignal.FlowPending + #################### + # - FlowKind.Func + #################### @events.computes_output_socket( 'Expr', kind=ct.FlowKind.Func, @@ -505,10 +177,7 @@ class OperateMathNode(base.MaxwellSimNode): output_socket_kinds={'Expr': ct.FlowKind.Info}, ) def compute_func(self, props, input_sockets, output_sockets): - operation = props['operation'] - if operation is None: - return ct.FlowSignal.FlowPending - + """Binary operation on two lazy-defined input expressions.""" expr_l = input_sockets['Expr L'] expr_r = input_sockets['Expr R'] output_info = output_sockets['Expr'] @@ -517,14 +186,9 @@ class OperateMathNode(base.MaxwellSimNode): has_expr_r = not ct.FlowSignal.check(expr_r) has_output_info = not ct.FlowSignal.check(output_info) - # Compute Jax Function - ## -> The operation enum directly provides the appropriate function. - if has_expr_l and has_expr_r and has_output_info: - return (expr_l | expr_r).compose_within( - operation.jax_func, - enclosing_func_output=output_info.output, - supports_jax=True, - ) + operation = props['operation'] + if operation is not None and has_expr_l and has_expr_r and has_output_info: + return self.operation.transform_funcs(expr_l, expr_r) return ct.FlowSignal.FlowPending #################### @@ -541,22 +205,17 @@ class OperateMathNode(base.MaxwellSimNode): }, ) def compute_info(self, props, input_sockets) -> ct.InfoFlow: - BO = BinaryOperation - - operation = props['operation'] + """Transform the input information of both lazy inputs.""" info_l = input_sockets['Expr L'] info_r = input_sockets['Expr R'] has_info_l = not ct.FlowSignal.check(info_l) has_info_r = not ct.FlowSignal.check(info_r) - # Compute Info - ## -> The operation enum directly provides the appropriate transform. + operation = props['operation'] if ( - has_info_l - and has_info_r - and operation is not None - and operation in BO.by_infos(info_l, info_r) + has_info_l and has_info_r and operation is not None + # and operation in BO.by_infos(info_l, info_r) ): return operation.transform_infos(info_l, info_r) @@ -576,15 +235,14 @@ class OperateMathNode(base.MaxwellSimNode): }, ) def compute_params(self, props, input_sockets) -> ct.ParamsFlow | ct.FlowSignal: - operation = props['operation'] + """Merge the lazy input parameters.""" params_l = input_sockets['Expr L'] params_r = input_sockets['Expr R'] has_params_l = not ct.FlowSignal.check(params_l) has_params_r = not ct.FlowSignal.check(params_r) - # Compute Params - ## -> Operations don't add new parameters, so just concatenate L|R. + operation = props['operation'] if has_params_l and has_params_r and operation is not None: return params_l | params_r diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py index 6744f7d..aa6ed07 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py @@ -20,334 +20,18 @@ import enum import typing as typ import bpy -import jax.numpy as jnp -import jaxtyping as jtyp import sympy as sp -from blender_maxwell.utils import bl_cache, logger, sci_constants, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import bl_cache, logger, sim_symbols +from blender_maxwell.utils import sympy_extra as spux from .... import contracts as ct -from .... import sockets +from .... import math_system, sockets from ... import base, events log = logger.get(__name__) -#################### -# - Operation Enum -#################### -class TransformOperation(enum.StrEnum): - """Valid operations for the `TransformMathNode`. - - Attributes: - FreqToVacWL: Transform an frequency dimension to vacuum wavelength. - VacWLToFreq: Transform a vacuum wavelength dimension to frequency. - ConvertIdxUnit: Convert the unit of a dimension to a compatible unit. - SetIdxUnit: Set all properties of a dimension. - FirstColToFirstIdx: Extract the first data column and set the first dimension's index array equal to it. - **For 2D integer-indexed data only**. - - IntDimToComplex: Fold a last length-2 integer dimension into the output, transforming it from a real-like type to complex type. - DimToVec: Fold the last dimension into the scalar output, creating a vector output type. - DimsToMat: Fold the last two dimensions into the scalar output, creating a matrix output type. - FT: Compute the 1D fourier transform along a dimension. - New dimensional bounds are computing using the Nyquist Limit. - For higher dimensions, simply repeat along more dimensions. - InvFT1D: Compute the inverse 1D fourier transform along a dimension. - New dimensional bounds are computing using the Nyquist Limit. - For higher dimensions, simply repeat along more dimensions. - """ - - # Covariant Transform - FreqToVacWL = enum.auto() - VacWLToFreq = enum.auto() - ConvertIdxUnit = enum.auto() - SetIdxUnit = enum.auto() - FirstColToFirstIdx = enum.auto() - - # Fold - IntDimToComplex = enum.auto() - DimToVec = enum.auto() - DimsToMat = enum.auto() - - # Fourier - FT1D = enum.auto() - InvFT1D = enum.auto() - - # TODO: Affine - ## TODO - - #################### - # - UI - #################### - @staticmethod - def to_name(value: typ.Self) -> str: - TO = TransformOperation - return { - # Covariant Transform - TO.FreqToVacWL: '𝑓 → λᵥ', - TO.VacWLToFreq: 'λᵥ → 𝑓', - TO.ConvertIdxUnit: 'Convert Dim', - TO.SetIdxUnit: 'Set Dim', - TO.FirstColToFirstIdx: '1st Col → 1st Dim', - # Fold - TO.IntDimToComplex: '→ ℂ', - TO.DimToVec: '→ Vector', - TO.DimsToMat: '→ Matrix', - ## TODO: Vector to new last-dim integer - ## TODO: Matrix to two last-dim integers - # Fourier - TO.FT1D: 'FT', - TO.InvFT1D: 'iFT', - }[value] - - @property - def name(self) -> str: - return TransformOperation.to_name(self) - - @staticmethod - def to_icon(_: typ.Self) -> str: - return '' - - def bl_enum_element(self, i: int) -> ct.BLEnumElement: - TO = TransformOperation - return ( - str(self), - TO.to_name(self), - TO.to_name(self), - TO.to_icon(self), - i, - ) - - #################### - # - Methods - #################### - def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]: - TO = TransformOperation - match self: - case TO.FreqToVacWL: - return [ - dim - for dim in info.dims - if dim.physical_type is spux.PhysicalType.Freq - ] - - case TO.VacWLToFreq: - return [ - dim - for dim in info.dims - if dim.physical_type is spux.PhysicalType.Length - ] - - case TO.ConvertIdxUnit: - return [ - dim - for dim in info.dims - if not info.has_idx_labels(dim) - and spux.PhysicalType.from_unit(dim.unit, optional=True) is not None - ] - - case TO.SetIdxUnit: - return [dim for dim in info.dims if not info.has_idx_labels(dim)] - - ## ColDimToComplex: Implicit Last Dimension - ## DimToVec: Implicit Last Dimension - ## DimsToMat: Implicit Last 2 Dimensions - - case TO.FT1D | TO.InvFT1D: - # Filter by Axis Uniformity - ## -> FT requires uniform axis (aka. must be RangeFlow). - ## -> NOTE: If FT isn't popping up, check ExtractDataNode. - return [dim for dim in info.dims if info.is_idx_uniform(dim)] - - return [] - - @staticmethod - def by_info(info: ct.InfoFlow) -> list[typ.Self]: - TO = TransformOperation - operations = [] - - # Covariant Transform - ## Freq -> VacWL - if TO.FreqToVacWL.valid_dims(info): - operations += [TO.FreqToVacWL] - - ## VacWL -> Freq - if TO.VacWLToFreq.valid_dims(info): - operations += [TO.VacWLToFreq] - - ## Convert Index Unit - if TO.ConvertIdxUnit.valid_dims(info): - operations += [TO.ConvertIdxUnit] - - if TO.SetIdxUnit.valid_dims(info): - operations += [TO.SetIdxUnit] - - ## Column to First Index (Array) - if ( - len(info.dims) == 2 # noqa: PLR2004 - and info.first_dim.mathtype is spux.MathType.Integer - and info.last_dim.mathtype is spux.MathType.Integer - and info.output.shape_len == 0 - ): - operations += [TO.FirstColToFirstIdx] - - # Fold - ## Last Dim -> Complex - if ( - len(info.dims) >= 1 - and ( - info.output.mathtype - in [spux.MathType.Integer, spux.MathType.Rational, spux.MathType.Real] - ) - and info.last_dim.mathtype is spux.MathType.Integer - and info.has_idx_labels(info.last_dim) - and len(info.dims[info.last_dim]) == 2 # noqa: PLR2004 - ): - operations += [TO.IntDimToComplex] - - ## Last Dim -> Vector - if len(info.dims) >= 1 and info.output.shape_len == 0: - operations += [TO.DimToVec] - - ## Last Dim -> Matrix - if len(info.dims) >= 2 and info.output.shape_len == 0: # noqa: PLR2004 - operations += [TO.DimsToMat] - - # Fourier - if TO.FT1D.valid_dims(info): - operations += [TO.FT1D] - - if TO.InvFT1D.valid_dims(info): - operations += [TO.InvFT1D] - - return operations - - #################### - # - Function Properties - #################### - def jax_func(self, axis: int | None = None): - TO = TransformOperation - return { - # Covariant Transform - ## -> Freq <-> WL is a rescale (noop) AND flip (not noop). - TO.FreqToVacWL: lambda expr: jnp.flip(expr, axis=axis), - TO.VacWLToFreq: lambda expr: jnp.flip(expr, axis=axis), - TO.ConvertIdxUnit: lambda expr: expr, - TO.SetIdxUnit: lambda expr: expr, - TO.FirstColToFirstIdx: lambda expr: jnp.delete(expr, 0, axis=1), - # Fold - ## -> To Complex: This should generally be a no-op. - TO.IntDimToComplex: lambda expr: jnp.squeeze( - expr.view(dtype=jnp.complex64), axis=-1 - ), - TO.DimToVec: lambda expr: expr, - TO.DimsToMat: lambda expr: expr, - # Fourier - TO.FT1D: lambda expr: jnp.fft(expr, axis=axis), - TO.InvFT1D: lambda expr: jnp.ifft(expr, axis=axis), - }[self] - - def transform_info( - self, - info: ct.InfoFlow, - dim: sim_symbols.SimSymbol | None = None, - data_col: jtyp.Shaped[jtyp.Array, ' size'] | None = None, - new_dim_name: str | None = None, - unit: spux.Unit | None = None, - physical_type: spux.PhysicalType | None = None, - ) -> ct.InfoFlow: - TO = TransformOperation - return { - # Covariant Transform - TO.FreqToVacWL: lambda: info.replace_dim( - (f_dim := dim), - sim_symbols.wl(unit), - info.dims[f_dim].rescale( - lambda el: sci_constants.vac_speed_of_light / el, - reverse=True, - new_unit=unit, - ), - ), - TO.VacWLToFreq: lambda: info.replace_dim( - (wl_dim := dim), - sim_symbols.freq(unit), - info.dims[wl_dim].rescale( - lambda el: sci_constants.vac_speed_of_light / el, - reverse=True, - new_unit=unit, - ), - ), - TO.ConvertIdxUnit: lambda: info.replace_dim( - dim, - dim.update(unit=unit), - ( - info.dims[dim].rescale_to_unit(unit) - if info.has_idx_discrete(dim) - else None ## Continuous -- dim SimSymbol already scaled - ), - ), - TO.SetIdxUnit: lambda: info.replace_dim( - dim, - dim.update( - sym_name=new_dim_name, - physical_type=physical_type, - unit=unit, - ), - ( - info.dims[dim].correct_unit(unit) - if info.has_idx_discrete(dim) - else None ## Continuous -- dim SimSymbol already scaled - ), - ), - TO.FirstColToFirstIdx: lambda: info.replace_dim( - info.first_dim, - info.first_dim.update( - sym_name=new_dim_name, - mathtype=spux.MathType.from_jax_array(data_col), - physical_type=physical_type, - unit=unit, - ), - ct.RangeFlow.try_from_array(ct.ArrayFlow(values=data_col, unit=unit)), - ).slice_dim(info.last_dim, (1, len(info.dims[info.last_dim]), 1)), - # Fold - TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output( - mathtype=spux.MathType.Complex - ), - TO.DimToVec: lambda: info.fold_last_input(), - TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(), - # Fourier - TO.FT1D: lambda: info.replace_dim( - dim, - [ - # FT'ed Unit: Reciprocal of the Original Unit - dim.update( - unit=1 / dim.unit if dim.unit is not None else 1 - ), ## TODO: Okay to not scale interval? - # FT'ed Bounds: Reciprocal of the Original Unit - info.dims[dim].bound_fourier_transform, - ], - ), - TO.InvFT1D: lambda: info.replace_dim( - info.last_dim, - [ - # FT'ed Unit: Reciprocal of the Original Unit - dim.update( - unit=1 / dim.unit if dim.unit is not None else 1 - ), ## TODO: Okay to not scale interval? - # FT'ed Bounds: Reciprocal of the Original Unit - ## -> Note the midpoint may revert to 0. - ## -> See docs for `RangeFlow.bound_inv_fourier_transform` for more. - info.dims[dim].bound_inv_fourier_transform, - ], - ), - }[self]() - - -#################### -# - Node -#################### class TransformMathNode(base.MaxwellSimNode): r"""Applies a function to the array as a whole, with arbitrary results. @@ -409,7 +93,7 @@ class TransformMathNode(base.MaxwellSimNode): #################### # - Properties: Operation #################### - operation: TransformOperation = bl_cache.BLField( + operation: math_system.TransformOperation = bl_cache.BLField( enum_cb=lambda self, _: self.search_operations(), cb_depends_on={'expr_info'}, ) @@ -419,7 +103,7 @@ class TransformMathNode(base.MaxwellSimNode): return [ operation.bl_enum_element(i) for i, operation in enumerate( - TransformOperation.by_info(self.expr_info) + math_system.TransformOperation.by_info(self.expr_info) ) ] return [] @@ -461,7 +145,7 @@ class TransformMathNode(base.MaxwellSimNode): ) def search_units(self) -> list[ct.BLEnumElement]: - TO = TransformOperation + TO = math_system.TransformOperation match self.operation: # Covariant Transform case TO.ConvertIdxUnit if self.dim is not None: @@ -521,7 +205,7 @@ class TransformMathNode(base.MaxwellSimNode): return spux.sp_to_str(self.new_unit) def draw_label(self): - TO = TransformOperation + TO = math_system.TransformOperation match self.operation: case TO.FreqToVacWL if self.dim is not None: return f'T: {self.dim.name_pretty} | 𝑓 → {self.new_unit_str}' @@ -556,7 +240,7 @@ class TransformMathNode(base.MaxwellSimNode): def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: layout.prop(self, self.blfields['operation'], text='') - TO = TransformOperation + TO = math_system.TransformOperation match self.operation: case TO.ConvertIdxUnit: row = layout.row(align=True) @@ -613,7 +297,7 @@ class TransformMathNode(base.MaxwellSimNode): self, props, input_sockets, output_sockets ) -> ct.FuncFlow | ct.FlowSignal: """Transform the input `InfoFlow` depending on the transform operation.""" - TO = TransformOperation + TO = math_system.TransformOperation lazy_func = input_sockets['Expr'][ct.FlowKind.Func] info = input_sockets['Expr'][ct.FlowKind.Info] @@ -662,7 +346,7 @@ class TransformMathNode(base.MaxwellSimNode): self, props: dict, input_sockets: dict ) -> ct.InfoFlow | typ.Literal[ct.FlowSignal.FlowPending]: """Transform the input `InfoFlow` depending on the transform operation.""" - TO = TransformOperation + TO = math_system.TransformOperation operation = props['operation'] info = input_sockets['Expr'][ct.FlowKind.Info] diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py index ad44ef8..dda78ff 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py @@ -24,7 +24,7 @@ import sympy as sp import sympy.physics.units as spu from blender_maxwell.utils import bl_cache, image_ops, logger, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/absorbing_bound_cond.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/absorbing_bound_cond.py index 2f8f276..f03eb77 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/absorbing_bound_cond.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/absorbing_bound_cond.py @@ -22,7 +22,7 @@ import bpy import sympy as sp import tidy3d as td -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from .... import contracts as ct diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/pml_bound_cond.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/pml_bound_cond.py index 91faec1..1490bdd 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/pml_bound_cond.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/pml_bound_cond.py @@ -22,7 +22,7 @@ import bpy import sympy as sp import tidy3d as td -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from .... import contracts as ct diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py index aa94c09..9a94de8 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py @@ -19,7 +19,7 @@ import inspect import typing as typ from types import MappingProxyType -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from .. import contracts as ct diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py index dba473a..11068e3 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py @@ -21,7 +21,7 @@ import sympy as sp import sympy.physics.units as spu from blender_maxwell.utils import bl_cache, sci_constants, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from .... import contracts as ct from .... import sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py index db9df67..213746a 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py @@ -16,12 +16,13 @@ import enum import typing as typ +from fractions import Fraction import bpy import sympy as sp from blender_maxwell.utils import bl_cache, logger, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from .... import contracts as ct from .... import sockets @@ -50,6 +51,8 @@ class SymbolConstantNode(base.MaxwellSimNode): ) size: spux.NumberSize1D = bl_cache.BLField(spux.NumberSize1D.Scalar) + ## Use of NumberSize1D implicitly guarantees UI-realizability later. + mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real) physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical) @@ -109,31 +112,110 @@ class SymbolConstantNode(base.MaxwellSimNode): preview_value_re: float = bl_cache.BLField(0.0) preview_value_im: float = bl_cache.BLField(0.0) - #################### - # - Computed Properties - #################### @bl_cache.cached_bl_property( depends_on={ - 'sym_name', - 'size', 'mathtype', - 'physical_type', - 'unit', 'interval_finite_z', 'interval_finite_q', 'interval_finite_re', - 'interval_inf', - 'interval_closed', 'interval_finite_im', - 'interval_inf_im', - 'interval_closed_im', + } + ) + def interval_finite( + self, + ) -> ( + tuple[int | Fraction | float, int | Fraction | float] + | tuple[tuple[float, float], tuple[float, float]] + ): + """Return the appropriate finite interval from the UI, as guided by `self.mathtype`.""" + MT = spux.MathType + match self.mathtype: + case MT.Integer: + return self.interval_finite_z + case MT.Rational: + return [Fraction(*q) for q in self.interval_finite_q] + case MT.Real: + return self.interval_finite_re + case MT.Complex: + return (self.interval_finite_re, self.interval_finite_im) + + @bl_cache.cached_bl_property( + depends_on={ + 'mathtype', 'preview_value_z', 'preview_value_q', 'preview_value_re', 'preview_value_im', } ) + def preview_value( + self, + ) -> int | Fraction | float | complex: + """Return the appropriate finite interval from the UI, as guided by `self.mathtype`.""" + MT = spux.MathType + match self.mathtype: + case MT.Integer: + return self.preview_value_z + case MT.Rational: + return Fraction(*self.preview_value_q) + case MT.Real: + return self.preview_value_re + case MT.Complex: + return complex(self.preview_value_re, self.preview_value_im) + + @bl_cache.cached_bl_property( + depends_on={ + 'mathtype', + 'interval_finite', + 'interval_inf', + 'interval_inf_im', + 'interval_closed', + 'interval_closed_im', + } + ) + def domain( + self, + ) -> sp.Interval | sp.sets.fancysets.CartesianComplexRegion: + """Deduce the domain specified in the UI.""" + MT = spux.MathType + match self.mathtype: + case MT.Integer | MT.Real | MT.Rational: + return sim_symbols.mk_interval( + self.interval_finite, + self.interval_inf, + self.interval_closed, + ) + + case MT.Complex: + region = self.interval_finite + domain_re = sim_symbols.mk_interval( + region[0], + self.interval_inf, + self.interval_closed, + ) + domain_im = sim_symbols.mk_interval( + region[1], + self.interval_inf_im, + self.interval_closed_im, + ) + return sp.ComplexRegion(domain_re, domain_im, polar=False) + + #################### + # - Computed Properties + #################### + @bl_cache.cached_bl_property( + depends_on={ + 'sym_name', + 'mathtype', + 'physical_type', + 'unit', + 'size', + 'domain', + 'preview_value', + } + ) def symbol(self) -> sim_symbols.SimSymbol: + """Generate the `SimSymbol` matching the user-specification.""" return sim_symbols.SimSymbol( sym_name=self.sym_name, mathtype=self.mathtype, @@ -141,18 +223,8 @@ class SymbolConstantNode(base.MaxwellSimNode): unit=self.unit, rows=self.size.rows, cols=self.size.cols, - interval_finite_z=self.interval_finite_z, - interval_finite_q=self.interval_finite_q, - interval_finite_re=self.interval_finite_re, - interval_inf=self.interval_inf, - interval_closed=self.interval_closed, - interval_finite_im=self.interval_finite_im, - interval_inf_im=self.interval_inf_im, - interval_closed_im=self.interval_closed_im, - preview_value_z=self.preview_value_z, - preview_value_q=self.preview_value_q, - preview_value_re=self.preview_value_re, - preview_value_im=self.preview_value_im, + domain=self.domain, + preview_value=self.preview_value, ) #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py index 252dbed..e9134fc 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py @@ -23,7 +23,7 @@ import sympy as sp import tidy3d as td from blender_maxwell.utils import bl_cache, logger, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from .... import contracts as ct from .... import sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/scene.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/scene.py index 2acafa7..8b6fc43 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/scene.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/scene.py @@ -23,7 +23,7 @@ import sympy as sp import sympy.physics.units as spu from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py index 68d6d49..55f96c6 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py @@ -23,7 +23,7 @@ import sympy as sp import sympy.physics.units as spu from blender_maxwell.utils import bl_cache, logger, sci_constants -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/library_medium.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/library_medium.py index da00004..e9c275b 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/library_medium.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/library_medium.py @@ -25,7 +25,7 @@ from tidy3d.material_library.material_library import MaterialItem as Tidy3DMediu from tidy3d.material_library.material_library import VariantItem as Tidy3DMediumVariant from blender_maxwell.utils import bl_cache, logger, sci_constants -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py index 7d40ba2..d9ca2e8 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py @@ -23,7 +23,7 @@ import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py index 319ce0b..d40e0ab 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py @@ -23,7 +23,7 @@ import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/permittivity_monitor.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/permittivity_monitor.py index b5a287c..85d845b 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/permittivity_monitor.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/permittivity_monitor.py @@ -20,7 +20,7 @@ import sympy as sp import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from ... import contracts as ct diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/file_exporters/data_file_exporter.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/file_exporters/data_file_exporter.py index 4cf1503..963186b 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/file_exporters/data_file_exporter.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/file_exporters/data_file_exporter.py @@ -20,7 +20,7 @@ from pathlib import Path import bpy from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from .... import contracts as ct from .... import sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py index 585793b..0ad5d02 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py @@ -21,7 +21,7 @@ import sympy as sp import tidy3d as td from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py index 0197f5f..7c525ff 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py @@ -22,7 +22,7 @@ import sympy as sp import sympy.physics.units as spu from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from ... import contracts as ct diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/gaussian_beam_source.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/gaussian_beam_source.py index 2bbd6b0..94eb827 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/gaussian_beam_source.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/gaussian_beam_source.py @@ -22,7 +22,7 @@ import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/plane_wave_source.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/plane_wave_source.py index d1d3ab5..3f60b9e 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/plane_wave_source.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/plane_wave_source.py @@ -22,7 +22,7 @@ import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py index 0ed364b..0457790 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py @@ -22,7 +22,7 @@ import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py index 1f48f31..c307143 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py @@ -28,7 +28,7 @@ from tidy3d.components.data.data_array import TimeDataArray as td_TimeDataArray from tidy3d.components.data.dataset import TimeDataset as td_TimeDataset from blender_maxwell.utils import bl_cache, logger, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py index 27369d5..6669a50 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py @@ -20,7 +20,7 @@ import sympy as sp import sympy.physics.units as spu import tidy3d as td -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from ... import bl_socket_map, managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py index 2109185..754ac29 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py @@ -24,7 +24,7 @@ import tidy3d.plugins.adjoint as tdadj from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from .... import contracts as ct from .... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py index 072756d..1177dbd 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py @@ -21,7 +21,7 @@ import sympy.physics.units as spu import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from .... import contracts as ct diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/sphere_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/sphere_structure.py index d5a83b4..8a51183 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/sphere_structure.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/sphere_structure.py @@ -21,7 +21,7 @@ import sympy.physics.units as spu import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from .... import contracts as ct diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py index 84e077b..6c553a7 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py @@ -24,7 +24,7 @@ import pydantic as pyd import sympy as sp from blender_maxwell.utils import bl_cache, logger, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from .. import contracts as ct from . import base @@ -155,12 +155,14 @@ class ExprBLSocket(base.MaxwellSimSocket): unit=self.unit, rows=self.size.rows, cols=self.size.cols, + is_constant=True, + ## TODO: Should we set preview values exclude_zero=( not self.value.is_zero if self.value.is_zero is not None else False ), - ## TODO: Does this work for matrix elements? + ## TODO: Does this 0-check work for matrix elements? ) case ct.FlowKind.Range if self.symbols: @@ -208,7 +210,7 @@ class ExprBLSocket(base.MaxwellSimSocket): sim_symbols.SimSymbolName.Expr ) output_name: sim_symbols.SimSymbolName = bl_cache.BLField( - sim_symbols.SimSymbolName.Expr + sim_symbols.SimSymbolName.Constant ) symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([]) @@ -441,12 +443,11 @@ class ExprBLSocket(base.MaxwellSimSocket): #################### def _to_raw_value(self, expr: spux.SympyExpr, force_complex: bool = False): """Cast the given expression to the appropriate raw value, with scaling guided by `self.unit`.""" - if self.unit is not None: - pyvalue = spux.sympy_to_python(spux.scale_to_unit(expr, self.unit)) - else: - pyvalue = spux.sympy_to_python(expr) + pyvalue = spux.scale_to_unit(expr, self.unit) # Cast complex -> tuple[float, float] + ## -> We can't set complex to BLProps. + ## -> We must deconstruct it appropriately. if isinstance(pyvalue, complex) or ( isinstance(pyvalue, int | float) and force_complex ): diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py index 2c2119f..aa41351 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py @@ -24,7 +24,7 @@ import tidy3d as td import tidy3d.plugins.adjoint as tdadj from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from .. import base diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/pol.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/pol.py index 7cf5cf0..a39fcbf 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/pol.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/pol.py @@ -19,7 +19,7 @@ import sympy as sp import sympy.physics.optics.polarization as spo_pol import sympy.physics.units as spu -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from .. import base diff --git a/src/blender_maxwell/utils/__init__.py b/src/blender_maxwell/utils/__init__.py index a44be30..e1e2d51 100644 --- a/src/blender_maxwell/utils/__init__.py +++ b/src/blender_maxwell/utils/__init__.py @@ -17,22 +17,22 @@ from ..nodeps.utils import blender_type_enum, pydeps from . import ( bl_cache, - extra_sympy_units, image_ops, logger, sci_constants, serialize, staticproperty, + sympy_extra, ) __all__ = [ 'blender_type_enum', 'pydeps', 'bl_cache', - 'extra_sympy_units', 'image_ops', 'logger', 'sci_constants', 'serialize', 'staticproperty', + 'sympy_extra', ] diff --git a/src/blender_maxwell/utils/extra_sympy_units.py b/src/blender_maxwell/utils/extra_sympy_units.py deleted file mode 100644 index fa31709..0000000 --- a/src/blender_maxwell/utils/extra_sympy_units.py +++ /dev/null @@ -1,1699 +0,0 @@ -# blender_maxwell -# Copyright (C) 2024 blender_maxwell Project Contributors -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . - -"""Declares useful sympy units and functions, to make it easier to work with `sympy` as the basis for a unit-aware system. - -Attributes: - UNIT_BY_SYMBOL: Maps all abbreviated Sympy symbols to their corresponding Sympy unit. - This is essential for parsing string expressions that use units, since a pure parse of ex. `a*m + m` would not otherwise be able to differentiate between `sp.Symbol(m)` and `spu.meter`. - SympyType: A simple union of valid `sympy` types, used to check whether arbitrary objects should be handled using `sympy` functions. - For simple `isinstance` checks, this should be preferred, as it is most performant. - For general use, `SympyExpr` should be preferred. - SympyExpr: A `SympyType` that is compatible with `pydantic`, including serialization/deserialization. - Should be used via the `ConstrSympyExpr`, which also adds expression validation. -""" - -import enum -import functools -import sys -import typing as typ -from fractions import Fraction - -import jax -import jax.numpy as jnp -import jaxtyping as jtyp -import pydantic as pyd -import sympy as sp -import sympy.physics.units as spu -import typing_extensions as typx -from pydantic_core import core_schema as pyd_core_schema - -from blender_maxwell import contracts as ct - -from . import logger -from .staticproperty import staticproperty - -log = logger.get(__name__) - -SympyType = ( - sp.Basic - | sp.Expr - | sp.MatrixBase - | sp.MutableDenseMatrix - | spu.Quantity - | spu.Dimension -) - - -#################### -# - Math Type -#################### -class MathType(enum.StrEnum): - """Type identifiers that encompass common sets of mathematical objects.""" - - Integer = enum.auto() - Rational = enum.auto() - Real = enum.auto() - Complex = enum.auto() - - @staticmethod - def combine(*mathtypes: list[typ.Self], optional: bool = False) -> typ.Self | None: - if MathType.Complex in mathtypes: - return MathType.Complex - if MathType.Real in mathtypes: - return MathType.Real - if MathType.Rational in mathtypes: - return MathType.Rational - if MathType.Integer in mathtypes: - return MathType.Integer - - if optional: - return None - - msg = f"Can't combine mathtypes {mathtypes}" - raise ValueError(msg) - - def is_compatible(self, other: typ.Self) -> bool: - MT = MathType - return ( - other - in { - MT.Integer: [MT.Integer], - MT.Rational: [MT.Integer, MT.Rational], - MT.Real: [MT.Integer, MT.Rational, MT.Real], - MT.Complex: [MT.Integer, MT.Rational, MT.Real, MT.Complex], - }[self] - ) - - def coerce_compatible_pyobj( - self, pyobj: bool | int | Fraction | float | complex - ) -> int | Fraction | float | complex: - MT = MathType - match self: - case MT.Integer: - return int(pyobj) - case MT.Rational if isinstance(pyobj, int): - return Fraction(pyobj, 1) - case MT.Rational if isinstance(pyobj, Fraction): - return pyobj - case MT.Real: - return float(pyobj) - case MT.Complex if isinstance(pyobj, int | Fraction): - return complex(float(pyobj), 0) - case MT.Complex if isinstance(pyobj, float): - return complex(pyobj, 0) - - @staticmethod - def from_expr(sp_obj: SympyType, optional: bool = False) -> type | None: - if isinstance(sp_obj, sp.MatrixBase): - return MathType.combine( - *[MathType.from_expr(v) for v in sp.flatten(sp_obj)] - ) - - if sp_obj.is_integer: - return MathType.Integer - if sp_obj.is_rational: - return MathType.Rational - if sp_obj.is_real: - return MathType.Real - if sp_obj.is_complex: - return MathType.Complex - - # Infinities - if sp_obj in [sp.oo, -sp.oo]: - return MathType.Real ## TODO: Strictly, could be ex. integer... - if sp_obj in [sp.zoo, -sp.zoo]: - return MathType.Complex - - if optional: - return None - - msg = f"Can't determine MathType from sympy object: {sp_obj}" - raise ValueError(msg) - - @staticmethod - def from_pytype(dtype: type) -> type: - return { - int: MathType.Integer, - Fraction: MathType.Rational, - float: MathType.Real, - complex: MathType.Complex, - }[dtype] - - @staticmethod - def from_jax_array(data: jtyp.Shaped[jtyp.Array, '...']) -> type: - """Deduce the MathType corresponding to a JAX array. - - We go about this by leveraging that: - - `data` is of a homogeneous type. - - `data.item(0)` returns a single element of the array w/pure-python type. - - By combing this with `type()` and `MathType.from_pytype`, we can effectively deduce the `MathType` of the entire array with relative efficiency. - - Notes: - Should also work with numpy arrays. - """ - return MathType.from_pytype(type(data.item(0))) - - @staticmethod - def has_mathtype(obj: typ.Any) -> typ.Literal['pytype', 'expr'] | None: - if isinstance(obj, bool | int | Fraction | float | complex): - return 'pytype' - if isinstance(obj, sp.Basic | sp.MatrixBase | sp.MutableDenseMatrix): - return 'expr' - - return None - - @property - def pytype(self) -> type: - MT = MathType - return { - MT.Integer: int, - MT.Rational: Fraction, - MT.Real: float, - MT.Complex: complex, - }[self] - - @property - def symbolic_set(self) -> type: - MT = MathType - return { - MT.Integer: sp.Integers, - MT.Rational: sp.Rationals, - MT.Real: sp.Reals, - MT.Complex: sp.Complexes, - }[self] - - @property - def inf_finite(self) -> type: - """Opinionated finite representation of "infinity" within this `MathType`. - - These are chosen using `sys.maxsize` and `sys.float_info`. - As such, while not arbitrary, this "finite representation of infinity" certainly is opinionated. - - **Note** that, in practice, most systems will have no trouble working with values that exceed those defined here. - - Notes: - Values should be presumed to vary by-platform, as the `sys` attributes may be influenced by CPU architecture, OS, runtime environment, etc. . - - These values can be used directly in `jax` arrays, but at the cost of an overflow warning (in part because `jax` generally only allows the use of `float32`). - In this case, the warning doesn't matter, as the value will be cast to `jnp.inf` anyway. - - However, it's generally cleaner to directly use `jnp.inf` if infinite values must be defined in an array context. - """ - MT = MathType - Z = MT.Integer - R = MT.Integer - return { - MT.Integer: (-sys.maxsize, sys.maxsize), - MT.Rational: ( - Fraction(Z.inf_finite[0], 1), - Fraction(Z.inf_finite[1], 1), - ), - MT.Real: -(sys.float_info.min, sys.float_info.max), - MT.Complex: ( - complex(R.inf_finite[0], R.inf_finite[0]), - complex(R.inf_finite[1], R.inf_finite[1]), - ), - }[self] - - @property - def sp_symbol_a(self) -> type: - MT = MathType - return { - MT.Integer: sp.Symbol('a', integer=True), - MT.Rational: sp.Symbol('a', rational=True), - MT.Real: sp.Symbol('a', real=True), - MT.Complex: sp.Symbol('a', complex=True), - }[self] - - @staticmethod - def to_str(value: typ.Self) -> type: - return { - MathType.Integer: 'ℤ', - MathType.Rational: 'ℚ', - MathType.Real: 'ℝ', - MathType.Complex: 'ℂ', - }[value] - - @property - def label_pretty(self) -> str: - return MathType.to_str(self) - - @staticmethod - def to_name(value: typ.Self) -> str: - return MathType.to_str(value) - - @staticmethod - def to_icon(value: typ.Self) -> str: - return '' - - def bl_enum_element(self, i: int) -> ct.BLEnumElement: - return ( - str(self), - MathType.to_name(self), - MathType.to_name(self), - MathType.to_icon(self), - i, - ) - - -#################### -# - Size: 1D -#################### -class NumberSize1D(enum.StrEnum): - """Valid 1D-constrained shape.""" - - Scalar = enum.auto() - Vec2 = enum.auto() - Vec3 = enum.auto() - Vec4 = enum.auto() - - @staticmethod - def to_name(value: typ.Self) -> str: - NS = NumberSize1D - return { - NS.Scalar: 'Scalar', - NS.Vec2: '2D', - NS.Vec3: '3D', - NS.Vec4: '4D', - }[value] - - @staticmethod - def to_icon(value: typ.Self) -> str: - NS = NumberSize1D - return { - NS.Scalar: '', - NS.Vec2: '', - NS.Vec3: '', - NS.Vec4: '', - }[value] - - def bl_enum_element(self, i: int) -> ct.BLEnumElement: - return ( - str(self), - NumberSize1D.to_name(self), - NumberSize1D.to_name(self), - NumberSize1D.to_icon(self), - i, - ) - - @staticmethod - def has_shape(shape: tuple[int, ...] | None): - return shape in [None, (2,), (3,), (4,), (2, 1), (3, 1), (4, 1)] - - def supports_shape(self, shape: tuple[int, ...] | None): - NS = NumberSize1D - match self: - case NS.Scalar: - return shape is None - case NS.Vec2: - return shape in ((2,), (2, 1)) - case NS.Vec3: - return shape in ((3,), (3, 1)) - case NS.Vec4: - return shape in ((4,), (4, 1)) - - @staticmethod - def from_shape(shape: tuple[typ.Literal[2, 3]] | None) -> typ.Self: - NS = NumberSize1D - return { - None: NS.Scalar, - (2,): NS.Vec2, - (3,): NS.Vec3, - (4,): NS.Vec4, - (2, 1): NS.Vec2, - (3, 1): NS.Vec3, - (4, 1): NS.Vec4, - }[shape] - - @property - def rows(self): - NS = NumberSize1D - return { - NS.Scalar: 1, - NS.Vec2: 2, - NS.Vec3: 3, - NS.Vec4: 4, - }[self] - - @property - def cols(self): - return 1 - - @property - def shape(self): - NS = NumberSize1D - return { - NS.Scalar: None, - NS.Vec2: (2,), - NS.Vec3: (3,), - NS.Vec4: (4,), - }[self] - - -def symbol_range(sym: sp.Symbol) -> str: - return f'{sym.name} ∈ ' + ( - 'ℂ' - if sym.is_complex - else ('ℝ' if sym.is_real else ('ℤ' if sym.is_integer else '?')) - ) - - -#################### -# - Symbol Sizes -#################### -class SimpleSize2D(enum.StrEnum): - """Simple subset of sizes for rank-2 tensors.""" - - Scalar = enum.auto() - - # Vectors - Vec2 = enum.auto() ## 2x1 - Vec3 = enum.auto() ## 3x1 - Vec4 = enum.auto() ## 4x1 - - # Covectors - CoVec2 = enum.auto() ## 1x2 - CoVec3 = enum.auto() ## 1x3 - CoVec4 = enum.auto() ## 1x4 - - # Square Matrices - Mat22 = enum.auto() ## 2x2 - Mat33 = enum.auto() ## 3x3 - Mat44 = enum.auto() ## 4x4 - - -#################### -# - Unit Dimensions -#################### -class DimsMeta(type): - def __getattr__(cls, attr: str) -> spu.Dimension: - if ( - attr in spu.definitions.dimension_definitions.__dir__() - and not attr.startswith('__') - ): - return getattr(spu.definitions.dimension_definitions, attr) - - raise AttributeError(name=attr, obj=Dims) - - -class Dims(metaclass=DimsMeta): - """Access `sympy.physics.units` dimensions with less hassle. - - Any unit dimension available in `sympy.physics.units.definitions.dimension_definitions` can be accessed as an attribute of `Dims`. - - An `AttributeError` is raised if the unit cannot be found in `sympy`. - - Examples: - The objects returned are a direct alias to `sympy`, with less hassle: - ```python - assert Dims.length == ( - sympy.physics.units.definitions.dimension_definitions.length - ) - ``` - """ - - -#################### -# - Units -#################### -femtosecond = fs = spu.Quantity('femtosecond', abbrev='fs') -femtosecond.set_global_relative_scale_factor(spu.femto, spu.second) - -# Length -femtometer = fm = spu.Quantity('femtometer', abbrev='fm') -femtometer.set_global_relative_scale_factor(spu.femto, spu.meter) - -# Lum Flux -lumen = lm = spu.Quantity('lumen', abbrev='lm') -lumen.set_global_relative_scale_factor(1, spu.candela * spu.steradian) - -# Force -nanonewton = nN = spu.Quantity('nanonewton', abbrev='nN') # noqa: N816 -nanonewton.set_global_relative_scale_factor(spu.nano, spu.newton) - -micronewton = uN = spu.Quantity('micronewton', abbrev='μN') # noqa: N816 -micronewton.set_global_relative_scale_factor(spu.micro, spu.newton) - -millinewton = mN = spu.Quantity('micronewton', abbrev='mN') # noqa: N816 -micronewton.set_global_relative_scale_factor(spu.milli, spu.newton) - -# Frequency -kilohertz = KHz = spu.Quantity('kilohertz', abbrev='KHz') -kilohertz.set_global_relative_scale_factor(spu.kilo, spu.hertz) - -megahertz = MHz = spu.Quantity('megahertz', abbrev='MHz') -kilohertz.set_global_relative_scale_factor(spu.kilo, spu.hertz) - -gigahertz = GHz = spu.Quantity('gigahertz', abbrev='GHz') -gigahertz.set_global_relative_scale_factor(spu.giga, spu.hertz) - -terahertz = THz = spu.Quantity('terahertz', abbrev='THz') -terahertz.set_global_relative_scale_factor(spu.tera, spu.hertz) - -petahertz = PHz = spu.Quantity('petahertz', abbrev='PHz') -petahertz.set_global_relative_scale_factor(spu.peta, spu.hertz) - -exahertz = EHz = spu.Quantity('exahertz', abbrev='EHz') -exahertz.set_global_relative_scale_factor(spu.exa, spu.hertz) - -# Pressure -millibar = mbar = spu.Quantity('millibar', abbrev='mbar') -millibar.set_global_relative_scale_factor(spu.milli, spu.bar) - -hectopascal = hPa = spu.Quantity('hectopascal', abbrev='hPa') # noqa: N816 -hectopascal.set_global_relative_scale_factor(spu.hecto, spu.pascal) - -UNIT_BY_SYMBOL: dict[sp.Symbol, spu.Quantity] = { - unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity) -} | {unit.name: unit for unit in globals().values() if isinstance(unit, spu.Quantity)} - -UNIT_TO_1: dict[spu.Quantity, 1] = {unit: 1 for unit in UNIT_BY_SYMBOL.values()} - - -#################### -# - Expr Analysis: Units -#################### -## TODO: Caching w/srepr'ed expression. -## TODO: An LFU cache could do better than an LRU. -def uses_units(sp_obj: SympyType) -> bool: - """Determines if an expression uses any units. - - Notes: - The expression graph is traversed depth-first with `sp.postorder_traversal`, to search for `sp.Quantity` elements. - Depth-first was chosen since `sp.Quantity`s are likelier to be found among individual symbols, rather than complete subexpressions. - - The **worst-case** runtime is when there are no units, in which case the **entire expression graph will be traversed**. - - Parameters: - expr: The sympy expression that may contain units. - - Returns: - Whether or not there are units used within the expression. - """ - return sp_obj.has(spu.Quantity) - # return any( - # isinstance(subexpr, spu.Quantity) for subexpr in sp.postorder_traversal(sp_obj) - # ) - - -## TODO: Caching w/srepr'ed expression. -## TODO: An LFU cache could do better than an LRU. -def get_units(expr: sp.Expr) -> set[spu.Quantity]: - """Finds all units used by the expression, and returns them as a set. - - No information about _the relationship between units_ is exposed. - For example, compound units like `spu.meter / spu.second` would be mapped to `{spu.meter, spu.second}`. - - - Notes: - The expression graph is traversed depth-first with `sp.postorder_traversal`, to search for `sp.Quantity` elements. - - The performance is comparable to the performance of `sp.postorder_traversal`, since the **entire expression graph will always be traversed**, with the added overhead of one `isinstance` call per expression-graph-node. - - Parameters: - expr: The sympy expression that may contain units. - - Returns: - All units (`spu.Quantity`) used within the expression. - """ - return { - subexpr - for subexpr in sp.postorder_traversal(expr) - if isinstance(subexpr, spu.Quantity) - } - - -def parse_shape(sp_obj: SympyType) -> int | None: - if isinstance(sp_obj, sp.MatrixBase): - return sp_obj.shape - - return None - - -#################### -# - Pydantic-Validated SympyExpr -#################### -class _SympyExpr: - """Low-level `pydantic`, schema describing how to serialize/deserialize fields that have a `SympyType` (like `sp.Expr`), so we can cleanly use `sympy` types in `pyd.BaseModel`. - - Notes: - You probably want to use `SympyExpr`. - - Examples: - To be usable as a type annotation on `pyd.BaseModel`, attach this to `SympyType` using `typx.Annotated`: - - ```python - SympyExpr = typx.Annotated[SympyType, _SympyExpr] - - class Spam(pyd.BaseModel): - line: SympyExpr = sp.Eq(sp.y, 2*sp.Symbol(x, real=True) - 3) - ``` - """ - - @classmethod - def __get_pydantic_core_schema__( - cls, - _source_type: SympyType, - _handler: pyd.GetCoreSchemaHandler, - ) -> pyd_core_schema.CoreSchema: - """Compute a schema that allows `pydantic` to validate a `sympy` type.""" - - def validate_from_str(sp_str: str | typ.Any) -> SympyType | typ.Any: - """Parse and validate a string expression. - - Parameters: - sp_str: A stringified `sympy` object, that will be parsed to a sympy type. - Before use, `isinstance(expr_str, str)` is checked. - If the object isn't a string, then the validation will be skipped. - - Returns: - Either a `sympy` object, if the input is parseable, or the same untouched object. - - Raises: - ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression. - """ - # Constrain to String - if not isinstance(sp_str, str): - return sp_str - - # Parse String -> Sympy - try: - expr = sp.sympify(sp_str) - except ValueError as ex: - msg = f'String {sp_str} is not a valid sympy expression' - raise ValueError(msg) from ex - - # Substitute Symbol -> Quantity - return expr.subs(UNIT_BY_SYMBOL) - - def validate_from_pytype( - sp_pytype: int | Fraction | float | complex, - ) -> SympyType | typ.Any: - """Parse and validate a pure Python type. - - Parameters: - sp_str: A stringified `sympy` object, that will be parsed to a sympy type. - Before use, `isinstance(expr_str, str)` is checked. - If the object isn't a string, then the validation will be skipped. - - Returns: - Either a `sympy` object, if the input is parseable, or the same untouched object. - - Raises: - ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression. - """ - # Constrain to String - if not isinstance(sp_pytype, int | Fraction | float | complex): - return sp_pytype - - if isinstance(sp_pytype, int): - return sp.Integer(sp_pytype) - if isinstance(sp_pytype, Fraction): - return sp.Rational(sp_pytype.numerator, sp_pytype.denominator) - if isinstance(sp_pytype, float): - return sp.Float(sp_pytype) - - # sp_pytype => Complex - return sp_pytype.real + sp.I * sp_pytype.imag - - sympy_expr_schema = pyd_core_schema.chain_schema( - [ - pyd_core_schema.no_info_plain_validator_function(validate_from_str), - pyd_core_schema.no_info_plain_validator_function(validate_from_pytype), - pyd_core_schema.is_instance_schema(SympyType), - ] - ) - return pyd_core_schema.json_or_python_schema( - json_schema=sympy_expr_schema, - python_schema=sympy_expr_schema, - serialization=pyd_core_schema.plain_serializer_function_ser_schema( - lambda sp_obj: sp.srepr(sp_obj) - ), - ) - - -SympyExpr = typx.Annotated[ - sp.Basic, ## Treat all sympy types as sp.Basic - _SympyExpr, -] -## TODO: The type game between SympyType, SympyExpr, and the various flavors of ConstrSympyExpr(), is starting to be a bit much. Let's consolidate. - - -def ConstrSympyExpr( # noqa: N802, PLR0913 - # Features - allow_variables: bool = True, - allow_units: bool = True, - # Structures - allowed_sets: set[typ.Literal['integer', 'rational', 'real', 'complex']] - | None = None, - allowed_structures: set[typ.Literal['scalar', 'matrix']] | None = None, - # Element Class - max_symbols: int | None = None, - allowed_symbols: set[sp.Symbol] | None = None, - allowed_units: set[spu.Quantity] | None = None, - # Shape Class - allowed_matrix_shapes: set[tuple[int, int]] | None = None, -) -> SympyType: - """Constructs a `SympyExpr` type, which will validate `sympy` types when used in a `pyd.BaseModel`. - - Relies on the `sympy` assumptions system. - See - - Parameters (TBD): - - Returns: - A type that represents a constrained `sympy` expression. - """ - - def validate_expr(expr: SympyType): - if not (isinstance(expr, SympyType),): - msg = f"expr '{expr}' is not an allowed Sympy expression ({SympyType})" - raise ValueError(msg) - - msgs = set() - - # Validate Feature Class - if (not allow_variables) and (len(expr.free_symbols) > 0): - msgs.add( - f'allow_variables={allow_variables} does not match expression {expr}.' - ) - if (not allow_units) and uses_units(expr): - msgs.add(f'allow_units={allow_units} does not match expression {expr}.') - - # Validate Structure Class - if ( - allowed_sets - and isinstance(expr, sp.Expr) - and not any( - { - 'integer': expr.is_integer, - 'rational': expr.is_rational, - 'real': expr.is_real, - 'complex': expr.is_complex, - }[allowed_set] - for allowed_set in allowed_sets - ) - ): - msgs.add( - f"allowed_sets={allowed_sets} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))" - ) - if allowed_structures and not any( - { - 'scalar': True, - 'matrix': isinstance(expr, sp.MatrixBase), - }[allowed_set] - for allowed_set in allowed_structures - ): - msgs.add( - f"allowed_structures={allowed_structures} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))" - ) - - # Validate Element Class - if max_symbols and len(expr.free_symbols) > max_symbols: - msgs.add(f'max_symbols={max_symbols} does not match expression {expr}') - if allowed_symbols and expr.free_symbols.issubset(allowed_symbols): - msgs.add( - f'allowed_symbols={allowed_symbols} does not match expression {expr}' - ) - if allowed_units and get_units(expr).issubset(allowed_units): - msgs.add(f'allowed_units={allowed_units} does not match expression {expr}') - - # Validate Shape Class - if ( - allowed_matrix_shapes and isinstance(expr, sp.MatrixBase) - ) and expr.shape not in allowed_matrix_shapes: - msgs.add( - f'allowed_matrix_shapes={allowed_matrix_shapes} does not match expression {expr} with shape {expr.shape}' - ) - - # Error or Return - if msgs: - raise ValueError(str(msgs)) - return expr - - return typx.Annotated[ - sp.Basic, - _SympyExpr, - pyd.AfterValidator(validate_expr), - ] - - -#################### -# - Common ConstrSympyExpr -#################### -# Expression -ScalarUnitlessRealExpr: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=False, - allowed_structures={'scalar'}, - allowed_sets={'integer', 'rational', 'real'}, -) -ScalarUnitlessComplexExpr: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=False, - allowed_structures={'scalar'}, - allowed_sets={'integer', 'rational', 'real', 'complex'}, -) - -# Symbol -IntSymbol: typ.TypeAlias = ConstrSympyExpr( - allow_variables=True, - allow_units=False, - allowed_sets={'integer'}, - max_symbols=1, -) -RationalSymbol: typ.TypeAlias = ConstrSympyExpr( - allow_variables=True, - allow_units=False, - allowed_sets={'integer', 'rational'}, - max_symbols=1, -) -RealSymbol: typ.TypeAlias = ConstrSympyExpr( - allow_variables=True, - allow_units=False, - allowed_sets={'integer', 'rational', 'real'}, - max_symbols=1, -) -ComplexSymbol: typ.TypeAlias = ConstrSympyExpr( - allow_variables=True, - allow_units=False, - allowed_sets={'integer', 'rational', 'real', 'complex'}, - max_symbols=1, -) -Symbol: typ.TypeAlias = IntSymbol | RealSymbol | ComplexSymbol - -# Unit -UnitDimension: typ.TypeAlias = SympyExpr ## Actually spu.Dimension - -## Technically a "unit expression", which includes compound types. -## Support for this is the reason to prefer over raw spu.Quantity. -Unit: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=True, - allowed_structures={'scalar'}, -) - -# Number -IntNumber: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=False, - allowed_sets={'integer'}, - allowed_structures={'scalar'}, -) -RealNumber: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=False, - allowed_sets={'integer', 'rational', 'real'}, - allowed_structures={'scalar'}, -) -ComplexNumber: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=False, - allowed_sets={'integer', 'rational', 'real', 'complex'}, - allowed_structures={'scalar'}, -) -Number: typ.TypeAlias = IntNumber | RealNumber | ComplexNumber - -# Number -PhysicalRealNumber: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=True, - allowed_sets={'integer', 'rational', 'real'}, - allowed_structures={'scalar'}, -) -PhysicalComplexNumber: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=True, - allowed_sets={'integer', 'rational', 'real', 'complex'}, - allowed_structures={'scalar'}, -) -PhysicalNumber: typ.TypeAlias = PhysicalRealNumber | PhysicalComplexNumber - -# Vector -Real3DVector: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=False, - allowed_sets={'integer', 'rational', 'real'}, - allowed_structures={'matrix'}, - allowed_matrix_shapes={(3, 1)}, -) - - -#################### -# - Sympy Utilities: Printing -#################### -_SYMPY_EXPR_PRINTER_STR = sp.printing.str.StrPrinter( - settings={ - 'abbrev': True, - } -) - - -def sp_to_str(sp_obj: SympyExpr) -> str: - """Converts a sympy object to an output-oriented string (w/abbreviated units), using a dedicated StrPrinter. - - This should be used whenever a **string for UI use** is needed from a `sympy` object. - - Notes: - This should **NOT** be used in cases where the string will be `sp.sympify()`ed back into a sympy expression. - For such cases, rely on `sp.srepr()`, which uses an _explicit_ representation. - - Parameters: - sp_obj: The `sympy` object to convert to a string. - - Returns: - A string representing the expression for human use. - _The string is not re-encodable to the expression._ - """ - ## TODO: A bool flag property that does a lot of find/replace to make it super pretty - return _SYMPY_EXPR_PRINTER_STR.doprint(sp_obj) - - -def pretty_symbol(sym: sp.Symbol) -> str: - return f'{sym.name} ∈ ' + ( - 'ℤ' - if sym.is_integer - else ('ℝ' if sym.is_real else ('ℂ' if sym.is_complex else '?')) - ) - - -#################### -# - Unit Utilities -#################### -def scale_to_unit(sp_obj: SympyType, unit: spu.Quantity) -> Number: - """Convert an expression that uses units to a different unit, then strip all units, leaving only a unitless `sympy` value. - - This is used whenever the unitless part of an expression is needed, but guaranteed expressed in a particular unit, aka. **unit system normalization**. - - Notes: - The unitless output is still an `sp.Expr`, which may contain ex. symbols. - - If you know that the output **should** work as a corresponding Python type (ex. `sp.Integer` vs. `int`), but it doesn't, you can use `sympy_to_python()` to produce a pure-Python type. - In this way, with a little care, broad compatiblity can be bridged between the `sympy.physics.units` unit system and the wider Python ecosystem. - - Parameters: - expr: The unit-containing expression to convert. - unit_to: The unit that is converted to. - - Returns: - The unitless part of `expr`, after scaling the entire expression to `unit`. - - Raises: - ValueError: If the result of unit-conversion and -stripping still has units, as determined by `uses_units()`. - """ - unitless_expr = spu.convert_to(sp_obj, unit) / unit if unit is not None else sp_obj - if not uses_units(unitless_expr): - return unitless_expr - - msg = f'Sympy object "{sp_obj}" was scaled to the unit "{unit}" with the expectation that the result would be unitless, but the result "{unitless_expr}" has units "{get_units(unitless_expr)}"' - raise ValueError(msg) - - -def scaling_factor(unit_from: spu.Quantity, unit_to: spu.Quantity) -> Number: - """Compute the numerical scaling factor imposed on the unitless part of the expression when converting from one unit to another. - - Parameters: - unit_from: The unit that is converted from. - unit_to: The unit that is converted to. - - Returns: - The numerical scaling factor between the two units. - - Raises: - ValueError: If the two units don't share a common dimension. - """ - if unit_from.dimension == unit_to.dimension: - return scale_to_unit(unit_from, unit_to) - - msg = f"Dimension of unit_from={unit_from} ({unit_from.dimension}) doesn't match the dimension of unit_to={unit_to} ({unit_to.dimension}); therefore, there is no scaling factor between them" - raise ValueError(msg) - - -@functools.cache -def unit_str_to_unit(unit_str: str) -> Unit | None: - # Edge Case: Manually Parse Degrees - ## -> sp.sympify('degree') actually produces the sp.degree() function. - ## -> Therefore, we must special case this particular unit. - if unit_str == 'degree': - expr = spu.degree - else: - expr = sp.sympify(unit_str).subs(UNIT_BY_SYMBOL) - - if expr.has(spu.Quantity): - return expr - - msg = f'No valid unit for unit string {unit_str}' - raise ValueError(msg) - - -#################### -# - "Physical" Type -#################### -def unit_dim_to_unit_dim_deps( - unit_dims: SympyType, -) -> dict[spu.dimensions.Dimension, int] | None: - dimsys_SI = spu.systems.si.dimsys_SI - - # Retrieve Dimensional Dependencies - try: - return dimsys_SI.get_dimensional_dependencies(unit_dims) - - # Catch TypeError - ## -> Happens if `+` or `-` is in `unit`. - ## -> Generally, it doesn't make sense to add/subtract differing unit dims. - ## -> Thus, when trying to figure out the unit dimension, there isn't one. - except TypeError: - return None - - -def unit_to_unit_dim_deps( - unit: SympyType, -) -> dict[spu.dimensions.Dimension, int] | None: - # Retrieve Dimensional Dependencies - ## -> NOTE: .subs() alone seems to produce sp.Symbol atoms. - ## -> This is extremely problematic; `Dims` arithmetic has key properties. - ## -> So we have to go all the way to the dimensional dependencies. - ## -> This isn't really respecting the args, but it seems to work :) - return unit_dim_to_unit_dim_deps( - unit.subs({arg: arg.dimension for arg in unit.atoms(spu.Quantity)}) - ) - - -def compare_unit_dims(unit_dim_l: SympyType, unit_dim_r: SympyType) -> bool: - return unit_dim_to_unit_dim_deps(unit_dim_l) == unit_dim_to_unit_dim_deps( - unit_dim_r - ) - - -def compare_unit_dim_to_unit_dim_deps( - unit_dim: SympyType, unit_dim_deps: dict[spu.dimensions.Dimension, int] -) -> bool: - return unit_dim_to_unit_dim_deps(unit_dim) == unit_dim_deps - - -class PhysicalType(enum.StrEnum): - """Type identifiers for expressions with both `MathType` and a unit, aka a "physical" type.""" - - # Unitless - NonPhysical = enum.auto() - - # Global - Time = enum.auto() - Angle = enum.auto() - SolidAngle = enum.auto() - ## TODO: Some kind of 3D-specific orientation ex. a quaternion - Freq = enum.auto() - AngFreq = enum.auto() ## rad*hertz - # Cartesian - Length = enum.auto() - Area = enum.auto() - Volume = enum.auto() - # Mechanical - Vel = enum.auto() - Accel = enum.auto() - Mass = enum.auto() - Force = enum.auto() - Pressure = enum.auto() - # Energy - Work = enum.auto() ## joule - Power = enum.auto() ## watt - PowerFlux = enum.auto() ## watt - Temp = enum.auto() - # Electrodynamics - Current = enum.auto() ## ampere - CurrentDensity = enum.auto() - Charge = enum.auto() ## coulomb - Voltage = enum.auto() - Capacitance = enum.auto() ## farad - Impedance = enum.auto() ## ohm - Conductance = enum.auto() ## siemens - Conductivity = enum.auto() ## siemens / length - MFlux = enum.auto() ## weber - MFluxDensity = enum.auto() ## tesla - Inductance = enum.auto() ## henry - EField = enum.auto() - HField = enum.auto() - # Luminal - LumIntensity = enum.auto() - LumFlux = enum.auto() - Illuminance = enum.auto() - - @functools.cached_property - def unit_dim(self) -> SympyType: - PT = PhysicalType - return { - PT.NonPhysical: None, - # Global - PT.Time: Dims.time, - PT.Angle: Dims.angle, - PT.SolidAngle: spu.steradian.dimension, ## MISSING - PT.Freq: Dims.frequency, - PT.AngFreq: Dims.angle * Dims.frequency, - # Cartesian - PT.Length: Dims.length, - PT.Area: Dims.length**2, - PT.Volume: Dims.length**3, - # Mechanical - PT.Vel: Dims.length / Dims.time, - PT.Accel: Dims.length / Dims.time**2, - PT.Mass: Dims.mass, - PT.Force: Dims.force, - PT.Pressure: Dims.pressure, - # Energy - PT.Work: Dims.energy, - PT.Power: Dims.power, - PT.PowerFlux: Dims.power / Dims.length**2, - PT.Temp: Dims.temperature, - # Electrodynamics - PT.Current: Dims.current, - PT.CurrentDensity: Dims.current / Dims.length**2, - PT.Charge: Dims.charge, - PT.Voltage: Dims.voltage, - PT.Capacitance: Dims.capacitance, - PT.Impedance: Dims.impedance, - PT.Conductance: Dims.conductance, - PT.Conductivity: Dims.conductance / Dims.length, - PT.MFlux: Dims.magnetic_flux, - PT.MFluxDensity: Dims.magnetic_density, - PT.Inductance: Dims.inductance, - PT.EField: Dims.voltage / Dims.length, - PT.HField: Dims.current / Dims.length, - # Luminal - PT.LumIntensity: Dims.luminous_intensity, - PT.LumFlux: Dims.luminous_intensity * spu.steradian.dimension, - PT.Illuminance: Dims.luminous_intensity / Dims.length**2, - }[self] - - @staticproperty - def unit_dims() -> dict[typ.Self, SympyType]: - return { - physical_type: physical_type.unit_dim - for physical_type in list(PhysicalType) - } - - @functools.cached_property - def color(self): - """A color corresponding to the physical type. - - The color selections were initially generated using AI, as this is a rote task that's better adjusted than invented. - The LLM provided the following rationale for its choices: - - > Non-Physical: Grey signifies neutrality and non-physical nature. - > Global: - > Time: Blue is often associated with calmness and the passage of time. - > Angle and Solid Angle: Different shades of blue and cyan suggest angular dimensions and spatial aspects. - > Frequency and Angular Frequency: Darker shades of blue to maintain the link to time. - > Cartesian: - > Length, Area, Volume: Shades of green to represent spatial dimensions, with intensity increasing with dimension. - > Mechanical: - > Velocity and Acceleration: Red signifies motion and dynamics, with lighter reds for related quantities. - > Mass: Dark red for the fundamental property. - > Force and Pressure: Shades of red indicating intensity. - > Energy: - > Work and Power: Orange signifies energy transformation, with lighter oranges for related quantities. - > Temperature: Yellow for heat. - > Electrodynamics: - > Current and related quantities: Cyan shades indicating flow. - > Voltage, Capacitance: Greenish and blueish cyan for electrical potential. - > Impedance, Conductance, Conductivity: Purples and magentas to signify resistance and conductance. - > Magnetic properties: Magenta shades for magnetism. - > Electric Field: Light blue. - > Magnetic Field: Grey, as it can be considered neutral in terms of direction. - > Luminal: - > Luminous properties: Yellows to signify light and illumination. - > - > This color mapping helps maintain intuitive connections for users interacting with these physical types. - """ - PT = PhysicalType - return { - PT.NonPhysical: (0.75, 0.75, 0.75, 1.0), # Light Grey: Non-physical - # Global - PT.Time: (0.5, 0.5, 1.0, 1.0), # Light Blue: Time - PT.Angle: (0.5, 0.75, 1.0, 1.0), # Light Blue: Angle - PT.SolidAngle: (0.5, 0.75, 0.75, 1.0), # Light Cyan: Solid Angle - PT.Freq: (0.5, 0.5, 0.9, 1.0), # Light Blue: Frequency - PT.AngFreq: (0.5, 0.5, 0.8, 1.0), # Light Blue: Angular Frequency - # Cartesian - PT.Length: (0.5, 1.0, 0.5, 1.0), # Light Green: Length - PT.Area: (0.6, 1.0, 0.6, 1.0), # Light Green: Area - PT.Volume: (0.7, 1.0, 0.7, 1.0), # Light Green: Volume - # Mechanical - PT.Vel: (1.0, 0.5, 0.5, 1.0), # Light Red: Velocity - PT.Accel: (1.0, 0.6, 0.6, 1.0), # Light Red: Acceleration - PT.Mass: (0.75, 0.5, 0.5, 1.0), # Light Red: Mass - PT.Force: (0.9, 0.5, 0.5, 1.0), # Light Red: Force - PT.Pressure: (1.0, 0.7, 0.7, 1.0), # Light Red: Pressure - # Energy - PT.Work: (1.0, 0.75, 0.5, 1.0), # Light Orange: Work - PT.Power: (1.0, 0.85, 0.5, 1.0), # Light Orange: Power - PT.PowerFlux: (1.0, 0.8, 0.6, 1.0), # Light Orange: Power Flux - PT.Temp: (1.0, 1.0, 0.5, 1.0), # Light Yellow: Temperature - # Electrodynamics - PT.Current: (0.5, 1.0, 1.0, 1.0), # Light Cyan: Current - PT.CurrentDensity: (0.5, 0.9, 0.9, 1.0), # Light Cyan: Current Density - PT.Charge: (0.5, 0.85, 0.85, 1.0), # Light Cyan: Charge - PT.Voltage: (0.5, 1.0, 0.75, 1.0), # Light Greenish Cyan: Voltage - PT.Capacitance: (0.5, 0.75, 1.0, 1.0), # Light Blueish Cyan: Capacitance - PT.Impedance: (0.6, 0.5, 0.75, 1.0), # Light Purple: Impedance - PT.Conductance: (0.7, 0.5, 0.8, 1.0), # Light Purple: Conductance - PT.Conductivity: (0.8, 0.5, 0.9, 1.0), # Light Purple: Conductivity - PT.MFlux: (0.75, 0.5, 0.75, 1.0), # Light Magenta: Magnetic Flux - PT.MFluxDensity: ( - 0.85, - 0.5, - 0.85, - 1.0, - ), # Light Magenta: Magnetic Flux Density - PT.Inductance: (0.8, 0.5, 0.8, 1.0), # Light Magenta: Inductance - PT.EField: (0.75, 0.75, 1.0, 1.0), # Light Blue: Electric Field - PT.HField: (0.75, 0.75, 0.75, 1.0), # Light Grey: Magnetic Field - # Luminal - PT.LumIntensity: (1.0, 0.95, 0.5, 1.0), # Light Yellow: Luminous Intensity - PT.LumFlux: (1.0, 0.95, 0.6, 1.0), # Light Yellow: Luminous Flux - PT.Illuminance: (1.0, 1.0, 0.75, 1.0), # Pale Yellow: Illuminance - }[self] - - @functools.cached_property - def default_unit(self) -> list[Unit]: - PT = PhysicalType - return { - PT.NonPhysical: None, - # Global - PT.Time: spu.picosecond, - PT.Angle: spu.radian, - PT.SolidAngle: spu.steradian, - PT.Freq: terahertz, - PT.AngFreq: spu.radian * terahertz, - # Cartesian - PT.Length: spu.micrometer, - PT.Area: spu.um**2, - PT.Volume: spu.um**3, - # Mechanical - PT.Vel: spu.um / spu.second, - PT.Accel: spu.um / spu.second, - PT.Mass: spu.microgram, - PT.Force: micronewton, - PT.Pressure: millibar, - # Energy - PT.Work: spu.joule, - PT.Power: spu.watt, - PT.PowerFlux: spu.watt / spu.meter**2, - PT.Temp: spu.kelvin, - # Electrodynamics - PT.Current: spu.ampere, - PT.CurrentDensity: spu.ampere / spu.meter**2, - PT.Charge: spu.coulomb, - PT.Voltage: spu.volt, - PT.Capacitance: spu.farad, - PT.Impedance: spu.ohm, - PT.Conductance: spu.siemens, - PT.Conductivity: spu.siemens / spu.micrometer, - PT.MFlux: spu.weber, - PT.MFluxDensity: spu.tesla, - PT.Inductance: spu.henry, - PT.EField: spu.volt / spu.micrometer, - PT.HField: spu.ampere / spu.micrometer, - # Luminal - PT.LumIntensity: spu.candela, - PT.LumFlux: spu.candela * spu.steradian, - PT.Illuminance: spu.candela / spu.meter**2, - }[self] - - @functools.cached_property - def valid_units(self) -> list[Unit]: - """Retrieve an ordered (by subjective usefulness) list of units for this physical type. - - Notes: - The order in which valid units are declared is the exact same order that UI dropdowns display them. - - **Altering the order of units breaks backwards compatibility**. - """ - PT = PhysicalType - return { - PT.NonPhysical: [None], - # Global - PT.Time: [ - spu.picosecond, - femtosecond, - spu.nanosecond, - spu.microsecond, - spu.millisecond, - spu.second, - spu.minute, - spu.hour, - spu.day, - ], - PT.Angle: [ - spu.radian, - spu.degree, - ], - PT.SolidAngle: [ - spu.steradian, - ], - PT.Freq: ( - _valid_freqs := [ - terahertz, - spu.hertz, - kilohertz, - megahertz, - gigahertz, - petahertz, - exahertz, - ] - ), - PT.AngFreq: [spu.radian * _unit for _unit in _valid_freqs], - # Cartesian - PT.Length: ( - _valid_lens := [ - spu.micrometer, - spu.nanometer, - spu.picometer, - spu.angstrom, - spu.millimeter, - spu.centimeter, - spu.meter, - spu.inch, - spu.foot, - spu.yard, - spu.mile, - ] - ), - PT.Area: [_unit**2 for _unit in _valid_lens], - PT.Volume: [_unit**3 for _unit in _valid_lens], - # Mechanical - PT.Vel: [_unit / spu.second for _unit in _valid_lens], - PT.Accel: [_unit / spu.second**2 for _unit in _valid_lens], - PT.Mass: [ - spu.kilogram, - spu.electron_rest_mass, - spu.dalton, - spu.microgram, - spu.milligram, - spu.gram, - spu.metric_ton, - ], - PT.Force: [ - micronewton, - nanonewton, - millinewton, - spu.newton, - spu.kg * spu.meter / spu.second**2, - ], - PT.Pressure: [ - spu.bar, - millibar, - spu.pascal, - hectopascal, - spu.atmosphere, - spu.psi, - spu.mmHg, - spu.torr, - ], - # Energy - PT.Work: [ - spu.joule, - spu.electronvolt, - ], - PT.Power: [ - spu.watt, - ], - PT.PowerFlux: [ - spu.watt / spu.meter**2, - ], - PT.Temp: [ - spu.kelvin, - ], - # Electrodynamics - PT.Current: [ - spu.ampere, - ], - PT.CurrentDensity: [ - spu.ampere / spu.meter**2, - ], - PT.Charge: [ - spu.coulomb, - ], - PT.Voltage: [ - spu.volt, - ], - PT.Capacitance: [ - spu.farad, - ], - PT.Impedance: [ - spu.ohm, - ], - PT.Conductance: [ - spu.siemens, - ], - PT.Conductivity: [ - spu.siemens / spu.micrometer, - spu.siemens / spu.meter, - ], - PT.MFlux: [ - spu.weber, - ], - PT.MFluxDensity: [ - spu.tesla, - ], - PT.Inductance: [ - spu.henry, - ], - PT.EField: [ - spu.volt / spu.micrometer, - spu.volt / spu.meter, - ], - PT.HField: [ - spu.ampere / spu.micrometer, - spu.ampere / spu.meter, - ], - # Luminal - PT.LumIntensity: [ - spu.candela, - ], - PT.LumFlux: [ - spu.candela * spu.steradian, - ], - PT.Illuminance: [ - spu.candela / spu.meter**2, - ], - }[self] - - @staticmethod - def from_unit(unit: Unit | None, optional: bool = False) -> typ.Self | None: - """Attempt to determine a matching `PhysicalType` from a unit. - - NOTE: It is not guaranteed that `unit` is within `valid_units`, only that it can be converted to any unit in `valid_units`. - - Returns: - The matched `PhysicalType`. - - If none could be matched, then either return `None` (if `optional` is set) or error. - - Raises: - ValueError: If no `PhysicalType` could be matched, and `optional` is `False`. - """ - if unit is None: - return PhysicalType.NonPhysical - - ## TODO_ This enough? - if unit in [spu.radian, spu.degree]: - return PhysicalType.Angle - - unit_dim_deps = unit_to_unit_dim_deps(unit) - if unit_dim_deps is not None: - for physical_type, candidate_unit_dim in PhysicalType.unit_dims.items(): - if compare_unit_dim_to_unit_dim_deps(candidate_unit_dim, unit_dim_deps): - return physical_type - - if optional: - return None - msg = f'Could not determine PhysicalType for {unit}' - raise ValueError(msg) - - @staticmethod - def from_unit_dim( - unit_dim: SympyType | None, optional: bool = False - ) -> typ.Self | None: - """Attempts to match an arbitrary unit dimension expression to a corresponding `PhysicalType`. - - For comparing arbitrary unit dimensions (via expressions of `spu.dimensions.Dimension`), it is critical that equivalent dimensions are also compared as equal (ex. `mass*length/time^2 == force`). - To do so, we employ the `SI` unit conventions, for extracting the fundamental dimensional dependencies of unit dimension expressions. - - Returns: - The matched `PhysicalType`. - - If none could be matched, then either return `None` (if `optional` is set) or error. - - Raises: - ValueError: If no `PhysicalType` could be matched, and `optional` is `False`. - """ - for physical_type, candidate_unit_dim in PhysicalType.unit_dims.items(): - if compare_unit_dims(unit_dim, candidate_unit_dim): - return physical_type - - if optional: - return None - msg = f'Could not determine PhysicalType for {unit_dim}' - raise ValueError(msg) - - @functools.cached_property - def valid_shapes(self) -> list[typ.Literal[(3,), (2,)] | None]: - PT = PhysicalType - overrides = { - # Cartesian - PT.Length: [None, (2,), (3,)], - # Mechanical - PT.Vel: [None, (2,), (3,)], - PT.Accel: [None, (2,), (3,)], - PT.Force: [None, (2,), (3,)], - # Energy - PT.Work: [None, (2,), (3,)], - PT.PowerFlux: [None, (2,), (3,)], - # Electrodynamics - PT.CurrentDensity: [None, (2,), (3,)], - PT.MFluxDensity: [None, (2,), (3,)], - PT.EField: [None, (2,), (3,)], - PT.HField: [None, (2,), (3,)], - # Luminal - PT.LumFlux: [None, (2,), (3,)], - } - - return overrides.get(self, [None]) - - @functools.cached_property - def valid_mathtypes(self) -> list[MathType]: - """Returns a list of valid mathematical types, especially whether it can be real- or complex-valued. - - Generally, all unit quantities are real, in the algebraic mathematical sense. - However, in electrodynamics especially, it becomes enormously useful to bake in a _rotational component_ as an imaginary value, be it simply to model phase or oscillation-oriented dampening. - This imaginary part has physical meaning, which can be expressed using the same mathematical formalism associated with unit systems. - In general, the value is a phasor. - - While it is difficult to arrive at a well-defined way of saying, "this is when a quantity is complex", an attempt has been made to form a sensible baseline based on when phasor math may apply. - - Notes: - - **Freq**/**AngFreq**: The imaginary part represents growth/dampening of the oscillation. - - **Current**/**Voltage**: The imaginary part represents the phase. - This also holds for any downstream units. - - **Charge**: Generally, it is real. - However, an imaginary phase term seems to have research applications when dealing with high-order harmonics in high-energy pulsed lasers: - - **Conductance**: The imaginary part represents the extinction, in the Drude-model sense. - - """ - MT = MathType - PT = PhysicalType - overrides = { - PT.NonPhysical: list(MT), ## Support All - # Cartesian - PT.Freq: [MT.Real, MT.Complex], ## Im -> Growth/Damping - PT.AngFreq: [MT.Real, MT.Complex], ## Im -> Growth/Damping - # Mechanical - # Energy - # Electrodynamics - PT.Current: [MT.Real, MT.Complex], ## Im -> Phase - PT.CurrentDensity: [MT.Real, MT.Complex], ## Im -> Phase - PT.Charge: [MT.Real, MT.Complex], ## Im -> Phase - PT.Voltage: [MT.Real, MT.Complex], ## Im -> Phase - PT.Capacitance: [MT.Real, MT.Complex], ## Im -> Phase - PT.Impedance: [MT.Real, MT.Complex], ## Im -> Reactance - PT.Inductance: [MT.Real, MT.Complex], ## Im -> Extinction - PT.Conductance: [MT.Real, MT.Complex], ## Im -> Extinction - PT.Conductivity: [MT.Real, MT.Complex], ## Im -> Extinction - PT.MFlux: [MT.Real, MT.Complex], ## Im -> Phase - PT.MFluxDensity: [MT.Real, MT.Complex], ## Im -> Phase - PT.EField: [MT.Real, MT.Complex], ## Im -> Phase - PT.HField: [MT.Real, MT.Complex], ## Im -> Phase - # Luminal - } - - return overrides.get(self, [MT.Real]) - - @staticmethod - def to_name(value: typ.Self) -> str: - if value is PhysicalType.NonPhysical: - return 'Unitless' - return PhysicalType(value).name - - @staticmethod - def to_icon(value: typ.Self) -> str: - return '' - - def bl_enum_element(self, i: int) -> ct.BLEnumElement: - PT = PhysicalType - return ( - str(self), - PT.to_name(self), - PT.to_name(self), - PT.to_icon(self), - i, - ) - - -#################### -# - Standard Unit Systems -#################### -UnitSystem: typ.TypeAlias = dict[PhysicalType, Unit] - -_PT = PhysicalType -UNITS_SI: UnitSystem = { - _PT.NonPhysical: None, - # Global - _PT.Time: spu.second, - _PT.Angle: spu.radian, - _PT.SolidAngle: spu.steradian, - _PT.Freq: spu.hertz, - _PT.AngFreq: spu.radian * spu.hertz, - # Cartesian - _PT.Length: spu.meter, - _PT.Area: spu.meter**2, - _PT.Volume: spu.meter**3, - # Mechanical - _PT.Vel: spu.meter / spu.second, - _PT.Accel: spu.meter / spu.second**2, - _PT.Mass: spu.kilogram, - _PT.Force: spu.newton, - # Energy - _PT.Work: spu.joule, - _PT.Power: spu.watt, - _PT.PowerFlux: spu.watt / spu.meter**2, - _PT.Temp: spu.kelvin, - # Electrodynamics - _PT.Current: spu.ampere, - _PT.CurrentDensity: spu.ampere / spu.meter**2, - _PT.Voltage: spu.volt, - _PT.Capacitance: spu.farad, - _PT.Impedance: spu.ohm, - _PT.Conductance: spu.siemens, - _PT.Conductivity: spu.siemens / spu.meter, - _PT.MFlux: spu.weber, - _PT.MFluxDensity: spu.tesla, - _PT.Inductance: spu.henry, - _PT.EField: spu.volt / spu.meter, - _PT.HField: spu.ampere / spu.meter, - # Luminal - _PT.LumIntensity: spu.candela, - _PT.LumFlux: lumen, - _PT.Illuminance: spu.lux, -} - - -#################### -# - Sympy Utilities: Cast to Python -#################### -def sympy_to_python( - scalar: sp.Basic, use_jax_array: bool = False -) -> int | float | complex | tuple | jax.Array: - """Convert a scalar sympy expression to the directly corresponding Python type. - - Arguments: - scalar: A sympy expression that has no symbols, but is expressed as a Sympy type. - For expressions that are equivalent to a scalar (ex. "(2a + a)/a"), you must simplify the expression with ex. `sp.simplify()` before passing to this parameter. - - Returns: - A pure Python type that directly corresponds to the input scalar expression. - """ - if isinstance(scalar, sp.MatrixBase): - # Detect Single Column Vector - ## --> Flatten to Single Row Vector - if len(scalar.shape) == 2 and scalar.shape[1] == 1: - _scalar = scalar.T - else: - _scalar = scalar - - # Convert to Tuple of Tuples - matrix = tuple( - [tuple([sympy_to_python(el) for el in row]) for row in _scalar.tolist()] - ) - - # Detect Single Row Vector - ## --> This could be because the scalar had it. - ## --> This could also be because we flattened a column vector. - ## Either way, we should strip the pointless dimensions. - if len(matrix) == 1: - return matrix[0] if not use_jax_array else jnp.array(matrix[0]) - - return matrix if not use_jax_array else jnp.array(matrix) - if scalar.is_integer: - return int(scalar) - if scalar.is_rational or scalar.is_real: - return float(scalar) - if scalar.is_complex: - return complex(scalar) - - msg = f'Cannot convert sympy scalar expression "{scalar}" to a Python type. Check the assumptions on the expr (current expr assumptions: "{scalar._assumptions}")' # noqa: SLF001 - raise ValueError(msg) - - -#################### -# - Convert to Unit System -#################### -def strip_unit_system( - sp_obj: SympyExpr, unit_system: UnitSystem | None = None -) -> SympyExpr: - """Strip units occurring in the given unit system from the expression. - - Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`". - Obviously, the semantic correctness of this operation depends entirely on _the units adding no semantic meaning to the expression_. - - Notes: - You should probably use `scale_to_unit_system()` or `convert_to_unit_system()`. - """ - if unit_system is None: - return sp_obj.subs(UNIT_TO_1) - return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None}) - - -def convert_to_unit_system( - sp_obj: SympyExpr, unit_system: UnitSystem | None -) -> SympyExpr: - """Convert an expression to the units of a given unit system.""" - if unit_system is None: - return sp_obj - - return spu.convert_to( - sp_obj, - {unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)}, - ) - - -def scale_to_unit_system( - sp_obj: SympyExpr, unit_system: UnitSystem | None, use_jax_array: bool = False -) -> int | float | complex | tuple | jax.Array: - """Convert an expression to the units of a given unit system, then strip all units of the unit system. - - Afterwards, it is converted to an appropriate Python type. - - Notes: - For stability, and performance, reasons, this should only be used at the very last stage. - - Regarding performance: **This is not a fast function**. - - Parameters: - sp_obj: An arbitrary sympy object, presumably with units. - unit_system: A unit system mapping `PhysicalType` to particular choices of (compound) units. - Note that, in this context, only `unit_system.values()` is used. - - Returns: - An appropriate pure Python type, after scaling to the unit system and stripping all units away. - - If the returned type is array-like, and `use_jax_array` is specified, then (and **only** then) will a `jax.Array` be returned instead of a nested `tuple`. - """ - return sympy_to_python( - strip_unit_system(convert_to_unit_system(sp_obj, unit_system), unit_system), - use_jax_array=use_jax_array, - ) diff --git a/src/blender_maxwell/utils/image_ops.py b/src/blender_maxwell/utils/image_ops.py index 6ca4412..68820c3 100644 --- a/src/blender_maxwell/utils/image_ops.py +++ b/src/blender_maxwell/utils/image_ops.py @@ -30,7 +30,7 @@ import matplotlib.figure import seaborn as sns from blender_maxwell import contracts as ct -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger sns.set_theme() diff --git a/src/blender_maxwell/utils/sci_constants.py b/src/blender_maxwell/utils/sci_constants.py index 920bb2d..935e24a 100644 --- a/src/blender_maxwell/utils/sci_constants.py +++ b/src/blender_maxwell/utils/sci_constants.py @@ -32,7 +32,7 @@ import scipy as sc import sympy as sp import sympy.physics.units as spu -from . import extra_sympy_units as spux +from . import sympy_extra as spux SUPPORTED_SCIPY_PREFIX = '1.12' if not sc.version.full_version.startswith(SUPPORTED_SCIPY_PREFIX): diff --git a/src/blender_maxwell/utils/serialize.py b/src/blender_maxwell/utils/serialize.py index bab40a0..b58aebc 100644 --- a/src/blender_maxwell/utils/serialize.py +++ b/src/blender_maxwell/utils/serialize.py @@ -31,8 +31,8 @@ import uuid import msgspec import sympy as sp -from . import extra_sympy_units as spux from . import logger +from . import sympy_extra as spux log = logger.get(__name__) diff --git a/src/blender_maxwell/utils/sim_symbols.py b/src/blender_maxwell/utils/sim_symbols.py index d3e4366..f663d36 100644 --- a/src/blender_maxwell/utils/sim_symbols.py +++ b/src/blender_maxwell/utils/sim_symbols.py @@ -25,8 +25,8 @@ import jaxtyping as jtyp import pydantic as pyd import sympy as sp -from . import extra_sympy_units as spux from . import logger, serialize +from . import sympy_extra as spux int_min = -(2**64) int_max = 2**64 @@ -101,6 +101,12 @@ class SimSymbolName(enum.StrEnum): BlochY = enum.auto() BlochZ = enum.auto() + # New Backwards Compatible Entries + ## -> Ordered lists carry a particular enum integer index. + ## -> Therefore, anything but adding an index breaks backwards compat. + ## -> ...With all previous files. + ConstantRange = enum.auto() + #################### # - UI #################### @@ -143,7 +149,8 @@ class SimSymbolName(enum.StrEnum): } | { # Generic - SSN.Constant: 'constant', + SSN.Constant: 'cst', + SSN.ConstantRange: 'cst_range', SSN.Expr: 'expr', SSN.Data: 'data', # Greek Letters @@ -210,24 +217,43 @@ def mk_interval( interval_finite: tuple[int | Fraction | float, int | Fraction | float], interval_inf: tuple[bool, bool], interval_closed: tuple[bool, bool], - unit_factor: typ.Literal[1] | spux.Unit, ) -> sp.Interval: """Create a symbolic interval from the tuples (and unit) defining it.""" return sp.Interval( - start=(interval_finite[0] * unit_factor if not interval_inf[0] else -sp.oo), - end=(interval_finite[1] * unit_factor if not interval_inf[1] else sp.oo), + start=(interval_finite[0] if not interval_inf[0] else -sp.oo), + end=(interval_finite[1] if not interval_inf[1] else sp.oo), left_open=(True if interval_inf[0] else not interval_closed[0]), right_open=(True if interval_inf[1] else not interval_closed[1]), ) class SimSymbol(pyd.BaseModel): - """A declarative representation of a symbolic variable. + """A convenient, constrained representation of a symbolic variable suitable for many tasks. - `sympy`'s symbols aren't quite flexible enough for our needs: The symbols that we're transporting often need exact domain information, an associated unit dimension, and a great deal of determinism in checks thereof. + The original motivation was to enhance `sp.Symbol` with greater flexibility, semantic context, and a UI-friendly representation. + Today, `SimSymbol` is a fully capable primitive for defining the interfaces between externally tracked mathematical elements, and planning the required operations between them. + + A symbol represented as `SimSymbol` carries all the semantic meaning of that symbol, and comes with a comprehensive library of useful (computed) properties and methods. + It is immutable, hashable, and serializable, and as a `pydantic.BaseModel` with aggressive property caching, its performance properties should also be well-suited for use in the hot-loops of ex. UI draw methods. + + Attributes: + sym_name: For humans and computers, symbol names induces a lot of implicit semantics. + mathtype: Symbols are associated with some set of valid values. + We choose to constrain `SimSymbol` to only associate with _mathematical_ (aka. number-like) sets. + This prohibits ex. booleans and predicate-logic applications, but eases a lot of burdens associated with actually using `SimSymbol`. + physical_type: Symbols may be associated with a particular unit dimension expression. + This allows the symbol to have _physical meaning_. + This information is **generally not** encoded in auxiliary attributes like `self.domain`, but **generally is** encoded by computed properties/methods. + unit: Symbols may be associated with a particular unit, which must be compatible with the `PhysicalType`. + **NOTE**: Unit expressions may well have physical meaning, without being strictly conformable to a pre-blessed `PhysicalType`s. + We do try to avoid such cases, but for the sake of correctness, our chosen convention is to let `self.physical_type` be "`NonPhysical`", while still allowing a unit. + size: Symbols may themselves have shape. + **NOTE**: We deliberately choose to constrain `SimSymbol`s to two dimensions, allowing them to represent scalars, vectors, covectors, and matrices, but **not** arbitrary tensors. + This is a practical tradeoff, made both to make it easier (in terms of mathematical analysis) to implement `SimSymbol`, but also to make it easier to define UI elements that drive / are driven by `SimSymbol`s. + domain: Symbols are associated with a _domain of valid values_, expressed with any mathematical set implemented as a subclass of `sympy.Set`. + By using a true symbolic set, we gain unbounded flexibility in how to define the validity of a set, including an extremely capable `* in self.domain` operator encapsulating a lot of otherwise very manual logic. + **NOTE** that `self.unit` is **not** baked into the domain, due to practicalities associated with subclasses of `sp.Set`. - This dataclass is UI-friendly, as it only uses field type annotations/defaults supported by `bl_cache.BLProp`. - It's easy to persist, easy to transport, and has many helpful properties which greatly simplify working with symbols. """ model_config = pyd.ConfigDict(frozen=True) @@ -238,11 +264,8 @@ class SimSymbol(pyd.BaseModel): # Units ## -> 'None' indicates that no particular unit has yet been chosen. - ## -> Not exposed in the UI; must be set some other way. + ## -> When 'self.physical_type' is NonPhysical, can only be None. unit: spux.Unit | None = None - ## -> TODO: We currently allowing units that don't match PhysicalType - ## -> -- In particular, NonPhysical w/units means "unknown units". - ## -> -- This is essential for the Scientific Constant Node. # Size ## -> All SimSymbol sizes are "2D", but interpreted by convention. @@ -253,39 +276,96 @@ class SimSymbol(pyd.BaseModel): rows: int = 1 cols: int = 1 - # Scalar Domain: "Interval" - ## -> NOTE: interval_finite_*[0] must be strictly smaller than [1]. - ## -> See self.domain. - ## -> We have to deconstruct symbolic interval semantics a bit for UI. - is_constant: bool = False - exclude_zero: bool = False + # Valid Domain + ## -> Declares the valid set of values that may be given to this symbol. + ## -> By convention, units are not encoded in the domain sp.Set. + ## -> 'sp.Set's are extremely expressive and cool. + domain: spux.SympyExpr | None = None - interval_finite_z: tuple[int, int] = (0, 1) - interval_finite_q: tuple[tuple[int, int], tuple[int, int]] = ((0, 1), (1, 1)) - interval_finite_re: tuple[float, float] = (0.0, 1.0) - interval_inf: tuple[bool, bool] = (True, True) - interval_closed: tuple[bool, bool] = (False, False) + @functools.cached_property + def domain_mat(self) -> sp.Set | sp.matrices.MatrixSet: + if self.rows > 1 or self.cols > 1: + return sp.matrices.MatrixSet(self.rows, self.cols, self.domain) + return self.domain - interval_finite_im: tuple[float, float] = (0.0, 1.0) - interval_inf_im: tuple[bool, bool] = (True, True) - interval_closed_im: tuple[bool, bool] = (False, False) - - preview_value_z: int = 0 - preview_value_q: tuple[int, int] = (0, 1) - preview_value_re: float = 0.0 - preview_value_im: float = 0.0 + preview_value: spux.SympyExpr | None = None #################### - # - Core + # - Validators + #################### + ## TODO: Check domain against MathType + ## -- Surprisingly hard without a lot of special-casing. + + ## TODO: Check that size is valid for the PhysicalType. + + ## TODO: Check that constant value (domain=FiniteSet(cst)) is compatible with the MathType. + + ## TODO: Check that preview_value is in the domain. + + @pyd.model_validator(mode='after') + def set_undefined_domain_from_mathtype(self) -> typ.Self: + """When the domain is not set, then set it using the symbolic set of the MathType.""" + if self.domain is None: + object.__setattr__(self, 'domain', self.mathtype.symbolic_set) + return self + + @pyd.model_validator(mode='after') + def conform_undefined_preview_value_to_constant(self) -> typ.Self: + """When the `SimSymbol` is a constant, but the preview value is not set, then set the preview value from the constant.""" + if self.is_constant and not self.preview_value: + object.__setattr__(self, 'preview_value', self.constant_value) + return self + + @pyd.model_validator(mode='after') + def conform_preview_value(self) -> typ.Self: + """Conform the given preview value to the `SimSymbol`.""" + if self.is_constant and not self.preview_value: + object.__setattr__( + self, + 'preview_value', + self.conform(self.preview_value, strip_units=True), + ) + return self + + #################### + # - Domain #################### @functools.cached_property - def name(self) -> str: - """Usable name for the symbol.""" - return self.sym_name.name + def is_constant(self) -> bool: + """When the symbol domain is a single-element `sp.FiniteSet`, then the symbol can be considered to be a constant.""" + return isinstance(self.domain, sp.FiniteSet) and len(self.domain) == 1 + + @functools.cached_property + def constant_value(self) -> bool: + """Get the constant when `is_constant` is True. + + The `self.unit_factor` is multiplied onto the constant at this point. + """ + if self.is_constant: + return next(iter(self.domain)) * self.unit_factor + + msg = 'Tried to get constant value of non-constant SimSymbol.' + raise ValueError(msg) + + @functools.cached_property + def is_nonzero(self) -> bool: + """Whether $0$ is a valid value for this symbol. + + When shaped, $0$ refers to the relevant shaped object with all elements $0$. + + Notes: + Most notably, this symbol cannot be used as the right hand side of a division operation when this property is `False`. + """ + return 0 in self.domain #################### # - Labels #################### + @functools.cached_property + def name(self) -> str: + """Usable string name for the symbol.""" + return self.sym_name.name + @functools.cached_property def name_pretty(self) -> str: """Pretty (possibly unicode) name for the thing.""" @@ -340,7 +420,8 @@ class SimSymbol(pyd.BaseModel): return self.unit if self.unit is not None else sp.S(1) @functools.cached_property - def size(self) -> tuple[int, ...] | None: + def size(self) -> spux.NumberSize1D | None: + """The 1D number size of this `SimSymbol`, if it has one; else None.""" return { (1, 1): spux.NumberSize1D.Scalar, (2, 1): spux.NumberSize1D.Vec2, @@ -350,13 +431,17 @@ class SimSymbol(pyd.BaseModel): @functools.cached_property def shape(self) -> tuple[int, ...]: + """Deterministic chosen shape of this `SimSymbol`. + + Derived from `self.rows` and `self.cols`. + + Is never `None`; instead, empty tuple `()` is used. + """ match (self.rows, self.cols): case (1, 1): return () case (_, 1): return (self.rows,) - case (1, _): - return (1, self.rows) case (_, _): return (self.rows, self.cols) @@ -365,116 +450,6 @@ class SimSymbol(pyd.BaseModel): """Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking.""" return len(self.shape) - @functools.cached_property - def domain(self) -> sp.Interval | sp.Set: - """Return the scalar domain of valid values for each element of the symbol. - - For integer/rational/real symbols, the domain is an interval defined using the `interval_*` properties. - This interval **must** have the property`start <= stop`. - - Otherwise, the domain is the symbolic set corresponding to `self.mathtype`. - """ - match self.mathtype: - case spux.MathType.Integer: - return mk_interval( - self.interval_finite_z, - self.interval_inf, - self.interval_closed, - self.unit_factor, - ) - - case spux.MathType.Rational: - return mk_interval( - Fraction(*self.interval_finite_q), - self.interval_inf, - self.interval_closed, - self.unit_factor, - ) - - case spux.MathType.Real: - return mk_interval( - self.interval_finite_re, - self.interval_inf, - self.interval_closed, - self.unit_factor, - ) - - case spux.MathType.Complex: - return ( - mk_interval( - self.interval_finite_re, - self.interval_inf, - self.interval_closed, - self.unit_factor, - ), - mk_interval( - self.interval_finite_im, - self.interval_inf_im, - self.interval_closed_im, - self.unit_factor, - ), - ) - - @functools.cached_property - def valid_domain_value(self) -> spux.SympyExpr: - """A single value guaranteed to be conformant to this `SimSymbol` and within `self.domain`.""" - match (self.domain.start.is_finite, self.domain.end.is_finite): - case (True, True): - if self.mathtype is spux.MathType.Integer: - return (self.domain.start + self.domain.end) // 2 - return (self.domain.start + self.domain.end) / 2 - - case (True, False): - one = sp.S(self.mathtype.coerce_compatible_pyobj(-1)) - return self.domain.start + one - - case (False, True): - one = sp.S(self.mathtype.coerce_compatible_pyobj(-1)) - return self.domain.end - one - - case (False, False): - return sp.S(self.mathtype.coerce_compatible_pyobj(-1)) - - @functools.cached_property - def is_nonzero(self) -> bool: - """Whether or not the value of this symbol can ever be $0$. - - Notes: - Most notably, this symbol cannot be used as the right hand side of a division operation when this property is `False`. - """ - if self.exclude_zero: - return True - - def check_real_domain(real_domain): - return ( - ( - real_domain.left == 0 - and real_domain.left_open - or real_domain.right == 0 - and real_domain.right_open - ) - or real_domain.left > 0 - or real_domain.right < 0 - ) - - if self.mathtype is spux.MathType.Complex: - return check_real_domain(self.domain[0]) and check_real_domain( - self.domain[1] - ) - return check_real_domain(self.domain) - - @functools.cached_property - def can_diff(self) -> bool: - """Whether this symbol can be used as the input / output variable when differentiating.""" - # Check Constants - ## -> Constants (w/pinned values) are never differentiable. - if self.is_constant: - return False - - # TODO: Discontinuities (especially across 0)? - - return self.mathtype in [spux.MathType.Real, spux.MathType.Complex] - #################### # - Properties #################### @@ -511,9 +486,9 @@ class SimSymbol(pyd.BaseModel): # Positive/Negative Assumption if self.mathtype is not spux.MathType.Complex: - if self.domain.left >= 0: + if self.domain.inf >= 0: mathtype_kwargs |= {'positive': True} - elif self.domain.right <= 0: + elif self.domain.sup < 0: mathtype_kwargs |= {'negative': True} # Scalar: Return Symbol @@ -571,7 +546,7 @@ class SimSymbol(pyd.BaseModel): """ if self.size is not None: if self.unit in self.physical_type.valid_units: - return { + socket_info = { 'output_name': self.sym_name, # Socket Interface 'size': self.size, @@ -580,23 +555,42 @@ class SimSymbol(pyd.BaseModel): # Defaults: Units 'default_unit': self.unit, 'default_symbols': [], - # Defaults: FlowKind.Value - 'default_value': self.conform( - self.valid_domain_value, strip_unit=True - ), - # Defaults: FlowKind.Range - 'default_min': self.conform(self.domain.start, strip_unit=True), - 'default_max': self.conform(self.domain.end, strip_unit=True), } + + # Defaults: FlowKind.Value + if self.preview_value: + socket_info |= { + 'default_value': self.conform( + self.preview_value, strip_unit=True + ) + } + + # Defaults: FlowKind.Range + if ( + self.mathtype is not spux.MathType.Complex + and self.rows == 1 + and self.cols == 1 + ): + socket_info |= { + 'default_min': self.domain.inf, + 'default_max': self.domain.sup, + } + ## TODO: Handle discontinuities / disjointness / open boundaries. + msg = f'Tried to generate an ExprSocket from a SymSymbol "{self.name}", but its unit ({self.unit}) is not a valid unit of its physical type ({self.physical_type}) (SimSymbol={self})' raise NotImplementedError(msg) + msg = f'Tried to generate an ExprSocket from a SymSymbol "{self.name}", but its size ({self.rows} by {self.cols}) is incompatible with ExprSocket (SimSymbol={self})' raise NotImplementedError(msg) #################### - # - Operations + # - Operations: Raw Update #################### def update(self, **kwargs) -> typ.Self: + """Create a new `SimSymbol`, such that the given keyword arguments override the existing values.""" + if not kwargs: + return self + def get_attr(attr: str): _notfound = 'notfound' if kwargs.get(attr, _notfound) is _notfound: @@ -610,61 +604,101 @@ class SimSymbol(pyd.BaseModel): unit=get_attr('unit'), rows=get_attr('rows'), cols=get_attr('cols'), - interval_finite_z=get_attr('interval_finite_z'), - interval_finite_q=get_attr('interval_finite_q'), - interval_finite_re=get_attr('interval_finite_re'), - interval_inf=get_attr('interval_inf'), - interval_closed=get_attr('interval_closed'), - interval_finite_im=get_attr('interval_finite_im'), - interval_inf_im=get_attr('interval_inf_im'), - interval_closed_im=get_attr('interval_closed_im'), + domain=get_attr('domain'), ) - def set_finite_domain( # noqa: PLR0913 - self, - start: int | float, - end: int | float, - start_closed: bool = True, - end_closed: bool = True, - start_im: bool = float, - end_im: bool = float, - start_closed_im: bool = True, - end_closed_im: bool = True, - ) -> typ.Self: - """Update the symbol with a finite range.""" - closed_re = (start_closed, end_closed) - closed_im = (start_closed_im, end_closed_im) - match self.mathtype: - case spux.MathType.Integer: - return self.update( - interval_finite_z=(start, end), - interval_inf=(False, False), - interval_closed=closed_re, - ) - case spux.MathType.Rational: - return self.update( - interval_finite_q=(start, end), - interval_inf=(False, False), - interval_closed=closed_re, - ) - case spux.MathType.Real: - return self.update( - interval_finite_re=(start, end), - interval_inf=(False, False), - interval_closed=closed_re, - ) - case spux.MathType.Complex: - return self.update( - interval_finite_re=(start, end), - interval_finite_im=(start_im, end_im), - interval_inf=(False, False), - interval_closed=closed_re, - interval_closed_im=closed_im, - ) + #################### + # - Operations: Comparison + #################### + def compare(self, other: typ.Self) -> typ.Self: + """Whether this SimSymbol can be considered equivalent to another, and thus universally usable in arbitrary mathematical operations together. - def set_size(self, rows: int, cols: int) -> typ.Self: - return self.update(rows=rows, cols=cols) + In particular, two attributes are ignored: + - **Name**: The particluar choices of name are not generally important. + - **Unit**: The particulars of unit equivilancy are not generally important; only that the `PhysicalType` is equal, and thus that they are compatible. + While not usable in all cases, this method ends up being very helpful for simplifying certain checks that would otherwise take up a lot of space. + """ + return ( + self.mathtype is other.mathtype + and self.physical_type is other.physical_type + and self.compare_size(other) + and self.domain == other.domain + ) + + def compare_size(self, other: typ.Self) -> typ.Self: + """Compare the size of this `SimSymbol` with another.""" + return self.rows == other.rows and self.cols == other.cols + + def compare_addable( + self, other: typ.Self, allow_differing_unit: bool = False + ) -> bool: + """Whether two `SimSymbol`s can be added.""" + common = ( + self.compare_size(other.output) + and self.physical_type is other.physical_type + and not ( + self.physical_type is spux.NonPhysical + and self.unit is not None + and self.unit != other.unit + ) + and not ( + other.physical_type is spux.NonPhysical + and other.unit is not None + and self.unit != other.unit + ) + ) + if not allow_differing_unit: + return common and self.output.unit == other.output.unit + return common + + def compare_multiplicable(self, other: typ.Self) -> bool: + """Whether two `SimSymbol`s can be multiplied.""" + return self.shape_len == 0 or self.compare_size(other) + + def compare_exponentiable(self, other: typ.Self) -> bool: + """Whether two `SimSymbol`s can be exponentiated. + + "Hadamard Power" is defined for any combination of scalar/vector/matrix operands, for any `MathType` combination. + The only important thing to check is that the exponent cannot have a physical unit. + + Sometimes, people write equations with units in the exponent. + This is a notational shorthand that only works in the context of an implicit, cancelling factor. + We reject such things. + + See https://physics.stackexchange.com/questions/109995/exponential-or-logarithm-of-a-dimensionful-quantity + """ + return ( + other.physical_type is spux.PhysicalType.NonPhysical and other.unit is None + ) + + #################### + # - Operations: Copying Setters + #################### + def set_constant(self, constant_value: spux.SympyType) -> typ.Self: + """Set the constant value of this `SimSymbol`, by setting it as the only value in a `sp.FiniteSet` domain. + + The `constant_value` will be conformed and stripped (with `self.conform()`) before being injected into the new `sp.FiniteSet` domain. + + Warnings: + Keep in mind that domains do not encode units, for practical reasons related to the diverging ways in which various `sp.Set` subclasses interpret units. + + This isn't noticeable in normal constant-symbol workflows, where the constant is retrieved using `self.constant_value` (which adds `self.unit_factor`). + However, **remember that retrieving the domain directly won't add the unit**. + + Ye been warned! + """ + if self.is_constant: + return self.update( + domain=sp.FiniteSet(self.conform(constant_value, strip_unit=True)) + ) + + msg = 'Tried to set constant value of non-constant SimSymbol.' + raise ValueError(msg) + + #################### + # - Operations: Conforming Mappers + #################### def conform( self, sp_obj: spux.SympyType, strip_unit: bool = False ) -> spux.SympyType: @@ -732,6 +766,9 @@ class SimSymbol(pyd.BaseModel): return res # noqa: RET504 + #################### + # - Creation + #################### @staticmethod def from_expr( sym_name: SimSymbolName, diff --git a/src/blender_maxwell/utils/sympy_extra/__init__.py b/src/blender_maxwell/utils/sympy_extra/__init__.py new file mode 100644 index 0000000..6ac4fca --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/__init__.py @@ -0,0 +1,173 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""Declares many useful primitives to greatly simplify working with `sympy` in the context of a unit-aware system.""" + +from .math_type import MathType +from .number_size import NumberSize1D, NumberSize2D +from .parse_cast import parse_shape, pretty_symbol, sp_to_str, sympy_to_python +from .physical_type import Dims, PhysicalType +from .sympy_expr import ( + ComplexNumber, + ComplexSymbol, + ConstrSympyExpr, + IntNumber, + IntSymbol, + Number, + PhysicalComplexNumber, + PhysicalNumber, + PhysicalRealNumber, + RationalSymbol, + Real3DVector, + RealNumber, + RealSymbol, + ScalarUnitlessComplexExpr, + ScalarUnitlessRealExpr, + Symbol, + SympyExpr, + Unit, + UnitDimension, +) +from .sympy_type import SympyType +from .unit_analysis import ( + compare_unit_dims, + compare_units_by_unit_dims, + convert_to_unit, + get_units, + scale_to_unit, + scaling_factor, + strip_units, + unit_dim_to_unit_dim_deps, + unit_str_to_unit, + unit_to_unit_dim_deps, + uses_units, +) +from .unit_system_analysis import ( + convert_to_unit_system, + scale_to_unit_system, + strip_unit_system, +) +from .unit_systems import UNITS_SI, UnitSystem +from .units import ( + UNIT_BY_SYMBOL, + UNIT_TO_1, + EHz, + GHz, + KHz, + MHz, + PHz, + THz, + exahertz, + femtometer, + femtosecond, + fm, + fs, + gigahertz, + hectopascal, + hPa, + kilohertz, + lm, + lumen, + mbar, + megahertz, + micronewton, + millibar, + millinewton, + mN, + nanonewton, + nN, + petahertz, + terahertz, + uN, +) + +__all__ = [ + 'MathType', + 'NumberSize1D', + 'NumberSize2D', + 'parse_shape', + 'pretty_symbol', + 'sp_to_str', + 'sympy_to_python', + 'Dims', + 'PhysicalType', + 'ComplexNumber', + 'ComplexSymbol', + 'ConstrSympyExpr', + 'IntNumber', + 'IntSymbol', + 'Number', + 'PhysicalComplexNumber', + 'PhysicalNumber', + 'PhysicalRealNumber', + 'RationalSymbol', + 'Real3DVector', + 'RealNumber', + 'RealSymbol', + 'ScalarUnitlessComplexExpr', + 'ScalarUnitlessRealExpr', + 'Symbol', + 'SympyExpr', + 'Unit', + 'UnitDimension', + 'SympyType', + 'compare_unit_dims', + 'compare_units_by_unit_dims', + 'convert_to_unit', + 'get_units', + 'scale_to_unit', + 'scaling_factor', + 'strip_units', + 'unit_dim_to_unit_dim_deps', + 'unit_str_to_unit', + 'unit_to_unit_dim_deps', + 'uses_units', + 'strip_unit_system', + 'UNITS_SI', + 'UnitSystem', + 'convert_to_unit_system', + 'scale_to_unit_system', + 'UNIT_BY_SYMBOL', + 'UNIT_TO_1', + 'EHz', + 'GHz', + 'KHz', + 'MHz', + 'PHz', + 'THz', + 'exahertz', + 'femtometer', + 'femtosecond', + 'fm', + 'fs', + 'gigahertz', + 'hectopascal', + 'hPa', + 'kilohertz', + 'lm', + 'lumen', + 'mbar', + 'megahertz', + 'micronewton', + 'millibar', + 'millinewton', + 'mN', + 'nanonewton', + 'nN', + 'petahertz', + 'terahertz', + 'uN', +] diff --git a/src/blender_maxwell/utils/sympy_extra/math_type.py b/src/blender_maxwell/utils/sympy_extra/math_type.py new file mode 100644 index 0000000..8830d85 --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/math_type.py @@ -0,0 +1,362 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""Implements `MathType`, a convenient UI-friendly identifier of numerical identity.""" + +import enum +import sys +import typing as typ +from fractions import Fraction + +import jax +import jaxtyping as jtyp +import sympy as sp + +from blender_maxwell import contracts as ct + +from .. import logger +from .sympy_type import SympyType + +log = logger.get(__name__) + + +class MathType(enum.StrEnum): + """A convenient, UI-friendly identifier of a numerical object's identity.""" + + Integer = enum.auto() + Rational = enum.auto() + Real = enum.auto() + Complex = enum.auto() + + #################### + # - Checks + #################### + @staticmethod + def has_mathtype(obj: typ.Any) -> typ.Literal['pytype', 'jax', 'expr'] | None: + """Determine whether an object of arbitrary type can be considered to have a `MathType`. + + - **Pure Python**: The numerical Python types (`int | Fraction | float | complex`) are all valid. + - **Expression**: Sympy types / expression are in general considered to have a valid MathType. + - **Jax**: Non-empty `jax` arrays with a valid numerical Python type as the first element are valid. + + Returns: + A string literal indicating how to parse the object for a valid `MathType`. + + If the presence of a MathType couldn't be deduced, then return None. + """ + if isinstance(obj, int | Fraction | float | complex): + return 'pytype' + + if ( + isinstance(obj, jax.Array) + and obj + and isinstance(obj.item(0), int | Fraction | float | complex) + ): + return 'jax' + + if isinstance(obj, sp.Basic | sp.MatrixBase): + return 'expr' + ## TODO: Should we check deeper? + + return None + + #################### + # - Creation + #################### + @staticmethod + def from_expr(sp_obj: SympyType, optional: bool = False) -> type | None: # noqa: PLR0911 + """Deduce the `MathType` of an arbitrary sympy object (/expression). + + The "assumptions" system of `sympy` is relied on to determine the key properties of the expression. + To this end, it's important to note several of the shortcomings of the "assumptions" system: + + - All elements, especially symbols, must have well-defined assumptions, ex. `real=True`. + - Only the "narrowest" possible `MathType` will be deduced, ex. `5` may well be the result of a complex expression, but since it is now an integer, it will parse to `MathType.Integer`. This may break some + - For infinities, only real and complex infinities are distinguished between in `sympy` (`sp.oo` vs. `sp.zoo`) - aka. there is no "integer infinity" which will parse to `Integer` with this method. + + Warnings: + Using the "assumptions" system like this requires a lot of rigor in the entire program. + + Notes: + Any matrix-like object will have `MathType.combine()` run on all of its (flattened) elements. + This is an extremely **slow** operation, but accurate, according to the semantics of `MathType.combine()`. + + Note that `sp.MatrixSymbol` _cannot have assumptions_, and thus shouldn't be used in `sp_obj`. + + Returns: + A corresponding `MathType`; else, if `optional=True`, return `None`. + + Raises: + ValueError: If no corresponding `MathType` could be determined, and `optional=False`. + + """ + if isinstance(sp_obj, sp.MatrixBase): + return MathType.combine( + *[MathType.from_expr(v) for v in sp.flatten(sp_obj)] + ) + + if sp_obj.is_integer: + return MathType.Integer + if sp_obj.is_rational: + return MathType.Rational + if sp_obj.is_real: + return MathType.Real + if sp_obj.is_complex: + return MathType.Complex + + # Infinities + if sp_obj in [sp.oo, -sp.oo]: + return MathType.Real + if sp_obj in [sp.zoo, -sp.zoo]: + return MathType.Complex + + if optional: + return None + + msg = f"Can't determine MathType from sympy object: {sp_obj}" + raise ValueError(msg) + + @staticmethod + def from_pytype(dtype: type) -> type: + return { + int: MathType.Integer, + Fraction: MathType.Rational, + float: MathType.Real, + complex: MathType.Complex, + }[dtype] + + @staticmethod + def from_jax_array(data: jtyp.Shaped[jtyp.Array, '...']) -> type: + """Deduce the MathType corresponding to a JAX array. + + We go about this by leveraging that: + - `data` is of a homogeneous type. + - `data.item(0)` returns a single element of the array w/pure-python type. + + By combing this with `type()` and `MathType.from_pytype`, we can effectively deduce the `MathType` of the entire array with relative efficiency. + + Notes: + Should also work with numpy arrays. + """ + if len(data) > 0: + return MathType.from_pytype(type(data.item(0))) + + msg = 'Cannot determine MathType from empty jax array.' + raise ValueError(msg) + + #################### + # - Operations + #################### + @staticmethod + def combine(*mathtypes: list[typ.Self], optional: bool = False) -> typ.Self | None: + if MathType.Complex in mathtypes: + return MathType.Complex + if MathType.Real in mathtypes: + return MathType.Real + if MathType.Rational in mathtypes: + return MathType.Rational + if MathType.Integer in mathtypes: + return MathType.Integer + + if optional: + return None + + msg = f"Can't combine mathtypes {mathtypes}" + raise ValueError(msg) + + def is_compatible(self, other: typ.Self) -> bool: + MT = MathType + return ( + other + in { + MT.Integer: [MT.Integer], + MT.Rational: [MT.Integer, MT.Rational], + MT.Real: [MT.Integer, MT.Rational, MT.Real], + MT.Complex: [MT.Integer, MT.Rational, MT.Real, MT.Complex], + }[self] + ) + + def coerce_compatible_pyobj( + self, pyobj: bool | int | Fraction | float | complex + ) -> int | Fraction | float | complex: + """Coerce a pure-python object of numerical type to the _exact_ type indicated by this `MathType`. + + This is needed when ex. one has an integer, but it is important that that integer be passed as a complex number. + """ + MT = MathType + match self: + case MT.Integer: + return int(pyobj) + case MT.Rational if isinstance(pyobj, int): + return Fraction(pyobj, 1) + case MT.Rational if isinstance(pyobj, Fraction): + return pyobj + case MT.Real: + return float(pyobj) + case MT.Complex if isinstance(pyobj, int | Fraction): + return complex(float(pyobj), 0) + case MT.Complex if isinstance(pyobj, float): + return complex(pyobj, 0) + + @staticmethod + def from_symbolic_set( + s: typ.Literal[ + sp.Naturals + | sp.Naturals0 + | sp.Integers + | sp.Rationals + | sp.Reals + | sp.Complexes + ] + | sp.Set, + optional: bool = False, + ) -> typ.Self | None: + """Deduce the `MathType` from a particular symbolic set. + + Currently hard-coded. + Any deviation that might be expected to work, ex. `sp.Reals - {0}`, currently won't (currently). + + Raises: + ValueError: If a non-hardcoded symbolic set is passed. + """ + MT = MathType + match s: + case sp.Naturals | sp.Naturals0 | sp.Integers: + return MT.Integer + case sp.Rationals: + return MT.Rational + case sp.Reals: + return MT.Real + case sp.Complexes: + return MT.Complex + + if optional: + return None + + msg = f"Can't deduce MathType from symbolic set {s}" + raise ValueError(msg) + + #################### + # - Casting: Pytype + #################### + @property + def pytype(self) -> type: + """Deduce the pure-Python type that corresponds to this `MathType`.""" + MT = MathType + return { + MT.Integer: int, + MT.Rational: Fraction, + MT.Real: float, + MT.Complex: complex, + }[self] + + @property + def inf_finite(self) -> type: + """Opinionated finite representation of "infinity" within this `MathType`. + + These are chosen using `sys.maxsize` and `sys.float_info`. + As such, while not arbitrary, this "finite representation of infinity" certainly is opinionated. + + **Note** that, in practice, most systems will have no trouble working with values that exceed those defined here. + + Notes: + Values should be presumed to vary by-platform, as the `sys` attributes may be influenced by CPU architecture, OS, runtime environment, etc. . + + These values can be used directly in `jax` arrays, but at the cost of an overflow warning (in part because `jax` generally only allows the use of `float32`). + In this case, the warning doesn't matter, as the value will be cast to `jnp.inf` anyway. + + However, it's generally cleaner to directly use `jnp.inf` if infinite values must be defined in an array context. + """ + MT = MathType + Z = MT.Integer + R = MT.Integer + return { + MT.Integer: (-sys.maxsize, sys.maxsize), + MT.Rational: ( + Fraction(Z.inf_finite[0], 1), + Fraction(Z.inf_finite[1], 1), + ), + MT.Real: -(sys.float_info.min, sys.float_info.max), + MT.Complex: ( + complex(R.inf_finite[0], R.inf_finite[0]), + complex(R.inf_finite[1], R.inf_finite[1]), + ), + }[self] + + #################### + # - Casting: Symbolic + #################### + @property + def symbolic_set(self) -> sp.Set: + """Deduce the symbolic `sp.Set` type that corresponds to this `MathType`.""" + MT = MathType + return { + MT.Integer: sp.Integers, + MT.Rational: sp.Rationals, + MT.Real: sp.Reals, + MT.Complex: sp.Complexes, + }[self] + + @property + def sp_symbol_a(self) -> type: + MT = MathType + return { + MT.Integer: sp.Symbol('a', integer=True), + MT.Rational: sp.Symbol('a', rational=True), + MT.Real: sp.Symbol('a', real=True), + MT.Complex: sp.Symbol('a', complex=True), + }[self] + + #################### + # - Labels + #################### + @staticmethod + def to_str(value: typ.Self) -> type: + return { + MathType.Integer: 'ℤ', + MathType.Rational: 'ℚ', + MathType.Real: 'ℝ', + MathType.Complex: 'ℂ', + }[value] + + @property + def name(self) -> str: + """Simple non-unicode name of the math type.""" + return str(self) + + @property + def label_pretty(self) -> str: + return MathType.to_str(self) + + #################### + # - UI + #################### + @staticmethod + def to_name(value: typ.Self) -> str: + return MathType.to_str(value) + + @staticmethod + def to_icon(value: typ.Self) -> str: + return '' + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + return ( + str(self), + MathType.to_name(self), + MathType.to_name(self), + MathType.to_icon(self), + i, + ) diff --git a/src/blender_maxwell/utils/sympy_extra/number_size.py b/src/blender_maxwell/utils/sympy_extra/number_size.py new file mode 100644 index 0000000..14bf34d --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/number_size.py @@ -0,0 +1,148 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import enum +import typing as typ + +import sympy as sp + +from blender_maxwell import contracts as ct + + +#################### +# - Size: 1D +#################### +class NumberSize1D(enum.StrEnum): + """Valid 1D-constrained shape.""" + + Scalar = enum.auto() + Vec2 = enum.auto() + Vec3 = enum.auto() + Vec4 = enum.auto() + + @staticmethod + def to_name(value: typ.Self) -> str: + NS = NumberSize1D + return { + NS.Scalar: 'Scalar', + NS.Vec2: '2D', + NS.Vec3: '3D', + NS.Vec4: '4D', + }[value] + + @staticmethod + def to_icon(value: typ.Self) -> str: + NS = NumberSize1D + return { + NS.Scalar: '', + NS.Vec2: '', + NS.Vec3: '', + NS.Vec4: '', + }[value] + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + return ( + str(self), + NumberSize1D.to_name(self), + NumberSize1D.to_name(self), + NumberSize1D.to_icon(self), + i, + ) + + @staticmethod + def has_shape(shape: tuple[int, ...] | None): + return shape in [None, (2,), (3,), (4,), (2, 1), (3, 1), (4, 1)] + + def supports_shape(self, shape: tuple[int, ...] | None): + NS = NumberSize1D + match self: + case NS.Scalar: + return shape is None + case NS.Vec2: + return shape in ((2,), (2, 1)) + case NS.Vec3: + return shape in ((3,), (3, 1)) + case NS.Vec4: + return shape in ((4,), (4, 1)) + + @staticmethod + def from_shape(shape: tuple[typ.Literal[2, 3]] | None) -> typ.Self: + NS = NumberSize1D + return { + None: NS.Scalar, + (2,): NS.Vec2, + (3,): NS.Vec3, + (4,): NS.Vec4, + (2, 1): NS.Vec2, + (3, 1): NS.Vec3, + (4, 1): NS.Vec4, + }[shape] + + @property + def rows(self): + NS = NumberSize1D + return { + NS.Scalar: 1, + NS.Vec2: 2, + NS.Vec3: 3, + NS.Vec4: 4, + }[self] + + @property + def cols(self): + return 1 + + @property + def shape(self): + NS = NumberSize1D + return { + NS.Scalar: None, + NS.Vec2: (2,), + NS.Vec3: (3,), + NS.Vec4: (4,), + }[self] + + +def symbol_range(sym: sp.Symbol) -> str: + return f'{sym.name} ∈ ' + ( + 'ℂ' + if sym.is_complex + else ('ℝ' if sym.is_real else ('ℤ' if sym.is_integer else '?')) + ) + + +#################### +# - Symbol Sizes +#################### +class NumberSize2D(enum.StrEnum): + """Simple subset of sizes for rank-2 tensors.""" + + Scalar = enum.auto() + + # Vectors + Vec2 = enum.auto() ## 2x1 + Vec3 = enum.auto() ## 3x1 + Vec4 = enum.auto() ## 4x1 + + # Covectors + CoVec2 = enum.auto() ## 1x2 + CoVec3 = enum.auto() ## 1x3 + CoVec4 = enum.auto() ## 1x4 + + # Square Matrices + Mat22 = enum.auto() ## 2x2 + Mat33 = enum.auto() ## 3x3 + Mat44 = enum.auto() ## 4x4 diff --git a/src/blender_maxwell/utils/sympy_extra/parse_cast.py b/src/blender_maxwell/utils/sympy_extra/parse_cast.py new file mode 100644 index 0000000..37672f9 --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/parse_cast.py @@ -0,0 +1,119 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import jax +import jax.numpy as jnp +import sympy as sp + +from .. import logger +from .sympy_type import SympyType + +log = logger.get(__name__) + + +#################### +# - Parsing: Info from SympyType +#################### +def parse_shape(sp_obj: SympyType) -> int | None: + if isinstance(sp_obj, sp.MatrixBase): + return sp_obj.shape + + return None + + +#################### +# - Casting: Python +#################### +def sympy_to_python( + scalar: sp.Basic, use_jax_array: bool = False +) -> int | float | complex | tuple | jax.Array: + """Convert a scalar sympy expression to the directly corresponding Python type. + + Arguments: + scalar: A sympy expression that has no symbols, but is expressed as a Sympy type. + For expressions that are equivalent to a scalar (ex. "(2a + a)/a"), you must simplify the expression with ex. `sp.simplify()` before passing to this parameter. + + Returns: + A pure Python type that directly corresponds to the input scalar expression. + """ + if isinstance(scalar, sp.MatrixBase): + # Detect Single Column Vector + ## --> Flatten to Single Row Vector + if len(scalar.shape) == 2 and scalar.shape[1] == 1: + _scalar = scalar.T + else: + _scalar = scalar + + # Convert to Tuple of Tuples + matrix = tuple( + [tuple([sympy_to_python(el) for el in row]) for row in _scalar.tolist()] + ) + + # Detect Single Row Vector + ## --> This could be because the scalar had it. + ## --> This could also be because we flattened a column vector. + ## Either way, we should strip the pointless dimensions. + if len(matrix) == 1: + return matrix[0] if not use_jax_array else jnp.array(matrix[0]) + + return matrix if not use_jax_array else jnp.array(matrix) + if scalar.is_integer: + return int(scalar) + if scalar.is_rational or scalar.is_real: + return float(scalar) + if scalar.is_complex: + return complex(scalar) + + msg = f'Cannot convert sympy scalar expression "{scalar}" to a Python type. Check the assumptions on the expr (current expr assumptions: "{scalar._assumptions}")' # noqa: SLF001 + raise ValueError(msg) + + +#################### +# - Casting: Printing +#################### +_SYMPY_EXPR_PRINTER_STR = sp.printing.str.StrPrinter( + settings={ + 'abbrev': True, + } +) + + +def sp_to_str(sp_obj: SympyType) -> str: + """Converts a sympy object to an output-oriented string (w/abbreviated units), using a dedicated StrPrinter. + + This should be used whenever a **string for UI use** is needed from a `sympy` object. + + Notes: + This should **NOT** be used in cases where the string will be `sp.sympify()`ed back into a sympy expression. + For such cases, rely on `sp.srepr()`, which uses an _explicit_ representation. + + Parameters: + sp_obj: The `sympy` object to convert to a string. + + Returns: + A string representing the expression for human use. + _The string is not re-encodable to the expression._ + """ + ## TODO: A bool flag property that does a lot of find/replace to make it super pretty + return _SYMPY_EXPR_PRINTER_STR.doprint(sp_obj) + + +def pretty_symbol(sym: sp.Symbol) -> str: + return f'{sym.name} ∈ ' + ( + 'ℤ' + if sym.is_integer + else ('ℝ' if sym.is_real else ('ℂ' if sym.is_complex else '?')) + ) diff --git a/src/blender_maxwell/utils/sympy_extra/physical_type.py b/src/blender_maxwell/utils/sympy_extra/physical_type.py new file mode 100644 index 0000000..2adedda --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/physical_type.py @@ -0,0 +1,644 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""Implements `PhysicalType`, a convenient, UI-friendly way of deterministically handling the unit-dimensionality of arbitrary objects.""" + +import enum +import functools +import typing as typ + +import sympy.physics.units as spu + +from blender_maxwell import contracts as ct + +from ..staticproperty import staticproperty +from . import units as spux +from .math_type import MathType +from .sympy_expr import Unit +from .sympy_type import SympyType +from .unit_analysis import ( + compare_unit_dim_to_unit_dim_deps, + compare_unit_dims, + unit_to_unit_dim_deps, +) + + +#################### +# - Unit Dimensions +#################### +class DimsMeta(type): + """Metaclass allowing an implementing (ideally empty) class to access `spu.definitions.dimension_definitions` attributes directly via its own attribute.""" + + def __getattr__(cls, attr: str) -> spu.Dimension: + """Alias for `spu.definitions.dimension_definitions.*` (isn't that a mouthful?). + + Raises: + AttributeError: If the name cannot be found. + """ + if ( + attr in spu.definitions.dimension_definitions.__dir__() + and not attr.startswith('__') + ): + return getattr(spu.definitions.dimension_definitions, attr) + + raise AttributeError(name=attr, obj=Dims) + + +class Dims(metaclass=DimsMeta): + """Access `sympy.physics.units` dimensions with less hassle. + + Any unit dimension available in `sympy.physics.units.definitions.dimension_definitions` can be accessed as an attribute of `Dims`. + + An `AttributeError` is raised if the unit cannot be found in `sympy`. + + Examples: + The objects returned are a direct alias to `sympy`, with less hassle: + ```python + assert Dims.length == ( + sympy.physics.units.definitions.dimension_definitions.length + ) + ``` + """ + + +#################### +# - Physical Type +#################### +class PhysicalType(enum.StrEnum): + """An identifier of unit dimensionality with many useful properties.""" + + # Unitless + NonPhysical = enum.auto() + + # Global + Time = enum.auto() + Angle = enum.auto() + SolidAngle = enum.auto() + ## TODO: Some kind of 3D-specific orientation ex. a quaternion + Freq = enum.auto() + AngFreq = enum.auto() ## rad*hertz + # Cartesian + Length = enum.auto() + Area = enum.auto() + Volume = enum.auto() + # Mechanical + Vel = enum.auto() + Accel = enum.auto() + Mass = enum.auto() + Force = enum.auto() + Pressure = enum.auto() + # Energy + Work = enum.auto() ## joule + Power = enum.auto() ## watt + PowerFlux = enum.auto() ## watt + Temp = enum.auto() + # Electrodynamics + Current = enum.auto() ## ampere + CurrentDensity = enum.auto() + Charge = enum.auto() ## coulomb + Voltage = enum.auto() + Capacitance = enum.auto() ## farad + Impedance = enum.auto() ## ohm + Conductance = enum.auto() ## siemens + Conductivity = enum.auto() ## siemens / length + MFlux = enum.auto() ## weber + MFluxDensity = enum.auto() ## tesla + Inductance = enum.auto() ## henry + EField = enum.auto() + HField = enum.auto() + # Luminal + LumIntensity = enum.auto() + LumFlux = enum.auto() + Illuminance = enum.auto() + + #################### + # - Unit Dimensions + #################### + @functools.cached_property + def unit_dim(self) -> SympyType: + """The unit dimension expression associated with the `PhysicalType`. + + A `PhysicalType` is, in its essence, merely an identifier for a particular unit dimension expression. + """ + PT = PhysicalType + return { + PT.NonPhysical: None, + # Global + PT.Time: Dims.time, + PT.Angle: Dims.angle, + PT.SolidAngle: spu.steradian.dimension, ## MISSING + PT.Freq: Dims.frequency, + PT.AngFreq: Dims.angle * Dims.frequency, + # Cartesian + PT.Length: Dims.length, + PT.Area: Dims.length**2, + PT.Volume: Dims.length**3, + # Mechanical + PT.Vel: Dims.length / Dims.time, + PT.Accel: Dims.length / Dims.time**2, + PT.Mass: Dims.mass, + PT.Force: Dims.force, + PT.Pressure: Dims.pressure, + # Energy + PT.Work: Dims.energy, + PT.Power: Dims.power, + PT.PowerFlux: Dims.power / Dims.length**2, + PT.Temp: Dims.temperature, + # Electrodynamics + PT.Current: Dims.current, + PT.CurrentDensity: Dims.current / Dims.length**2, + PT.Charge: Dims.charge, + PT.Voltage: Dims.voltage, + PT.Capacitance: Dims.capacitance, + PT.Impedance: Dims.impedance, + PT.Conductance: Dims.conductance, + PT.Conductivity: Dims.conductance / Dims.length, + PT.MFlux: Dims.magnetic_flux, + PT.MFluxDensity: Dims.magnetic_density, + PT.Inductance: Dims.inductance, + PT.EField: Dims.voltage / Dims.length, + PT.HField: Dims.current / Dims.length, + # Luminal + PT.LumIntensity: Dims.luminous_intensity, + PT.LumFlux: Dims.luminous_intensity * spu.steradian.dimension, + PT.Illuminance: Dims.luminous_intensity / Dims.length**2, + }[self] + + @staticproperty + def unit_dims() -> dict[typ.Self, SympyType]: + """All unit dimensions supported by all `PhysicalType`s.""" + return { + physical_type: physical_type.unit_dim + for physical_type in list(PhysicalType) + } + + #################### + # - Convenience Properties + #################### + @functools.cached_property + def default_unit(self) -> list[Unit]: + """Subjective choice of 'default' unit from `self.valid_units`. + + There is no requirement to use this. + """ + PT = PhysicalType + return { + PT.NonPhysical: None, + # Global + PT.Time: spu.picosecond, + PT.Angle: spu.radian, + PT.SolidAngle: spu.steradian, + PT.Freq: spux.terahertz, + PT.AngFreq: spu.radian * spux.terahertz, + # Cartesian + PT.Length: spu.micrometer, + PT.Area: spu.um**2, + PT.Volume: spu.um**3, + # Mechanical + PT.Vel: spu.um / spu.second, + PT.Accel: spu.um / spu.second, + PT.Mass: spu.microgram, + PT.Force: spux.micronewton, + PT.Pressure: spux.millibar, + # Energy + PT.Work: spu.joule, + PT.Power: spu.watt, + PT.PowerFlux: spu.watt / spu.meter**2, + PT.Temp: spu.kelvin, + # Electrodynamics + PT.Current: spu.ampere, + PT.CurrentDensity: spu.ampere / spu.meter**2, + PT.Charge: spu.coulomb, + PT.Voltage: spu.volt, + PT.Capacitance: spu.farad, + PT.Impedance: spu.ohm, + PT.Conductance: spu.siemens, + PT.Conductivity: spu.siemens / spu.micrometer, + PT.MFlux: spu.weber, + PT.MFluxDensity: spu.tesla, + PT.Inductance: spu.henry, + PT.EField: spu.volt / spu.micrometer, + PT.HField: spu.ampere / spu.micrometer, + # Luminal + PT.LumIntensity: spu.candela, + PT.LumFlux: spu.candela * spu.steradian, + PT.Illuminance: spu.candela / spu.meter**2, + }[self] + + #################### + # - Creation + #################### + @staticmethod + def from_unit(unit: Unit | None, optional: bool = False) -> typ.Self | None: + """Attempt to determine a matching `PhysicalType` from a unit. + + NOTE: It is not guaranteed that `unit` is within `valid_units`, only that it can be converted to any unit in `valid_units`. + + Returns: + The matched `PhysicalType`. + + If none could be matched, then either return `None` (if `optional` is set) or error. + + Raises: + ValueError: If no `PhysicalType` could be matched, and `optional` is `False`. + """ + if unit is None: + return PhysicalType.NonPhysical + + ## TODO_ This enough? + if unit in [spu.radian, spu.degree]: + return PhysicalType.Angle + + unit_dim_deps = unit_to_unit_dim_deps(unit) + if unit_dim_deps is not None: + for physical_type, candidate_unit_dim in PhysicalType.unit_dims.items(): + if compare_unit_dim_to_unit_dim_deps(candidate_unit_dim, unit_dim_deps): + return physical_type + + if optional: + return None + msg = f'Could not determine PhysicalType for {unit}' + raise ValueError(msg) + + @staticmethod + def from_unit_dim( + unit_dim: SympyType | None, optional: bool = False + ) -> typ.Self | None: + """Attempts to match an arbitrary unit dimension expression to a corresponding `PhysicalType`. + + For comparing arbitrary unit dimensions (via expressions of `spu.dimensions.Dimension`), it is critical that equivalent dimensions are also compared as equal (ex. `mass*length/time^2 == force`). + To do so, we employ the `SI` unit conventions, for extracting the fundamental dimensional dependencies of unit dimension expressions. + + Returns: + The matched `PhysicalType`. + + If none could be matched, then either return `None` (if `optional` is set) or error. + + Raises: + ValueError: If no `PhysicalType` could be matched, and `optional` is `False`. + """ + for physical_type, candidate_unit_dim in PhysicalType.unit_dims.items(): + if compare_unit_dims(unit_dim, candidate_unit_dim): + return physical_type + + if optional: + return None + msg = f'Could not determine PhysicalType for {unit_dim}' + raise ValueError(msg) + + #################### + # - Valid Properties + #################### + @functools.cached_property + def valid_units(self) -> list[Unit]: + """Retrieve an ordered (by subjective usefulness) list of units for this physical type. + + Warnings: + **Altering the order of units hard-breaks backwards compatibility**, since enums based on it only keep an integer index. + + Notes: + The order in which valid units are declared is the exact same order that UI dropdowns display them. + """ + PT = PhysicalType + return { + PT.NonPhysical: [None], + # Global + PT.Time: [ + spu.picosecond, + spux.femtosecond, + spu.nanosecond, + spu.microsecond, + spu.millisecond, + spu.second, + spu.minute, + spu.hour, + spu.day, + ], + PT.Angle: [ + spu.radian, + spu.degree, + ], + PT.SolidAngle: [ + spu.steradian, + ], + PT.Freq: ( + _valid_freqs := [ + spux.terahertz, + spu.hertz, + spux.kilohertz, + spux.megahertz, + spux.gigahertz, + spux.petahertz, + spux.exahertz, + ] + ), + PT.AngFreq: [spu.radian * _unit for _unit in _valid_freqs], + # Cartesian + PT.Length: ( + _valid_lens := [ + spu.micrometer, + spu.nanometer, + spu.picometer, + spu.angstrom, + spu.millimeter, + spu.centimeter, + spu.meter, + spu.inch, + spu.foot, + spu.yard, + spu.mile, + ] + ), + PT.Area: [_unit**2 for _unit in _valid_lens], + PT.Volume: [_unit**3 for _unit in _valid_lens], + # Mechanical + PT.Vel: [_unit / spu.second for _unit in _valid_lens], + PT.Accel: [_unit / spu.second**2 for _unit in _valid_lens], + PT.Mass: [ + spu.kilogram, + spu.electron_rest_mass, + spu.dalton, + spu.microgram, + spu.milligram, + spu.gram, + spu.metric_ton, + ], + PT.Force: [ + spux.micronewton, + spux.nanonewton, + spux.millinewton, + spu.newton, + spu.kg * spu.meter / spu.second**2, + ], + PT.Pressure: [ + spu.bar, + spux.millibar, + spu.pascal, + spux.hectopascal, + spu.atmosphere, + spu.psi, + spu.mmHg, + spu.torr, + ], + # Energy + PT.Work: [ + spu.joule, + spu.electronvolt, + ], + PT.Power: [ + spu.watt, + ], + PT.PowerFlux: [ + spu.watt / spu.meter**2, + ], + PT.Temp: [ + spu.kelvin, + ], + # Electrodynamics + PT.Current: [ + spu.ampere, + ], + PT.CurrentDensity: [ + spu.ampere / spu.meter**2, + ], + PT.Charge: [ + spu.coulomb, + ], + PT.Voltage: [ + spu.volt, + ], + PT.Capacitance: [ + spu.farad, + ], + PT.Impedance: [ + spu.ohm, + ], + PT.Conductance: [ + spu.siemens, + ], + PT.Conductivity: [ + spu.siemens / spu.micrometer, + spu.siemens / spu.meter, + ], + PT.MFlux: [ + spu.weber, + ], + PT.MFluxDensity: [ + spu.tesla, + ], + PT.Inductance: [ + spu.henry, + ], + PT.EField: [ + spu.volt / spu.micrometer, + spu.volt / spu.meter, + ], + PT.HField: [ + spu.ampere / spu.micrometer, + spu.ampere / spu.meter, + ], + # Luminal + PT.LumIntensity: [ + spu.candela, + ], + PT.LumFlux: [ + spu.candela * spu.steradian, + ], + PT.Illuminance: [ + spu.candela / spu.meter**2, + ], + }[self] + + @functools.cached_property + def valid_shapes(self) -> list[typ.Literal[(3,), (2,)] | None]: + """All shapes with physical meaning in the context of a particular unit dimension.""" + PT = PhysicalType + overrides = { + # Cartesian + PT.Length: [None, (2,), (3,)], + # Mechanical + PT.Vel: [None, (2,), (3,)], + PT.Accel: [None, (2,), (3,)], + PT.Force: [None, (2,), (3,)], + # Energy + PT.Work: [None, (2,), (3,)], + PT.PowerFlux: [None, (2,), (3,)], + # Electrodynamics + PT.CurrentDensity: [None, (2,), (3,)], + PT.MFluxDensity: [None, (2,), (3,)], + PT.EField: [None, (2,), (3,)], + PT.HField: [None, (2,), (3,)], + # Luminal + PT.LumFlux: [None, (2,), (3,)], + } + + return overrides.get(self, [None]) + + @functools.cached_property + def valid_mathtypes(self) -> list[MathType]: + """Returns a list of valid mathematical types, especially whether it can be real- or complex-valued. + + Generally, all unit quantities are real, in the algebraic mathematical sense. + However, in electrodynamics especially, it becomes enormously useful to bake in a _rotational component_ as an imaginary value, be it simply to model phase or oscillation-oriented dampening. + This imaginary part has physical meaning, which can be expressed using the same mathematical formalism associated with unit systems. + In general, the value is a phasor. + + While it is difficult to arrive at a well-defined way of saying, "this is when a quantity is complex", an attempt has been made to form a sensible baseline based on when phasor math may apply. + + Notes: + - **Freq**/**AngFreq**: The imaginary part represents growth/dampening of the oscillation. + - **Current**/**Voltage**: The imaginary part represents the phase. + This also holds for any downstream units. + - **Charge**: Generally, it is real. + However, an imaginary phase term seems to have research applications when dealing with high-order harmonics in high-energy pulsed lasers: + - **Conductance**: The imaginary part represents the extinction, in the Drude-model sense. + + """ + MT = MathType + PT = PhysicalType + overrides = { + PT.NonPhysical: list(MT), ## Support All + # Cartesian + PT.Freq: [MT.Real, MT.Complex], ## Im -> Growth/Damping + PT.AngFreq: [MT.Real, MT.Complex], ## Im -> Growth/Damping + # Mechanical + # Energy + # Electrodynamics + PT.Current: [MT.Real, MT.Complex], ## Im -> Phase + PT.CurrentDensity: [MT.Real, MT.Complex], ## Im -> Phase + PT.Charge: [MT.Real, MT.Complex], ## Im -> Phase + PT.Voltage: [MT.Real, MT.Complex], ## Im -> Phase + PT.Capacitance: [MT.Real, MT.Complex], ## Im -> Phase + PT.Impedance: [MT.Real, MT.Complex], ## Im -> Reactance + PT.Inductance: [MT.Real, MT.Complex], ## Im -> Extinction + PT.Conductance: [MT.Real, MT.Complex], ## Im -> Extinction + PT.Conductivity: [MT.Real, MT.Complex], ## Im -> Extinction + PT.MFlux: [MT.Real, MT.Complex], ## Im -> Phase + PT.MFluxDensity: [MT.Real, MT.Complex], ## Im -> Phase + PT.EField: [MT.Real, MT.Complex], ## Im -> Phase + PT.HField: [MT.Real, MT.Complex], ## Im -> Phase + # Luminal + } + + return overrides.get(self, [MT.Real]) + + #################### + # - UI + #################### + @staticmethod + def to_name(value: typ.Self) -> str: + """A human-readable UI-oriented name for a physical type.""" + if value is PhysicalType.NonPhysical: + return 'Unitless' + return PhysicalType(value).name + + @staticmethod + def to_icon(_: typ.Self) -> str: + """No icons.""" + return '' + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + """Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`.""" + PT = PhysicalType + return ( + str(self), + PT.to_name(self), + PT.to_name(self), + PT.to_icon(self), + i, + ) + + @functools.cached_property + def color(self): + """A color corresponding to the physical type. + + The color selections were initially generated using AI, as this is a rote task that's better adjusted than invented. + The LLM provided the following rationale for its choices: + + > Non-Physical: Grey signifies neutrality and non-physical nature. + > Global: + > Time: Blue is often associated with calmness and the passage of time. + > Angle and Solid Angle: Different shades of blue and cyan suggest angular dimensions and spatial aspects. + > Frequency and Angular Frequency: Darker shades of blue to maintain the link to time. + > Cartesian: + > Length, Area, Volume: Shades of green to represent spatial dimensions, with intensity increasing with dimension. + > Mechanical: + > Velocity and Acceleration: Red signifies motion and dynamics, with lighter reds for related quantities. + > Mass: Dark red for the fundamental property. + > Force and Pressure: Shades of red indicating intensity. + > Energy: + > Work and Power: Orange signifies energy transformation, with lighter oranges for related quantities. + > Temperature: Yellow for heat. + > Electrodynamics: + > Current and related quantities: Cyan shades indicating flow. + > Voltage, Capacitance: Greenish and blueish cyan for electrical potential. + > Impedance, Conductance, Conductivity: Purples and magentas to signify resistance and conductance. + > Magnetic properties: Magenta shades for magnetism. + > Electric Field: Light blue. + > Magnetic Field: Grey, as it can be considered neutral in terms of direction. + > Luminal: + > Luminous properties: Yellows to signify light and illumination. + > + > This color mapping helps maintain intuitive connections for users interacting with these physical types. + """ + PT = PhysicalType + return { + PT.NonPhysical: (0.75, 0.75, 0.75, 1.0), # Light Grey: Non-physical + # Global + PT.Time: (0.5, 0.5, 1.0, 1.0), # Light Blue: Time + PT.Angle: (0.5, 0.75, 1.0, 1.0), # Light Blue: Angle + PT.SolidAngle: (0.5, 0.75, 0.75, 1.0), # Light Cyan: Solid Angle + PT.Freq: (0.5, 0.5, 0.9, 1.0), # Light Blue: Frequency + PT.AngFreq: (0.5, 0.5, 0.8, 1.0), # Light Blue: Angular Frequency + # Cartesian + PT.Length: (0.5, 1.0, 0.5, 1.0), # Light Green: Length + PT.Area: (0.6, 1.0, 0.6, 1.0), # Light Green: Area + PT.Volume: (0.7, 1.0, 0.7, 1.0), # Light Green: Volume + # Mechanical + PT.Vel: (1.0, 0.5, 0.5, 1.0), # Light Red: Velocity + PT.Accel: (1.0, 0.6, 0.6, 1.0), # Light Red: Acceleration + PT.Mass: (0.75, 0.5, 0.5, 1.0), # Light Red: Mass + PT.Force: (0.9, 0.5, 0.5, 1.0), # Light Red: Force + PT.Pressure: (1.0, 0.7, 0.7, 1.0), # Light Red: Pressure + # Energy + PT.Work: (1.0, 0.75, 0.5, 1.0), # Light Orange: Work + PT.Power: (1.0, 0.85, 0.5, 1.0), # Light Orange: Power + PT.PowerFlux: (1.0, 0.8, 0.6, 1.0), # Light Orange: Power Flux + PT.Temp: (1.0, 1.0, 0.5, 1.0), # Light Yellow: Temperature + # Electrodynamics + PT.Current: (0.5, 1.0, 1.0, 1.0), # Light Cyan: Current + PT.CurrentDensity: (0.5, 0.9, 0.9, 1.0), # Light Cyan: Current Density + PT.Charge: (0.5, 0.85, 0.85, 1.0), # Light Cyan: Charge + PT.Voltage: (0.5, 1.0, 0.75, 1.0), # Light Greenish Cyan: Voltage + PT.Capacitance: (0.5, 0.75, 1.0, 1.0), # Light Blueish Cyan: Capacitance + PT.Impedance: (0.6, 0.5, 0.75, 1.0), # Light Purple: Impedance + PT.Conductance: (0.7, 0.5, 0.8, 1.0), # Light Purple: Conductance + PT.Conductivity: (0.8, 0.5, 0.9, 1.0), # Light Purple: Conductivity + PT.MFlux: (0.75, 0.5, 0.75, 1.0), # Light Magenta: Magnetic Flux + PT.MFluxDensity: ( + 0.85, + 0.5, + 0.85, + 1.0, + ), # Light Magenta: Magnetic Flux Density + PT.Inductance: (0.8, 0.5, 0.8, 1.0), # Light Magenta: Inductance + PT.EField: (0.75, 0.75, 1.0, 1.0), # Light Blue: Electric Field + PT.HField: (0.75, 0.75, 0.75, 1.0), # Light Grey: Magnetic Field + # Luminal + PT.LumIntensity: (1.0, 0.95, 0.5, 1.0), # Light Yellow: Luminous Intensity + PT.LumFlux: (1.0, 0.95, 0.6, 1.0), # Light Yellow: Luminous Flux + PT.Illuminance: (1.0, 1.0, 0.75, 1.0), # Pale Yellow: Illuminance + }[self] diff --git a/src/blender_maxwell/utils/sympy_extra/sympy_expr.py b/src/blender_maxwell/utils/sympy_extra/sympy_expr.py new file mode 100644 index 0000000..cafd421 --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/sympy_expr.py @@ -0,0 +1,337 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import typing as typ +from fractions import Fraction + +import pydantic as pyd +import sympy as sp +import sympy.physics.units as spu +import typing_extensions as typx +from pydantic_core import core_schema as pyd_core_schema + +from . import units as spux +from .sympy_type import SympyType +from .unit_analysis import get_units, uses_units + + +#################### +# - Pydantic "Sympy Expr" +#################### +class _SympyExpr: + """Low-level `pydantic`, schema describing how to serialize/deserialize fields that have a `SympyType` (like `sp.Expr`), so we can cleanly use `sympy` types in `pyd.BaseModel`. + + Notes: + You probably want to use `SympyExpr`. + + Examples: + To be usable as a type annotation on `pyd.BaseModel`, attach this to `SympyType` using `typx.Annotated`: + + ```python + SympyExpr = typx.Annotated[SympyType, _SympyExpr] + + class Spam(pyd.BaseModel): + line: SympyExpr = sp.Eq(sp.y, 2*sp.Symbol(x, real=True) - 3) + ``` + """ + + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: SympyType, + _handler: pyd.GetCoreSchemaHandler, + ) -> pyd_core_schema.CoreSchema: + """Compute a schema that allows `pydantic` to validate a `sympy` type.""" + + def validate_from_str(sp_str: str | typ.Any) -> SympyType | typ.Any: + """Parse and validate a string expression. + + Parameters: + sp_str: A stringified `sympy` object, that will be parsed to a sympy type. + Before use, `isinstance(expr_str, str)` is checked. + If the object isn't a string, then the validation will be skipped. + + Returns: + Either a `sympy` object, if the input is parseable, or the same untouched object. + + Raises: + ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression. + """ + # Constrain to String + if not isinstance(sp_str, str): + return sp_str + + # Parse String -> Sympy + try: + expr = sp.sympify(sp_str) + except ValueError as ex: + msg = f'String {sp_str} is not a valid sympy expression' + raise ValueError(msg) from ex + + # Substitute Symbol -> Quantity + return expr.subs(spux.UNIT_BY_SYMBOL) + + def validate_from_pytype( + sp_pytype: int | Fraction | float | complex, + ) -> SympyType | typ.Any: + """Parse and validate a pure Python type. + + Parameters: + sp_str: A stringified `sympy` object, that will be parsed to a sympy type. + Before use, `isinstance(expr_str, str)` is checked. + If the object isn't a string, then the validation will be skipped. + + Returns: + Either a `sympy` object, if the input is parseable, or the same untouched object. + + Raises: + ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression. + """ + # Constrain to String + if not isinstance(sp_pytype, int | Fraction | float | complex): + return sp_pytype + + if isinstance(sp_pytype, int): + return sp.Integer(sp_pytype) + if isinstance(sp_pytype, Fraction): + return sp.Rational(sp_pytype.numerator, sp_pytype.denominator) + if isinstance(sp_pytype, float): + return sp.Float(sp_pytype) + + # sp_pytype => Complex + return sp_pytype.real + sp.I * sp_pytype.imag + + sympy_expr_schema = pyd_core_schema.chain_schema( + [ + pyd_core_schema.no_info_plain_validator_function(validate_from_str), + pyd_core_schema.no_info_plain_validator_function(validate_from_pytype), + pyd_core_schema.is_instance_schema(SympyType), + ] + ) + return pyd_core_schema.json_or_python_schema( + json_schema=sympy_expr_schema, + python_schema=sympy_expr_schema, + serialization=pyd_core_schema.plain_serializer_function_ser_schema( + lambda sp_obj: sp.srepr(sp_obj) + ), + ) + + +SympyExpr = typx.Annotated[ + sp.Basic, ## Treat all sympy types as sp.Basic + _SympyExpr, +] +## TODO: The type game between SympyType, SympyExpr, and the various flavors of ConstrSympyExpr(), is starting to be a bit much. Let's consolidate. + + +def ConstrSympyExpr( # noqa: N802, PLR0913 + # Features + allow_variables: bool = True, + allow_units: bool = True, + # Structures + allowed_sets: set[typ.Literal['integer', 'rational', 'real', 'complex']] + | None = None, + allowed_structures: set[typ.Literal['scalar', 'matrix']] | None = None, + # Element Class + max_symbols: int | None = None, + allowed_symbols: set[sp.Symbol] | None = None, + allowed_units: set[spu.Quantity] | None = None, + # Shape Class + allowed_matrix_shapes: set[tuple[int, int]] | None = None, +) -> SympyType: + """Constructs a `SympyExpr` type, which will validate `sympy` types when used in a `pyd.BaseModel`. + + Relies on the `sympy` assumptions system. + See + + Parameters (TBD): + + Returns: + A type that represents a constrained `sympy` expression. + """ + + def validate_expr(expr: SympyType): + if not (isinstance(expr, SympyType),): + msg = f"expr '{expr}' is not an allowed Sympy expression ({SympyType})" + raise ValueError(msg) + + msgs = set() + + # Validate Feature Class + if (not allow_variables) and (len(expr.free_symbols) > 0): + msgs.add( + f'allow_variables={allow_variables} does not match expression {expr}.' + ) + if (not allow_units) and uses_units(expr): + msgs.add(f'allow_units={allow_units} does not match expression {expr}.') + + # Validate Structure Class + if ( + allowed_sets + and isinstance(expr, sp.Expr) + and not any( + { + 'integer': expr.is_integer, + 'rational': expr.is_rational, + 'real': expr.is_real, + 'complex': expr.is_complex, + }[allowed_set] + for allowed_set in allowed_sets + ) + ): + msgs.add( + f"allowed_sets={allowed_sets} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))" + ) + if allowed_structures and not any( + { + 'scalar': True, + 'matrix': isinstance(expr, sp.MatrixBase), + }[allowed_set] + for allowed_set in allowed_structures + ): + msgs.add( + f"allowed_structures={allowed_structures} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))" + ) + + # Validate Element Class + if max_symbols and len(expr.free_symbols) > max_symbols: + msgs.add(f'max_symbols={max_symbols} does not match expression {expr}') + if allowed_symbols and expr.free_symbols.issubset(allowed_symbols): + msgs.add( + f'allowed_symbols={allowed_symbols} does not match expression {expr}' + ) + if allowed_units and get_units(expr).issubset(allowed_units): + msgs.add(f'allowed_units={allowed_units} does not match expression {expr}') + + # Validate Shape Class + if ( + allowed_matrix_shapes and isinstance(expr, sp.MatrixBase) + ) and expr.shape not in allowed_matrix_shapes: + msgs.add( + f'allowed_matrix_shapes={allowed_matrix_shapes} does not match expression {expr} with shape {expr.shape}' + ) + + # Error or Return + if msgs: + raise ValueError(str(msgs)) + return expr + + return typx.Annotated[ + sp.Basic, + _SympyExpr, + pyd.AfterValidator(validate_expr), + ] + + +#################### +# - Common ConstrSympyExpr +#################### +# Expression +ScalarUnitlessRealExpr: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=False, + allowed_structures={'scalar'}, + allowed_sets={'integer', 'rational', 'real'}, +) +ScalarUnitlessComplexExpr: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=False, + allowed_structures={'scalar'}, + allowed_sets={'integer', 'rational', 'real', 'complex'}, +) + +# Symbol +IntSymbol: typ.TypeAlias = ConstrSympyExpr( + allow_variables=True, + allow_units=False, + allowed_sets={'integer'}, + max_symbols=1, +) +RationalSymbol: typ.TypeAlias = ConstrSympyExpr( + allow_variables=True, + allow_units=False, + allowed_sets={'integer', 'rational'}, + max_symbols=1, +) +RealSymbol: typ.TypeAlias = ConstrSympyExpr( + allow_variables=True, + allow_units=False, + allowed_sets={'integer', 'rational', 'real'}, + max_symbols=1, +) +ComplexSymbol: typ.TypeAlias = ConstrSympyExpr( + allow_variables=True, + allow_units=False, + allowed_sets={'integer', 'rational', 'real', 'complex'}, + max_symbols=1, +) +Symbol: typ.TypeAlias = IntSymbol | RealSymbol | ComplexSymbol + +# Unit +UnitDimension: typ.TypeAlias = SympyExpr ## Actually spu.Dimension + +## Technically a "unit expression", which includes compound types. +## Support for this is the reason to prefer over raw spu.Quantity. +Unit: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=True, + allowed_structures={'scalar'}, +) + +# Number +IntNumber: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=False, + allowed_sets={'integer'}, + allowed_structures={'scalar'}, +) +RealNumber: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=False, + allowed_sets={'integer', 'rational', 'real'}, + allowed_structures={'scalar'}, +) +ComplexNumber: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=False, + allowed_sets={'integer', 'rational', 'real', 'complex'}, + allowed_structures={'scalar'}, +) +Number: typ.TypeAlias = IntNumber | RealNumber | ComplexNumber + +# Number +PhysicalRealNumber: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=True, + allowed_sets={'integer', 'rational', 'real'}, + allowed_structures={'scalar'}, +) +PhysicalComplexNumber: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=True, + allowed_sets={'integer', 'rational', 'real', 'complex'}, + allowed_structures={'scalar'}, +) +PhysicalNumber: typ.TypeAlias = PhysicalRealNumber | PhysicalComplexNumber + +# Vector +Real3DVector: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=False, + allowed_sets={'integer', 'rational', 'real'}, + allowed_structures={'matrix'}, + allowed_matrix_shapes={(3, 1)}, +) diff --git a/src/blender_maxwell/utils/sympy_extra/sympy_type.py b/src/blender_maxwell/utils/sympy_extra/sympy_type.py new file mode 100644 index 0000000..ecb736e --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/sympy_type.py @@ -0,0 +1,23 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import sympy as sp +import sympy.physics.units as spu + +#################### +# - Underlying "Sympy Type" +#################### +SympyType = sp.Basic | sp.MatrixBase | spu.Quantity | spu.Dimension diff --git a/src/blender_maxwell/utils/sympy_extra/unit_analysis.py b/src/blender_maxwell/utils/sympy_extra/unit_analysis.py new file mode 100644 index 0000000..3234407 --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/unit_analysis.py @@ -0,0 +1,287 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""Functions for characterizaiton, conversion and casting of `sympy` objects that use units.""" + +import functools + +import sympy as sp +import sympy.physics.units as spu + +from . import units as spux +from .parse_cast import sympy_to_python +from .sympy_type import SympyType + + +#################### +# - Unit Characterization +#################### +## TODO: Caching w/srepr'ed expression. +## TODO: An LFU cache could do better than an LRU. +def uses_units(sp_obj: SympyType) -> bool: + """Determines if an expression uses any units. + + Parameters: + expr: The sympy object that may contain units. + + Returns: + Whether or not units are present in the object. + """ + return sp_obj.has(spu.Quantity) + + +## TODO: Caching w/srepr'ed expression. +## TODO: An LFU cache could do better than an LRU. +def get_units(expr: sp.Expr) -> set[spu.Quantity]: + """Finds all units used by the expression, and returns them as a set. + + No information about _the relationship between units_ is exposed. + For example, compound units like `spu.meter / spu.second` would be mapped to `{spu.meter, spu.second}`. + + + Notes: + The expression graph is traversed depth-first with `sp.postorder_traversal`, to search for `sp.Quantity` elements. + + The performance is comparable to the performance of `sp.postorder_traversal`, since the **entire expression graph will always be traversed**, with the added overhead of one `isinstance` call per expression-graph-node. + + Parameters: + expr: The sympy expression that may contain units. + + Returns: + All units (`spu.Quantity`) used within the expression. + """ + return { + subexpr + for subexpr in sp.postorder_traversal(expr) + if isinstance(subexpr, spu.Quantity) + } + + +#################### +# - Dimensional Characterization +#################### +def unit_dim_to_unit_dim_deps( + unit_dims: SympyType, +) -> dict[spu.dimensions.Dimension, int] | None: + """Normalize an expression to a mapping of its dimensional dependencies. + + Comparing the dimensional dependencies of two `unit_dims` is a meaningful way of determining whether they are equivalent. + + Notes: + We adhere to SI unit conventions when determining dimensional dependencies, to ensure that ex. `freq -> 1/time` equivalences are normalized away. + This allows the output of this method to be compared meaningfully, to determine whether two dimensional expressions are equivalent. + + We choose to catch a `TypeError`, for cases where dimensional analysis is impossible (especially `+` or `-` between differing dimensions). + This may have a slight performance penalty. + + Returns: + The dimensional dependencies of the dimensional expression. + + If such a thing makes no sense, ex. if `+` or `-` is present between differing unit dimensions, then return None. + """ + dimsys_SI = spu.systems.si.dimsys_SI + + # Retrieve Dimensional Dependencies + try: + return dimsys_SI.get_dimensional_dependencies(unit_dims) + + # Catch TypeError + ## -> Happens if `+` or `-` is in `unit`. + ## -> Generally, it doesn't make sense to add/subtract differing unit dims. + ## -> Thus, when trying to figure out the unit dimension, there isn't one. + except TypeError: + return None + + +def unit_to_unit_dim_deps( + unit: SympyType, +) -> dict[spu.dimensions.Dimension, int] | None: + """Deduce the dimensional dependencies of a unit. + + Notes: + Using `.subs()` to replace `sp.Quantity`s with `spu.dimensions.Dimension`s seems to result in an expression that absolutely refuses to claim that it has anything other than raw `sp.Symbol`s. + + This is extremely problematic - dimensional analysis relies on the arithmetic properties of proper `Dimension` objects. + + For this reason, though we'd rather have a `unit_to_unit_dims()` function, we have not yet found a way to do this. + Luckily, most of our uses cases seem only to require the dimensional dictionary, which (surprisingly) seems accessible using `unit_dim_to_unit_dim_deps()`. + + """ + # Retrieve Dimensional Dependencies + ## -> NOTE: .subs() alone seems to produce sp.Symbol atoms. + ## -> This is extremely problematic; `Dims` arithmetic has key properties. + ## -> So we have to go all the way to the dimensional dependencies. + ## -> This isn't really respecting the args, but it seems to work :) + return unit_dim_to_unit_dim_deps( + unit.subs({arg: arg.dimension for arg in unit.atoms(spu.Quantity)}) + ) + + +def compare_unit_dims(unit_dim_l: SympyType, unit_dim_r: SympyType) -> bool: + """Compare the dimensional dependencies of two unit dimensions. + + Comparing the dimensional dependencies of two `unit_dims` is a meaningful way of determining whether they are equivalent. + """ + return unit_dim_to_unit_dim_deps(unit_dim_l) == unit_dim_to_unit_dim_deps( + unit_dim_r + ) + + +def compare_units_by_unit_dims(unit_l: SympyType, unit_r: SympyType) -> bool: + """Compare two units by their unit dimensions.""" + return unit_to_unit_dim_deps(unit_l) == unit_to_unit_dim_deps(unit_r) + + +def compare_unit_dim_to_unit_dim_deps( + unit_dim: SympyType, unit_dim_deps: dict[spu.dimensions.Dimension, int] +) -> bool: + """Compare the dimensional dependencies of unit dimensions to pre-defined unit dimensions.""" + return unit_dim_to_unit_dim_deps(unit_dim) == unit_dim_deps + + +#################### +# - Unit Casting +#################### +def strip_units(sp_obj: SympyType) -> SympyType: + """Strip all units by replacing them to `1`. + + This is a rather unsafe method. + You probably shouldn't use it. + + Warnings: + Absolutely no effort is made to determine whether stripping units is a _meaningful thing to do_. + + For example, using `+` expressions of compatible dimension, but different units, is a clear mistake. + For example, `8*meter + 9*millimeter` strips to `8(1) + 9(1) = 17`, which is a garbage result. + + The **user of this method** must themselves perform appropriate checks on th eobject before stripping units. + + Parameters: + sp_obj: A sympy object that contains unit symbols. + **NOTE**: Unit symbols (from `sympy.physics.units`) are not _free_ symbols, in that they are not unknown. + Nonetheless, they are not _numbers_ either, and thus they cannot be used in a numerical expression. + + Returns: + The sympy object with all unit symbols replaced by `1`, effectively extracting the unitless part of the object. + """ + return sp_obj.subs(spux.UNIT_TO_1) + + +def convert_to_unit(sp_obj: SympyType, unit: SympyType | None) -> SympyType: + """Convert a sympy object to the given unit. + + Supports a unit of `None`, which simply causes the object to have its units stripped. + """ + if unit is None: + return strip_units(sp_obj) + return spu.convert_to(sp_obj, unit) + + # msg = f'Sympy object "{sp_obj}" was scaled to the unit "{unit}" with the expectation that the result would be unitless, but the result "{unitless_expr}" has units "{get_units(unitless_expr)}"' + # raise ValueError(msg) + + +## TODO: Include sympy_to_python in 'scale_to' to match semantics of 'scale_to_unit_system' +## -- Introduce a 'strip_unit +def scale_to_unit( + sp_obj: SympyType, + unit: spu.Quantity | None, + cast_to_pytype: bool = False, + use_jax_array: bool = False, +) -> SympyType: + """Convert an expression that uses units to a different unit, then strip all units, leaving only a unitless `sympy` value. + + This is used whenever the unitless part of an expression is needed, but guaranteed expressed in a particular unit, aka. **unit system normalization**. + + Notes: + The unitless output is still an `sp.Expr`, which may contain ex. symbols. + + If you know that the output **should** work as a corresponding Python type (ex. `sp.Integer` vs. `int`), but it doesn't, you can use `sympy_to_python()` to produce a pure-Python type. + In this way, with a little care, broad compatiblity can be bridged between the `sympy.physics.units` unit system and the wider Python ecosystem. + + Parameters: + expr: The unit-containing expression to convert. + unit_to: The unit that is converted to. + + Returns: + The unitless part of `expr`, after scaling the entire expression to `unit`. + + Raises: + ValueError: If the result of unit-conversion and -stripping still has units, as determined by `uses_units()`. + """ + sp_obj_stripped = strip_units(convert_to_unit(sp_obj, unit)) + if cast_to_pytype: + return sympy_to_python( + sp_obj_stripped, + use_jax_array=use_jax_array, + ) + return sp_obj_stripped + + +def scaling_factor( + unit_from: SympyType, unit_to: SympyType +) -> int | float | complex | tuple | None: + """Compute the numerical scaling factor imposed on the unitless part of the expression when converting from one unit to another. + + Parameters: + unit_from: The unit that is converted from. + unit_to: The unit that is converted to. + + Returns: + The numerical scaling factor between the two units. + + If the units are incompatible, then we return None. + + Raises: + ValueError: If the two units don't share a common dimension. + """ + if compare_units_by_unit_dims(unit_from, unit_to): + return scale_to_unit(unit_from, unit_to) + return None + + +@functools.cache +def unit_str_to_unit(unit_str: str, optional: bool = False) -> SympyType | None: + """Determine the `sympy` unit expression that matches the given unit string. + + Parameters: + unit_str: A string parseable with `sp.sympify`, which contains a unit expression. + optional: Whether to return + **NOTE**: `None` is itself a valid "unit", denoting dimensionlessness, in general. + Ensure that appropriate checks are performed to account for this nuance. + + Returns: + The matching `sympy` unit. + + Raises: + ValueError: When no valid unit can be matched to the unit string, and `optional` is `False`. + """ + match unit_str: + # Special-Case 'degree' + ## -> sp.sympify('degree') produces the sp.degree(). + ## -> TODO: Proper Analysis analysis. + case 'degree': + unit = spu.degree + + case _: + unit = sp.sympify(unit_str).subs(spux.UNIT_BY_SYMBOL) + + if uses_units(unit): + return unit + + if optional: + return None + msg = f'No valid unit for unit string {unit_str}' + raise ValueError(msg) diff --git a/src/blender_maxwell/utils/sympy_extra/unit_system_analysis.py b/src/blender_maxwell/utils/sympy_extra/unit_system_analysis.py new file mode 100644 index 0000000..cb30375 --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/unit_system_analysis.py @@ -0,0 +1,93 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""Functions for conversion and casting of `sympy` objects that use units, via unit systems.""" + +import jax +import sympy.physics.units as spu + +from . import units as spux +from .parse_cast import sympy_to_python +from .physical_type import PhysicalType +from .sympy_type import SympyType +from .unit_analysis import get_units +from .unit_systems import UnitSystem + + +#################### +# - Conversion +#################### +def strip_unit_system( + sp_obj: SympyType, unit_system: UnitSystem | None = None +) -> SympyType: + """Strip units occurring in the given unit system from the expression. + + Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`". + Obviously, the semantic correctness of this operation depends entirely on _the units adding no semantic meaning to the expression_. + + Notes: + You should probably use `scale_to_unit_system()` or `convert_to_unit_system()`. + """ + if unit_system is None: + return sp_obj.subs(spux.UNIT_TO_1) + + return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None}) + + +def convert_to_unit_system( + sp_obj: SympyType, unit_system: UnitSystem | None +) -> SympyType: + """Convert an expression to the units of a given unit system.""" + if unit_system is None: + return sp_obj + + return spu.convert_to( + sp_obj, + {unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)}, + ) + + +#################### +# - Casting +#################### +def scale_to_unit_system( + sp_obj: SympyType, + unit_system: UnitSystem | None, + use_jax_array: bool = False, +) -> int | float | complex | tuple | jax.Array: + """Convert an expression to the units of a given unit system, then strip all units of the unit system. + + Afterwards, it is converted to an appropriate Python type. + + Notes: + For stability, and performance, reasons, this should only be used at the very last stage. + + Regarding performance: **This is not a fast function**. + + Parameters: + sp_obj: An arbitrary sympy object, presumably with units. + unit_system: A unit system mapping `PhysicalType` to particular choices of (compound) units. + Note that, in this context, only `unit_system.values()` is used. + + Returns: + An appropriate pure Python type, after scaling to the unit system and stripping all units away. + + If the returned type is array-like, and `use_jax_array` is specified, then (and **only** then) will a `jax.Array` be returned instead of a nested `tuple`. + """ + return sympy_to_python( + strip_unit_system(convert_to_unit_system(sp_obj, unit_system), unit_system), + use_jax_array=use_jax_array, + ) diff --git a/src/blender_maxwell/utils/sympy_extra/unit_systems.py b/src/blender_maxwell/utils/sympy_extra/unit_systems.py new file mode 100644 index 0000000..1de298a --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/unit_systems.py @@ -0,0 +1,80 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""Defines a common unit system representation, as well as a few of the most common / useful unit systems. + +Attributes: + UnitSystem: Type of a unit system representation, as an exhaustive mapping from `PhysicalType` to a unit expression. + **Compatibility between `PhysicalType` and unit must be manually guaranteed** when defining new unit systems. + UNITS_SI: Pre-defined go-to choice of unit system, which can also be a useful base to build other unit systems on. +""" + +import typing as typ + +import sympy.physics.units as spu + +from . import units as spux +from .physical_type import PhysicalType as PT # noqa: N817 +from .sympy_expr import Unit + +#################### +# - Unit System Representation +#################### +UnitSystem: typ.TypeAlias = dict[PT, Unit] + +#################### +# - Standard Unit Systems +#################### +UNITS_SI: UnitSystem = { + PT.NonPhysical: None, + # Global + PT.Time: spu.second, + PT.Angle: spu.radian, + PT.SolidAngle: spu.steradian, + PT.Freq: spu.hertz, + PT.AngFreq: spu.radian * spu.hertz, + # Cartesian + PT.Length: spu.meter, + PT.Area: spu.meter**2, + PT.Volume: spu.meter**3, + # Mechanical + PT.Vel: spu.meter / spu.second, + PT.Accel: spu.meter / spu.second**2, + PT.Mass: spu.kilogram, + PT.Force: spu.newton, + # Energy + PT.Work: spu.joule, + PT.Power: spu.watt, + PT.PowerFlux: spu.watt / spu.meter**2, + PT.Temp: spu.kelvin, + # Electrodynamics + PT.Current: spu.ampere, + PT.CurrentDensity: spu.ampere / spu.meter**2, + PT.Voltage: spu.volt, + PT.Capacitance: spu.farad, + PT.Impedance: spu.ohm, + PT.Conductance: spu.siemens, + PT.Conductivity: spu.siemens / spu.meter, + PT.MFlux: spu.weber, + PT.MFluxDensity: spu.tesla, + PT.Inductance: spu.henry, + PT.EField: spu.volt / spu.meter, + PT.HField: spu.ampere / spu.meter, + # Luminal + PT.LumIntensity: spu.candela, + PT.LumFlux: spux.lumen, + PT.Illuminance: spu.lux, +} diff --git a/src/blender_maxwell/utils/sympy_extra/units.py b/src/blender_maxwell/utils/sympy_extra/units.py new file mode 100644 index 0000000..9ffd4cf --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/units.py @@ -0,0 +1,77 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import typing as typ + +import sympy as sp +import sympy.physics.units as spu + +#################### +# - Units +#################### +# Time +femtosecond = fs = spu.Quantity('femtosecond', abbrev='fs') +femtosecond.set_global_relative_scale_factor(spu.femto, spu.second) + +# Length +femtometer = fm = spu.Quantity('femtometer', abbrev='fm') +femtometer.set_global_relative_scale_factor(spu.femto, spu.meter) + +# Lum Flux +lumen = lm = spu.Quantity('lumen', abbrev='lm') +lumen.set_global_relative_scale_factor(1, spu.candela * spu.steradian) + +# Force +nanonewton = nN = spu.Quantity('nanonewton', abbrev='nN') # noqa: N816 +nanonewton.set_global_relative_scale_factor(spu.nano, spu.newton) + +micronewton = uN = spu.Quantity('micronewton', abbrev='μN') # noqa: N816 +micronewton.set_global_relative_scale_factor(spu.micro, spu.newton) + +millinewton = mN = spu.Quantity('micronewton', abbrev='mN') # noqa: N816 +micronewton.set_global_relative_scale_factor(spu.milli, spu.newton) + +# Frequency +kilohertz = KHz = spu.Quantity('kilohertz', abbrev='KHz') +kilohertz.set_global_relative_scale_factor(spu.kilo, spu.hertz) + +megahertz = MHz = spu.Quantity('megahertz', abbrev='MHz') +kilohertz.set_global_relative_scale_factor(spu.kilo, spu.hertz) + +gigahertz = GHz = spu.Quantity('gigahertz', abbrev='GHz') +gigahertz.set_global_relative_scale_factor(spu.giga, spu.hertz) + +terahertz = THz = spu.Quantity('terahertz', abbrev='THz') +terahertz.set_global_relative_scale_factor(spu.tera, spu.hertz) + +petahertz = PHz = spu.Quantity('petahertz', abbrev='PHz') +petahertz.set_global_relative_scale_factor(spu.peta, spu.hertz) + +exahertz = EHz = spu.Quantity('exahertz', abbrev='EHz') +exahertz.set_global_relative_scale_factor(spu.exa, spu.hertz) + +# Pressure +millibar = mbar = spu.Quantity('millibar', abbrev='mbar') +millibar.set_global_relative_scale_factor(spu.milli, spu.bar) + +hectopascal = hPa = spu.Quantity('hectopascal', abbrev='hPa') # noqa: N816 +hectopascal.set_global_relative_scale_factor(spu.hecto, spu.pascal) + +UNIT_BY_SYMBOL: dict[sp.Symbol, spu.Quantity] = { + unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity) +} | {unit.name: unit for unit in globals().values() if isinstance(unit, spu.Quantity)} + +UNIT_TO_1: dict[spu.Quantity, 1] = {unit: 1 for unit in UNIT_BY_SYMBOL.values()}