From b51c4f18891b8633a949c80db0d4b921fea4b7f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sofus=20Albert=20H=C3=B8gsbro=20Rose?= Date: Sat, 1 Jun 2024 19:08:43 +0200 Subject: [PATCH] feat: deep refactors / fixes We refactored the entire `extra_sympy_units` into a (rather decently sized) proper package, and changed its name. Additionally, we've seperated the operation enums from the math nodes themselves and moved them to dedicated `math_system` package, which is a very big breath of fresh air. I'll need a moment to fix a few typos, but incredibly, this architecture is kind of "just works" (TM) at this point - not like the `FlowKind` refactoring debacle... To go with that, we've greatly streamlined domain handling in `SimSymbol`. We now track a symbolic set as the domain, which is very simple and effective, but comes with some headaches too (`sympy` is fighting me, grumble grumble...) We've also managed to enforce a few unit-conversions in the operate node, and revamp the entire operation validity detection (weirdly hard). The big pain point at the moment is determining the image of functions we apply, on the `sp.Set` domain of `SimSymbol`s. We can express this trivially, but `sympy` simply doesn't care to evaluate it. One can use `SetExpr` and `AccumBounds` to "sometimes work" for some set kinds, but nothing nearly good enough for even our relatively humble needs. Let alone stuff like fourier. We've ended up deciding to hard-code this part of the process by-operation. With domain-specific knowledge and little bit of suffering, we can manually ensure the output domain of every operation makes it to the output symbol. As for "why bother", well, the entire premise of a symbolic nodal math system that is tolerable to use, requires checking the valid domain of the input. We do wish it were optional, but eh. --- .../contracts/bl_socket_types.py | 2 +- .../contracts/flow_kinds/array.py | 2 +- .../contracts/flow_kinds/expr_info.py | 2 +- .../contracts/flow_kinds/flow_kinds.py | 2 +- .../contracts/flow_kinds/info.py | 76 +- .../contracts/flow_kinds/lazy_func.py | 4 +- .../contracts/flow_kinds/lazy_range.py | 34 +- .../contracts/flow_kinds/params.py | 7 +- .../maxwell_sim_nodes/contracts/sim_types.py | 2 +- .../contracts/unit_systems.py | 2 +- .../managed_objs/managed_bl_modifier.py | 2 +- .../maxwell_sim_nodes/math_system/__init__.py | 29 + .../maxwell_sim_nodes/math_system/filter.py | 233 +++ .../maxwell_sim_nodes/math_system/map.py | 365 ++++ .../maxwell_sim_nodes/math_system/operate.py | 476 +++++ .../maxwell_sim_nodes/math_system/reduce.py | 116 ++ .../math_system/transform.py | 336 ++++ .../nodes/analysis/extract_data.py | 2 +- .../nodes/analysis/math/filter_math.py | 240 +-- .../nodes/analysis/math/map_math.py | 356 +--- .../nodes/analysis/math/operate_math.py | 404 +--- .../nodes/analysis/math/transform_math.py | 336 +--- .../maxwell_sim_nodes/nodes/analysis/viz.py | 2 +- .../bound_cond_nodes/absorbing_bound_cond.py | 2 +- .../bounds/bound_cond_nodes/pml_bound_cond.py | 2 +- .../maxwell_sim_nodes/nodes/events.py | 2 +- .../inputs/constants/scientific_constant.py | 2 +- .../nodes/inputs/constants/symbol_constant.py | 120 +- .../file_importers/data_file_importer.py | 2 +- .../maxwell_sim_nodes/nodes/inputs/scene.py | 2 +- .../nodes/inputs/wave_constant.py | 2 +- .../nodes/mediums/library_medium.py | 2 +- .../nodes/monitors/eh_field_monitor.py | 2 +- .../monitors/field_power_flux_monitor.py | 2 +- .../nodes/monitors/permittivity_monitor.py | 2 +- .../file_exporters/data_file_exporter.py | 2 +- .../maxwell_sim_nodes/nodes/outputs/viewer.py | 2 +- .../nodes/simulations/sim_domain.py | 2 +- .../nodes/sources/gaussian_beam_source.py | 2 +- .../nodes/sources/plane_wave_source.py | 2 +- .../nodes/sources/point_dipole_source.py | 2 +- .../nodes/sources/temporal_shape.py | 2 +- .../nodes/structures/geonodes_structure.py | 2 +- .../structures/primitives/box_structure.py | 2 +- .../primitives/cylinder_structure.py | 2 +- .../structures/primitives/sphere_structure.py | 2 +- .../maxwell_sim_nodes/sockets/expr.py | 15 +- .../sockets/maxwell/medium.py | 2 +- .../maxwell_sim_nodes/sockets/physical/pol.py | 2 +- src/blender_maxwell/utils/__init__.py | 4 +- .../utils/extra_sympy_units.py | 1699 ----------------- src/blender_maxwell/utils/image_ops.py | 2 +- src/blender_maxwell/utils/sci_constants.py | 2 +- src/blender_maxwell/utils/serialize.py | 2 +- src/blender_maxwell/utils/sim_symbols.py | 459 +++-- .../utils/sympy_extra/__init__.py | 173 ++ .../utils/sympy_extra/math_type.py | 362 ++++ .../utils/sympy_extra/number_size.py | 148 ++ .../utils/sympy_extra/parse_cast.py | 119 ++ .../utils/sympy_extra/physical_type.py | 644 +++++++ .../utils/sympy_extra/sympy_expr.py | 337 ++++ .../utils/sympy_extra/sympy_type.py | 23 + .../utils/sympy_extra/unit_analysis.py | 287 +++ .../utils/sympy_extra/unit_system_analysis.py | 93 + .../utils/sympy_extra/unit_systems.py | 80 + .../utils/sympy_extra/units.py | 77 + 66 files changed, 4449 insertions(+), 3275 deletions(-) create mode 100644 src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/__init__.py create mode 100644 src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/filter.py create mode 100644 src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/map.py create mode 100644 src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/operate.py create mode 100644 src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/reduce.py create mode 100644 src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/transform.py delete mode 100644 src/blender_maxwell/utils/extra_sympy_units.py create mode 100644 src/blender_maxwell/utils/sympy_extra/__init__.py create mode 100644 src/blender_maxwell/utils/sympy_extra/math_type.py create mode 100644 src/blender_maxwell/utils/sympy_extra/number_size.py create mode 100644 src/blender_maxwell/utils/sympy_extra/parse_cast.py create mode 100644 src/blender_maxwell/utils/sympy_extra/physical_type.py create mode 100644 src/blender_maxwell/utils/sympy_extra/sympy_expr.py create mode 100644 src/blender_maxwell/utils/sympy_extra/sympy_type.py create mode 100644 src/blender_maxwell/utils/sympy_extra/unit_analysis.py create mode 100644 src/blender_maxwell/utils/sympy_extra/unit_system_analysis.py create mode 100644 src/blender_maxwell/utils/sympy_extra/unit_systems.py create mode 100644 src/blender_maxwell/utils/sympy_extra/units.py diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py index 244e54b..b5def9c 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py @@ -21,7 +21,7 @@ import typing as typ import bpy import sympy as sp -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from .socket_types import SocketType diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py index a6c2b54..1599f45 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py @@ -22,7 +22,7 @@ import numpy as np import pydantic as pyd import sympy as sp -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger log = logger.get(__name__) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/expr_info.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/expr_info.py index 412cd05..c588fab 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/expr_info.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/expr_info.py @@ -16,7 +16,7 @@ import typing as typ -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from . import FlowKind diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py index 4ff917e..d892ece 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py @@ -19,7 +19,7 @@ import functools import typing as typ from blender_maxwell.contracts import BLEnumElement -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from blender_maxwell.utils.staticproperty import staticproperty diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py index 43dbfe5..58eba20 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py @@ -18,8 +18,8 @@ import dataclasses import functools import typing as typ -from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import logger, sim_symbols +from blender_maxwell.utils import sympy_extra as spux from .array import ArrayFlow from .lazy_range import RangeFlow @@ -89,6 +89,7 @@ class InfoFlow: return None def dim_by_idx(self, idx: int) -> sim_symbols.SimSymbol | None: + """Retrieve the dimension associated with a particular index.""" if idx > 0 and idx < len(self.dims) - 1: return list(self.dims.keys())[idx] return None @@ -179,15 +180,23 @@ class InfoFlow: While that sounds fancy and all, it boils down to: $$ - \texttt{dims} + |\texttt{output}.\texttt{shape}| + |\texttt{dims}| + |\texttt{output}.\texttt{shape}| $$ - Doing so characterizes the full dimensionality of the tensor, which also perfectly matches the length of the raw data's shape exactly. + Doing so characterizes the full dimensionality of the tensor, which also perfectly matches the length of the raw data's shape. Notes: Corresponds to `len(raw_data.shape)`, if `raw_data` is the n-dimensional array corresponding to this `InfoFlow`. """ - return len(self.input_mathtypes) + self.output_shape_len + return len(self.dims) + self.output_shape_len + + @functools.cached_property + def is_scalar(self) -> tuple[spux.MathType, int, int]: + """Whether the described object can be described as "scalar". + + True when `self.order == 0`. + """ + return self.order == 0 #################### # - Properties @@ -204,6 +213,58 @@ class InfoFlow: for dim, dim_idx in self.dims.items() } + #################### + # - Operations: Comparison + #################### + def compare_dims_identical(self, other: typ.Self) -> bool: + """Whether that the quantity and properites of all dimension `SimSymbol`s are "identical". + + "Identical" is defined according to the semantics of `SimSymbol.compare()`, which generally means that everything but the exact name and unit are different. + """ + return len(self.dims) == len(other.dims) and all( + dim_l.compare(dim_r) + for dim_l, dim_r in zip(self.dims, other.dims, strict=True) + ) + + def compare_addable( + self, other: typ.Self, allow_differing_unit: bool = False + ) -> bool: + """Whether the two `InfoFlows` can be added/subtracted elementwise. + + Parameters: + allow_differing_unit: When set, + Forces the user to be explicit about specifying + """ + return self.compare_dims_identical(other) and self.output.compare_addable( + other.output, allow_differing_unit=allow_differing_unit + ) + + def compare_multiplicable(self, other: typ.Self) -> bool: + """Whether the two `InfoFlow`s can be multiplied (elementwise). + + - The output `SimSymbol`s must be able to be multiplied. + - Either the LHS is a scalar, the RHS is a scalar, or the dimensions are identical. + """ + return self.output.compare_multiplicable(other.output) and ( + (len(self.dims) == 0 and self.output.shape_len == 0) + or (len(other.dims) == 0 and other.output.shape_len == 0) + or self.compare_dims_identical(other) + ) + + def compare_exponentiable(self, other: typ.Self) -> bool: + """Whether the two `InfoFlow`s can be exponentiated. + + In general, we follow the rules of the "Hadamard Power" operator, which is also in use in `numpy` broadcasting rules. + + - The output `SimSymbol`s must be able to be exponentiated (mainly, the exponent can't have a unit). + - Either the LHS is a scalar, the RHS is a scalar, or the dimensions are identical. + """ + return self.output.compare_exponentiable(other.output) and ( + (len(self.dims) == 0 and self.output.shape_len == 0) + or (len(other.dims) == 0 and other.output.shape_len == 0) + or self.compare_dims_identical(other) + ) + #################### # - Operations: Dimensions #################### @@ -319,6 +380,7 @@ class InfoFlow: op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr], unit_op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr], ) -> spux.SympyExpr: + """Apply an operation between two the values and units of two `InfoFlow`s by reconstructing the properties of the new output `SimSymbol`.""" sym_name = sim_symbols.SimSymbolName.Expr expr = op(self.output.sp_symbol_phy, other.output.sp_symbol_phy) unit_expr = unit_op(self.output.unit_factor, other.output.unit_factor) @@ -341,11 +403,11 @@ class InfoFlow: cols = self.output.cols match (rows, cols): case (1, 1): - new_output = self.output.set_size(len(last_idx), 1) + new_output = self.output.update(rows=len(last_idx), cols=1) case (_, 1): - new_output = self.output.set_size(rows, len(last_idx)) + new_output = self.output.update(rows=rows, cols=len(last_idx)) case (1, _): - new_output = self.output.set_size(len(last_idx), cols) + new_output = self.output.update(rows=len(last_idx), cols=cols) case (_, _): raise NotImplementedError ## Not yet :) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py index e881c14..025ef71 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py @@ -226,7 +226,7 @@ import jaxtyping as jtyp import pydantic as pyd import sympy as sp -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger, sim_symbols from .array import ArrayFlow @@ -335,7 +335,7 @@ class FuncFlow(pyd.BaseModel): disallow_jax: Don't use `self.func_jax` to evaluate, even if possible. This is desirable when the overhead of `jax.jit()` is known in advance to exceed the performance benefits. """ - if self.supports_jax: + if self.supports_jax and not disallow_jax: return self.func_jax( *params.scaled_func_args(symbol_values), **params.scaled_func_kwargs(symbol_values), diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py index a204eef..a7fd668 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py @@ -24,8 +24,8 @@ import jaxtyping as jtyp import pydantic as pyd import sympy as sp -from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import logger, sim_symbols +from blender_maxwell.utils import sympy_extra as spux from .array import ArrayFlow @@ -95,10 +95,12 @@ class RangeFlow(pyd.BaseModel): stop: spux.ScalarUnitlessRealExpr steps: int = 0 scaling: ScalingMode = ScalingMode.Lin + ## TODO: No support for non-Lin (yet) unit: spux.Unit | None = None symbols: frozenset[sim_symbols.SimSymbol] = frozenset() + ## TODO: No proper support for symbols (yet) # Helper Attributes pre_fourier_ideal_midpoint: spux.ScalarUnitlessRealExpr | None = None @@ -112,18 +114,26 @@ class RangeFlow(pyd.BaseModel): steps: int = 50, scaling: ScalingMode | str = ScalingMode.Lin, ) -> typ.Self: - if sym.domain.start.is_infinite or sym.domain.end.is_infinite: - use_steps = 0 - else: - use_steps = steps + if ( + sym.mathtype is not spux.MathType.Complex + and sym.rows == 1 + and sym.cols == 1 + ): + if sym.domain.inf.is_infinite or sym.domain.sup.is_infinite: + _steps = 0 + else: + _steps = steps - return RangeFlow( - start=sym.domain.start if sym.domain.start.is_finite else sp.S(-1), - stop=sym.domain.end if sym.domain.end.is_finite else sp.S(1), - steps=use_steps, - scaling=ScalingMode(scaling), - unit=sym.unit, - ) + return RangeFlow( + start=sym.domain.inf if sym.domain.inf.is_finite else sp.S(-1), + stop=sym.domain.sup if sym.domain.sup.is_finite else sp.S(1), + steps=_steps, + scaling=ScalingMode(scaling), + unit=sym.unit, + ) + + msg = f'RangeFlow is incompatible with SimSymbol {sym}' + raise ValueError(msg) def to_sym( self, diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py index 20d6562..b7a9cb7 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py @@ -23,7 +23,7 @@ import jaxtyping as jtyp import pydantic as pyd import sympy as sp -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger, sim_symbols from .array import ArrayFlow @@ -53,11 +53,6 @@ class ParamsFlow(pyd.BaseModel): sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow ] = pyd.Field(default_factory=dict) - @functools.cached_property - def diff_symbols(self) -> set[sim_symbols.SimSymbol]: - """Set of all unrealized `SimSymbol`s that can act as inputs when differentiating the function for which this `ParamsFlow` tracks arguments.""" - return {sym for sym in self.symbols if sym.can_diff} - #################### # - Symbols #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py index 16eeca5..9765991 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py @@ -29,7 +29,7 @@ import tidy3d as td from blender_maxwell.contracts import BLEnumElement from blender_maxwell.services import tdcloud -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from .flow_kinds.info import InfoFlow diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py index 0dc4752..ce6259b 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py @@ -28,7 +28,7 @@ import typing as typ import sympy.physics.units as spu -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux #################### # - Unit Systems diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py index 4f16660..4242efd 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py @@ -20,7 +20,7 @@ import typing as typ import bpy -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from .. import bl_socket_map diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/__init__.py new file mode 100644 index 0000000..04c6341 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/__init__.py @@ -0,0 +1,29 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from .filter import FilterOperation +from .map import MapOperation +from .operate import BinaryOperation +from .reduce import ReduceOperation +from .transform import TransformOperation + +__all__ = [ + 'FilterOperation', + 'MapOperation', + 'BinaryOperation', + 'ReduceOperation', + 'TransformOperation', +] diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/filter.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/filter.py new file mode 100644 index 0000000..6bea59f --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/filter.py @@ -0,0 +1,233 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import enum +import typing as typ + +import jax.lax as jlax +import jax.numpy as jnp + +from blender_maxwell.utils import logger, sim_symbols + +from .. import contracts as ct + +log = logger.get(__name__) + + +class FilterOperation(enum.StrEnum): + """Valid operations for the `FilterMathNode`. + + Attributes: + DimToVec: Shift last dimension to output. + DimsToMat: Shift last 2 dimensions to output. + PinLen1: Remove a len(1) dimension. + Pin: Remove a len(n) dimension by selecting a particular index. + Swap: Swap the positions of two dimensions. + """ + + # Slice + Slice = enum.auto() + SliceIdx = enum.auto() + + # Pin + PinLen1 = enum.auto() + Pin = enum.auto() + PinIdx = enum.auto() + + # Dimension + Swap = enum.auto() + + #################### + # - UI + #################### + @staticmethod + def to_name(value: typ.Self) -> str: + FO = FilterOperation + return { + # Slice + FO.Slice: '≈a[v₁:v₂]', + FO.SliceIdx: '=a[i:j]', + # Pin + FO.PinLen1: 'a[0] → a', + FO.Pin: 'a[v] ⇝ a', + FO.PinIdx: 'a[i] → a', + # Reinterpret + FO.Swap: 'a₁ ↔ a₂', + }[value] + + @staticmethod + def to_icon(value: typ.Self) -> str: + return '' + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + FO = FilterOperation + return ( + str(self), + FO.to_name(self), + FO.to_name(self), + FO.to_icon(self), + i, + ) + + #################### + # - Ops from Info + #################### + @staticmethod + def by_info(info: ct.InfoFlow) -> list[typ.Self]: + FO = FilterOperation + operations = [] + + # Slice + if info.dims: + operations.append(FO.SliceIdx) + + # Pin + ## PinLen1 + ## -> There must be a dimension with length 1. + if 1 in [dim_idx for dim_idx in info.dims.values() if dim_idx is not None]: + operations.append(FO.PinLen1) + + ## Pin | PinIdx + ## -> There must be a dimension, full stop. + if info.dims: + operations += [FO.Pin, FO.PinIdx] + + # Reinterpret + ## Swap + ## -> There must be at least two dimensions. + if len(info.dims) >= 2: # noqa: PLR2004 + operations.append(FO.Swap) + + return operations + + #################### + # - Computed Properties + #################### + @property + def func_args(self) -> list[sim_symbols.SimSymbol]: + FO = FilterOperation + return { + # Pin + FO.Pin: [sim_symbols.idx(None)], + FO.PinIdx: [sim_symbols.idx(None)], + }.get(self, []) + + #################### + # - Methods + #################### + @property + def num_dim_inputs(self) -> None: + FO = FilterOperation + return { + # Slice + FO.Slice: 1, + FO.SliceIdx: 1, + # Pin + FO.PinLen1: 1, + FO.Pin: 1, + FO.PinIdx: 1, + # Reinterpret + FO.Swap: 2, + }[self] + + def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]: + FO = FilterOperation + match self: + # Slice + case FO.Slice: + return [dim for dim in info.dims if not info.has_idx_labels(dim)] + + case FO.SliceIdx: + return [dim for dim in info.dims if not info.has_idx_labels(dim)] + + # Pin + case FO.PinLen1: + return [ + dim + for dim, dim_idx in info.dims.items() + if not info.has_idx_cont(dim) and len(dim_idx) == 1 + ] + + case FO.Pin: + return info.dims + + case FO.PinIdx: + return [dim for dim in info.dims if not info.has_idx_cont(dim)] + + # Dimension + case FO.Swap: + return info.dims + + return [] + + def are_dims_valid( + self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None + ) -> bool: + """Check whether the given dimension inputs are valid in the context of this operation, and of the information.""" + if self.num_dim_inputs == 1: + return dim_0 in self.valid_dims(info) + + if self.num_dim_inputs == 2: # noqa: PLR2004 + valid_dims = self.valid_dims(info) + return dim_0 in valid_dims and dim_1 in valid_dims + + return False + + #################### + # - UI + #################### + def jax_func( + self, + axis_0: int | None, + axis_1: int | None, + slice_tuple: tuple[int, int, int] | None = None, + ): + FO = FilterOperation + return { + # Pin + FO.Slice: lambda expr: jlax.slice_in_dim( + expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0 + ), + FO.SliceIdx: lambda expr: jlax.slice_in_dim( + expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0 + ), + # Pin + FO.PinLen1: lambda expr: jnp.squeeze(expr, axis_0), + FO.Pin: lambda expr, idx: jnp.take(expr, idx, axis=axis_0), + FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0), + # Dimension + FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1), + }[self] + + def transform_info( + self, + info: ct.InfoFlow, + dim_0: sim_symbols.SimSymbol, + dim_1: sim_symbols.SimSymbol, + pin_idx: int | None = None, + slice_tuple: tuple[int, int, int] | None = None, + ): + FO = FilterOperation + return { + FO.Slice: lambda: info.slice_dim(dim_0, slice_tuple), + FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple), + # Pin + FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx), + FO.Pin: lambda: info.delete_dim(dim_0, pin_idx=pin_idx), + FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx), + # Reinterpret + FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1), + }[self]() diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/map.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/map.py new file mode 100644 index 0000000..d708561 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/map.py @@ -0,0 +1,365 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import enum +import typing as typ + +import jax.numpy as jnp +import sympy as sp + +from blender_maxwell.utils import logger, sim_symbols +from blender_maxwell.utils import sympy_extra as spux + +from .. import contracts as ct + +log = logger.get(__name__) + + +class MapOperation(enum.StrEnum): + """Valid operations for the `MapMathNode`. + + Attributes: + Real: Compute the real part of the input. + Imag: Compute the imaginary part of the input. + Abs: Compute the absolute value of the input. + Sq: Square the input. + Sqrt: Compute the (principal) square root of the input. + InvSqrt: Compute the inverse square root of the input. + Cos: Compute the cosine of the input. + Sin: Compute the sine of the input. + Tan: Compute the tangent of the input. + Acos: Compute the inverse cosine of the input. + Asin: Compute the inverse sine of the input. + Atan: Compute the inverse tangent of the input. + Norm2: Compute the 2-norm (aka. length) of the input vector. + Det: Compute the determinant of the input matrix. + Cond: Compute the condition number of the input matrix. + NormFro: Compute the frobenius norm of the input matrix. + Rank: Compute the rank of the input matrix. + Diag: Compute the diagonal vector of the input matrix. + EigVals: Compute the eigenvalues vector of the input matrix. + SvdVals: Compute the singular values vector of the input matrix. + Inv: Compute the inverse matrix of the input matrix. + Tra: Compute the transpose matrix of the input matrix. + Qr: Compute the QR-factorized matrices of the input matrix. + Chol: Compute the Cholesky-factorized matrices of the input matrix. + Svd: Compute the SVD-factorized matrices of the input matrix. + """ + + # By Number + Real = enum.auto() + Imag = enum.auto() + Abs = enum.auto() + Sq = enum.auto() + Sqrt = enum.auto() + InvSqrt = enum.auto() + Cos = enum.auto() + Sin = enum.auto() + Tan = enum.auto() + Acos = enum.auto() + Asin = enum.auto() + Atan = enum.auto() + Sinc = enum.auto() + # By Vector + Norm2 = enum.auto() + # By Matrix + Det = enum.auto() + Cond = enum.auto() + NormFro = enum.auto() + Rank = enum.auto() + Diag = enum.auto() + EigVals = enum.auto() + SvdVals = enum.auto() + Inv = enum.auto() + Tra = enum.auto() + Qr = enum.auto() + Chol = enum.auto() + Svd = enum.auto() + + #################### + # - UI + #################### + @staticmethod + def to_name(value: typ.Self) -> str: + """A human-readable UI-oriented name for a physical type.""" + MO = MapOperation + return { + # By Number + MO.Real: 'ℝ(v)', + MO.Imag: 'Im(v)', + MO.Abs: '|v|', + MO.Sq: 'v²', + MO.Sqrt: '√v', + MO.InvSqrt: '1/√v', + MO.Cos: 'cos v', + MO.Sin: 'sin v', + MO.Tan: 'tan v', + MO.Acos: 'acos v', + MO.Asin: 'asin v', + MO.Atan: 'atan v', + MO.Sinc: 'sinc v', + # By Vector + MO.Norm2: '||v||₂', + # By Matrix + MO.Det: 'det V', + MO.Cond: 'κ(V)', + MO.NormFro: '||V||_F', + MO.Rank: 'rank V', + MO.Diag: 'diag V', + MO.EigVals: 'eigvals V', + MO.SvdVals: 'svdvals V', + MO.Inv: 'V⁻¹', + MO.Tra: 'Vt', + MO.Qr: 'qr V', + MO.Chol: 'chol V', + MO.Svd: 'svd V', + }[value] + + @staticmethod + def to_icon(_: typ.Self) -> str: + """No icons.""" + return '' + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + """Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`.""" + MO = MapOperation + return ( + str(self), + MO.to_name(self), + MO.to_name(self), + MO.to_icon(self), + i, + ) + + #################### + # - Ops from Shape + #################### + @staticmethod + def by_expr_info(info: ct.InfoFlow) -> list[typ.Self]: + ## TODO: By info, not shape. + ## TODO: Check valid domains/mathtypes for some functions. + MO = MapOperation + element_ops = [ + MO.Real, + MO.Imag, + MO.Abs, + MO.Sq, + MO.Sqrt, + MO.InvSqrt, + MO.Cos, + MO.Sin, + MO.Tan, + MO.Acos, + MO.Asin, + MO.Atan, + MO.Sinc, + ] + + match (info.output.rows, info.output.cols): + case (1, 1): + return element_ops + + case (_, 1): + return [*element_ops, MO.Norm2] + + case (rows, cols) if rows == cols: + ## TODO: Check hermitian/posdef for cholesky. + ## - Can we even do this with just the output symbol approach? + return [ + *element_ops, + MO.Det, + MO.Cond, + MO.NormFro, + MO.Rank, + MO.Diag, + MO.EigVals, + MO.SvdVals, + MO.Inv, + MO.Tra, + MO.Qr, + MO.Chol, + MO.Svd, + ] + + case (rows, cols): + return [ + *element_ops, + MO.Cond, + MO.NormFro, + MO.Rank, + MO.SvdVals, + MO.Inv, + MO.Tra, + MO.Svd, + ] + + return [] + + #################### + # - Function Properties + #################### + @property + def sp_func(self): + MO = MapOperation + return { + # By Number + MO.Real: lambda expr: sp.re(expr), + MO.Imag: lambda expr: sp.im(expr), + MO.Abs: lambda expr: sp.Abs(expr), + MO.Sq: lambda expr: expr**2, + MO.Sqrt: lambda expr: sp.sqrt(expr), + MO.InvSqrt: lambda expr: 1 / sp.sqrt(expr), + MO.Cos: lambda expr: sp.cos(expr), + MO.Sin: lambda expr: sp.sin(expr), + MO.Tan: lambda expr: sp.tan(expr), + MO.Acos: lambda expr: sp.acos(expr), + MO.Asin: lambda expr: sp.asin(expr), + MO.Atan: lambda expr: sp.atan(expr), + MO.Sinc: lambda expr: sp.sinc(expr), + # By Vector + # Vector -> Number + MO.Norm2: lambda expr: sp.sqrt(expr.T @ expr)[0], + # By Matrix + # Matrix -> Number + MO.Det: lambda expr: sp.det(expr), + MO.Cond: lambda expr: expr.condition_number(), + MO.NormFro: lambda expr: expr.norm(ord='fro'), + MO.Rank: lambda expr: expr.rank(), + # Matrix -> Vec + MO.Diag: lambda expr: expr.diagonal(), + MO.EigVals: lambda expr: sp.Matrix(list(expr.eigenvals().keys())), + MO.SvdVals: lambda expr: expr.singular_values(), + # Matrix -> Matrix + MO.Inv: lambda expr: expr.inv(), + MO.Tra: lambda expr: expr.T, + # Matrix -> Matrices + MO.Qr: lambda expr: expr.QRdecomposition(), + MO.Chol: lambda expr: expr.cholesky(), + MO.Svd: lambda expr: expr.singular_value_decomposition(), + }[self] + + @property + def jax_func(self): + MO = MapOperation + return { + # By Number + MO.Real: lambda expr: jnp.real(expr), + MO.Imag: lambda expr: jnp.imag(expr), + MO.Abs: lambda expr: jnp.abs(expr), + MO.Sq: lambda expr: jnp.square(expr), + MO.Sqrt: lambda expr: jnp.sqrt(expr), + MO.InvSqrt: lambda expr: 1 / jnp.sqrt(expr), + MO.Cos: lambda expr: jnp.cos(expr), + MO.Sin: lambda expr: jnp.sin(expr), + MO.Tan: lambda expr: jnp.tan(expr), + MO.Acos: lambda expr: jnp.acos(expr), + MO.Asin: lambda expr: jnp.asin(expr), + MO.Atan: lambda expr: jnp.atan(expr), + MO.Sinc: lambda expr: jnp.sinc(expr), + # By Vector + # Vector -> Number + MO.Norm2: lambda expr: jnp.linalg.norm(expr, ord=2, axis=-1), + # By Matrix + # Matrix -> Number + MO.Det: lambda expr: jnp.linalg.det(expr), + MO.Cond: lambda expr: jnp.linalg.cond(expr), + MO.NormFro: lambda expr: jnp.linalg.matrix_norm(expr, ord='fro'), + MO.Rank: lambda expr: jnp.linalg.matrix_rank(expr), + # Matrix -> Vec + MO.Diag: lambda expr: jnp.diagonal(expr, axis1=-2, axis2=-1), + MO.EigVals: lambda expr: jnp.linalg.eigvals(expr), + MO.SvdVals: lambda expr: jnp.linalg.svdvals(expr), + # Matrix -> Matrix + MO.Inv: lambda expr: jnp.linalg.inv(expr), + MO.Tra: lambda expr: jnp.matrix_transpose(expr), + # Matrix -> Matrices + MO.Qr: lambda expr: jnp.linalg.qr(expr), + MO.Chol: lambda expr: jnp.linalg.cholesky(expr), + MO.Svd: lambda expr: jnp.linalg.svd(expr), + }[self] + + def transform_info(self, info: ct.InfoFlow): + MO = MapOperation + + return { + # By Number + MO.Real: lambda: info.update_output(mathtype=spux.MathType.Real), + MO.Imag: lambda: info.update_output(mathtype=spux.MathType.Real), + MO.Abs: lambda: info.update_output(mathtype=spux.MathType.Real), + MO.Sq: lambda: info, + MO.Sqrt: lambda: info, + MO.InvSqrt: lambda: info, + MO.Cos: lambda: info, + MO.Sin: lambda: info, + MO.Tan: lambda: info, + MO.Acos: lambda: info, + MO.Asin: lambda: info, + MO.Atan: lambda: info, + MO.Sinc: lambda: info, + # By Vector + MO.Norm2: lambda: info.update_output( + mathtype=spux.MathType.Real, + rows=1, + cols=1, + # Interval + interval_finite_re=(0, sim_symbols.float_max), + interval_inf=(False, True), + interval_closed=(True, False), + ), + # By Matrix + MO.Det: lambda: info.update_output( + rows=1, + cols=1, + ), + MO.Cond: lambda: info.update_output( + mathtype=spux.MathType.Real, + rows=1, + cols=1, + physical_type=spux.PhysicalType.NonPhysical, + unit=None, + ), + MO.NormFro: lambda: info.update_output( + mathtype=spux.MathType.Real, + rows=1, + cols=1, + # Interval + interval_finite_re=(0, sim_symbols.float_max), + interval_inf=(False, True), + interval_closed=(True, False), + ), + MO.Rank: lambda: info.update_output( + mathtype=spux.MathType.Integer, + rows=1, + cols=1, + physical_type=spux.PhysicalType.NonPhysical, + unit=None, + # Interval + interval_finite_re=(0, sim_symbols.int_max), + interval_inf=(False, True), + interval_closed=(True, False), + ), + # Matrix -> Vector ## TODO: ALL OF THESE + MO.Diag: lambda: info, + MO.EigVals: lambda: info, + MO.SvdVals: lambda: info, + # Matrix -> Matrix ## TODO: ALL OF THESE + MO.Inv: lambda: info, + MO.Tra: lambda: info, + # Matrix -> Matrices ## TODO: ALL OF THESE + MO.Qr: lambda: info, + MO.Chol: lambda: info, + MO.Svd: lambda: info, + }[self]() diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/operate.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/operate.py new file mode 100644 index 0000000..cf3d228 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/operate.py @@ -0,0 +1,476 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import enum +import typing as typ + +import jax.numpy as jnp +import sympy as sp +import sympy.physics.quantum as spq +import sympy.physics.units as spu + +from blender_maxwell.utils import logger, sim_symbols +from blender_maxwell.utils import sympy_extra as spux + +from .. import contracts as ct + +log = logger.get(__name__) + + +def hadamard_power(lhs: spux.SympyType, rhs: spux.SympyType) -> spux.SympyType: + """Implement the Hadamard Power. + + Follows the specification in , which also conforms to `numpy` broadcasting rules for `**` on `np.ndarray`. + """ + match (isinstance(lhs, sp.MatrixBase), isinstance(rhs, sp.MatrixBase)): + case (False, False): + msg = f"Hadamard Power for two scalars is valid, but shouldn't be used - use normal power instead: {lhs} | {rhs}" + raise ValueError(msg) + + case (True, False): + return lhs.applyfunc(lambda el: el**rhs) + + case (False, True): + return rhs.applyfunc(lambda el: lhs**el) + + case (True, True) if lhs.shape == rhs.shape: + common_shape = lhs.shape + return sp.ImmutableMatrix( + *common_shape, lambda i, j: lhs[i, j] ** rhs[i, j] + ) + + case _: + msg = f'Incompatible lhs and rhs for hadamard power: {lhs} | {rhs}' + raise ValueError(msg) + + +class BinaryOperation(enum.StrEnum): + """Valid operations for the `OperateMathNode`. + + Attributes: + Mul: Scalar multiplication. + Div: Scalar division. + Pow: Scalar exponentiation. + Add: Elementwise addition. + Sub: Elementwise subtraction. + HadamMul: Elementwise multiplication (hadamard product). + HadamPow: Principled shape-aware exponentiation (hadamard power). + Atan2: Quadrant-respecting 2D arctangent. + VecVecDot: Dot product for identically shaped vectors w/transpose. + Cross: Cross product between identically shaped 3D vectors. + VecVecOuter: Vector-vector outer product. + LinSolve: Solve a linear system. + LsqSolve: Minimize error of an underdetermined linear system. + VecMatOuter: Vector-matrix outer product. + MatMatDot: Matrix-matrix dot product. + """ + + # Number | Number + Mul = enum.auto() + Div = enum.auto() + Pow = enum.auto() + + # Elements | Elements + Add = enum.auto() + Sub = enum.auto() + HadamMul = enum.auto() + HadamPow = enum.auto() + HadamDiv = enum.auto() + Atan2 = enum.auto() + + # Vector | Vector + VecVecDot = enum.auto() + Cross = enum.auto() + VecVecOuter = enum.auto() + + # Matrix | Vector + LinSolve = enum.auto() + LsqSolve = enum.auto() + + # Vector | Matrix + VecMatOuter = enum.auto() + + # Matrix | Matrix + MatMatDot = enum.auto() + + #################### + # - UI + #################### + @staticmethod + def to_name(value: typ.Self) -> str: + """A human-readable UI-oriented name for a physical type.""" + BO = BinaryOperation + return { + # Number | Number + BO.Mul: 'ℓ · r', + BO.Div: 'ℓ / r', + BO.Pow: 'ℓ ^ r', ## Also for square-matrix powers. + # Elements | Elements + BO.Add: 'ℓ + r', + BO.Sub: 'ℓ - r', + BO.HadamMul: '𝐋 ⊙ 𝐑', + BO.HadamDiv: '𝐋 ⊙/ 𝐑', + BO.HadamPow: '𝐥 ⊙^ 𝐫', + BO.Atan2: 'atan2(ℓ:x, r:y)', + # Vector | Vector + BO.VecVecDot: '𝐥 · 𝐫', + BO.Cross: 'cross(𝐥,𝐫)', + BO.VecVecOuter: '𝐥 ⊗ 𝐫', + # Matrix | Vector + BO.LinSolve: '𝐋 ∖ 𝐫', + BO.LsqSolve: 'argminₓ∥𝐋𝐱−𝐫∥₂', + # Vector | Matrix + BO.VecMatOuter: '𝐋 ⊗ 𝐫', + # Matrix | Matrix + BO.MatMatDot: '𝐋 · 𝐑', + }[value] + + @staticmethod + def to_icon(value: typ.Self) -> str: + """No icons.""" + return '' + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + """Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`.""" + BO = BinaryOperation + return ( + str(self), + BO.to_name(self), + BO.to_name(self), + BO.to_icon(self), + i, + ) + + def bl_enum_elements( + self, info_l: ct.InfoFlow, info_r: ct.InfoFlow + ) -> list[ct.BLEnumElement]: + """Generate a list of guaranteed-valid operations based on the passed `InfoFlow`s. + + Returns a `bpy.props.EnumProperty.items`-compatible list. + """ + return [ + operation.bl_enum_element(i) + for i, operation in enumerate(BinaryOperation.by_infos(info_l, info_r)) + ] + + #################### + # - Ops from Shape + #################### + @staticmethod + def by_infos(info_l: ct.InfoFlow, info_r: ct.InfoFlow) -> list[typ.Self]: + """Deduce valid binary operations from the shapes of the inputs.""" + BO = BinaryOperation + ops = [] + + # Add/Sub + if info_l.compare_addable(info_r, allow_differing_unit=True): + ops += [BO.Add, BO.Sub] + + # Mul/Div + ## -> Mul is ambiguous; we differentiate Hadamard and Standard. + ## -> Div additionally requires non-zero guarantees. + if info_l.compare_multiplicable(info_r): + match (info_l.order, info_r.order, info_r.output.is_nonzero): + case (ordl, ordr, True) if ordl == 0 and ordr == 0: + ops += [BO.Mul, BO.Div] + case (ordl, ordr, True) if ordl > 0 and ordr == 0: + ops += [BO.Mul, BO.Div] + case (ordl, ordr, True) if ordl == 0 and ordr > 0: + ops += [BO.Mul] + case (ordl, ordr, True) if ordl > 0 and ordr > 0: + ops += [BO.HadamMul, BO.HadamDiv] + + case (ordl, ordr, False) if ordl == 0 and ordr == 0: + ops += [BO.Mul] + case (ordl, ordr, False) if ordl > 0 and ordr == 0: + ops += [BO.Mul] + case (ordl, ordr, True) if ordl == 0 and ordr > 0: + ops += [BO.Mul] + case (ordl, ordr, False) if ordl > 0 and ordr > 0: + ops += [BO.HadamMul] + + # Pow + ## -> We distinguish between "Hadamard Power" and "Power". + ## -> For scalars, they are the same (but we only expose "power"). + ## -> For matrices, square matrices can be exp'ed by int powers. + ## -> Any other combination is well-defined by the Hadamard Power. + if info_l.compare_exponentiable(info_r): + match (info_l.order, info_r.order, info_r.output.mathtype): + case (ordl, ordr, _) if ordl == 0 and ordr == 0: + ops += [BO.Pow] + + case (ordl, ordr, spux.MathType.Integer) if ( + ordl > 0 and ordr == 0 and info_l.output.rows == info_l.output.cols + ): + ops += [BO.Pow, BO.HadamPow] + + case _: + ops += [BO.HadamPow] + + # Operations by-Output Length + match ( + info_l.output.shape_len, + info_r.output.shape_len, + ): + # Number | Number + case (0, 0) if info_l.is_scalar and info_r.is_scalar: + # atan2: PhysicalType Must Both be Length | NonPhysical + ## -> atan2() produces radians from Cartesian coordinates. + ## -> This wouldn't make sense on non-Length / non-Unitless. + if ( + info_l.output.physical_type is spux.PhysicalType.Length + and info_r.output.physical_type is spux.PhysicalType.Length + ) or ( + info_l.output.physical_type is spux.PhysicalType.NonPhysical + and info_l.output.unit is None + and info_r.output.physical_type is spux.PhysicalType.NonPhysical + and info_r.output.unit is None + ): + ops += [BO.Atan2] + + return ops + + # Vector | Vector + case (1, 1) if info_l.compare_dims_identical(info_r): + outl = info_l.output + outr = info_r.output + + # 1D Orders: Outer Product is Valid + ## -> We can't do per-element outer product. + ## -> However, it's still super useful on its own. + if info_l.order == 1 and info_r.order == 1: + ops += [BO.VecVecOuter] + + # Vector | Vector + if outl.rows > outl.cols and outr.rows > outr.cols: + ops += [BO.VecVecDot] + + # Covector | Vector + if outl.rows < outl.cols and outr.rows > outr.cols: + ops += [BO.MatMatDot] + + # Vector | Covector + if outl.rows > outl.cols and outr.rows < outr.cols: + ops += [BO.MatMatDot] + + # Covector | Covector + if outl.rows < outl.cols and outr.rows < outr.cols: + ops += [BO.VecVecDot] + + # Cross Product + ## -> Works great element-wise. + ## -> Enforce that both are 3x1 or 1x3. + ## -> See https://docs.sympy.org/latest/modules/matrices/matrices.html#sympy.matrices.matrices.MatrixBase.cross + if (outl.rows == 3 and outr.rows == 3) or ( + outl.cols == 3 and outl.cols == 3 + ): + ops += [BO.Cross] + + # Vector | Matrix + ## -> We can't do per-element outer product. + ## -> However, it's still super useful on its own. + case (1, 2) if info_l.compare_dims_identical( + info_r + ) and info_l.order == 1 and info_r.order == 2: + ops += [BO.VecMatOuter] + + # Matrix | Vector + case (2, 1) if info_l.compare_dims_identical(info_r): + # Mat-Vec Dot: Enforce RHS Column Vector + if outr.rows > outl.cols: + ops += [BO.MatMatDot] + + ops += [BO.LinSolve, BO.LsqSolve] + + ## Matrix | Matrix + case (2, 2): + ops += [BO.MatMatDot] + + return ops + + #################### + # - Function Properties + #################### + @property + def sp_func(self): + """Deduce an appropriate sympy-based function that implements the binary operation for symbolic inputs.""" + BO = BinaryOperation + + ## TODO: Make this compatible with sp.Matrix inputs + return { + # Number | Number + BO.Mul: lambda exprs: exprs[0] * exprs[1], + BO.Div: lambda exprs: exprs[0] / exprs[1], + BO.Pow: lambda exprs: exprs[0] ** exprs[1], + # Elements | Elements + BO.Add: lambda exprs: exprs[0] + exprs[1], + BO.Sub: lambda exprs: exprs[0] - exprs[1], + BO.HadamMul: lambda exprs: exprs[0].multiply_elementwise(exprs[1]), + BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]), + BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]), + # Vector | Vector + BO.VecVecDot: lambda exprs: (exprs[0].T @ exprs[1])[0], + BO.Cross: lambda exprs: exprs[0].cross(exprs[1]), + BO.VecVecOuter: lambda exprs: exprs[0] @ exprs[1].T, + # Matrix | Vector + BO.LinSolve: lambda exprs: exprs[0].solve(exprs[1]), + BO.LsqSolve: lambda exprs: exprs[0].solve_least_squares(exprs[1]), + # Vector | Matrix + BO.VecMatOuter: lambda exprs: spq.TensorProduct(exprs[0], exprs[1]), + # Matrix | Matrix + BO.MatMatDot: lambda exprs: exprs[0] @ exprs[1], + }[self] + + @property + def unit_func(self): + """The binary function to apply to both unit expressions, in order to deduce the unit expression of the output.""" + BO = BinaryOperation + + ## TODO: Make this compatible with sp.Matrix inputs + return { + # Number | Number + BO.Mul: BO.Mul.sp_func, + BO.Div: BO.Div.sp_func, + BO.Pow: BO.Pow.sp_func, + # Elements | Elements + BO.Add: BO.Add.sp_func, + BO.Sub: BO.Sub.sp_func, + BO.HadamMul: BO.Mul.sp_func, + # BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]), + BO.Atan2: lambda _: spu.radian, + # Vector | Vector + BO.VecVecDot: BO.Mul.sp_func, + BO.Cross: BO.Mul.sp_func, + BO.VecVecOuter: BO.Mul.sp_func, + # Matrix | Vector + ## -> A,b in Ax = b have units, and the equality must hold. + ## -> Therefore, A \ b must have the units [b]/[A]. + BO.LinSolve: lambda exprs: exprs[1] / exprs[0], + BO.LsqSolve: lambda exprs: exprs[1] / exprs[0], + # Vector | Matrix + BO.VecMatOuter: BO.Mul.sp_func, + # Matrix | Matrix + BO.MatMatDot: BO.Mul.sp_func, + }[self] + + @property + def jax_func(self): + """Deduce an appropriate jax-based function that implements the binary operation for array inputs.""" + ## TODO: Scale the units of one side to the other. + BO = BinaryOperation + + return { + # Number | Number + BO.Mul: lambda exprs: exprs[0] * exprs[1], + BO.Div: lambda exprs: exprs[0] / exprs[1], + BO.Pow: lambda exprs: exprs[0] ** exprs[1], + # Elements | Elements + BO.Add: lambda exprs: exprs[0] + exprs[1], + BO.Sub: lambda exprs: exprs[0] - exprs[1], + BO.HadamMul: lambda exprs: exprs[0].multiply_elementwise(exprs[1]), + BO.HadamDiv: lambda exprs: exprs[0].multiply_elementwise( + exprs[1].applyfunc(lambda el: 1 / el) + ), + BO.HadamPow: lambda exprs: hadamard_power(exprs[0], exprs[1]), + BO.Atan2: lambda exprs: jnp.atan2(exprs[1], exprs[0]), + # Vector | Vector + BO.VecVecDot: lambda exprs: jnp.linalg.vecdot(exprs[0], exprs[1]), + BO.Cross: lambda exprs: jnp.cross(exprs[0], exprs[1]), + BO.VecVecOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]), + # Matrix | Vector + BO.LinSolve: lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]), + BO.LsqSolve: lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]), + # Vector | Matrix + BO.VecMatOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]), + # Matrix | Matrix + BO.MatMatDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]), + }[self] + + #################### + # - Transforms + #################### + def transform_funcs(self, func_l: ct.FuncFlow, func_r: ct.FuncFlow) -> ct.FuncFlow: + """Transform two input functions according to the current operation.""" + BO = BinaryOperation + + # Add/Sub: Normalize Unit of RHS to LHS + ## -> We can only add/sub identical units. + ## -> To be nice, we only require identical PhysicalType. + ## -> The result of a binary operation should have one unit. + if self is BO.Add or self is BO.Sub: + norm_func_r = func_r.scale_to_unit(func_l.func_output.unit) + else: + norm_func_r = func_r + + return (func_l, norm_func_r).compose_within( + self.jax_func, + enclosing_func_output=self.transform_outputs( + func_l.func_output, norm_func_r.func_output + ), + supports_jax=True, + ) + + def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow): + """Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression. + + Warnings: + `self` MUST be an element of `BinaryOperation.by_infos(info_l, info_r). + + If not, bad things will happen. + """ + return info_l.operate_output( + info_r, + lambda a, b: self.sp_func([a, b]), + lambda a, b: self.unit_func([a, b]), + ) + + #################### + # - InfoFlow Transform + #################### + def transform_outputs( + self, output_l: sim_symbols.SimSymbol, output_r: sim_symbols.SimSymbol + ) -> sim_symbols.SimSymbol: + # TO = TransformOperation + return None + # match self: + # # Number | Number + # case TO.Mul: + # return + # case TO.Div: + # case TO.Pow: + + # # Elements | Elements + # Add = enum.auto() + # Sub = enum.auto() + # HadamMul = enum.auto() + # HadamPow = enum.auto() + # HadamDiv = enum.auto() + # Atan2 = enum.auto() + + # # Vector | Vector + # VecVecDot = enum.auto() + # Cross = enum.auto() + # VecVecOuter = enum.auto() + + # # Matrix | Vector + # LinSolve = enum.auto() + # LsqSolve = enum.auto() + + # # Vector | Matrix + # VecMatOuter = enum.auto() + + # # Matrix | Matrix + # MatMatDot = enum.auto() diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/reduce.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/reduce.py new file mode 100644 index 0000000..4d26985 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/reduce.py @@ -0,0 +1,116 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import enum +import typing as typ + +import jax.numpy as jnp +import sympy as sp + +from blender_maxwell.utils import logger, sim_symbols +from blender_maxwell.utils import sympy_extra as spux + +from .. import contracts as ct + +log = logger.get(__name__) + + +class ReduceOperation(enum.StrEnum): + # Summary + Count = enum.auto() + + # Statistics + Mean = enum.auto() + Std = enum.auto() + Var = enum.auto() + + StdErr = enum.auto() + + Min = enum.auto() + Q25 = enum.auto() + Median = enum.auto() + Q75 = enum.auto() + Max = enum.auto() + + Mode = enum.auto() + + # Reductions + Sum = enum.auto() + Prod = enum.auto() + + #################### + # - UI + #################### + @staticmethod + def to_name(value: typ.Self) -> str: + """A human-readable UI-oriented name for a physical type.""" + RO = ReduceOperation + return { + # Summary + RO.Count: '# [a]', + RO.Mode: 'mode [a]', + # Statistics + RO.Mean: 'μ [a]', + RO.Std: 'σ [a]', + RO.Var: 'σ² [a]', + RO.StdErr: 'stderr [a]', + RO.Min: 'min [a]', + RO.Q25: 'q₂₅ [a]', + RO.Median: 'median [a]', + RO.Q75: 'q₇₅ [a]', + RO.Min: 'max [a]', + # Reductions + RO.Sum: 'sum [a]', + RO.Prod: 'prod [a]', + }[value] + + @staticmethod + def to_icon(_: typ.Self) -> str: + """No icons.""" + return '' + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + """Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`.""" + RO = ReduceOperation + return ( + str(self), + RO.to_name(self), + RO.to_name(self), + RO.to_icon(self), + i, + ) + + #################### + # - Derivation + #################### + @staticmethod + def from_info(info: ct.InfoFlow) -> list[typ.Self]: + """Derive valid reduction operations from the `InfoFlow` of the operand.""" + pass + + #################### + # - Composable Functions + #################### + @property + def jax_func(self): + RO = ReduceOperation + return {}[self] + + #################### + # - Transforms + #################### + def transform_info(self, info: ct.InfoFlow): + pass diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/transform.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/transform.py new file mode 100644 index 0000000..499c082 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/transform.py @@ -0,0 +1,336 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import enum +import typing as typ + +import jax.numpy as jnp +import jaxtyping as jtyp + +from blender_maxwell.utils import logger, sci_constants, sim_symbols +from blender_maxwell.utils import sympy_extra as spux + +from .. import contracts as ct + +log = logger.get(__name__) + + +class TransformOperation(enum.StrEnum): + """Valid operations for the `TransformMathNode`. + + Attributes: + FreqToVacWL: Transform an frequency dimension to vacuum wavelength. + VacWLToFreq: Transform a vacuum wavelength dimension to frequency. + ConvertIdxUnit: Convert the unit of a dimension to a compatible unit. + SetIdxUnit: Set all properties of a dimension. + FirstColToFirstIdx: Extract the first data column and set the first dimension's index array equal to it. + **For 2D integer-indexed data only**. + + IntDimToComplex: Fold a last length-2 integer dimension into the output, transforming it from a real-like type to complex type. + DimToVec: Fold the last dimension into the scalar output, creating a vector output type. + DimsToMat: Fold the last two dimensions into the scalar output, creating a matrix output type. + FT: Compute the 1D fourier transform along a dimension. + New dimensional bounds are computing using the Nyquist Limit. + For higher dimensions, simply repeat along more dimensions. + InvFT1D: Compute the inverse 1D fourier transform along a dimension. + New dimensional bounds are computing using the Nyquist Limit. + For higher dimensions, simply repeat along more dimensions. + """ + + # Covariant Transform + FreqToVacWL = enum.auto() + VacWLToFreq = enum.auto() + ConvertIdxUnit = enum.auto() + SetIdxUnit = enum.auto() + FirstColToFirstIdx = enum.auto() + + # Fold + IntDimToComplex = enum.auto() + DimToVec = enum.auto() + DimsToMat = enum.auto() + + # Fourier + FT1D = enum.auto() + InvFT1D = enum.auto() + + # TODO: Affine + ## TODO + + #################### + # - UI + #################### + @staticmethod + def to_name(value: typ.Self) -> str: + TO = TransformOperation + return { + # Covariant Transform + TO.FreqToVacWL: '𝑓 → λᵥ', + TO.VacWLToFreq: 'λᵥ → 𝑓', + TO.ConvertIdxUnit: 'Convert Dim', + TO.SetIdxUnit: 'Set Dim', + TO.FirstColToFirstIdx: '1st Col → 1st Dim', + # Fold + TO.IntDimToComplex: '→ ℂ', + TO.DimToVec: '→ Vector', + TO.DimsToMat: '→ Matrix', + ## TODO: Vector to new last-dim integer + ## TODO: Matrix to two last-dim integers + # Fourier + TO.FT1D: 'FT', + TO.InvFT1D: 'iFT', + }[value] + + @property + def name(self) -> str: + return TransformOperation.to_name(self) + + @staticmethod + def to_icon(_: typ.Self) -> str: + return '' + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + TO = TransformOperation + return ( + str(self), + TO.to_name(self), + TO.to_name(self), + TO.to_icon(self), + i, + ) + + #################### + # - Methods + #################### + def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]: + TO = TransformOperation + match self: + case TO.FreqToVacWL: + return [ + dim + for dim in info.dims + if dim.physical_type is spux.PhysicalType.Freq + ] + + case TO.VacWLToFreq: + return [ + dim + for dim in info.dims + if dim.physical_type is spux.PhysicalType.Length + ] + + case TO.ConvertIdxUnit: + return [ + dim + for dim in info.dims + if not info.has_idx_labels(dim) + and spux.PhysicalType.from_unit(dim.unit, optional=True) is not None + ] + + case TO.SetIdxUnit: + return [dim for dim in info.dims if not info.has_idx_labels(dim)] + + ## ColDimToComplex: Implicit Last Dimension + ## DimToVec: Implicit Last Dimension + ## DimsToMat: Implicit Last 2 Dimensions + + case TO.FT1D | TO.InvFT1D: + # Filter by Axis Uniformity + ## -> FT requires uniform axis (aka. must be RangeFlow). + ## -> NOTE: If FT isn't popping up, check ExtractDataNode. + return [dim for dim in info.dims if info.is_idx_uniform(dim)] + + return [] + + @staticmethod + def by_info(info: ct.InfoFlow) -> list[typ.Self]: + TO = TransformOperation + operations = [] + + # Covariant Transform + ## Freq -> VacWL + if TO.FreqToVacWL.valid_dims(info): + operations += [TO.FreqToVacWL] + + ## VacWL -> Freq + if TO.VacWLToFreq.valid_dims(info): + operations += [TO.VacWLToFreq] + + ## Convert Index Unit + if TO.ConvertIdxUnit.valid_dims(info): + operations += [TO.ConvertIdxUnit] + + if TO.SetIdxUnit.valid_dims(info): + operations += [TO.SetIdxUnit] + + ## Column to First Index (Array) + if ( + len(info.dims) == 2 # noqa: PLR2004 + and info.first_dim.mathtype is spux.MathType.Integer + and info.last_dim.mathtype is spux.MathType.Integer + and info.output.shape_len == 0 + ): + operations += [TO.FirstColToFirstIdx] + + # Fold + ## Last Dim -> Complex + if ( + len(info.dims) >= 1 + and ( + info.output.mathtype + in [spux.MathType.Integer, spux.MathType.Rational, spux.MathType.Real] + ) + and info.last_dim.mathtype is spux.MathType.Integer + and info.has_idx_labels(info.last_dim) + and len(info.dims[info.last_dim]) == 2 # noqa: PLR2004 + ): + operations += [TO.IntDimToComplex] + + ## Last Dim -> Vector + if len(info.dims) >= 1 and info.output.shape_len == 0: + operations += [TO.DimToVec] + + ## Last Dim -> Matrix + if len(info.dims) >= 2 and info.output.shape_len == 0: # noqa: PLR2004 + operations += [TO.DimsToMat] + + # Fourier + if TO.FT1D.valid_dims(info): + operations += [TO.FT1D] + + if TO.InvFT1D.valid_dims(info): + operations += [TO.InvFT1D] + + return operations + + #################### + # - Function Properties + #################### + def jax_func(self, axis: int | None = None): + TO = TransformOperation + return { + # Covariant Transform + ## -> Freq <-> WL is a rescale (noop) AND flip (not noop). + TO.FreqToVacWL: lambda expr: jnp.flip(expr, axis=axis), + TO.VacWLToFreq: lambda expr: jnp.flip(expr, axis=axis), + TO.ConvertIdxUnit: lambda expr: expr, + TO.SetIdxUnit: lambda expr: expr, + TO.FirstColToFirstIdx: lambda expr: jnp.delete(expr, 0, axis=1), + # Fold + ## -> To Complex: This should generally be a no-op. + TO.IntDimToComplex: lambda expr: jnp.squeeze( + expr.view(dtype=jnp.complex64), axis=-1 + ), + TO.DimToVec: lambda expr: expr, + TO.DimsToMat: lambda expr: expr, + # Fourier + TO.FT1D: lambda expr: jnp.fft(expr, axis=axis), + TO.InvFT1D: lambda expr: jnp.ifft(expr, axis=axis), + }[self] + + def transform_info( + self, + info: ct.InfoFlow, + dim: sim_symbols.SimSymbol | None = None, + data_col: jtyp.Shaped[jtyp.Array, ' size'] | None = None, + new_dim_name: str | None = None, + unit: spux.Unit | None = None, + physical_type: spux.PhysicalType | None = None, + ) -> ct.InfoFlow: + TO = TransformOperation + return { + # Covariant Transform + TO.FreqToVacWL: lambda: info.replace_dim( + (f_dim := dim), + sim_symbols.wl(unit), + info.dims[f_dim].rescale( + lambda el: sci_constants.vac_speed_of_light / el, + reverse=True, + new_unit=unit, + ), + ), + TO.VacWLToFreq: lambda: info.replace_dim( + (wl_dim := dim), + sim_symbols.freq(unit), + info.dims[wl_dim].rescale( + lambda el: sci_constants.vac_speed_of_light / el, + reverse=True, + new_unit=unit, + ), + ), + TO.ConvertIdxUnit: lambda: info.replace_dim( + dim, + dim.update(unit=unit), + ( + info.dims[dim].rescale_to_unit(unit) + if info.has_idx_discrete(dim) + else None ## Continuous -- dim SimSymbol already scaled + ), + ), + TO.SetIdxUnit: lambda: info.replace_dim( + dim, + dim.update( + sym_name=new_dim_name, + physical_type=physical_type, + unit=unit, + ), + ( + info.dims[dim].correct_unit(unit) + if info.has_idx_discrete(dim) + else None ## Continuous -- dim SimSymbol already scaled + ), + ), + TO.FirstColToFirstIdx: lambda: info.replace_dim( + info.first_dim, + info.first_dim.update( + sym_name=new_dim_name, + mathtype=spux.MathType.from_jax_array(data_col), + physical_type=physical_type, + unit=unit, + ), + ct.RangeFlow.try_from_array(ct.ArrayFlow(values=data_col, unit=unit)), + ).slice_dim(info.last_dim, (1, len(info.dims[info.last_dim]), 1)), + # Fold + TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output( + mathtype=spux.MathType.Complex + ), + TO.DimToVec: lambda: info.fold_last_input(), + TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(), + # Fourier + TO.FT1D: lambda: info.replace_dim( + dim, + [ + # FT'ed Unit: Reciprocal of the Original Unit + dim.update( + unit=1 / dim.unit if dim.unit is not None else 1 + ), ## TODO: Okay to not scale interval? + # FT'ed Bounds: Reciprocal of the Original Unit + info.dims[dim].bound_fourier_transform, + ], + ), + TO.InvFT1D: lambda: info.replace_dim( + info.last_dim, + [ + # FT'ed Unit: Reciprocal of the Original Unit + dim.update( + unit=1 / dim.unit if dim.unit is not None else 1 + ), ## TODO: Okay to not scale interval? + # FT'ed Bounds: Reciprocal of the Original Unit + ## -> Note the midpoint may revert to 0. + ## -> See docs for `RangeFlow.bound_inv_fourier_transform` for more. + info.dims[dim].bound_inv_fourier_transform, + ], + ), + }[self]() diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py index e036857..a2f54db 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py @@ -26,7 +26,7 @@ import sympy.physics.units as spu import tidy3d as td from blender_maxwell.utils import bl_cache, logger, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py index eddeff4..6c02ec1 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py @@ -20,225 +20,15 @@ import enum import typing as typ import bpy -import jax.lax as jlax -import jax.numpy as jnp import sympy as sp -from blender_maxwell.utils import bl_cache, logger, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import bl_cache, sim_symbols +from blender_maxwell.utils import sympy_extra as spux from .... import contracts as ct -from .... import sockets +from .... import math_system, sockets from ... import base, events -log = logger.get(__name__) - - -class FilterOperation(enum.StrEnum): - """Valid operations for the `FilterMathNode`. - - Attributes: - DimToVec: Shift last dimension to output. - DimsToMat: Shift last 2 dimensions to output. - PinLen1: Remove a len(1) dimension. - Pin: Remove a len(n) dimension by selecting a particular index. - Swap: Swap the positions of two dimensions. - """ - - # Slice - Slice = enum.auto() - SliceIdx = enum.auto() - - # Pin - PinLen1 = enum.auto() - Pin = enum.auto() - PinIdx = enum.auto() - - # Dimension - Swap = enum.auto() - - #################### - # - UI - #################### - @staticmethod - def to_name(value: typ.Self) -> str: - FO = FilterOperation - return { - # Slice - FO.Slice: '≈a[v₁:v₂]', - FO.SliceIdx: '=a[i:j]', - # Pin - FO.PinLen1: 'a[0] → a', - FO.Pin: 'a[v] ⇝ a', - FO.PinIdx: 'a[i] → a', - # Reinterpret - FO.Swap: 'a₁ ↔ a₂', - }[value] - - @staticmethod - def to_icon(value: typ.Self) -> str: - return '' - - def bl_enum_element(self, i: int) -> ct.BLEnumElement: - FO = FilterOperation - return ( - str(self), - FO.to_name(self), - FO.to_name(self), - FO.to_icon(self), - i, - ) - - #################### - # - Ops from Info - #################### - @staticmethod - def by_info(info: ct.InfoFlow) -> list[typ.Self]: - FO = FilterOperation - operations = [] - - # Slice - if info.dims: - operations.append(FO.SliceIdx) - - # Pin - ## PinLen1 - ## -> There must be a dimension with length 1. - if 1 in [dim_idx for dim_idx in info.dims.values() if dim_idx is not None]: - operations.append(FO.PinLen1) - - ## Pin | PinIdx - ## -> There must be a dimension, full stop. - if info.dims: - operations += [FO.Pin, FO.PinIdx] - - # Reinterpret - ## Swap - ## -> There must be at least two dimensions. - if len(info.dims) >= 2: # noqa: PLR2004 - operations.append(FO.Swap) - - return operations - - #################### - # - Computed Properties - #################### - @property - def func_args(self) -> list[sim_symbols.SimSymbol]: - FO = FilterOperation - return { - # Pin - FO.Pin: [sim_symbols.idx(None)], - FO.PinIdx: [sim_symbols.idx(None)], - }.get(self, []) - - #################### - # - Methods - #################### - @property - def num_dim_inputs(self) -> None: - FO = FilterOperation - return { - # Slice - FO.Slice: 1, - FO.SliceIdx: 1, - # Pin - FO.PinLen1: 1, - FO.Pin: 1, - FO.PinIdx: 1, - # Reinterpret - FO.Swap: 2, - }[self] - - def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]: - FO = FilterOperation - match self: - # Slice - case FO.Slice: - return [dim for dim in info.dims if not info.has_idx_labels(dim)] - - case FO.SliceIdx: - return [dim for dim in info.dims if not info.has_idx_labels(dim)] - - # Pin - case FO.PinLen1: - return [ - dim - for dim, dim_idx in info.dims.items() - if not info.has_idx_cont(dim) and len(dim_idx) == 1 - ] - - case FO.Pin: - return info.dims - - case FO.PinIdx: - return [dim for dim in info.dims if not info.has_idx_cont(dim)] - - # Dimension - case FO.Swap: - return info.dims - - return [] - - def are_dims_valid( - self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None - ) -> bool: - """Check whether the given dimension inputs are valid in the context of this operation, and of the information.""" - if self.num_dim_inputs == 1: - return dim_0 in self.valid_dims(info) - - if self.num_dim_inputs == 2: # noqa: PLR2004 - valid_dims = self.valid_dims(info) - return dim_0 in valid_dims and dim_1 in valid_dims - - return False - - #################### - # - UI - #################### - def jax_func( - self, - axis_0: int | None, - axis_1: int | None, - slice_tuple: tuple[int, int, int] | None = None, - ): - FO = FilterOperation - return { - # Pin - FO.Slice: lambda expr: jlax.slice_in_dim( - expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0 - ), - FO.SliceIdx: lambda expr: jlax.slice_in_dim( - expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0 - ), - # Pin - FO.PinLen1: lambda expr: jnp.squeeze(expr, axis_0), - FO.Pin: lambda expr, idx: jnp.take(expr, idx, axis=axis_0), - FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0), - # Dimension - FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1), - }[self] - - def transform_info( - self, - info: ct.InfoFlow, - dim_0: sim_symbols.SimSymbol, - dim_1: sim_symbols.SimSymbol, - pin_idx: int | None = None, - slice_tuple: tuple[int, int, int] | None = None, - ): - FO = FilterOperation - return { - FO.Slice: lambda: info.slice_dim(dim_0, slice_tuple), - FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple), - # Pin - FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx), - FO.Pin: lambda: info.delete_dim(dim_0, pin_idx=pin_idx), - FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx), - # Reinterpret - FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1), - }[self]() - class FilterMathNode(base.MaxwellSimNode): r"""Applies a function that operates on the shape of the array. @@ -304,7 +94,7 @@ class FilterMathNode(base.MaxwellSimNode): #################### # - Properties: Operation #################### - operation: FilterOperation = bl_cache.BLField( + operation: math_system.FilterOperation = bl_cache.BLField( enum_cb=lambda self, _: self.search_operations(), cb_depends_on={'expr_info'}, ) @@ -313,7 +103,9 @@ class FilterMathNode(base.MaxwellSimNode): if self.expr_info is not None: return [ operation.bl_enum_element(i) - for i, operation in enumerate(FilterOperation.by_info(self.expr_info)) + for i, operation in enumerate( + math_system.FilterOperation.by_info(self.expr_info) + ) ] return [] @@ -358,7 +150,7 @@ class FilterMathNode(base.MaxwellSimNode): # - UI #################### def draw_label(self): - FO = FilterOperation + FO = math_system.FilterOperation match self.operation: # Slice case FO.SliceIdx: @@ -398,7 +190,7 @@ class FilterMathNode(base.MaxwellSimNode): row.prop(self, self.blfields['active_dim_0'], text='') row.prop(self, self.blfields['active_dim_1'], text='') - if self.operation is FilterOperation.SliceIdx: + if self.operation is math_system.FilterOperation.SliceIdx: layout.prop(self, self.blfields['slice_tuple'], text='') #################### @@ -434,7 +226,7 @@ class FilterMathNode(base.MaxwellSimNode): ## -> Works with continuous / discrete indexes. ## -> The user will be given a socket w/correct mathtype, unit, etc. . if ( - props['operation'] is FilterOperation.Pin + props['operation'] is math_system.FilterOperation.Pin and dim_0 is not None and (info.has_idx_cont(dim_0) or info.has_idx_discrete(dim_0)) ): @@ -460,7 +252,7 @@ class FilterMathNode(base.MaxwellSimNode): # Loose Sockets: Pin Dim by-Value ## -> Works with discrete points / labelled integers. elif ( - props['operation'] is FilterOperation.PinIdx + props['operation'] is math_system.FilterOperation.PinIdx and dim_0 is not None and (info.has_idx_discrete(dim_0) or info.has_idx_labels(dim_0)) ): @@ -594,7 +386,10 @@ class FilterMathNode(base.MaxwellSimNode): # Pin by-Value: Compute Nearest IDX ## -> Presume a sorted index array to be able to use binary search. - if props['operation'] is FilterOperation.Pin and has_pinned_value: + if ( + props['operation'] is math_system.FilterOperation.Pin + and has_pinned_value + ): nearest_idx_to_value = info.dims[dim_0].nearest_idx_of( pinned_value, require_sorted=True ) @@ -605,7 +400,10 @@ class FilterMathNode(base.MaxwellSimNode): ) # Pin by-Index - if props['operation'] is FilterOperation.PinIdx and has_pinned_axis: + if ( + props['operation'] is math_system.FilterOperation.PinIdx + and has_pinned_axis + ): return params.compose_within( enclosing_arg_targets=[sim_symbols.idx(None)], enclosing_func_args=[sp.S(pinned_axis)], diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py index 04765a8..4959822 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py @@ -16,363 +16,19 @@ """Declares `MapMathNode`.""" -import enum import typing as typ import bpy -import jax.numpy as jnp -import sympy as sp -from blender_maxwell.utils import bl_cache, logger, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import bl_cache, logger from .... import contracts as ct -from .... import sockets +from .... import math_system, sockets from ... import base, events log = logger.get(__name__) -#################### -# - Operation Enum -#################### -class MapOperation(enum.StrEnum): - """Valid operations for the `MapMathNode`. - - Attributes: - Real: Compute the real part of the input. - Imag: Compute the imaginary part of the input. - Abs: Compute the absolute value of the input. - Sq: Square the input. - Sqrt: Compute the (principal) square root of the input. - InvSqrt: Compute the inverse square root of the input. - Cos: Compute the cosine of the input. - Sin: Compute the sine of the input. - Tan: Compute the tangent of the input. - Acos: Compute the inverse cosine of the input. - Asin: Compute the inverse sine of the input. - Atan: Compute the inverse tangent of the input. - Norm2: Compute the 2-norm (aka. length) of the input vector. - Det: Compute the determinant of the input matrix. - Cond: Compute the condition number of the input matrix. - NormFro: Compute the frobenius norm of the input matrix. - Rank: Compute the rank of the input matrix. - Diag: Compute the diagonal vector of the input matrix. - EigVals: Compute the eigenvalues vector of the input matrix. - SvdVals: Compute the singular values vector of the input matrix. - Inv: Compute the inverse matrix of the input matrix. - Tra: Compute the transpose matrix of the input matrix. - Qr: Compute the QR-factorized matrices of the input matrix. - Chol: Compute the Cholesky-factorized matrices of the input matrix. - Svd: Compute the SVD-factorized matrices of the input matrix. - """ - - # By Number - Real = enum.auto() - Imag = enum.auto() - Abs = enum.auto() - Sq = enum.auto() - Sqrt = enum.auto() - InvSqrt = enum.auto() - Cos = enum.auto() - Sin = enum.auto() - Tan = enum.auto() - Acos = enum.auto() - Asin = enum.auto() - Atan = enum.auto() - Sinc = enum.auto() - # By Vector - Norm2 = enum.auto() - # By Matrix - Det = enum.auto() - Cond = enum.auto() - NormFro = enum.auto() - Rank = enum.auto() - Diag = enum.auto() - EigVals = enum.auto() - SvdVals = enum.auto() - Inv = enum.auto() - Tra = enum.auto() - Qr = enum.auto() - Chol = enum.auto() - Svd = enum.auto() - - #################### - # - UI - #################### - @staticmethod - def to_name(value: typ.Self) -> str: - MO = MapOperation - return { - # By Number - MO.Real: 'ℝ(v)', - MO.Imag: 'Im(v)', - MO.Abs: '|v|', - MO.Sq: 'v²', - MO.Sqrt: '√v', - MO.InvSqrt: '1/√v', - MO.Cos: 'cos v', - MO.Sin: 'sin v', - MO.Tan: 'tan v', - MO.Acos: 'acos v', - MO.Asin: 'asin v', - MO.Atan: 'atan v', - MO.Sinc: 'sinc v', - # By Vector - MO.Norm2: '||v||₂', - # By Matrix - MO.Det: 'det V', - MO.Cond: 'κ(V)', - MO.NormFro: '||V||_F', - MO.Rank: 'rank V', - MO.Diag: 'diag V', - MO.EigVals: 'eigvals V', - MO.SvdVals: 'svdvals V', - MO.Inv: 'V⁻¹', - MO.Tra: 'Vt', - MO.Qr: 'qr V', - MO.Chol: 'chol V', - MO.Svd: 'svd V', - }[value] - - @staticmethod - def to_icon(value: typ.Self) -> str: - return '' - - def bl_enum_element(self, i: int) -> ct.BLEnumElement: - MO = MapOperation - return ( - str(self), - MO.to_name(self), - MO.to_name(self), - MO.to_icon(self), - i, - ) - - #################### - # - Ops from Shape - #################### - @staticmethod - def by_expr_info(info: ct.InfoFlow) -> list[typ.Self]: - ## TODO: By info, not shape. - ## TODO: Check valid domains/mathtypes for some functions. - MO = MapOperation - element_ops = [ - MO.Real, - MO.Imag, - MO.Abs, - MO.Sq, - MO.Sqrt, - MO.InvSqrt, - MO.Cos, - MO.Sin, - MO.Tan, - MO.Acos, - MO.Asin, - MO.Atan, - MO.Sinc, - ] - - match (info.output.rows, info.output.cols): - case (1, 1): - return element_ops - - case (_, 1): - return [*element_ops, MO.Norm2] - - case (rows, cols) if rows == cols: - ## TODO: Check hermitian/posdef for cholesky. - ## - Can we even do this with just the output symbol approach? - return [ - *element_ops, - MO.Det, - MO.Cond, - MO.NormFro, - MO.Rank, - MO.Diag, - MO.EigVals, - MO.SvdVals, - MO.Inv, - MO.Tra, - MO.Qr, - MO.Chol, - MO.Svd, - ] - - case (rows, cols): - return [ - *element_ops, - MO.Cond, - MO.NormFro, - MO.Rank, - MO.SvdVals, - MO.Inv, - MO.Tra, - MO.Svd, - ] - - return [] - - #################### - # - Function Properties - #################### - @property - def sp_func(self): - MO = MapOperation - return { - # By Number - MO.Real: lambda expr: sp.re(expr), - MO.Imag: lambda expr: sp.im(expr), - MO.Abs: lambda expr: sp.Abs(expr), - MO.Sq: lambda expr: expr**2, - MO.Sqrt: lambda expr: sp.sqrt(expr), - MO.InvSqrt: lambda expr: 1 / sp.sqrt(expr), - MO.Cos: lambda expr: sp.cos(expr), - MO.Sin: lambda expr: sp.sin(expr), - MO.Tan: lambda expr: sp.tan(expr), - MO.Acos: lambda expr: sp.acos(expr), - MO.Asin: lambda expr: sp.asin(expr), - MO.Atan: lambda expr: sp.atan(expr), - MO.Sinc: lambda expr: sp.sinc(expr), - # By Vector - # Vector -> Number - MO.Norm2: lambda expr: sp.sqrt(expr.T @ expr)[0], - # By Matrix - # Matrix -> Number - MO.Det: lambda expr: sp.det(expr), - MO.Cond: lambda expr: expr.condition_number(), - MO.NormFro: lambda expr: expr.norm(ord='fro'), - MO.Rank: lambda expr: expr.rank(), - # Matrix -> Vec - MO.Diag: lambda expr: expr.diagonal(), - MO.EigVals: lambda expr: sp.Matrix(list(expr.eigenvals().keys())), - MO.SvdVals: lambda expr: expr.singular_values(), - # Matrix -> Matrix - MO.Inv: lambda expr: expr.inv(), - MO.Tra: lambda expr: expr.T, - # Matrix -> Matrices - MO.Qr: lambda expr: expr.QRdecomposition(), - MO.Chol: lambda expr: expr.cholesky(), - MO.Svd: lambda expr: expr.singular_value_decomposition(), - }[self] - - @property - def jax_func(self): - MO = MapOperation - return { - # By Number - MO.Real: lambda expr: jnp.real(expr), - MO.Imag: lambda expr: jnp.imag(expr), - MO.Abs: lambda expr: jnp.abs(expr), - MO.Sq: lambda expr: jnp.square(expr), - MO.Sqrt: lambda expr: jnp.sqrt(expr), - MO.InvSqrt: lambda expr: 1 / jnp.sqrt(expr), - MO.Cos: lambda expr: jnp.cos(expr), - MO.Sin: lambda expr: jnp.sin(expr), - MO.Tan: lambda expr: jnp.tan(expr), - MO.Acos: lambda expr: jnp.acos(expr), - MO.Asin: lambda expr: jnp.asin(expr), - MO.Atan: lambda expr: jnp.atan(expr), - MO.Sinc: lambda expr: jnp.sinc(expr), - # By Vector - # Vector -> Number - MO.Norm2: lambda expr: jnp.linalg.norm(expr, ord=2, axis=-1), - # By Matrix - # Matrix -> Number - MO.Det: lambda expr: jnp.linalg.det(expr), - MO.Cond: lambda expr: jnp.linalg.cond(expr), - MO.NormFro: lambda expr: jnp.linalg.matrix_norm(expr, ord='fro'), - MO.Rank: lambda expr: jnp.linalg.matrix_rank(expr), - # Matrix -> Vec - MO.Diag: lambda expr: jnp.diagonal(expr, axis1=-2, axis2=-1), - MO.EigVals: lambda expr: jnp.linalg.eigvals(expr), - MO.SvdVals: lambda expr: jnp.linalg.svdvals(expr), - # Matrix -> Matrix - MO.Inv: lambda expr: jnp.linalg.inv(expr), - MO.Tra: lambda expr: jnp.matrix_transpose(expr), - # Matrix -> Matrices - MO.Qr: lambda expr: jnp.linalg.qr(expr), - MO.Chol: lambda expr: jnp.linalg.cholesky(expr), - MO.Svd: lambda expr: jnp.linalg.svd(expr), - }[self] - - def transform_info(self, info: ct.InfoFlow): - MO = MapOperation - - return { - # By Number - MO.Real: lambda: info.update_output(mathtype=spux.MathType.Real), - MO.Imag: lambda: info.update_output(mathtype=spux.MathType.Real), - MO.Abs: lambda: info.update_output(mathtype=spux.MathType.Real), - MO.Sq: lambda: info, - MO.Sqrt: lambda: info, - MO.InvSqrt: lambda: info, - MO.Cos: lambda: info, - MO.Sin: lambda: info, - MO.Tan: lambda: info, - MO.Acos: lambda: info, - MO.Asin: lambda: info, - MO.Atan: lambda: info, - MO.Sinc: lambda: info, - # By Vector - MO.Norm2: lambda: info.update_output( - mathtype=spux.MathType.Real, - rows=1, - cols=1, - # Interval - interval_finite_re=(0, sim_symbols.float_max), - interval_inf=(False, True), - interval_closed=(True, False), - ), - # By Matrix - MO.Det: lambda: info.update_output( - rows=1, - cols=1, - ), - MO.Cond: lambda: info.update_output( - mathtype=spux.MathType.Real, - rows=1, - cols=1, - physical_type=spux.PhysicalType.NonPhysical, - unit=None, - ), - MO.NormFro: lambda: info.update_output( - mathtype=spux.MathType.Real, - rows=1, - cols=1, - # Interval - interval_finite_re=(0, sim_symbols.float_max), - interval_inf=(False, True), - interval_closed=(True, False), - ), - MO.Rank: lambda: info.update_output( - mathtype=spux.MathType.Integer, - rows=1, - cols=1, - physical_type=spux.PhysicalType.NonPhysical, - unit=None, - # Interval - interval_finite_re=(0, sim_symbols.int_max), - interval_inf=(False, True), - interval_closed=(True, False), - ), - # Matrix -> Vector ## TODO: ALL OF THESE - MO.Diag: lambda: info, - MO.EigVals: lambda: info, - MO.SvdVals: lambda: info, - # Matrix -> Matrix ## TODO: ALL OF THESE - MO.Inv: lambda: info, - MO.Tra: lambda: info, - # Matrix -> Matrices ## TODO: ALL OF THESE - MO.Qr: lambda: info, - MO.Chol: lambda: info, - MO.Svd: lambda: info, - }[self]() - - -#################### -# - Node -#################### class MapMathNode(base.MaxwellSimNode): r"""Applies a function by-structure to the data. @@ -495,7 +151,7 @@ class MapMathNode(base.MaxwellSimNode): return info return None - operation: MapOperation = bl_cache.BLField( + operation: math_system.MapOperation = bl_cache.BLField( enum_cb=lambda self, _: self.search_operations(), cb_depends_on={'expr_info'}, ) @@ -504,7 +160,9 @@ class MapMathNode(base.MaxwellSimNode): if self.expr_info is not None: return [ operation.bl_enum_element(i) - for i, operation in enumerate(MapOperation.by_expr_info(self.expr_info)) + for i, operation in enumerate( + math_system.MapOperation.by_expr_info(self.expr_info) + ) ] return [] @@ -513,7 +171,7 @@ class MapMathNode(base.MaxwellSimNode): #################### def draw_label(self): if self.operation is not None: - return 'Map: ' + MapOperation.to_name(self.operation) + return 'Map: ' + math_system.MapOperation.to_name(self.operation) return self.bl_label diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py index f59bde4..920f863 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py @@ -14,354 +14,24 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import enum +"""Implements the `OperateMathNode`. + +See `blender_maxwell.maxwell_sim_nodes.math_system` for the actual mathematics implementation. +""" + import typing as typ import bpy -import jax.numpy as jnp -import sympy as sp -import sympy.physics.quantum as spq -import sympy.physics.units as spu from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux from .... import contracts as ct -from .... import sockets +from .... import math_system, sockets from ... import base, events log = logger.get(__name__) -#################### -# - Operation Enum -#################### -class BinaryOperation(enum.StrEnum): - """Valid operations for the `OperateMathNode`. - - Attributes: - Mul: Scalar multiplication. - Div: Scalar division. - Pow: Scalar exponentiation. - Add: Elementwise addition. - Sub: Elementwise subtraction. - HadamMul: Elementwise multiplication (hadamard product). - HadamPow: Principled shape-aware exponentiation (hadamard power). - Atan2: Quadrant-respecting 2D arctangent. - VecVecDot: Dot product for identically shaped vectors w/transpose. - Cross: Cross product between identically shaped 3D vectors. - VecVecOuter: Vector-vector outer product. - LinSolve: Solve a linear system. - LsqSolve: Minimize error of an underdetermined linear system. - VecMatOuter: Vector-matrix outer product. - MatMatDot: Matrix-matrix dot product. - """ - - # Number | Number - Mul = enum.auto() - Div = enum.auto() - Pow = enum.auto() - - # Elements | Elements - Add = enum.auto() - Sub = enum.auto() - HadamMul = enum.auto() - # HadamPow = enum.auto() ## TODO: Sympy's HadamardPower is problematic. - Atan2 = enum.auto() - - # Vector | Vector - VecVecDot = enum.auto() - Cross = enum.auto() - VecVecOuter = enum.auto() - - # Matrix | Vector - LinSolve = enum.auto() - LsqSolve = enum.auto() - - # Vector | Matrix - VecMatOuter = enum.auto() - - # Matrix | Matrix - MatMatDot = enum.auto() - - #################### - # - UI - #################### - @staticmethod - def to_name(value: typ.Self) -> str: - BO = BinaryOperation - return { - # Number | Number - BO.Mul: 'ℓ · r', - BO.Div: 'ℓ / r', - BO.Pow: 'ℓ ^ r', - # Elements | Elements - BO.Add: 'ℓ + r', - BO.Sub: 'ℓ - r', - BO.HadamMul: '𝐋 ⊙ 𝐑', - # BO.HadamPow: '𝐥 ⊙^ 𝐫', - BO.Atan2: 'atan2(ℓ:x, r:y)', - # Vector | Vector - BO.VecVecDot: '𝐥 · 𝐫', - BO.Cross: 'cross(𝐥,𝐫)', - BO.VecVecOuter: '𝐥 ⊗ 𝐫', - # Matrix | Vector - BO.LinSolve: '𝐋 ∖ 𝐫', - BO.LsqSolve: 'argminₓ∥𝐋𝐱−𝐫∥₂', - # Vector | Matrix - BO.VecMatOuter: '𝐋 ⊗ 𝐫', - # Matrix | Matrix - BO.MatMatDot: '𝐋 · 𝐑', - }[value] - - @staticmethod - def to_icon(value: typ.Self) -> str: - return '' - - def bl_enum_element(self, i: int) -> ct.BLEnumElement: - BO = BinaryOperation - return ( - str(self), - BO.to_name(self), - BO.to_name(self), - BO.to_icon(self), - i, - ) - - #################### - # - Ops from Shape - #################### - @staticmethod - def by_infos(info_l: int, info_r: int) -> list[typ.Self]: - """Deduce valid binary operations from the shapes of the inputs.""" - BO = BinaryOperation - - ops_el_el = [ - BO.Add, - BO.Sub, - BO.HadamMul, - # BO.HadamPow, - ] - - outl = info_l.output - outr = info_r.output - match (outl.shape_len, outr.shape_len): - # Number | * - ## Number | Number - case (0, 0): - ops = [ - BO.Add, - BO.Sub, - BO.Mul, - ] - - # Check Non-Zero Right Hand Side - ## -> Obviously, we can't ever divide by zero. - ## -> Sympy's assumptions system must always guarantee rhs != 0. - ## -> If it can't, then we simply don't expose division. - ## -> The is_zero assumption must be provided elsewhere. - ## -> NOTE: This may prevent some valid uses of division. - ## -> Watch out for "division is missing" bugs. - if info_r.output.is_nonzero: - ops.append(BO.Div) - - if ( - info_l.output.physical_type == spux.PhysicalType.Length - and info_l.output.unit == info_r.output.unit - ): - ops += [BO.Atan2] - - return [*ops, BO.Pow] - - ## Number | Vector - case (0, 1): - return [BO.Mul] # , BO.HadamPow] - - ## Number | Matrix - case (0, 2): - return [BO.Mul] # , BO.HadamPow] - - # Vector | * - ## Vector | Number - case (1, 0): - return [BO.Mul] # , BO.HadamPow] - - ## Vector | Vector - case (1, 1): - ops = [] - - # Vector | Vector - ## -> Dot: Convenience; utilize special vec-vec dot w/transp. - if outl.rows > outl.cols and outr.rows > outr.cols: - ops += [BO.VecVecDot, BO.VecVecOuter] - - # Covector | Vector - ## -> Dot: Directly use matrix-matrix dot, as it's now correct. - if outl.rows < outl.cols and outr.rows > outr.cols: - ops += [BO.MatMatDot, BO.VecVecOuter] - - # Vector | Covector - ## -> Dot: Directly use matrix-matrix dot, as it's now correct. - ## -> These are both the same operation, in this case. - if outl.rows > outl.cols and outr.rows < outr.cols: - ops += [BO.MatMatDot, BO.VecVecOuter] - - # Covector | Covector - ## -> Dot: Convenience; utilize special vec-vec dot w/transp. - if outl.rows < outl.cols and outr.rows < outr.cols: - ops += [BO.VecVecDot, BO.VecVecOuter] - - # Cross Product - ## -> Enforce that both are 3x1 or 1x3. - ## -> See https://docs.sympy.org/latest/modules/matrices/matrices.html#sympy.matrices.matrices.MatrixBase.cross - if (outl.rows == 3 and outr.rows == 3) or ( - outl.cols == 3 and outl.cols == 3 - ): - ops += [BO.Cross] - - return ops_el_el + ops - - ## Vector | Matrix - case (1, 2): - return [BO.VecMatOuter] - - # Matrix | * - ## Matrix | Number - case (2, 0): - return [BO.Mul] # , BO.HadamPow] - - ## Matrix | Vector - case (2, 1): - prepend_ops = [] - - # Mat-Vec Dot: Enforce RHS Column Vector - if outr.rows > outl.cols: - prepend_ops += [BO.MatMatDot] - - return [*ops, BO.LinSolve, BO.LsqSolve] # , BO.HadamPow] - - ## Matrix | Matrix - case (2, 2): - return [*ops_el_el, BO.MatMatDot] - - return [] - - #################### - # - Function Properties - #################### - @property - def sp_func(self): - """Deduce an appropriate sympy-based function that implements the binary operation for symbolic inputs.""" - BO = BinaryOperation - - ## TODO: Make this compatible with sp.Matrix inputs - return { - # Number | Number - BO.Mul: lambda exprs: exprs[0] * exprs[1], - BO.Div: lambda exprs: exprs[0] / exprs[1], - BO.Pow: lambda exprs: exprs[0] ** exprs[1], - # Elements | Elements - BO.Add: lambda exprs: exprs[0] + exprs[1], - BO.Sub: lambda exprs: exprs[0] - exprs[1], - BO.HadamMul: lambda exprs: sp.hadamard_product(exprs[0], exprs[1]), - # BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]), - BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]), - # Vector | Vector - BO.VecVecDot: lambda exprs: (exprs[0].T @ exprs[1])[0], - BO.Cross: lambda exprs: exprs[0].cross(exprs[1]), - BO.VecVecOuter: lambda exprs: exprs[0] @ exprs[1].T, - # Matrix | Vector - BO.LinSolve: lambda exprs: exprs[0].solve(exprs[1]), - BO.LsqSolve: lambda exprs: exprs[0].solve_least_squares(exprs[1]), - # Vector | Matrix - BO.VecMatOuter: lambda exprs: spq.TensorProduct(exprs[0], exprs[1]), - # Matrix | Matrix - BO.MatMatDot: lambda exprs: exprs[0] @ exprs[1], - }[self] - - @property - def unit_func(self): - """The binary function to apply to both unit expressions, in order to deduce the unit expression of the output.""" - BO = BinaryOperation - - ## TODO: Make this compatible with sp.Matrix inputs - return { - # Number | Number - BO.Mul: BO.Mul.sp_func, - BO.Div: BO.Div.sp_func, - BO.Pow: BO.Pow.sp_func, - # Elements | Elements - BO.Add: BO.Add.sp_func, - BO.Sub: BO.Sub.sp_func, - BO.HadamMul: BO.Mul.sp_func, - # BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]), - BO.Atan2: lambda _: spu.radian, - # Vector | Vector - BO.VecVecDot: BO.Mul.sp_func, - BO.Cross: BO.Mul.sp_func, - BO.VecVecOuter: BO.Mul.sp_func, - # Matrix | Vector - ## -> A,b in Ax = b have units, and the equality must hold. - ## -> Therefore, A \ b must have the units [b]/[A]. - BO.LinSolve: lambda exprs: exprs[1] / exprs[0], - BO.LsqSolve: lambda exprs: exprs[1] / exprs[0], - # Vector | Matrix - BO.VecMatOuter: BO.Mul.sp_func, - # Matrix | Matrix - BO.MatMatDot: BO.Mul.sp_func, - }[self] - - @property - def jax_func(self): - """Deduce an appropriate jax-based function that implements the binary operation for array inputs.""" - ## TODO: Scale the units of one side to the other. - BO = BinaryOperation - - return { - # Number | Number - BO.Mul: lambda exprs: exprs[0] * exprs[1], - BO.Div: lambda exprs: exprs[0] / exprs[1], - BO.Pow: lambda exprs: exprs[0] ** exprs[1], - # Elements | Elements - BO.Add: lambda exprs: exprs[0] + exprs[1], - BO.Sub: lambda exprs: exprs[0] - exprs[1], - BO.HadamMul: lambda exprs: exprs[0] * exprs[1], - # BO.HadamPow: lambda exprs: exprs[0] ** exprs[1], - BO.Atan2: lambda exprs: jnp.atan2(exprs[1], exprs[0]), - # Vector | Vector - BO.VecVecDot: lambda exprs: jnp.dot(exprs[0], exprs[1]), - BO.Cross: lambda exprs: jnp.cross(exprs[0], exprs[1]), - BO.VecVecOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]), - # Matrix | Vector - BO.LinSolve: lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]), - BO.LsqSolve: lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]), - # Vector | Matrix - BO.VecMatOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]), - # Matrix | Matrix - BO.MatMatDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]), - }[self] - - #################### - # - InfoFlow Transform - #################### - def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow): - """Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression. - - Warnings: - `self` MUST be an element of `BinaryOperation.by_infos(info_l, info_r). - - If not, bad things will happen. - """ - return info_l.operate_output( - info_r, - lambda a, b: self.sp_func([a, b]), - lambda a, b: self.unit_func([a, b]), - ) - - -#################### -# - Node -#################### class OperateMathNode(base.MaxwellSimNode): r"""Applies a binary function between two expressions. @@ -386,7 +56,7 @@ class OperateMathNode(base.MaxwellSimNode): } #################### - # - Properties + # - Properties: Incoming InfoFlows #################### @events.on_value_changed( # Trigger @@ -415,6 +85,7 @@ class OperateMathNode(base.MaxwellSimNode): @bl_cache.cached_bl_property() def expr_infos(self) -> tuple[ct.InfoFlow, ct.InfoFlow] | None: + """Computed `InfoFlow`s of both expressions.""" info_l = self._compute_input('Expr L', kind=ct.FlowKind.Info) info_r = self._compute_input('Expr R', kind=ct.FlowKind.Info) @@ -426,19 +97,18 @@ class OperateMathNode(base.MaxwellSimNode): return None - operation: BinaryOperation = bl_cache.BLField( + #################### + # - Property: Operation + #################### + operation: math_system.BinaryOperation = bl_cache.BLField( enum_cb=lambda self, _: self.search_operations(), cb_depends_on={'expr_infos'}, ) def search_operations(self) -> list[ct.BLEnumElement]: + """Retrieve valid operations based on the input `InfoFlow`s.""" if self.expr_infos is not None: - return [ - operation.bl_enum_element(i) - for i, operation in enumerate( - BinaryOperation.by_infos(*self.expr_infos) - ) - ] + return math_system.BinaryOperation.bl_enum_elements(*self.expr_infos) return [] #################### @@ -451,7 +121,7 @@ class OperateMathNode(base.MaxwellSimNode): Called by Blender to determine the text to place in the node's header. """ if self.operation is not None: - return 'Op: ' + BinaryOperation.to_name(self.operation) + return 'Op: ' + math_system.BinaryOperation.to_name(self.operation) return self.bl_label @@ -464,7 +134,7 @@ class OperateMathNode(base.MaxwellSimNode): layout.prop(self, self.blfields['operation'], text='') #################### - # - FlowKind.Value|Func + # - FlowKind.Value #################### @events.computes_output_socket( 'Expr', @@ -477,20 +147,22 @@ class OperateMathNode(base.MaxwellSimNode): }, ) def compute_value(self, props: dict, input_sockets: dict): - operation = props['operation'] + """Binary operation on two symbolic input expressions.""" expr_l = input_sockets['Expr L'] expr_r = input_sockets['Expr R'] has_expr_l_value = not ct.FlowSignal.check(expr_l) has_expr_r_value = not ct.FlowSignal.check(expr_r) - # Compute Sympy Function - ## -> The operation enum directly provides the appropriate function. + operation = props['operation'] if has_expr_l_value and has_expr_r_value and operation is not None: return operation.sp_func([expr_l, expr_r]) return ct.FlowSignal.FlowPending + #################### + # - FlowKind.Func + #################### @events.computes_output_socket( 'Expr', kind=ct.FlowKind.Func, @@ -505,10 +177,7 @@ class OperateMathNode(base.MaxwellSimNode): output_socket_kinds={'Expr': ct.FlowKind.Info}, ) def compute_func(self, props, input_sockets, output_sockets): - operation = props['operation'] - if operation is None: - return ct.FlowSignal.FlowPending - + """Binary operation on two lazy-defined input expressions.""" expr_l = input_sockets['Expr L'] expr_r = input_sockets['Expr R'] output_info = output_sockets['Expr'] @@ -517,14 +186,9 @@ class OperateMathNode(base.MaxwellSimNode): has_expr_r = not ct.FlowSignal.check(expr_r) has_output_info = not ct.FlowSignal.check(output_info) - # Compute Jax Function - ## -> The operation enum directly provides the appropriate function. - if has_expr_l and has_expr_r and has_output_info: - return (expr_l | expr_r).compose_within( - operation.jax_func, - enclosing_func_output=output_info.output, - supports_jax=True, - ) + operation = props['operation'] + if operation is not None and has_expr_l and has_expr_r and has_output_info: + return self.operation.transform_funcs(expr_l, expr_r) return ct.FlowSignal.FlowPending #################### @@ -541,22 +205,17 @@ class OperateMathNode(base.MaxwellSimNode): }, ) def compute_info(self, props, input_sockets) -> ct.InfoFlow: - BO = BinaryOperation - - operation = props['operation'] + """Transform the input information of both lazy inputs.""" info_l = input_sockets['Expr L'] info_r = input_sockets['Expr R'] has_info_l = not ct.FlowSignal.check(info_l) has_info_r = not ct.FlowSignal.check(info_r) - # Compute Info - ## -> The operation enum directly provides the appropriate transform. + operation = props['operation'] if ( - has_info_l - and has_info_r - and operation is not None - and operation in BO.by_infos(info_l, info_r) + has_info_l and has_info_r and operation is not None + # and operation in BO.by_infos(info_l, info_r) ): return operation.transform_infos(info_l, info_r) @@ -576,15 +235,14 @@ class OperateMathNode(base.MaxwellSimNode): }, ) def compute_params(self, props, input_sockets) -> ct.ParamsFlow | ct.FlowSignal: - operation = props['operation'] + """Merge the lazy input parameters.""" params_l = input_sockets['Expr L'] params_r = input_sockets['Expr R'] has_params_l = not ct.FlowSignal.check(params_l) has_params_r = not ct.FlowSignal.check(params_r) - # Compute Params - ## -> Operations don't add new parameters, so just concatenate L|R. + operation = props['operation'] if has_params_l and has_params_r and operation is not None: return params_l | params_r diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py index 6744f7d..aa6ed07 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py @@ -20,334 +20,18 @@ import enum import typing as typ import bpy -import jax.numpy as jnp -import jaxtyping as jtyp import sympy as sp -from blender_maxwell.utils import bl_cache, logger, sci_constants, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import bl_cache, logger, sim_symbols +from blender_maxwell.utils import sympy_extra as spux from .... import contracts as ct -from .... import sockets +from .... import math_system, sockets from ... import base, events log = logger.get(__name__) -#################### -# - Operation Enum -#################### -class TransformOperation(enum.StrEnum): - """Valid operations for the `TransformMathNode`. - - Attributes: - FreqToVacWL: Transform an frequency dimension to vacuum wavelength. - VacWLToFreq: Transform a vacuum wavelength dimension to frequency. - ConvertIdxUnit: Convert the unit of a dimension to a compatible unit. - SetIdxUnit: Set all properties of a dimension. - FirstColToFirstIdx: Extract the first data column and set the first dimension's index array equal to it. - **For 2D integer-indexed data only**. - - IntDimToComplex: Fold a last length-2 integer dimension into the output, transforming it from a real-like type to complex type. - DimToVec: Fold the last dimension into the scalar output, creating a vector output type. - DimsToMat: Fold the last two dimensions into the scalar output, creating a matrix output type. - FT: Compute the 1D fourier transform along a dimension. - New dimensional bounds are computing using the Nyquist Limit. - For higher dimensions, simply repeat along more dimensions. - InvFT1D: Compute the inverse 1D fourier transform along a dimension. - New dimensional bounds are computing using the Nyquist Limit. - For higher dimensions, simply repeat along more dimensions. - """ - - # Covariant Transform - FreqToVacWL = enum.auto() - VacWLToFreq = enum.auto() - ConvertIdxUnit = enum.auto() - SetIdxUnit = enum.auto() - FirstColToFirstIdx = enum.auto() - - # Fold - IntDimToComplex = enum.auto() - DimToVec = enum.auto() - DimsToMat = enum.auto() - - # Fourier - FT1D = enum.auto() - InvFT1D = enum.auto() - - # TODO: Affine - ## TODO - - #################### - # - UI - #################### - @staticmethod - def to_name(value: typ.Self) -> str: - TO = TransformOperation - return { - # Covariant Transform - TO.FreqToVacWL: '𝑓 → λᵥ', - TO.VacWLToFreq: 'λᵥ → 𝑓', - TO.ConvertIdxUnit: 'Convert Dim', - TO.SetIdxUnit: 'Set Dim', - TO.FirstColToFirstIdx: '1st Col → 1st Dim', - # Fold - TO.IntDimToComplex: '→ ℂ', - TO.DimToVec: '→ Vector', - TO.DimsToMat: '→ Matrix', - ## TODO: Vector to new last-dim integer - ## TODO: Matrix to two last-dim integers - # Fourier - TO.FT1D: 'FT', - TO.InvFT1D: 'iFT', - }[value] - - @property - def name(self) -> str: - return TransformOperation.to_name(self) - - @staticmethod - def to_icon(_: typ.Self) -> str: - return '' - - def bl_enum_element(self, i: int) -> ct.BLEnumElement: - TO = TransformOperation - return ( - str(self), - TO.to_name(self), - TO.to_name(self), - TO.to_icon(self), - i, - ) - - #################### - # - Methods - #################### - def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]: - TO = TransformOperation - match self: - case TO.FreqToVacWL: - return [ - dim - for dim in info.dims - if dim.physical_type is spux.PhysicalType.Freq - ] - - case TO.VacWLToFreq: - return [ - dim - for dim in info.dims - if dim.physical_type is spux.PhysicalType.Length - ] - - case TO.ConvertIdxUnit: - return [ - dim - for dim in info.dims - if not info.has_idx_labels(dim) - and spux.PhysicalType.from_unit(dim.unit, optional=True) is not None - ] - - case TO.SetIdxUnit: - return [dim for dim in info.dims if not info.has_idx_labels(dim)] - - ## ColDimToComplex: Implicit Last Dimension - ## DimToVec: Implicit Last Dimension - ## DimsToMat: Implicit Last 2 Dimensions - - case TO.FT1D | TO.InvFT1D: - # Filter by Axis Uniformity - ## -> FT requires uniform axis (aka. must be RangeFlow). - ## -> NOTE: If FT isn't popping up, check ExtractDataNode. - return [dim for dim in info.dims if info.is_idx_uniform(dim)] - - return [] - - @staticmethod - def by_info(info: ct.InfoFlow) -> list[typ.Self]: - TO = TransformOperation - operations = [] - - # Covariant Transform - ## Freq -> VacWL - if TO.FreqToVacWL.valid_dims(info): - operations += [TO.FreqToVacWL] - - ## VacWL -> Freq - if TO.VacWLToFreq.valid_dims(info): - operations += [TO.VacWLToFreq] - - ## Convert Index Unit - if TO.ConvertIdxUnit.valid_dims(info): - operations += [TO.ConvertIdxUnit] - - if TO.SetIdxUnit.valid_dims(info): - operations += [TO.SetIdxUnit] - - ## Column to First Index (Array) - if ( - len(info.dims) == 2 # noqa: PLR2004 - and info.first_dim.mathtype is spux.MathType.Integer - and info.last_dim.mathtype is spux.MathType.Integer - and info.output.shape_len == 0 - ): - operations += [TO.FirstColToFirstIdx] - - # Fold - ## Last Dim -> Complex - if ( - len(info.dims) >= 1 - and ( - info.output.mathtype - in [spux.MathType.Integer, spux.MathType.Rational, spux.MathType.Real] - ) - and info.last_dim.mathtype is spux.MathType.Integer - and info.has_idx_labels(info.last_dim) - and len(info.dims[info.last_dim]) == 2 # noqa: PLR2004 - ): - operations += [TO.IntDimToComplex] - - ## Last Dim -> Vector - if len(info.dims) >= 1 and info.output.shape_len == 0: - operations += [TO.DimToVec] - - ## Last Dim -> Matrix - if len(info.dims) >= 2 and info.output.shape_len == 0: # noqa: PLR2004 - operations += [TO.DimsToMat] - - # Fourier - if TO.FT1D.valid_dims(info): - operations += [TO.FT1D] - - if TO.InvFT1D.valid_dims(info): - operations += [TO.InvFT1D] - - return operations - - #################### - # - Function Properties - #################### - def jax_func(self, axis: int | None = None): - TO = TransformOperation - return { - # Covariant Transform - ## -> Freq <-> WL is a rescale (noop) AND flip (not noop). - TO.FreqToVacWL: lambda expr: jnp.flip(expr, axis=axis), - TO.VacWLToFreq: lambda expr: jnp.flip(expr, axis=axis), - TO.ConvertIdxUnit: lambda expr: expr, - TO.SetIdxUnit: lambda expr: expr, - TO.FirstColToFirstIdx: lambda expr: jnp.delete(expr, 0, axis=1), - # Fold - ## -> To Complex: This should generally be a no-op. - TO.IntDimToComplex: lambda expr: jnp.squeeze( - expr.view(dtype=jnp.complex64), axis=-1 - ), - TO.DimToVec: lambda expr: expr, - TO.DimsToMat: lambda expr: expr, - # Fourier - TO.FT1D: lambda expr: jnp.fft(expr, axis=axis), - TO.InvFT1D: lambda expr: jnp.ifft(expr, axis=axis), - }[self] - - def transform_info( - self, - info: ct.InfoFlow, - dim: sim_symbols.SimSymbol | None = None, - data_col: jtyp.Shaped[jtyp.Array, ' size'] | None = None, - new_dim_name: str | None = None, - unit: spux.Unit | None = None, - physical_type: spux.PhysicalType | None = None, - ) -> ct.InfoFlow: - TO = TransformOperation - return { - # Covariant Transform - TO.FreqToVacWL: lambda: info.replace_dim( - (f_dim := dim), - sim_symbols.wl(unit), - info.dims[f_dim].rescale( - lambda el: sci_constants.vac_speed_of_light / el, - reverse=True, - new_unit=unit, - ), - ), - TO.VacWLToFreq: lambda: info.replace_dim( - (wl_dim := dim), - sim_symbols.freq(unit), - info.dims[wl_dim].rescale( - lambda el: sci_constants.vac_speed_of_light / el, - reverse=True, - new_unit=unit, - ), - ), - TO.ConvertIdxUnit: lambda: info.replace_dim( - dim, - dim.update(unit=unit), - ( - info.dims[dim].rescale_to_unit(unit) - if info.has_idx_discrete(dim) - else None ## Continuous -- dim SimSymbol already scaled - ), - ), - TO.SetIdxUnit: lambda: info.replace_dim( - dim, - dim.update( - sym_name=new_dim_name, - physical_type=physical_type, - unit=unit, - ), - ( - info.dims[dim].correct_unit(unit) - if info.has_idx_discrete(dim) - else None ## Continuous -- dim SimSymbol already scaled - ), - ), - TO.FirstColToFirstIdx: lambda: info.replace_dim( - info.first_dim, - info.first_dim.update( - sym_name=new_dim_name, - mathtype=spux.MathType.from_jax_array(data_col), - physical_type=physical_type, - unit=unit, - ), - ct.RangeFlow.try_from_array(ct.ArrayFlow(values=data_col, unit=unit)), - ).slice_dim(info.last_dim, (1, len(info.dims[info.last_dim]), 1)), - # Fold - TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output( - mathtype=spux.MathType.Complex - ), - TO.DimToVec: lambda: info.fold_last_input(), - TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(), - # Fourier - TO.FT1D: lambda: info.replace_dim( - dim, - [ - # FT'ed Unit: Reciprocal of the Original Unit - dim.update( - unit=1 / dim.unit if dim.unit is not None else 1 - ), ## TODO: Okay to not scale interval? - # FT'ed Bounds: Reciprocal of the Original Unit - info.dims[dim].bound_fourier_transform, - ], - ), - TO.InvFT1D: lambda: info.replace_dim( - info.last_dim, - [ - # FT'ed Unit: Reciprocal of the Original Unit - dim.update( - unit=1 / dim.unit if dim.unit is not None else 1 - ), ## TODO: Okay to not scale interval? - # FT'ed Bounds: Reciprocal of the Original Unit - ## -> Note the midpoint may revert to 0. - ## -> See docs for `RangeFlow.bound_inv_fourier_transform` for more. - info.dims[dim].bound_inv_fourier_transform, - ], - ), - }[self]() - - -#################### -# - Node -#################### class TransformMathNode(base.MaxwellSimNode): r"""Applies a function to the array as a whole, with arbitrary results. @@ -409,7 +93,7 @@ class TransformMathNode(base.MaxwellSimNode): #################### # - Properties: Operation #################### - operation: TransformOperation = bl_cache.BLField( + operation: math_system.TransformOperation = bl_cache.BLField( enum_cb=lambda self, _: self.search_operations(), cb_depends_on={'expr_info'}, ) @@ -419,7 +103,7 @@ class TransformMathNode(base.MaxwellSimNode): return [ operation.bl_enum_element(i) for i, operation in enumerate( - TransformOperation.by_info(self.expr_info) + math_system.TransformOperation.by_info(self.expr_info) ) ] return [] @@ -461,7 +145,7 @@ class TransformMathNode(base.MaxwellSimNode): ) def search_units(self) -> list[ct.BLEnumElement]: - TO = TransformOperation + TO = math_system.TransformOperation match self.operation: # Covariant Transform case TO.ConvertIdxUnit if self.dim is not None: @@ -521,7 +205,7 @@ class TransformMathNode(base.MaxwellSimNode): return spux.sp_to_str(self.new_unit) def draw_label(self): - TO = TransformOperation + TO = math_system.TransformOperation match self.operation: case TO.FreqToVacWL if self.dim is not None: return f'T: {self.dim.name_pretty} | 𝑓 → {self.new_unit_str}' @@ -556,7 +240,7 @@ class TransformMathNode(base.MaxwellSimNode): def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: layout.prop(self, self.blfields['operation'], text='') - TO = TransformOperation + TO = math_system.TransformOperation match self.operation: case TO.ConvertIdxUnit: row = layout.row(align=True) @@ -613,7 +297,7 @@ class TransformMathNode(base.MaxwellSimNode): self, props, input_sockets, output_sockets ) -> ct.FuncFlow | ct.FlowSignal: """Transform the input `InfoFlow` depending on the transform operation.""" - TO = TransformOperation + TO = math_system.TransformOperation lazy_func = input_sockets['Expr'][ct.FlowKind.Func] info = input_sockets['Expr'][ct.FlowKind.Info] @@ -662,7 +346,7 @@ class TransformMathNode(base.MaxwellSimNode): self, props: dict, input_sockets: dict ) -> ct.InfoFlow | typ.Literal[ct.FlowSignal.FlowPending]: """Transform the input `InfoFlow` depending on the transform operation.""" - TO = TransformOperation + TO = math_system.TransformOperation operation = props['operation'] info = input_sockets['Expr'][ct.FlowKind.Info] diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py index ad44ef8..dda78ff 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py @@ -24,7 +24,7 @@ import sympy as sp import sympy.physics.units as spu from blender_maxwell.utils import bl_cache, image_ops, logger, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/absorbing_bound_cond.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/absorbing_bound_cond.py index 2f8f276..f03eb77 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/absorbing_bound_cond.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/absorbing_bound_cond.py @@ -22,7 +22,7 @@ import bpy import sympy as sp import tidy3d as td -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from .... import contracts as ct diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/pml_bound_cond.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/pml_bound_cond.py index 91faec1..1490bdd 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/pml_bound_cond.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/pml_bound_cond.py @@ -22,7 +22,7 @@ import bpy import sympy as sp import tidy3d as td -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from .... import contracts as ct diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py index aa94c09..9a94de8 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py @@ -19,7 +19,7 @@ import inspect import typing as typ from types import MappingProxyType -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from .. import contracts as ct diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py index dba473a..11068e3 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py @@ -21,7 +21,7 @@ import sympy as sp import sympy.physics.units as spu from blender_maxwell.utils import bl_cache, sci_constants, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from .... import contracts as ct from .... import sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py index db9df67..213746a 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py @@ -16,12 +16,13 @@ import enum import typing as typ +from fractions import Fraction import bpy import sympy as sp from blender_maxwell.utils import bl_cache, logger, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from .... import contracts as ct from .... import sockets @@ -50,6 +51,8 @@ class SymbolConstantNode(base.MaxwellSimNode): ) size: spux.NumberSize1D = bl_cache.BLField(spux.NumberSize1D.Scalar) + ## Use of NumberSize1D implicitly guarantees UI-realizability later. + mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real) physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical) @@ -109,31 +112,110 @@ class SymbolConstantNode(base.MaxwellSimNode): preview_value_re: float = bl_cache.BLField(0.0) preview_value_im: float = bl_cache.BLField(0.0) - #################### - # - Computed Properties - #################### @bl_cache.cached_bl_property( depends_on={ - 'sym_name', - 'size', 'mathtype', - 'physical_type', - 'unit', 'interval_finite_z', 'interval_finite_q', 'interval_finite_re', - 'interval_inf', - 'interval_closed', 'interval_finite_im', - 'interval_inf_im', - 'interval_closed_im', + } + ) + def interval_finite( + self, + ) -> ( + tuple[int | Fraction | float, int | Fraction | float] + | tuple[tuple[float, float], tuple[float, float]] + ): + """Return the appropriate finite interval from the UI, as guided by `self.mathtype`.""" + MT = spux.MathType + match self.mathtype: + case MT.Integer: + return self.interval_finite_z + case MT.Rational: + return [Fraction(*q) for q in self.interval_finite_q] + case MT.Real: + return self.interval_finite_re + case MT.Complex: + return (self.interval_finite_re, self.interval_finite_im) + + @bl_cache.cached_bl_property( + depends_on={ + 'mathtype', 'preview_value_z', 'preview_value_q', 'preview_value_re', 'preview_value_im', } ) + def preview_value( + self, + ) -> int | Fraction | float | complex: + """Return the appropriate finite interval from the UI, as guided by `self.mathtype`.""" + MT = spux.MathType + match self.mathtype: + case MT.Integer: + return self.preview_value_z + case MT.Rational: + return Fraction(*self.preview_value_q) + case MT.Real: + return self.preview_value_re + case MT.Complex: + return complex(self.preview_value_re, self.preview_value_im) + + @bl_cache.cached_bl_property( + depends_on={ + 'mathtype', + 'interval_finite', + 'interval_inf', + 'interval_inf_im', + 'interval_closed', + 'interval_closed_im', + } + ) + def domain( + self, + ) -> sp.Interval | sp.sets.fancysets.CartesianComplexRegion: + """Deduce the domain specified in the UI.""" + MT = spux.MathType + match self.mathtype: + case MT.Integer | MT.Real | MT.Rational: + return sim_symbols.mk_interval( + self.interval_finite, + self.interval_inf, + self.interval_closed, + ) + + case MT.Complex: + region = self.interval_finite + domain_re = sim_symbols.mk_interval( + region[0], + self.interval_inf, + self.interval_closed, + ) + domain_im = sim_symbols.mk_interval( + region[1], + self.interval_inf_im, + self.interval_closed_im, + ) + return sp.ComplexRegion(domain_re, domain_im, polar=False) + + #################### + # - Computed Properties + #################### + @bl_cache.cached_bl_property( + depends_on={ + 'sym_name', + 'mathtype', + 'physical_type', + 'unit', + 'size', + 'domain', + 'preview_value', + } + ) def symbol(self) -> sim_symbols.SimSymbol: + """Generate the `SimSymbol` matching the user-specification.""" return sim_symbols.SimSymbol( sym_name=self.sym_name, mathtype=self.mathtype, @@ -141,18 +223,8 @@ class SymbolConstantNode(base.MaxwellSimNode): unit=self.unit, rows=self.size.rows, cols=self.size.cols, - interval_finite_z=self.interval_finite_z, - interval_finite_q=self.interval_finite_q, - interval_finite_re=self.interval_finite_re, - interval_inf=self.interval_inf, - interval_closed=self.interval_closed, - interval_finite_im=self.interval_finite_im, - interval_inf_im=self.interval_inf_im, - interval_closed_im=self.interval_closed_im, - preview_value_z=self.preview_value_z, - preview_value_q=self.preview_value_q, - preview_value_re=self.preview_value_re, - preview_value_im=self.preview_value_im, + domain=self.domain, + preview_value=self.preview_value, ) #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py index 252dbed..e9134fc 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py @@ -23,7 +23,7 @@ import sympy as sp import tidy3d as td from blender_maxwell.utils import bl_cache, logger, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from .... import contracts as ct from .... import sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/scene.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/scene.py index 2acafa7..8b6fc43 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/scene.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/scene.py @@ -23,7 +23,7 @@ import sympy as sp import sympy.physics.units as spu from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py index 68d6d49..55f96c6 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py @@ -23,7 +23,7 @@ import sympy as sp import sympy.physics.units as spu from blender_maxwell.utils import bl_cache, logger, sci_constants -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/library_medium.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/library_medium.py index da00004..e9c275b 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/library_medium.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/library_medium.py @@ -25,7 +25,7 @@ from tidy3d.material_library.material_library import MaterialItem as Tidy3DMediu from tidy3d.material_library.material_library import VariantItem as Tidy3DMediumVariant from blender_maxwell.utils import bl_cache, logger, sci_constants -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py index 7d40ba2..d9ca2e8 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py @@ -23,7 +23,7 @@ import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py index 319ce0b..d40e0ab 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py @@ -23,7 +23,7 @@ import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/permittivity_monitor.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/permittivity_monitor.py index b5a287c..85d845b 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/permittivity_monitor.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/permittivity_monitor.py @@ -20,7 +20,7 @@ import sympy as sp import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from ... import contracts as ct diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/file_exporters/data_file_exporter.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/file_exporters/data_file_exporter.py index 4cf1503..963186b 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/file_exporters/data_file_exporter.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/file_exporters/data_file_exporter.py @@ -20,7 +20,7 @@ from pathlib import Path import bpy from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from .... import contracts as ct from .... import sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py index 585793b..0ad5d02 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py @@ -21,7 +21,7 @@ import sympy as sp import tidy3d as td from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py index 0197f5f..7c525ff 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py @@ -22,7 +22,7 @@ import sympy as sp import sympy.physics.units as spu from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from ... import contracts as ct diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/gaussian_beam_source.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/gaussian_beam_source.py index 2bbd6b0..94eb827 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/gaussian_beam_source.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/gaussian_beam_source.py @@ -22,7 +22,7 @@ import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/plane_wave_source.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/plane_wave_source.py index d1d3ab5..3f60b9e 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/plane_wave_source.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/plane_wave_source.py @@ -22,7 +22,7 @@ import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py index 0ed364b..0457790 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py @@ -22,7 +22,7 @@ import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py index 1f48f31..c307143 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py @@ -28,7 +28,7 @@ from tidy3d.components.data.data_array import TimeDataArray as td_TimeDataArray from tidy3d.components.data.dataset import TimeDataset as td_TimeDataset from blender_maxwell.utils import bl_cache, logger, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from ... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py index 27369d5..6669a50 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py @@ -20,7 +20,7 @@ import sympy as sp import sympy.physics.units as spu import tidy3d as td -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from ... import bl_socket_map, managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py index 2109185..754ac29 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py @@ -24,7 +24,7 @@ import tidy3d.plugins.adjoint as tdadj from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from .... import contracts as ct from .... import managed_objs, sockets diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py index 072756d..1177dbd 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py @@ -21,7 +21,7 @@ import sympy.physics.units as spu import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from .... import contracts as ct diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/sphere_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/sphere_structure.py index d5a83b4..8a51183 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/sphere_structure.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/sphere_structure.py @@ -21,7 +21,7 @@ import sympy.physics.units as spu import tidy3d as td from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger from .... import contracts as ct diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py index 84e077b..6c553a7 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py @@ -24,7 +24,7 @@ import pydantic as pyd import sympy as sp from blender_maxwell.utils import bl_cache, logger, sim_symbols -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from .. import contracts as ct from . import base @@ -155,12 +155,14 @@ class ExprBLSocket(base.MaxwellSimSocket): unit=self.unit, rows=self.size.rows, cols=self.size.cols, + is_constant=True, + ## TODO: Should we set preview values exclude_zero=( not self.value.is_zero if self.value.is_zero is not None else False ), - ## TODO: Does this work for matrix elements? + ## TODO: Does this 0-check work for matrix elements? ) case ct.FlowKind.Range if self.symbols: @@ -208,7 +210,7 @@ class ExprBLSocket(base.MaxwellSimSocket): sim_symbols.SimSymbolName.Expr ) output_name: sim_symbols.SimSymbolName = bl_cache.BLField( - sim_symbols.SimSymbolName.Expr + sim_symbols.SimSymbolName.Constant ) symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([]) @@ -441,12 +443,11 @@ class ExprBLSocket(base.MaxwellSimSocket): #################### def _to_raw_value(self, expr: spux.SympyExpr, force_complex: bool = False): """Cast the given expression to the appropriate raw value, with scaling guided by `self.unit`.""" - if self.unit is not None: - pyvalue = spux.sympy_to_python(spux.scale_to_unit(expr, self.unit)) - else: - pyvalue = spux.sympy_to_python(expr) + pyvalue = spux.scale_to_unit(expr, self.unit) # Cast complex -> tuple[float, float] + ## -> We can't set complex to BLProps. + ## -> We must deconstruct it appropriately. if isinstance(pyvalue, complex) or ( isinstance(pyvalue, int | float) and force_complex ): diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py index 2c2119f..aa41351 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py @@ -24,7 +24,7 @@ import tidy3d as td import tidy3d.plugins.adjoint as tdadj from blender_maxwell.utils import bl_cache, logger -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from .. import base diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/pol.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/pol.py index 7cf5cf0..a39fcbf 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/pol.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/physical/pol.py @@ -19,7 +19,7 @@ import sympy as sp import sympy.physics.optics.polarization as spo_pol import sympy.physics.units as spu -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from ... import contracts as ct from .. import base diff --git a/src/blender_maxwell/utils/__init__.py b/src/blender_maxwell/utils/__init__.py index a44be30..e1e2d51 100644 --- a/src/blender_maxwell/utils/__init__.py +++ b/src/blender_maxwell/utils/__init__.py @@ -17,22 +17,22 @@ from ..nodeps.utils import blender_type_enum, pydeps from . import ( bl_cache, - extra_sympy_units, image_ops, logger, sci_constants, serialize, staticproperty, + sympy_extra, ) __all__ = [ 'blender_type_enum', 'pydeps', 'bl_cache', - 'extra_sympy_units', 'image_ops', 'logger', 'sci_constants', 'serialize', 'staticproperty', + 'sympy_extra', ] diff --git a/src/blender_maxwell/utils/extra_sympy_units.py b/src/blender_maxwell/utils/extra_sympy_units.py deleted file mode 100644 index fa31709..0000000 --- a/src/blender_maxwell/utils/extra_sympy_units.py +++ /dev/null @@ -1,1699 +0,0 @@ -# blender_maxwell -# Copyright (C) 2024 blender_maxwell Project Contributors -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . - -"""Declares useful sympy units and functions, to make it easier to work with `sympy` as the basis for a unit-aware system. - -Attributes: - UNIT_BY_SYMBOL: Maps all abbreviated Sympy symbols to their corresponding Sympy unit. - This is essential for parsing string expressions that use units, since a pure parse of ex. `a*m + m` would not otherwise be able to differentiate between `sp.Symbol(m)` and `spu.meter`. - SympyType: A simple union of valid `sympy` types, used to check whether arbitrary objects should be handled using `sympy` functions. - For simple `isinstance` checks, this should be preferred, as it is most performant. - For general use, `SympyExpr` should be preferred. - SympyExpr: A `SympyType` that is compatible with `pydantic`, including serialization/deserialization. - Should be used via the `ConstrSympyExpr`, which also adds expression validation. -""" - -import enum -import functools -import sys -import typing as typ -from fractions import Fraction - -import jax -import jax.numpy as jnp -import jaxtyping as jtyp -import pydantic as pyd -import sympy as sp -import sympy.physics.units as spu -import typing_extensions as typx -from pydantic_core import core_schema as pyd_core_schema - -from blender_maxwell import contracts as ct - -from . import logger -from .staticproperty import staticproperty - -log = logger.get(__name__) - -SympyType = ( - sp.Basic - | sp.Expr - | sp.MatrixBase - | sp.MutableDenseMatrix - | spu.Quantity - | spu.Dimension -) - - -#################### -# - Math Type -#################### -class MathType(enum.StrEnum): - """Type identifiers that encompass common sets of mathematical objects.""" - - Integer = enum.auto() - Rational = enum.auto() - Real = enum.auto() - Complex = enum.auto() - - @staticmethod - def combine(*mathtypes: list[typ.Self], optional: bool = False) -> typ.Self | None: - if MathType.Complex in mathtypes: - return MathType.Complex - if MathType.Real in mathtypes: - return MathType.Real - if MathType.Rational in mathtypes: - return MathType.Rational - if MathType.Integer in mathtypes: - return MathType.Integer - - if optional: - return None - - msg = f"Can't combine mathtypes {mathtypes}" - raise ValueError(msg) - - def is_compatible(self, other: typ.Self) -> bool: - MT = MathType - return ( - other - in { - MT.Integer: [MT.Integer], - MT.Rational: [MT.Integer, MT.Rational], - MT.Real: [MT.Integer, MT.Rational, MT.Real], - MT.Complex: [MT.Integer, MT.Rational, MT.Real, MT.Complex], - }[self] - ) - - def coerce_compatible_pyobj( - self, pyobj: bool | int | Fraction | float | complex - ) -> int | Fraction | float | complex: - MT = MathType - match self: - case MT.Integer: - return int(pyobj) - case MT.Rational if isinstance(pyobj, int): - return Fraction(pyobj, 1) - case MT.Rational if isinstance(pyobj, Fraction): - return pyobj - case MT.Real: - return float(pyobj) - case MT.Complex if isinstance(pyobj, int | Fraction): - return complex(float(pyobj), 0) - case MT.Complex if isinstance(pyobj, float): - return complex(pyobj, 0) - - @staticmethod - def from_expr(sp_obj: SympyType, optional: bool = False) -> type | None: - if isinstance(sp_obj, sp.MatrixBase): - return MathType.combine( - *[MathType.from_expr(v) for v in sp.flatten(sp_obj)] - ) - - if sp_obj.is_integer: - return MathType.Integer - if sp_obj.is_rational: - return MathType.Rational - if sp_obj.is_real: - return MathType.Real - if sp_obj.is_complex: - return MathType.Complex - - # Infinities - if sp_obj in [sp.oo, -sp.oo]: - return MathType.Real ## TODO: Strictly, could be ex. integer... - if sp_obj in [sp.zoo, -sp.zoo]: - return MathType.Complex - - if optional: - return None - - msg = f"Can't determine MathType from sympy object: {sp_obj}" - raise ValueError(msg) - - @staticmethod - def from_pytype(dtype: type) -> type: - return { - int: MathType.Integer, - Fraction: MathType.Rational, - float: MathType.Real, - complex: MathType.Complex, - }[dtype] - - @staticmethod - def from_jax_array(data: jtyp.Shaped[jtyp.Array, '...']) -> type: - """Deduce the MathType corresponding to a JAX array. - - We go about this by leveraging that: - - `data` is of a homogeneous type. - - `data.item(0)` returns a single element of the array w/pure-python type. - - By combing this with `type()` and `MathType.from_pytype`, we can effectively deduce the `MathType` of the entire array with relative efficiency. - - Notes: - Should also work with numpy arrays. - """ - return MathType.from_pytype(type(data.item(0))) - - @staticmethod - def has_mathtype(obj: typ.Any) -> typ.Literal['pytype', 'expr'] | None: - if isinstance(obj, bool | int | Fraction | float | complex): - return 'pytype' - if isinstance(obj, sp.Basic | sp.MatrixBase | sp.MutableDenseMatrix): - return 'expr' - - return None - - @property - def pytype(self) -> type: - MT = MathType - return { - MT.Integer: int, - MT.Rational: Fraction, - MT.Real: float, - MT.Complex: complex, - }[self] - - @property - def symbolic_set(self) -> type: - MT = MathType - return { - MT.Integer: sp.Integers, - MT.Rational: sp.Rationals, - MT.Real: sp.Reals, - MT.Complex: sp.Complexes, - }[self] - - @property - def inf_finite(self) -> type: - """Opinionated finite representation of "infinity" within this `MathType`. - - These are chosen using `sys.maxsize` and `sys.float_info`. - As such, while not arbitrary, this "finite representation of infinity" certainly is opinionated. - - **Note** that, in practice, most systems will have no trouble working with values that exceed those defined here. - - Notes: - Values should be presumed to vary by-platform, as the `sys` attributes may be influenced by CPU architecture, OS, runtime environment, etc. . - - These values can be used directly in `jax` arrays, but at the cost of an overflow warning (in part because `jax` generally only allows the use of `float32`). - In this case, the warning doesn't matter, as the value will be cast to `jnp.inf` anyway. - - However, it's generally cleaner to directly use `jnp.inf` if infinite values must be defined in an array context. - """ - MT = MathType - Z = MT.Integer - R = MT.Integer - return { - MT.Integer: (-sys.maxsize, sys.maxsize), - MT.Rational: ( - Fraction(Z.inf_finite[0], 1), - Fraction(Z.inf_finite[1], 1), - ), - MT.Real: -(sys.float_info.min, sys.float_info.max), - MT.Complex: ( - complex(R.inf_finite[0], R.inf_finite[0]), - complex(R.inf_finite[1], R.inf_finite[1]), - ), - }[self] - - @property - def sp_symbol_a(self) -> type: - MT = MathType - return { - MT.Integer: sp.Symbol('a', integer=True), - MT.Rational: sp.Symbol('a', rational=True), - MT.Real: sp.Symbol('a', real=True), - MT.Complex: sp.Symbol('a', complex=True), - }[self] - - @staticmethod - def to_str(value: typ.Self) -> type: - return { - MathType.Integer: 'ℤ', - MathType.Rational: 'ℚ', - MathType.Real: 'ℝ', - MathType.Complex: 'ℂ', - }[value] - - @property - def label_pretty(self) -> str: - return MathType.to_str(self) - - @staticmethod - def to_name(value: typ.Self) -> str: - return MathType.to_str(value) - - @staticmethod - def to_icon(value: typ.Self) -> str: - return '' - - def bl_enum_element(self, i: int) -> ct.BLEnumElement: - return ( - str(self), - MathType.to_name(self), - MathType.to_name(self), - MathType.to_icon(self), - i, - ) - - -#################### -# - Size: 1D -#################### -class NumberSize1D(enum.StrEnum): - """Valid 1D-constrained shape.""" - - Scalar = enum.auto() - Vec2 = enum.auto() - Vec3 = enum.auto() - Vec4 = enum.auto() - - @staticmethod - def to_name(value: typ.Self) -> str: - NS = NumberSize1D - return { - NS.Scalar: 'Scalar', - NS.Vec2: '2D', - NS.Vec3: '3D', - NS.Vec4: '4D', - }[value] - - @staticmethod - def to_icon(value: typ.Self) -> str: - NS = NumberSize1D - return { - NS.Scalar: '', - NS.Vec2: '', - NS.Vec3: '', - NS.Vec4: '', - }[value] - - def bl_enum_element(self, i: int) -> ct.BLEnumElement: - return ( - str(self), - NumberSize1D.to_name(self), - NumberSize1D.to_name(self), - NumberSize1D.to_icon(self), - i, - ) - - @staticmethod - def has_shape(shape: tuple[int, ...] | None): - return shape in [None, (2,), (3,), (4,), (2, 1), (3, 1), (4, 1)] - - def supports_shape(self, shape: tuple[int, ...] | None): - NS = NumberSize1D - match self: - case NS.Scalar: - return shape is None - case NS.Vec2: - return shape in ((2,), (2, 1)) - case NS.Vec3: - return shape in ((3,), (3, 1)) - case NS.Vec4: - return shape in ((4,), (4, 1)) - - @staticmethod - def from_shape(shape: tuple[typ.Literal[2, 3]] | None) -> typ.Self: - NS = NumberSize1D - return { - None: NS.Scalar, - (2,): NS.Vec2, - (3,): NS.Vec3, - (4,): NS.Vec4, - (2, 1): NS.Vec2, - (3, 1): NS.Vec3, - (4, 1): NS.Vec4, - }[shape] - - @property - def rows(self): - NS = NumberSize1D - return { - NS.Scalar: 1, - NS.Vec2: 2, - NS.Vec3: 3, - NS.Vec4: 4, - }[self] - - @property - def cols(self): - return 1 - - @property - def shape(self): - NS = NumberSize1D - return { - NS.Scalar: None, - NS.Vec2: (2,), - NS.Vec3: (3,), - NS.Vec4: (4,), - }[self] - - -def symbol_range(sym: sp.Symbol) -> str: - return f'{sym.name} ∈ ' + ( - 'ℂ' - if sym.is_complex - else ('ℝ' if sym.is_real else ('ℤ' if sym.is_integer else '?')) - ) - - -#################### -# - Symbol Sizes -#################### -class SimpleSize2D(enum.StrEnum): - """Simple subset of sizes for rank-2 tensors.""" - - Scalar = enum.auto() - - # Vectors - Vec2 = enum.auto() ## 2x1 - Vec3 = enum.auto() ## 3x1 - Vec4 = enum.auto() ## 4x1 - - # Covectors - CoVec2 = enum.auto() ## 1x2 - CoVec3 = enum.auto() ## 1x3 - CoVec4 = enum.auto() ## 1x4 - - # Square Matrices - Mat22 = enum.auto() ## 2x2 - Mat33 = enum.auto() ## 3x3 - Mat44 = enum.auto() ## 4x4 - - -#################### -# - Unit Dimensions -#################### -class DimsMeta(type): - def __getattr__(cls, attr: str) -> spu.Dimension: - if ( - attr in spu.definitions.dimension_definitions.__dir__() - and not attr.startswith('__') - ): - return getattr(spu.definitions.dimension_definitions, attr) - - raise AttributeError(name=attr, obj=Dims) - - -class Dims(metaclass=DimsMeta): - """Access `sympy.physics.units` dimensions with less hassle. - - Any unit dimension available in `sympy.physics.units.definitions.dimension_definitions` can be accessed as an attribute of `Dims`. - - An `AttributeError` is raised if the unit cannot be found in `sympy`. - - Examples: - The objects returned are a direct alias to `sympy`, with less hassle: - ```python - assert Dims.length == ( - sympy.physics.units.definitions.dimension_definitions.length - ) - ``` - """ - - -#################### -# - Units -#################### -femtosecond = fs = spu.Quantity('femtosecond', abbrev='fs') -femtosecond.set_global_relative_scale_factor(spu.femto, spu.second) - -# Length -femtometer = fm = spu.Quantity('femtometer', abbrev='fm') -femtometer.set_global_relative_scale_factor(spu.femto, spu.meter) - -# Lum Flux -lumen = lm = spu.Quantity('lumen', abbrev='lm') -lumen.set_global_relative_scale_factor(1, spu.candela * spu.steradian) - -# Force -nanonewton = nN = spu.Quantity('nanonewton', abbrev='nN') # noqa: N816 -nanonewton.set_global_relative_scale_factor(spu.nano, spu.newton) - -micronewton = uN = spu.Quantity('micronewton', abbrev='μN') # noqa: N816 -micronewton.set_global_relative_scale_factor(spu.micro, spu.newton) - -millinewton = mN = spu.Quantity('micronewton', abbrev='mN') # noqa: N816 -micronewton.set_global_relative_scale_factor(spu.milli, spu.newton) - -# Frequency -kilohertz = KHz = spu.Quantity('kilohertz', abbrev='KHz') -kilohertz.set_global_relative_scale_factor(spu.kilo, spu.hertz) - -megahertz = MHz = spu.Quantity('megahertz', abbrev='MHz') -kilohertz.set_global_relative_scale_factor(spu.kilo, spu.hertz) - -gigahertz = GHz = spu.Quantity('gigahertz', abbrev='GHz') -gigahertz.set_global_relative_scale_factor(spu.giga, spu.hertz) - -terahertz = THz = spu.Quantity('terahertz', abbrev='THz') -terahertz.set_global_relative_scale_factor(spu.tera, spu.hertz) - -petahertz = PHz = spu.Quantity('petahertz', abbrev='PHz') -petahertz.set_global_relative_scale_factor(spu.peta, spu.hertz) - -exahertz = EHz = spu.Quantity('exahertz', abbrev='EHz') -exahertz.set_global_relative_scale_factor(spu.exa, spu.hertz) - -# Pressure -millibar = mbar = spu.Quantity('millibar', abbrev='mbar') -millibar.set_global_relative_scale_factor(spu.milli, spu.bar) - -hectopascal = hPa = spu.Quantity('hectopascal', abbrev='hPa') # noqa: N816 -hectopascal.set_global_relative_scale_factor(spu.hecto, spu.pascal) - -UNIT_BY_SYMBOL: dict[sp.Symbol, spu.Quantity] = { - unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity) -} | {unit.name: unit for unit in globals().values() if isinstance(unit, spu.Quantity)} - -UNIT_TO_1: dict[spu.Quantity, 1] = {unit: 1 for unit in UNIT_BY_SYMBOL.values()} - - -#################### -# - Expr Analysis: Units -#################### -## TODO: Caching w/srepr'ed expression. -## TODO: An LFU cache could do better than an LRU. -def uses_units(sp_obj: SympyType) -> bool: - """Determines if an expression uses any units. - - Notes: - The expression graph is traversed depth-first with `sp.postorder_traversal`, to search for `sp.Quantity` elements. - Depth-first was chosen since `sp.Quantity`s are likelier to be found among individual symbols, rather than complete subexpressions. - - The **worst-case** runtime is when there are no units, in which case the **entire expression graph will be traversed**. - - Parameters: - expr: The sympy expression that may contain units. - - Returns: - Whether or not there are units used within the expression. - """ - return sp_obj.has(spu.Quantity) - # return any( - # isinstance(subexpr, spu.Quantity) for subexpr in sp.postorder_traversal(sp_obj) - # ) - - -## TODO: Caching w/srepr'ed expression. -## TODO: An LFU cache could do better than an LRU. -def get_units(expr: sp.Expr) -> set[spu.Quantity]: - """Finds all units used by the expression, and returns them as a set. - - No information about _the relationship between units_ is exposed. - For example, compound units like `spu.meter / spu.second` would be mapped to `{spu.meter, spu.second}`. - - - Notes: - The expression graph is traversed depth-first with `sp.postorder_traversal`, to search for `sp.Quantity` elements. - - The performance is comparable to the performance of `sp.postorder_traversal`, since the **entire expression graph will always be traversed**, with the added overhead of one `isinstance` call per expression-graph-node. - - Parameters: - expr: The sympy expression that may contain units. - - Returns: - All units (`spu.Quantity`) used within the expression. - """ - return { - subexpr - for subexpr in sp.postorder_traversal(expr) - if isinstance(subexpr, spu.Quantity) - } - - -def parse_shape(sp_obj: SympyType) -> int | None: - if isinstance(sp_obj, sp.MatrixBase): - return sp_obj.shape - - return None - - -#################### -# - Pydantic-Validated SympyExpr -#################### -class _SympyExpr: - """Low-level `pydantic`, schema describing how to serialize/deserialize fields that have a `SympyType` (like `sp.Expr`), so we can cleanly use `sympy` types in `pyd.BaseModel`. - - Notes: - You probably want to use `SympyExpr`. - - Examples: - To be usable as a type annotation on `pyd.BaseModel`, attach this to `SympyType` using `typx.Annotated`: - - ```python - SympyExpr = typx.Annotated[SympyType, _SympyExpr] - - class Spam(pyd.BaseModel): - line: SympyExpr = sp.Eq(sp.y, 2*sp.Symbol(x, real=True) - 3) - ``` - """ - - @classmethod - def __get_pydantic_core_schema__( - cls, - _source_type: SympyType, - _handler: pyd.GetCoreSchemaHandler, - ) -> pyd_core_schema.CoreSchema: - """Compute a schema that allows `pydantic` to validate a `sympy` type.""" - - def validate_from_str(sp_str: str | typ.Any) -> SympyType | typ.Any: - """Parse and validate a string expression. - - Parameters: - sp_str: A stringified `sympy` object, that will be parsed to a sympy type. - Before use, `isinstance(expr_str, str)` is checked. - If the object isn't a string, then the validation will be skipped. - - Returns: - Either a `sympy` object, if the input is parseable, or the same untouched object. - - Raises: - ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression. - """ - # Constrain to String - if not isinstance(sp_str, str): - return sp_str - - # Parse String -> Sympy - try: - expr = sp.sympify(sp_str) - except ValueError as ex: - msg = f'String {sp_str} is not a valid sympy expression' - raise ValueError(msg) from ex - - # Substitute Symbol -> Quantity - return expr.subs(UNIT_BY_SYMBOL) - - def validate_from_pytype( - sp_pytype: int | Fraction | float | complex, - ) -> SympyType | typ.Any: - """Parse and validate a pure Python type. - - Parameters: - sp_str: A stringified `sympy` object, that will be parsed to a sympy type. - Before use, `isinstance(expr_str, str)` is checked. - If the object isn't a string, then the validation will be skipped. - - Returns: - Either a `sympy` object, if the input is parseable, or the same untouched object. - - Raises: - ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression. - """ - # Constrain to String - if not isinstance(sp_pytype, int | Fraction | float | complex): - return sp_pytype - - if isinstance(sp_pytype, int): - return sp.Integer(sp_pytype) - if isinstance(sp_pytype, Fraction): - return sp.Rational(sp_pytype.numerator, sp_pytype.denominator) - if isinstance(sp_pytype, float): - return sp.Float(sp_pytype) - - # sp_pytype => Complex - return sp_pytype.real + sp.I * sp_pytype.imag - - sympy_expr_schema = pyd_core_schema.chain_schema( - [ - pyd_core_schema.no_info_plain_validator_function(validate_from_str), - pyd_core_schema.no_info_plain_validator_function(validate_from_pytype), - pyd_core_schema.is_instance_schema(SympyType), - ] - ) - return pyd_core_schema.json_or_python_schema( - json_schema=sympy_expr_schema, - python_schema=sympy_expr_schema, - serialization=pyd_core_schema.plain_serializer_function_ser_schema( - lambda sp_obj: sp.srepr(sp_obj) - ), - ) - - -SympyExpr = typx.Annotated[ - sp.Basic, ## Treat all sympy types as sp.Basic - _SympyExpr, -] -## TODO: The type game between SympyType, SympyExpr, and the various flavors of ConstrSympyExpr(), is starting to be a bit much. Let's consolidate. - - -def ConstrSympyExpr( # noqa: N802, PLR0913 - # Features - allow_variables: bool = True, - allow_units: bool = True, - # Structures - allowed_sets: set[typ.Literal['integer', 'rational', 'real', 'complex']] - | None = None, - allowed_structures: set[typ.Literal['scalar', 'matrix']] | None = None, - # Element Class - max_symbols: int | None = None, - allowed_symbols: set[sp.Symbol] | None = None, - allowed_units: set[spu.Quantity] | None = None, - # Shape Class - allowed_matrix_shapes: set[tuple[int, int]] | None = None, -) -> SympyType: - """Constructs a `SympyExpr` type, which will validate `sympy` types when used in a `pyd.BaseModel`. - - Relies on the `sympy` assumptions system. - See - - Parameters (TBD): - - Returns: - A type that represents a constrained `sympy` expression. - """ - - def validate_expr(expr: SympyType): - if not (isinstance(expr, SympyType),): - msg = f"expr '{expr}' is not an allowed Sympy expression ({SympyType})" - raise ValueError(msg) - - msgs = set() - - # Validate Feature Class - if (not allow_variables) and (len(expr.free_symbols) > 0): - msgs.add( - f'allow_variables={allow_variables} does not match expression {expr}.' - ) - if (not allow_units) and uses_units(expr): - msgs.add(f'allow_units={allow_units} does not match expression {expr}.') - - # Validate Structure Class - if ( - allowed_sets - and isinstance(expr, sp.Expr) - and not any( - { - 'integer': expr.is_integer, - 'rational': expr.is_rational, - 'real': expr.is_real, - 'complex': expr.is_complex, - }[allowed_set] - for allowed_set in allowed_sets - ) - ): - msgs.add( - f"allowed_sets={allowed_sets} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))" - ) - if allowed_structures and not any( - { - 'scalar': True, - 'matrix': isinstance(expr, sp.MatrixBase), - }[allowed_set] - for allowed_set in allowed_structures - ): - msgs.add( - f"allowed_structures={allowed_structures} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))" - ) - - # Validate Element Class - if max_symbols and len(expr.free_symbols) > max_symbols: - msgs.add(f'max_symbols={max_symbols} does not match expression {expr}') - if allowed_symbols and expr.free_symbols.issubset(allowed_symbols): - msgs.add( - f'allowed_symbols={allowed_symbols} does not match expression {expr}' - ) - if allowed_units and get_units(expr).issubset(allowed_units): - msgs.add(f'allowed_units={allowed_units} does not match expression {expr}') - - # Validate Shape Class - if ( - allowed_matrix_shapes and isinstance(expr, sp.MatrixBase) - ) and expr.shape not in allowed_matrix_shapes: - msgs.add( - f'allowed_matrix_shapes={allowed_matrix_shapes} does not match expression {expr} with shape {expr.shape}' - ) - - # Error or Return - if msgs: - raise ValueError(str(msgs)) - return expr - - return typx.Annotated[ - sp.Basic, - _SympyExpr, - pyd.AfterValidator(validate_expr), - ] - - -#################### -# - Common ConstrSympyExpr -#################### -# Expression -ScalarUnitlessRealExpr: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=False, - allowed_structures={'scalar'}, - allowed_sets={'integer', 'rational', 'real'}, -) -ScalarUnitlessComplexExpr: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=False, - allowed_structures={'scalar'}, - allowed_sets={'integer', 'rational', 'real', 'complex'}, -) - -# Symbol -IntSymbol: typ.TypeAlias = ConstrSympyExpr( - allow_variables=True, - allow_units=False, - allowed_sets={'integer'}, - max_symbols=1, -) -RationalSymbol: typ.TypeAlias = ConstrSympyExpr( - allow_variables=True, - allow_units=False, - allowed_sets={'integer', 'rational'}, - max_symbols=1, -) -RealSymbol: typ.TypeAlias = ConstrSympyExpr( - allow_variables=True, - allow_units=False, - allowed_sets={'integer', 'rational', 'real'}, - max_symbols=1, -) -ComplexSymbol: typ.TypeAlias = ConstrSympyExpr( - allow_variables=True, - allow_units=False, - allowed_sets={'integer', 'rational', 'real', 'complex'}, - max_symbols=1, -) -Symbol: typ.TypeAlias = IntSymbol | RealSymbol | ComplexSymbol - -# Unit -UnitDimension: typ.TypeAlias = SympyExpr ## Actually spu.Dimension - -## Technically a "unit expression", which includes compound types. -## Support for this is the reason to prefer over raw spu.Quantity. -Unit: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=True, - allowed_structures={'scalar'}, -) - -# Number -IntNumber: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=False, - allowed_sets={'integer'}, - allowed_structures={'scalar'}, -) -RealNumber: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=False, - allowed_sets={'integer', 'rational', 'real'}, - allowed_structures={'scalar'}, -) -ComplexNumber: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=False, - allowed_sets={'integer', 'rational', 'real', 'complex'}, - allowed_structures={'scalar'}, -) -Number: typ.TypeAlias = IntNumber | RealNumber | ComplexNumber - -# Number -PhysicalRealNumber: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=True, - allowed_sets={'integer', 'rational', 'real'}, - allowed_structures={'scalar'}, -) -PhysicalComplexNumber: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=True, - allowed_sets={'integer', 'rational', 'real', 'complex'}, - allowed_structures={'scalar'}, -) -PhysicalNumber: typ.TypeAlias = PhysicalRealNumber | PhysicalComplexNumber - -# Vector -Real3DVector: typ.TypeAlias = ConstrSympyExpr( - allow_variables=False, - allow_units=False, - allowed_sets={'integer', 'rational', 'real'}, - allowed_structures={'matrix'}, - allowed_matrix_shapes={(3, 1)}, -) - - -#################### -# - Sympy Utilities: Printing -#################### -_SYMPY_EXPR_PRINTER_STR = sp.printing.str.StrPrinter( - settings={ - 'abbrev': True, - } -) - - -def sp_to_str(sp_obj: SympyExpr) -> str: - """Converts a sympy object to an output-oriented string (w/abbreviated units), using a dedicated StrPrinter. - - This should be used whenever a **string for UI use** is needed from a `sympy` object. - - Notes: - This should **NOT** be used in cases where the string will be `sp.sympify()`ed back into a sympy expression. - For such cases, rely on `sp.srepr()`, which uses an _explicit_ representation. - - Parameters: - sp_obj: The `sympy` object to convert to a string. - - Returns: - A string representing the expression for human use. - _The string is not re-encodable to the expression._ - """ - ## TODO: A bool flag property that does a lot of find/replace to make it super pretty - return _SYMPY_EXPR_PRINTER_STR.doprint(sp_obj) - - -def pretty_symbol(sym: sp.Symbol) -> str: - return f'{sym.name} ∈ ' + ( - 'ℤ' - if sym.is_integer - else ('ℝ' if sym.is_real else ('ℂ' if sym.is_complex else '?')) - ) - - -#################### -# - Unit Utilities -#################### -def scale_to_unit(sp_obj: SympyType, unit: spu.Quantity) -> Number: - """Convert an expression that uses units to a different unit, then strip all units, leaving only a unitless `sympy` value. - - This is used whenever the unitless part of an expression is needed, but guaranteed expressed in a particular unit, aka. **unit system normalization**. - - Notes: - The unitless output is still an `sp.Expr`, which may contain ex. symbols. - - If you know that the output **should** work as a corresponding Python type (ex. `sp.Integer` vs. `int`), but it doesn't, you can use `sympy_to_python()` to produce a pure-Python type. - In this way, with a little care, broad compatiblity can be bridged between the `sympy.physics.units` unit system and the wider Python ecosystem. - - Parameters: - expr: The unit-containing expression to convert. - unit_to: The unit that is converted to. - - Returns: - The unitless part of `expr`, after scaling the entire expression to `unit`. - - Raises: - ValueError: If the result of unit-conversion and -stripping still has units, as determined by `uses_units()`. - """ - unitless_expr = spu.convert_to(sp_obj, unit) / unit if unit is not None else sp_obj - if not uses_units(unitless_expr): - return unitless_expr - - msg = f'Sympy object "{sp_obj}" was scaled to the unit "{unit}" with the expectation that the result would be unitless, but the result "{unitless_expr}" has units "{get_units(unitless_expr)}"' - raise ValueError(msg) - - -def scaling_factor(unit_from: spu.Quantity, unit_to: spu.Quantity) -> Number: - """Compute the numerical scaling factor imposed on the unitless part of the expression when converting from one unit to another. - - Parameters: - unit_from: The unit that is converted from. - unit_to: The unit that is converted to. - - Returns: - The numerical scaling factor between the two units. - - Raises: - ValueError: If the two units don't share a common dimension. - """ - if unit_from.dimension == unit_to.dimension: - return scale_to_unit(unit_from, unit_to) - - msg = f"Dimension of unit_from={unit_from} ({unit_from.dimension}) doesn't match the dimension of unit_to={unit_to} ({unit_to.dimension}); therefore, there is no scaling factor between them" - raise ValueError(msg) - - -@functools.cache -def unit_str_to_unit(unit_str: str) -> Unit | None: - # Edge Case: Manually Parse Degrees - ## -> sp.sympify('degree') actually produces the sp.degree() function. - ## -> Therefore, we must special case this particular unit. - if unit_str == 'degree': - expr = spu.degree - else: - expr = sp.sympify(unit_str).subs(UNIT_BY_SYMBOL) - - if expr.has(spu.Quantity): - return expr - - msg = f'No valid unit for unit string {unit_str}' - raise ValueError(msg) - - -#################### -# - "Physical" Type -#################### -def unit_dim_to_unit_dim_deps( - unit_dims: SympyType, -) -> dict[spu.dimensions.Dimension, int] | None: - dimsys_SI = spu.systems.si.dimsys_SI - - # Retrieve Dimensional Dependencies - try: - return dimsys_SI.get_dimensional_dependencies(unit_dims) - - # Catch TypeError - ## -> Happens if `+` or `-` is in `unit`. - ## -> Generally, it doesn't make sense to add/subtract differing unit dims. - ## -> Thus, when trying to figure out the unit dimension, there isn't one. - except TypeError: - return None - - -def unit_to_unit_dim_deps( - unit: SympyType, -) -> dict[spu.dimensions.Dimension, int] | None: - # Retrieve Dimensional Dependencies - ## -> NOTE: .subs() alone seems to produce sp.Symbol atoms. - ## -> This is extremely problematic; `Dims` arithmetic has key properties. - ## -> So we have to go all the way to the dimensional dependencies. - ## -> This isn't really respecting the args, but it seems to work :) - return unit_dim_to_unit_dim_deps( - unit.subs({arg: arg.dimension for arg in unit.atoms(spu.Quantity)}) - ) - - -def compare_unit_dims(unit_dim_l: SympyType, unit_dim_r: SympyType) -> bool: - return unit_dim_to_unit_dim_deps(unit_dim_l) == unit_dim_to_unit_dim_deps( - unit_dim_r - ) - - -def compare_unit_dim_to_unit_dim_deps( - unit_dim: SympyType, unit_dim_deps: dict[spu.dimensions.Dimension, int] -) -> bool: - return unit_dim_to_unit_dim_deps(unit_dim) == unit_dim_deps - - -class PhysicalType(enum.StrEnum): - """Type identifiers for expressions with both `MathType` and a unit, aka a "physical" type.""" - - # Unitless - NonPhysical = enum.auto() - - # Global - Time = enum.auto() - Angle = enum.auto() - SolidAngle = enum.auto() - ## TODO: Some kind of 3D-specific orientation ex. a quaternion - Freq = enum.auto() - AngFreq = enum.auto() ## rad*hertz - # Cartesian - Length = enum.auto() - Area = enum.auto() - Volume = enum.auto() - # Mechanical - Vel = enum.auto() - Accel = enum.auto() - Mass = enum.auto() - Force = enum.auto() - Pressure = enum.auto() - # Energy - Work = enum.auto() ## joule - Power = enum.auto() ## watt - PowerFlux = enum.auto() ## watt - Temp = enum.auto() - # Electrodynamics - Current = enum.auto() ## ampere - CurrentDensity = enum.auto() - Charge = enum.auto() ## coulomb - Voltage = enum.auto() - Capacitance = enum.auto() ## farad - Impedance = enum.auto() ## ohm - Conductance = enum.auto() ## siemens - Conductivity = enum.auto() ## siemens / length - MFlux = enum.auto() ## weber - MFluxDensity = enum.auto() ## tesla - Inductance = enum.auto() ## henry - EField = enum.auto() - HField = enum.auto() - # Luminal - LumIntensity = enum.auto() - LumFlux = enum.auto() - Illuminance = enum.auto() - - @functools.cached_property - def unit_dim(self) -> SympyType: - PT = PhysicalType - return { - PT.NonPhysical: None, - # Global - PT.Time: Dims.time, - PT.Angle: Dims.angle, - PT.SolidAngle: spu.steradian.dimension, ## MISSING - PT.Freq: Dims.frequency, - PT.AngFreq: Dims.angle * Dims.frequency, - # Cartesian - PT.Length: Dims.length, - PT.Area: Dims.length**2, - PT.Volume: Dims.length**3, - # Mechanical - PT.Vel: Dims.length / Dims.time, - PT.Accel: Dims.length / Dims.time**2, - PT.Mass: Dims.mass, - PT.Force: Dims.force, - PT.Pressure: Dims.pressure, - # Energy - PT.Work: Dims.energy, - PT.Power: Dims.power, - PT.PowerFlux: Dims.power / Dims.length**2, - PT.Temp: Dims.temperature, - # Electrodynamics - PT.Current: Dims.current, - PT.CurrentDensity: Dims.current / Dims.length**2, - PT.Charge: Dims.charge, - PT.Voltage: Dims.voltage, - PT.Capacitance: Dims.capacitance, - PT.Impedance: Dims.impedance, - PT.Conductance: Dims.conductance, - PT.Conductivity: Dims.conductance / Dims.length, - PT.MFlux: Dims.magnetic_flux, - PT.MFluxDensity: Dims.magnetic_density, - PT.Inductance: Dims.inductance, - PT.EField: Dims.voltage / Dims.length, - PT.HField: Dims.current / Dims.length, - # Luminal - PT.LumIntensity: Dims.luminous_intensity, - PT.LumFlux: Dims.luminous_intensity * spu.steradian.dimension, - PT.Illuminance: Dims.luminous_intensity / Dims.length**2, - }[self] - - @staticproperty - def unit_dims() -> dict[typ.Self, SympyType]: - return { - physical_type: physical_type.unit_dim - for physical_type in list(PhysicalType) - } - - @functools.cached_property - def color(self): - """A color corresponding to the physical type. - - The color selections were initially generated using AI, as this is a rote task that's better adjusted than invented. - The LLM provided the following rationale for its choices: - - > Non-Physical: Grey signifies neutrality and non-physical nature. - > Global: - > Time: Blue is often associated with calmness and the passage of time. - > Angle and Solid Angle: Different shades of blue and cyan suggest angular dimensions and spatial aspects. - > Frequency and Angular Frequency: Darker shades of blue to maintain the link to time. - > Cartesian: - > Length, Area, Volume: Shades of green to represent spatial dimensions, with intensity increasing with dimension. - > Mechanical: - > Velocity and Acceleration: Red signifies motion and dynamics, with lighter reds for related quantities. - > Mass: Dark red for the fundamental property. - > Force and Pressure: Shades of red indicating intensity. - > Energy: - > Work and Power: Orange signifies energy transformation, with lighter oranges for related quantities. - > Temperature: Yellow for heat. - > Electrodynamics: - > Current and related quantities: Cyan shades indicating flow. - > Voltage, Capacitance: Greenish and blueish cyan for electrical potential. - > Impedance, Conductance, Conductivity: Purples and magentas to signify resistance and conductance. - > Magnetic properties: Magenta shades for magnetism. - > Electric Field: Light blue. - > Magnetic Field: Grey, as it can be considered neutral in terms of direction. - > Luminal: - > Luminous properties: Yellows to signify light and illumination. - > - > This color mapping helps maintain intuitive connections for users interacting with these physical types. - """ - PT = PhysicalType - return { - PT.NonPhysical: (0.75, 0.75, 0.75, 1.0), # Light Grey: Non-physical - # Global - PT.Time: (0.5, 0.5, 1.0, 1.0), # Light Blue: Time - PT.Angle: (0.5, 0.75, 1.0, 1.0), # Light Blue: Angle - PT.SolidAngle: (0.5, 0.75, 0.75, 1.0), # Light Cyan: Solid Angle - PT.Freq: (0.5, 0.5, 0.9, 1.0), # Light Blue: Frequency - PT.AngFreq: (0.5, 0.5, 0.8, 1.0), # Light Blue: Angular Frequency - # Cartesian - PT.Length: (0.5, 1.0, 0.5, 1.0), # Light Green: Length - PT.Area: (0.6, 1.0, 0.6, 1.0), # Light Green: Area - PT.Volume: (0.7, 1.0, 0.7, 1.0), # Light Green: Volume - # Mechanical - PT.Vel: (1.0, 0.5, 0.5, 1.0), # Light Red: Velocity - PT.Accel: (1.0, 0.6, 0.6, 1.0), # Light Red: Acceleration - PT.Mass: (0.75, 0.5, 0.5, 1.0), # Light Red: Mass - PT.Force: (0.9, 0.5, 0.5, 1.0), # Light Red: Force - PT.Pressure: (1.0, 0.7, 0.7, 1.0), # Light Red: Pressure - # Energy - PT.Work: (1.0, 0.75, 0.5, 1.0), # Light Orange: Work - PT.Power: (1.0, 0.85, 0.5, 1.0), # Light Orange: Power - PT.PowerFlux: (1.0, 0.8, 0.6, 1.0), # Light Orange: Power Flux - PT.Temp: (1.0, 1.0, 0.5, 1.0), # Light Yellow: Temperature - # Electrodynamics - PT.Current: (0.5, 1.0, 1.0, 1.0), # Light Cyan: Current - PT.CurrentDensity: (0.5, 0.9, 0.9, 1.0), # Light Cyan: Current Density - PT.Charge: (0.5, 0.85, 0.85, 1.0), # Light Cyan: Charge - PT.Voltage: (0.5, 1.0, 0.75, 1.0), # Light Greenish Cyan: Voltage - PT.Capacitance: (0.5, 0.75, 1.0, 1.0), # Light Blueish Cyan: Capacitance - PT.Impedance: (0.6, 0.5, 0.75, 1.0), # Light Purple: Impedance - PT.Conductance: (0.7, 0.5, 0.8, 1.0), # Light Purple: Conductance - PT.Conductivity: (0.8, 0.5, 0.9, 1.0), # Light Purple: Conductivity - PT.MFlux: (0.75, 0.5, 0.75, 1.0), # Light Magenta: Magnetic Flux - PT.MFluxDensity: ( - 0.85, - 0.5, - 0.85, - 1.0, - ), # Light Magenta: Magnetic Flux Density - PT.Inductance: (0.8, 0.5, 0.8, 1.0), # Light Magenta: Inductance - PT.EField: (0.75, 0.75, 1.0, 1.0), # Light Blue: Electric Field - PT.HField: (0.75, 0.75, 0.75, 1.0), # Light Grey: Magnetic Field - # Luminal - PT.LumIntensity: (1.0, 0.95, 0.5, 1.0), # Light Yellow: Luminous Intensity - PT.LumFlux: (1.0, 0.95, 0.6, 1.0), # Light Yellow: Luminous Flux - PT.Illuminance: (1.0, 1.0, 0.75, 1.0), # Pale Yellow: Illuminance - }[self] - - @functools.cached_property - def default_unit(self) -> list[Unit]: - PT = PhysicalType - return { - PT.NonPhysical: None, - # Global - PT.Time: spu.picosecond, - PT.Angle: spu.radian, - PT.SolidAngle: spu.steradian, - PT.Freq: terahertz, - PT.AngFreq: spu.radian * terahertz, - # Cartesian - PT.Length: spu.micrometer, - PT.Area: spu.um**2, - PT.Volume: spu.um**3, - # Mechanical - PT.Vel: spu.um / spu.second, - PT.Accel: spu.um / spu.second, - PT.Mass: spu.microgram, - PT.Force: micronewton, - PT.Pressure: millibar, - # Energy - PT.Work: spu.joule, - PT.Power: spu.watt, - PT.PowerFlux: spu.watt / spu.meter**2, - PT.Temp: spu.kelvin, - # Electrodynamics - PT.Current: spu.ampere, - PT.CurrentDensity: spu.ampere / spu.meter**2, - PT.Charge: spu.coulomb, - PT.Voltage: spu.volt, - PT.Capacitance: spu.farad, - PT.Impedance: spu.ohm, - PT.Conductance: spu.siemens, - PT.Conductivity: spu.siemens / spu.micrometer, - PT.MFlux: spu.weber, - PT.MFluxDensity: spu.tesla, - PT.Inductance: spu.henry, - PT.EField: spu.volt / spu.micrometer, - PT.HField: spu.ampere / spu.micrometer, - # Luminal - PT.LumIntensity: spu.candela, - PT.LumFlux: spu.candela * spu.steradian, - PT.Illuminance: spu.candela / spu.meter**2, - }[self] - - @functools.cached_property - def valid_units(self) -> list[Unit]: - """Retrieve an ordered (by subjective usefulness) list of units for this physical type. - - Notes: - The order in which valid units are declared is the exact same order that UI dropdowns display them. - - **Altering the order of units breaks backwards compatibility**. - """ - PT = PhysicalType - return { - PT.NonPhysical: [None], - # Global - PT.Time: [ - spu.picosecond, - femtosecond, - spu.nanosecond, - spu.microsecond, - spu.millisecond, - spu.second, - spu.minute, - spu.hour, - spu.day, - ], - PT.Angle: [ - spu.radian, - spu.degree, - ], - PT.SolidAngle: [ - spu.steradian, - ], - PT.Freq: ( - _valid_freqs := [ - terahertz, - spu.hertz, - kilohertz, - megahertz, - gigahertz, - petahertz, - exahertz, - ] - ), - PT.AngFreq: [spu.radian * _unit for _unit in _valid_freqs], - # Cartesian - PT.Length: ( - _valid_lens := [ - spu.micrometer, - spu.nanometer, - spu.picometer, - spu.angstrom, - spu.millimeter, - spu.centimeter, - spu.meter, - spu.inch, - spu.foot, - spu.yard, - spu.mile, - ] - ), - PT.Area: [_unit**2 for _unit in _valid_lens], - PT.Volume: [_unit**3 for _unit in _valid_lens], - # Mechanical - PT.Vel: [_unit / spu.second for _unit in _valid_lens], - PT.Accel: [_unit / spu.second**2 for _unit in _valid_lens], - PT.Mass: [ - spu.kilogram, - spu.electron_rest_mass, - spu.dalton, - spu.microgram, - spu.milligram, - spu.gram, - spu.metric_ton, - ], - PT.Force: [ - micronewton, - nanonewton, - millinewton, - spu.newton, - spu.kg * spu.meter / spu.second**2, - ], - PT.Pressure: [ - spu.bar, - millibar, - spu.pascal, - hectopascal, - spu.atmosphere, - spu.psi, - spu.mmHg, - spu.torr, - ], - # Energy - PT.Work: [ - spu.joule, - spu.electronvolt, - ], - PT.Power: [ - spu.watt, - ], - PT.PowerFlux: [ - spu.watt / spu.meter**2, - ], - PT.Temp: [ - spu.kelvin, - ], - # Electrodynamics - PT.Current: [ - spu.ampere, - ], - PT.CurrentDensity: [ - spu.ampere / spu.meter**2, - ], - PT.Charge: [ - spu.coulomb, - ], - PT.Voltage: [ - spu.volt, - ], - PT.Capacitance: [ - spu.farad, - ], - PT.Impedance: [ - spu.ohm, - ], - PT.Conductance: [ - spu.siemens, - ], - PT.Conductivity: [ - spu.siemens / spu.micrometer, - spu.siemens / spu.meter, - ], - PT.MFlux: [ - spu.weber, - ], - PT.MFluxDensity: [ - spu.tesla, - ], - PT.Inductance: [ - spu.henry, - ], - PT.EField: [ - spu.volt / spu.micrometer, - spu.volt / spu.meter, - ], - PT.HField: [ - spu.ampere / spu.micrometer, - spu.ampere / spu.meter, - ], - # Luminal - PT.LumIntensity: [ - spu.candela, - ], - PT.LumFlux: [ - spu.candela * spu.steradian, - ], - PT.Illuminance: [ - spu.candela / spu.meter**2, - ], - }[self] - - @staticmethod - def from_unit(unit: Unit | None, optional: bool = False) -> typ.Self | None: - """Attempt to determine a matching `PhysicalType` from a unit. - - NOTE: It is not guaranteed that `unit` is within `valid_units`, only that it can be converted to any unit in `valid_units`. - - Returns: - The matched `PhysicalType`. - - If none could be matched, then either return `None` (if `optional` is set) or error. - - Raises: - ValueError: If no `PhysicalType` could be matched, and `optional` is `False`. - """ - if unit is None: - return PhysicalType.NonPhysical - - ## TODO_ This enough? - if unit in [spu.radian, spu.degree]: - return PhysicalType.Angle - - unit_dim_deps = unit_to_unit_dim_deps(unit) - if unit_dim_deps is not None: - for physical_type, candidate_unit_dim in PhysicalType.unit_dims.items(): - if compare_unit_dim_to_unit_dim_deps(candidate_unit_dim, unit_dim_deps): - return physical_type - - if optional: - return None - msg = f'Could not determine PhysicalType for {unit}' - raise ValueError(msg) - - @staticmethod - def from_unit_dim( - unit_dim: SympyType | None, optional: bool = False - ) -> typ.Self | None: - """Attempts to match an arbitrary unit dimension expression to a corresponding `PhysicalType`. - - For comparing arbitrary unit dimensions (via expressions of `spu.dimensions.Dimension`), it is critical that equivalent dimensions are also compared as equal (ex. `mass*length/time^2 == force`). - To do so, we employ the `SI` unit conventions, for extracting the fundamental dimensional dependencies of unit dimension expressions. - - Returns: - The matched `PhysicalType`. - - If none could be matched, then either return `None` (if `optional` is set) or error. - - Raises: - ValueError: If no `PhysicalType` could be matched, and `optional` is `False`. - """ - for physical_type, candidate_unit_dim in PhysicalType.unit_dims.items(): - if compare_unit_dims(unit_dim, candidate_unit_dim): - return physical_type - - if optional: - return None - msg = f'Could not determine PhysicalType for {unit_dim}' - raise ValueError(msg) - - @functools.cached_property - def valid_shapes(self) -> list[typ.Literal[(3,), (2,)] | None]: - PT = PhysicalType - overrides = { - # Cartesian - PT.Length: [None, (2,), (3,)], - # Mechanical - PT.Vel: [None, (2,), (3,)], - PT.Accel: [None, (2,), (3,)], - PT.Force: [None, (2,), (3,)], - # Energy - PT.Work: [None, (2,), (3,)], - PT.PowerFlux: [None, (2,), (3,)], - # Electrodynamics - PT.CurrentDensity: [None, (2,), (3,)], - PT.MFluxDensity: [None, (2,), (3,)], - PT.EField: [None, (2,), (3,)], - PT.HField: [None, (2,), (3,)], - # Luminal - PT.LumFlux: [None, (2,), (3,)], - } - - return overrides.get(self, [None]) - - @functools.cached_property - def valid_mathtypes(self) -> list[MathType]: - """Returns a list of valid mathematical types, especially whether it can be real- or complex-valued. - - Generally, all unit quantities are real, in the algebraic mathematical sense. - However, in electrodynamics especially, it becomes enormously useful to bake in a _rotational component_ as an imaginary value, be it simply to model phase or oscillation-oriented dampening. - This imaginary part has physical meaning, which can be expressed using the same mathematical formalism associated with unit systems. - In general, the value is a phasor. - - While it is difficult to arrive at a well-defined way of saying, "this is when a quantity is complex", an attempt has been made to form a sensible baseline based on when phasor math may apply. - - Notes: - - **Freq**/**AngFreq**: The imaginary part represents growth/dampening of the oscillation. - - **Current**/**Voltage**: The imaginary part represents the phase. - This also holds for any downstream units. - - **Charge**: Generally, it is real. - However, an imaginary phase term seems to have research applications when dealing with high-order harmonics in high-energy pulsed lasers: - - **Conductance**: The imaginary part represents the extinction, in the Drude-model sense. - - """ - MT = MathType - PT = PhysicalType - overrides = { - PT.NonPhysical: list(MT), ## Support All - # Cartesian - PT.Freq: [MT.Real, MT.Complex], ## Im -> Growth/Damping - PT.AngFreq: [MT.Real, MT.Complex], ## Im -> Growth/Damping - # Mechanical - # Energy - # Electrodynamics - PT.Current: [MT.Real, MT.Complex], ## Im -> Phase - PT.CurrentDensity: [MT.Real, MT.Complex], ## Im -> Phase - PT.Charge: [MT.Real, MT.Complex], ## Im -> Phase - PT.Voltage: [MT.Real, MT.Complex], ## Im -> Phase - PT.Capacitance: [MT.Real, MT.Complex], ## Im -> Phase - PT.Impedance: [MT.Real, MT.Complex], ## Im -> Reactance - PT.Inductance: [MT.Real, MT.Complex], ## Im -> Extinction - PT.Conductance: [MT.Real, MT.Complex], ## Im -> Extinction - PT.Conductivity: [MT.Real, MT.Complex], ## Im -> Extinction - PT.MFlux: [MT.Real, MT.Complex], ## Im -> Phase - PT.MFluxDensity: [MT.Real, MT.Complex], ## Im -> Phase - PT.EField: [MT.Real, MT.Complex], ## Im -> Phase - PT.HField: [MT.Real, MT.Complex], ## Im -> Phase - # Luminal - } - - return overrides.get(self, [MT.Real]) - - @staticmethod - def to_name(value: typ.Self) -> str: - if value is PhysicalType.NonPhysical: - return 'Unitless' - return PhysicalType(value).name - - @staticmethod - def to_icon(value: typ.Self) -> str: - return '' - - def bl_enum_element(self, i: int) -> ct.BLEnumElement: - PT = PhysicalType - return ( - str(self), - PT.to_name(self), - PT.to_name(self), - PT.to_icon(self), - i, - ) - - -#################### -# - Standard Unit Systems -#################### -UnitSystem: typ.TypeAlias = dict[PhysicalType, Unit] - -_PT = PhysicalType -UNITS_SI: UnitSystem = { - _PT.NonPhysical: None, - # Global - _PT.Time: spu.second, - _PT.Angle: spu.radian, - _PT.SolidAngle: spu.steradian, - _PT.Freq: spu.hertz, - _PT.AngFreq: spu.radian * spu.hertz, - # Cartesian - _PT.Length: spu.meter, - _PT.Area: spu.meter**2, - _PT.Volume: spu.meter**3, - # Mechanical - _PT.Vel: spu.meter / spu.second, - _PT.Accel: spu.meter / spu.second**2, - _PT.Mass: spu.kilogram, - _PT.Force: spu.newton, - # Energy - _PT.Work: spu.joule, - _PT.Power: spu.watt, - _PT.PowerFlux: spu.watt / spu.meter**2, - _PT.Temp: spu.kelvin, - # Electrodynamics - _PT.Current: spu.ampere, - _PT.CurrentDensity: spu.ampere / spu.meter**2, - _PT.Voltage: spu.volt, - _PT.Capacitance: spu.farad, - _PT.Impedance: spu.ohm, - _PT.Conductance: spu.siemens, - _PT.Conductivity: spu.siemens / spu.meter, - _PT.MFlux: spu.weber, - _PT.MFluxDensity: spu.tesla, - _PT.Inductance: spu.henry, - _PT.EField: spu.volt / spu.meter, - _PT.HField: spu.ampere / spu.meter, - # Luminal - _PT.LumIntensity: spu.candela, - _PT.LumFlux: lumen, - _PT.Illuminance: spu.lux, -} - - -#################### -# - Sympy Utilities: Cast to Python -#################### -def sympy_to_python( - scalar: sp.Basic, use_jax_array: bool = False -) -> int | float | complex | tuple | jax.Array: - """Convert a scalar sympy expression to the directly corresponding Python type. - - Arguments: - scalar: A sympy expression that has no symbols, but is expressed as a Sympy type. - For expressions that are equivalent to a scalar (ex. "(2a + a)/a"), you must simplify the expression with ex. `sp.simplify()` before passing to this parameter. - - Returns: - A pure Python type that directly corresponds to the input scalar expression. - """ - if isinstance(scalar, sp.MatrixBase): - # Detect Single Column Vector - ## --> Flatten to Single Row Vector - if len(scalar.shape) == 2 and scalar.shape[1] == 1: - _scalar = scalar.T - else: - _scalar = scalar - - # Convert to Tuple of Tuples - matrix = tuple( - [tuple([sympy_to_python(el) for el in row]) for row in _scalar.tolist()] - ) - - # Detect Single Row Vector - ## --> This could be because the scalar had it. - ## --> This could also be because we flattened a column vector. - ## Either way, we should strip the pointless dimensions. - if len(matrix) == 1: - return matrix[0] if not use_jax_array else jnp.array(matrix[0]) - - return matrix if not use_jax_array else jnp.array(matrix) - if scalar.is_integer: - return int(scalar) - if scalar.is_rational or scalar.is_real: - return float(scalar) - if scalar.is_complex: - return complex(scalar) - - msg = f'Cannot convert sympy scalar expression "{scalar}" to a Python type. Check the assumptions on the expr (current expr assumptions: "{scalar._assumptions}")' # noqa: SLF001 - raise ValueError(msg) - - -#################### -# - Convert to Unit System -#################### -def strip_unit_system( - sp_obj: SympyExpr, unit_system: UnitSystem | None = None -) -> SympyExpr: - """Strip units occurring in the given unit system from the expression. - - Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`". - Obviously, the semantic correctness of this operation depends entirely on _the units adding no semantic meaning to the expression_. - - Notes: - You should probably use `scale_to_unit_system()` or `convert_to_unit_system()`. - """ - if unit_system is None: - return sp_obj.subs(UNIT_TO_1) - return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None}) - - -def convert_to_unit_system( - sp_obj: SympyExpr, unit_system: UnitSystem | None -) -> SympyExpr: - """Convert an expression to the units of a given unit system.""" - if unit_system is None: - return sp_obj - - return spu.convert_to( - sp_obj, - {unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)}, - ) - - -def scale_to_unit_system( - sp_obj: SympyExpr, unit_system: UnitSystem | None, use_jax_array: bool = False -) -> int | float | complex | tuple | jax.Array: - """Convert an expression to the units of a given unit system, then strip all units of the unit system. - - Afterwards, it is converted to an appropriate Python type. - - Notes: - For stability, and performance, reasons, this should only be used at the very last stage. - - Regarding performance: **This is not a fast function**. - - Parameters: - sp_obj: An arbitrary sympy object, presumably with units. - unit_system: A unit system mapping `PhysicalType` to particular choices of (compound) units. - Note that, in this context, only `unit_system.values()` is used. - - Returns: - An appropriate pure Python type, after scaling to the unit system and stripping all units away. - - If the returned type is array-like, and `use_jax_array` is specified, then (and **only** then) will a `jax.Array` be returned instead of a nested `tuple`. - """ - return sympy_to_python( - strip_unit_system(convert_to_unit_system(sp_obj, unit_system), unit_system), - use_jax_array=use_jax_array, - ) diff --git a/src/blender_maxwell/utils/image_ops.py b/src/blender_maxwell/utils/image_ops.py index 6ca4412..68820c3 100644 --- a/src/blender_maxwell/utils/image_ops.py +++ b/src/blender_maxwell/utils/image_ops.py @@ -30,7 +30,7 @@ import matplotlib.figure import seaborn as sns from blender_maxwell import contracts as ct -from blender_maxwell.utils import extra_sympy_units as spux +from blender_maxwell.utils import sympy_extra as spux from blender_maxwell.utils import logger sns.set_theme() diff --git a/src/blender_maxwell/utils/sci_constants.py b/src/blender_maxwell/utils/sci_constants.py index 920bb2d..935e24a 100644 --- a/src/blender_maxwell/utils/sci_constants.py +++ b/src/blender_maxwell/utils/sci_constants.py @@ -32,7 +32,7 @@ import scipy as sc import sympy as sp import sympy.physics.units as spu -from . import extra_sympy_units as spux +from . import sympy_extra as spux SUPPORTED_SCIPY_PREFIX = '1.12' if not sc.version.full_version.startswith(SUPPORTED_SCIPY_PREFIX): diff --git a/src/blender_maxwell/utils/serialize.py b/src/blender_maxwell/utils/serialize.py index bab40a0..b58aebc 100644 --- a/src/blender_maxwell/utils/serialize.py +++ b/src/blender_maxwell/utils/serialize.py @@ -31,8 +31,8 @@ import uuid import msgspec import sympy as sp -from . import extra_sympy_units as spux from . import logger +from . import sympy_extra as spux log = logger.get(__name__) diff --git a/src/blender_maxwell/utils/sim_symbols.py b/src/blender_maxwell/utils/sim_symbols.py index d3e4366..f663d36 100644 --- a/src/blender_maxwell/utils/sim_symbols.py +++ b/src/blender_maxwell/utils/sim_symbols.py @@ -25,8 +25,8 @@ import jaxtyping as jtyp import pydantic as pyd import sympy as sp -from . import extra_sympy_units as spux from . import logger, serialize +from . import sympy_extra as spux int_min = -(2**64) int_max = 2**64 @@ -101,6 +101,12 @@ class SimSymbolName(enum.StrEnum): BlochY = enum.auto() BlochZ = enum.auto() + # New Backwards Compatible Entries + ## -> Ordered lists carry a particular enum integer index. + ## -> Therefore, anything but adding an index breaks backwards compat. + ## -> ...With all previous files. + ConstantRange = enum.auto() + #################### # - UI #################### @@ -143,7 +149,8 @@ class SimSymbolName(enum.StrEnum): } | { # Generic - SSN.Constant: 'constant', + SSN.Constant: 'cst', + SSN.ConstantRange: 'cst_range', SSN.Expr: 'expr', SSN.Data: 'data', # Greek Letters @@ -210,24 +217,43 @@ def mk_interval( interval_finite: tuple[int | Fraction | float, int | Fraction | float], interval_inf: tuple[bool, bool], interval_closed: tuple[bool, bool], - unit_factor: typ.Literal[1] | spux.Unit, ) -> sp.Interval: """Create a symbolic interval from the tuples (and unit) defining it.""" return sp.Interval( - start=(interval_finite[0] * unit_factor if not interval_inf[0] else -sp.oo), - end=(interval_finite[1] * unit_factor if not interval_inf[1] else sp.oo), + start=(interval_finite[0] if not interval_inf[0] else -sp.oo), + end=(interval_finite[1] if not interval_inf[1] else sp.oo), left_open=(True if interval_inf[0] else not interval_closed[0]), right_open=(True if interval_inf[1] else not interval_closed[1]), ) class SimSymbol(pyd.BaseModel): - """A declarative representation of a symbolic variable. + """A convenient, constrained representation of a symbolic variable suitable for many tasks. - `sympy`'s symbols aren't quite flexible enough for our needs: The symbols that we're transporting often need exact domain information, an associated unit dimension, and a great deal of determinism in checks thereof. + The original motivation was to enhance `sp.Symbol` with greater flexibility, semantic context, and a UI-friendly representation. + Today, `SimSymbol` is a fully capable primitive for defining the interfaces between externally tracked mathematical elements, and planning the required operations between them. + + A symbol represented as `SimSymbol` carries all the semantic meaning of that symbol, and comes with a comprehensive library of useful (computed) properties and methods. + It is immutable, hashable, and serializable, and as a `pydantic.BaseModel` with aggressive property caching, its performance properties should also be well-suited for use in the hot-loops of ex. UI draw methods. + + Attributes: + sym_name: For humans and computers, symbol names induces a lot of implicit semantics. + mathtype: Symbols are associated with some set of valid values. + We choose to constrain `SimSymbol` to only associate with _mathematical_ (aka. number-like) sets. + This prohibits ex. booleans and predicate-logic applications, but eases a lot of burdens associated with actually using `SimSymbol`. + physical_type: Symbols may be associated with a particular unit dimension expression. + This allows the symbol to have _physical meaning_. + This information is **generally not** encoded in auxiliary attributes like `self.domain`, but **generally is** encoded by computed properties/methods. + unit: Symbols may be associated with a particular unit, which must be compatible with the `PhysicalType`. + **NOTE**: Unit expressions may well have physical meaning, without being strictly conformable to a pre-blessed `PhysicalType`s. + We do try to avoid such cases, but for the sake of correctness, our chosen convention is to let `self.physical_type` be "`NonPhysical`", while still allowing a unit. + size: Symbols may themselves have shape. + **NOTE**: We deliberately choose to constrain `SimSymbol`s to two dimensions, allowing them to represent scalars, vectors, covectors, and matrices, but **not** arbitrary tensors. + This is a practical tradeoff, made both to make it easier (in terms of mathematical analysis) to implement `SimSymbol`, but also to make it easier to define UI elements that drive / are driven by `SimSymbol`s. + domain: Symbols are associated with a _domain of valid values_, expressed with any mathematical set implemented as a subclass of `sympy.Set`. + By using a true symbolic set, we gain unbounded flexibility in how to define the validity of a set, including an extremely capable `* in self.domain` operator encapsulating a lot of otherwise very manual logic. + **NOTE** that `self.unit` is **not** baked into the domain, due to practicalities associated with subclasses of `sp.Set`. - This dataclass is UI-friendly, as it only uses field type annotations/defaults supported by `bl_cache.BLProp`. - It's easy to persist, easy to transport, and has many helpful properties which greatly simplify working with symbols. """ model_config = pyd.ConfigDict(frozen=True) @@ -238,11 +264,8 @@ class SimSymbol(pyd.BaseModel): # Units ## -> 'None' indicates that no particular unit has yet been chosen. - ## -> Not exposed in the UI; must be set some other way. + ## -> When 'self.physical_type' is NonPhysical, can only be None. unit: spux.Unit | None = None - ## -> TODO: We currently allowing units that don't match PhysicalType - ## -> -- In particular, NonPhysical w/units means "unknown units". - ## -> -- This is essential for the Scientific Constant Node. # Size ## -> All SimSymbol sizes are "2D", but interpreted by convention. @@ -253,39 +276,96 @@ class SimSymbol(pyd.BaseModel): rows: int = 1 cols: int = 1 - # Scalar Domain: "Interval" - ## -> NOTE: interval_finite_*[0] must be strictly smaller than [1]. - ## -> See self.domain. - ## -> We have to deconstruct symbolic interval semantics a bit for UI. - is_constant: bool = False - exclude_zero: bool = False + # Valid Domain + ## -> Declares the valid set of values that may be given to this symbol. + ## -> By convention, units are not encoded in the domain sp.Set. + ## -> 'sp.Set's are extremely expressive and cool. + domain: spux.SympyExpr | None = None - interval_finite_z: tuple[int, int] = (0, 1) - interval_finite_q: tuple[tuple[int, int], tuple[int, int]] = ((0, 1), (1, 1)) - interval_finite_re: tuple[float, float] = (0.0, 1.0) - interval_inf: tuple[bool, bool] = (True, True) - interval_closed: tuple[bool, bool] = (False, False) + @functools.cached_property + def domain_mat(self) -> sp.Set | sp.matrices.MatrixSet: + if self.rows > 1 or self.cols > 1: + return sp.matrices.MatrixSet(self.rows, self.cols, self.domain) + return self.domain - interval_finite_im: tuple[float, float] = (0.0, 1.0) - interval_inf_im: tuple[bool, bool] = (True, True) - interval_closed_im: tuple[bool, bool] = (False, False) - - preview_value_z: int = 0 - preview_value_q: tuple[int, int] = (0, 1) - preview_value_re: float = 0.0 - preview_value_im: float = 0.0 + preview_value: spux.SympyExpr | None = None #################### - # - Core + # - Validators + #################### + ## TODO: Check domain against MathType + ## -- Surprisingly hard without a lot of special-casing. + + ## TODO: Check that size is valid for the PhysicalType. + + ## TODO: Check that constant value (domain=FiniteSet(cst)) is compatible with the MathType. + + ## TODO: Check that preview_value is in the domain. + + @pyd.model_validator(mode='after') + def set_undefined_domain_from_mathtype(self) -> typ.Self: + """When the domain is not set, then set it using the symbolic set of the MathType.""" + if self.domain is None: + object.__setattr__(self, 'domain', self.mathtype.symbolic_set) + return self + + @pyd.model_validator(mode='after') + def conform_undefined_preview_value_to_constant(self) -> typ.Self: + """When the `SimSymbol` is a constant, but the preview value is not set, then set the preview value from the constant.""" + if self.is_constant and not self.preview_value: + object.__setattr__(self, 'preview_value', self.constant_value) + return self + + @pyd.model_validator(mode='after') + def conform_preview_value(self) -> typ.Self: + """Conform the given preview value to the `SimSymbol`.""" + if self.is_constant and not self.preview_value: + object.__setattr__( + self, + 'preview_value', + self.conform(self.preview_value, strip_units=True), + ) + return self + + #################### + # - Domain #################### @functools.cached_property - def name(self) -> str: - """Usable name for the symbol.""" - return self.sym_name.name + def is_constant(self) -> bool: + """When the symbol domain is a single-element `sp.FiniteSet`, then the symbol can be considered to be a constant.""" + return isinstance(self.domain, sp.FiniteSet) and len(self.domain) == 1 + + @functools.cached_property + def constant_value(self) -> bool: + """Get the constant when `is_constant` is True. + + The `self.unit_factor` is multiplied onto the constant at this point. + """ + if self.is_constant: + return next(iter(self.domain)) * self.unit_factor + + msg = 'Tried to get constant value of non-constant SimSymbol.' + raise ValueError(msg) + + @functools.cached_property + def is_nonzero(self) -> bool: + """Whether $0$ is a valid value for this symbol. + + When shaped, $0$ refers to the relevant shaped object with all elements $0$. + + Notes: + Most notably, this symbol cannot be used as the right hand side of a division operation when this property is `False`. + """ + return 0 in self.domain #################### # - Labels #################### + @functools.cached_property + def name(self) -> str: + """Usable string name for the symbol.""" + return self.sym_name.name + @functools.cached_property def name_pretty(self) -> str: """Pretty (possibly unicode) name for the thing.""" @@ -340,7 +420,8 @@ class SimSymbol(pyd.BaseModel): return self.unit if self.unit is not None else sp.S(1) @functools.cached_property - def size(self) -> tuple[int, ...] | None: + def size(self) -> spux.NumberSize1D | None: + """The 1D number size of this `SimSymbol`, if it has one; else None.""" return { (1, 1): spux.NumberSize1D.Scalar, (2, 1): spux.NumberSize1D.Vec2, @@ -350,13 +431,17 @@ class SimSymbol(pyd.BaseModel): @functools.cached_property def shape(self) -> tuple[int, ...]: + """Deterministic chosen shape of this `SimSymbol`. + + Derived from `self.rows` and `self.cols`. + + Is never `None`; instead, empty tuple `()` is used. + """ match (self.rows, self.cols): case (1, 1): return () case (_, 1): return (self.rows,) - case (1, _): - return (1, self.rows) case (_, _): return (self.rows, self.cols) @@ -365,116 +450,6 @@ class SimSymbol(pyd.BaseModel): """Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking.""" return len(self.shape) - @functools.cached_property - def domain(self) -> sp.Interval | sp.Set: - """Return the scalar domain of valid values for each element of the symbol. - - For integer/rational/real symbols, the domain is an interval defined using the `interval_*` properties. - This interval **must** have the property`start <= stop`. - - Otherwise, the domain is the symbolic set corresponding to `self.mathtype`. - """ - match self.mathtype: - case spux.MathType.Integer: - return mk_interval( - self.interval_finite_z, - self.interval_inf, - self.interval_closed, - self.unit_factor, - ) - - case spux.MathType.Rational: - return mk_interval( - Fraction(*self.interval_finite_q), - self.interval_inf, - self.interval_closed, - self.unit_factor, - ) - - case spux.MathType.Real: - return mk_interval( - self.interval_finite_re, - self.interval_inf, - self.interval_closed, - self.unit_factor, - ) - - case spux.MathType.Complex: - return ( - mk_interval( - self.interval_finite_re, - self.interval_inf, - self.interval_closed, - self.unit_factor, - ), - mk_interval( - self.interval_finite_im, - self.interval_inf_im, - self.interval_closed_im, - self.unit_factor, - ), - ) - - @functools.cached_property - def valid_domain_value(self) -> spux.SympyExpr: - """A single value guaranteed to be conformant to this `SimSymbol` and within `self.domain`.""" - match (self.domain.start.is_finite, self.domain.end.is_finite): - case (True, True): - if self.mathtype is spux.MathType.Integer: - return (self.domain.start + self.domain.end) // 2 - return (self.domain.start + self.domain.end) / 2 - - case (True, False): - one = sp.S(self.mathtype.coerce_compatible_pyobj(-1)) - return self.domain.start + one - - case (False, True): - one = sp.S(self.mathtype.coerce_compatible_pyobj(-1)) - return self.domain.end - one - - case (False, False): - return sp.S(self.mathtype.coerce_compatible_pyobj(-1)) - - @functools.cached_property - def is_nonzero(self) -> bool: - """Whether or not the value of this symbol can ever be $0$. - - Notes: - Most notably, this symbol cannot be used as the right hand side of a division operation when this property is `False`. - """ - if self.exclude_zero: - return True - - def check_real_domain(real_domain): - return ( - ( - real_domain.left == 0 - and real_domain.left_open - or real_domain.right == 0 - and real_domain.right_open - ) - or real_domain.left > 0 - or real_domain.right < 0 - ) - - if self.mathtype is spux.MathType.Complex: - return check_real_domain(self.domain[0]) and check_real_domain( - self.domain[1] - ) - return check_real_domain(self.domain) - - @functools.cached_property - def can_diff(self) -> bool: - """Whether this symbol can be used as the input / output variable when differentiating.""" - # Check Constants - ## -> Constants (w/pinned values) are never differentiable. - if self.is_constant: - return False - - # TODO: Discontinuities (especially across 0)? - - return self.mathtype in [spux.MathType.Real, spux.MathType.Complex] - #################### # - Properties #################### @@ -511,9 +486,9 @@ class SimSymbol(pyd.BaseModel): # Positive/Negative Assumption if self.mathtype is not spux.MathType.Complex: - if self.domain.left >= 0: + if self.domain.inf >= 0: mathtype_kwargs |= {'positive': True} - elif self.domain.right <= 0: + elif self.domain.sup < 0: mathtype_kwargs |= {'negative': True} # Scalar: Return Symbol @@ -571,7 +546,7 @@ class SimSymbol(pyd.BaseModel): """ if self.size is not None: if self.unit in self.physical_type.valid_units: - return { + socket_info = { 'output_name': self.sym_name, # Socket Interface 'size': self.size, @@ -580,23 +555,42 @@ class SimSymbol(pyd.BaseModel): # Defaults: Units 'default_unit': self.unit, 'default_symbols': [], - # Defaults: FlowKind.Value - 'default_value': self.conform( - self.valid_domain_value, strip_unit=True - ), - # Defaults: FlowKind.Range - 'default_min': self.conform(self.domain.start, strip_unit=True), - 'default_max': self.conform(self.domain.end, strip_unit=True), } + + # Defaults: FlowKind.Value + if self.preview_value: + socket_info |= { + 'default_value': self.conform( + self.preview_value, strip_unit=True + ) + } + + # Defaults: FlowKind.Range + if ( + self.mathtype is not spux.MathType.Complex + and self.rows == 1 + and self.cols == 1 + ): + socket_info |= { + 'default_min': self.domain.inf, + 'default_max': self.domain.sup, + } + ## TODO: Handle discontinuities / disjointness / open boundaries. + msg = f'Tried to generate an ExprSocket from a SymSymbol "{self.name}", but its unit ({self.unit}) is not a valid unit of its physical type ({self.physical_type}) (SimSymbol={self})' raise NotImplementedError(msg) + msg = f'Tried to generate an ExprSocket from a SymSymbol "{self.name}", but its size ({self.rows} by {self.cols}) is incompatible with ExprSocket (SimSymbol={self})' raise NotImplementedError(msg) #################### - # - Operations + # - Operations: Raw Update #################### def update(self, **kwargs) -> typ.Self: + """Create a new `SimSymbol`, such that the given keyword arguments override the existing values.""" + if not kwargs: + return self + def get_attr(attr: str): _notfound = 'notfound' if kwargs.get(attr, _notfound) is _notfound: @@ -610,61 +604,101 @@ class SimSymbol(pyd.BaseModel): unit=get_attr('unit'), rows=get_attr('rows'), cols=get_attr('cols'), - interval_finite_z=get_attr('interval_finite_z'), - interval_finite_q=get_attr('interval_finite_q'), - interval_finite_re=get_attr('interval_finite_re'), - interval_inf=get_attr('interval_inf'), - interval_closed=get_attr('interval_closed'), - interval_finite_im=get_attr('interval_finite_im'), - interval_inf_im=get_attr('interval_inf_im'), - interval_closed_im=get_attr('interval_closed_im'), + domain=get_attr('domain'), ) - def set_finite_domain( # noqa: PLR0913 - self, - start: int | float, - end: int | float, - start_closed: bool = True, - end_closed: bool = True, - start_im: bool = float, - end_im: bool = float, - start_closed_im: bool = True, - end_closed_im: bool = True, - ) -> typ.Self: - """Update the symbol with a finite range.""" - closed_re = (start_closed, end_closed) - closed_im = (start_closed_im, end_closed_im) - match self.mathtype: - case spux.MathType.Integer: - return self.update( - interval_finite_z=(start, end), - interval_inf=(False, False), - interval_closed=closed_re, - ) - case spux.MathType.Rational: - return self.update( - interval_finite_q=(start, end), - interval_inf=(False, False), - interval_closed=closed_re, - ) - case spux.MathType.Real: - return self.update( - interval_finite_re=(start, end), - interval_inf=(False, False), - interval_closed=closed_re, - ) - case spux.MathType.Complex: - return self.update( - interval_finite_re=(start, end), - interval_finite_im=(start_im, end_im), - interval_inf=(False, False), - interval_closed=closed_re, - interval_closed_im=closed_im, - ) + #################### + # - Operations: Comparison + #################### + def compare(self, other: typ.Self) -> typ.Self: + """Whether this SimSymbol can be considered equivalent to another, and thus universally usable in arbitrary mathematical operations together. - def set_size(self, rows: int, cols: int) -> typ.Self: - return self.update(rows=rows, cols=cols) + In particular, two attributes are ignored: + - **Name**: The particluar choices of name are not generally important. + - **Unit**: The particulars of unit equivilancy are not generally important; only that the `PhysicalType` is equal, and thus that they are compatible. + While not usable in all cases, this method ends up being very helpful for simplifying certain checks that would otherwise take up a lot of space. + """ + return ( + self.mathtype is other.mathtype + and self.physical_type is other.physical_type + and self.compare_size(other) + and self.domain == other.domain + ) + + def compare_size(self, other: typ.Self) -> typ.Self: + """Compare the size of this `SimSymbol` with another.""" + return self.rows == other.rows and self.cols == other.cols + + def compare_addable( + self, other: typ.Self, allow_differing_unit: bool = False + ) -> bool: + """Whether two `SimSymbol`s can be added.""" + common = ( + self.compare_size(other.output) + and self.physical_type is other.physical_type + and not ( + self.physical_type is spux.NonPhysical + and self.unit is not None + and self.unit != other.unit + ) + and not ( + other.physical_type is spux.NonPhysical + and other.unit is not None + and self.unit != other.unit + ) + ) + if not allow_differing_unit: + return common and self.output.unit == other.output.unit + return common + + def compare_multiplicable(self, other: typ.Self) -> bool: + """Whether two `SimSymbol`s can be multiplied.""" + return self.shape_len == 0 or self.compare_size(other) + + def compare_exponentiable(self, other: typ.Self) -> bool: + """Whether two `SimSymbol`s can be exponentiated. + + "Hadamard Power" is defined for any combination of scalar/vector/matrix operands, for any `MathType` combination. + The only important thing to check is that the exponent cannot have a physical unit. + + Sometimes, people write equations with units in the exponent. + This is a notational shorthand that only works in the context of an implicit, cancelling factor. + We reject such things. + + See https://physics.stackexchange.com/questions/109995/exponential-or-logarithm-of-a-dimensionful-quantity + """ + return ( + other.physical_type is spux.PhysicalType.NonPhysical and other.unit is None + ) + + #################### + # - Operations: Copying Setters + #################### + def set_constant(self, constant_value: spux.SympyType) -> typ.Self: + """Set the constant value of this `SimSymbol`, by setting it as the only value in a `sp.FiniteSet` domain. + + The `constant_value` will be conformed and stripped (with `self.conform()`) before being injected into the new `sp.FiniteSet` domain. + + Warnings: + Keep in mind that domains do not encode units, for practical reasons related to the diverging ways in which various `sp.Set` subclasses interpret units. + + This isn't noticeable in normal constant-symbol workflows, where the constant is retrieved using `self.constant_value` (which adds `self.unit_factor`). + However, **remember that retrieving the domain directly won't add the unit**. + + Ye been warned! + """ + if self.is_constant: + return self.update( + domain=sp.FiniteSet(self.conform(constant_value, strip_unit=True)) + ) + + msg = 'Tried to set constant value of non-constant SimSymbol.' + raise ValueError(msg) + + #################### + # - Operations: Conforming Mappers + #################### def conform( self, sp_obj: spux.SympyType, strip_unit: bool = False ) -> spux.SympyType: @@ -732,6 +766,9 @@ class SimSymbol(pyd.BaseModel): return res # noqa: RET504 + #################### + # - Creation + #################### @staticmethod def from_expr( sym_name: SimSymbolName, diff --git a/src/blender_maxwell/utils/sympy_extra/__init__.py b/src/blender_maxwell/utils/sympy_extra/__init__.py new file mode 100644 index 0000000..6ac4fca --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/__init__.py @@ -0,0 +1,173 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""Declares many useful primitives to greatly simplify working with `sympy` in the context of a unit-aware system.""" + +from .math_type import MathType +from .number_size import NumberSize1D, NumberSize2D +from .parse_cast import parse_shape, pretty_symbol, sp_to_str, sympy_to_python +from .physical_type import Dims, PhysicalType +from .sympy_expr import ( + ComplexNumber, + ComplexSymbol, + ConstrSympyExpr, + IntNumber, + IntSymbol, + Number, + PhysicalComplexNumber, + PhysicalNumber, + PhysicalRealNumber, + RationalSymbol, + Real3DVector, + RealNumber, + RealSymbol, + ScalarUnitlessComplexExpr, + ScalarUnitlessRealExpr, + Symbol, + SympyExpr, + Unit, + UnitDimension, +) +from .sympy_type import SympyType +from .unit_analysis import ( + compare_unit_dims, + compare_units_by_unit_dims, + convert_to_unit, + get_units, + scale_to_unit, + scaling_factor, + strip_units, + unit_dim_to_unit_dim_deps, + unit_str_to_unit, + unit_to_unit_dim_deps, + uses_units, +) +from .unit_system_analysis import ( + convert_to_unit_system, + scale_to_unit_system, + strip_unit_system, +) +from .unit_systems import UNITS_SI, UnitSystem +from .units import ( + UNIT_BY_SYMBOL, + UNIT_TO_1, + EHz, + GHz, + KHz, + MHz, + PHz, + THz, + exahertz, + femtometer, + femtosecond, + fm, + fs, + gigahertz, + hectopascal, + hPa, + kilohertz, + lm, + lumen, + mbar, + megahertz, + micronewton, + millibar, + millinewton, + mN, + nanonewton, + nN, + petahertz, + terahertz, + uN, +) + +__all__ = [ + 'MathType', + 'NumberSize1D', + 'NumberSize2D', + 'parse_shape', + 'pretty_symbol', + 'sp_to_str', + 'sympy_to_python', + 'Dims', + 'PhysicalType', + 'ComplexNumber', + 'ComplexSymbol', + 'ConstrSympyExpr', + 'IntNumber', + 'IntSymbol', + 'Number', + 'PhysicalComplexNumber', + 'PhysicalNumber', + 'PhysicalRealNumber', + 'RationalSymbol', + 'Real3DVector', + 'RealNumber', + 'RealSymbol', + 'ScalarUnitlessComplexExpr', + 'ScalarUnitlessRealExpr', + 'Symbol', + 'SympyExpr', + 'Unit', + 'UnitDimension', + 'SympyType', + 'compare_unit_dims', + 'compare_units_by_unit_dims', + 'convert_to_unit', + 'get_units', + 'scale_to_unit', + 'scaling_factor', + 'strip_units', + 'unit_dim_to_unit_dim_deps', + 'unit_str_to_unit', + 'unit_to_unit_dim_deps', + 'uses_units', + 'strip_unit_system', + 'UNITS_SI', + 'UnitSystem', + 'convert_to_unit_system', + 'scale_to_unit_system', + 'UNIT_BY_SYMBOL', + 'UNIT_TO_1', + 'EHz', + 'GHz', + 'KHz', + 'MHz', + 'PHz', + 'THz', + 'exahertz', + 'femtometer', + 'femtosecond', + 'fm', + 'fs', + 'gigahertz', + 'hectopascal', + 'hPa', + 'kilohertz', + 'lm', + 'lumen', + 'mbar', + 'megahertz', + 'micronewton', + 'millibar', + 'millinewton', + 'mN', + 'nanonewton', + 'nN', + 'petahertz', + 'terahertz', + 'uN', +] diff --git a/src/blender_maxwell/utils/sympy_extra/math_type.py b/src/blender_maxwell/utils/sympy_extra/math_type.py new file mode 100644 index 0000000..8830d85 --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/math_type.py @@ -0,0 +1,362 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""Implements `MathType`, a convenient UI-friendly identifier of numerical identity.""" + +import enum +import sys +import typing as typ +from fractions import Fraction + +import jax +import jaxtyping as jtyp +import sympy as sp + +from blender_maxwell import contracts as ct + +from .. import logger +from .sympy_type import SympyType + +log = logger.get(__name__) + + +class MathType(enum.StrEnum): + """A convenient, UI-friendly identifier of a numerical object's identity.""" + + Integer = enum.auto() + Rational = enum.auto() + Real = enum.auto() + Complex = enum.auto() + + #################### + # - Checks + #################### + @staticmethod + def has_mathtype(obj: typ.Any) -> typ.Literal['pytype', 'jax', 'expr'] | None: + """Determine whether an object of arbitrary type can be considered to have a `MathType`. + + - **Pure Python**: The numerical Python types (`int | Fraction | float | complex`) are all valid. + - **Expression**: Sympy types / expression are in general considered to have a valid MathType. + - **Jax**: Non-empty `jax` arrays with a valid numerical Python type as the first element are valid. + + Returns: + A string literal indicating how to parse the object for a valid `MathType`. + + If the presence of a MathType couldn't be deduced, then return None. + """ + if isinstance(obj, int | Fraction | float | complex): + return 'pytype' + + if ( + isinstance(obj, jax.Array) + and obj + and isinstance(obj.item(0), int | Fraction | float | complex) + ): + return 'jax' + + if isinstance(obj, sp.Basic | sp.MatrixBase): + return 'expr' + ## TODO: Should we check deeper? + + return None + + #################### + # - Creation + #################### + @staticmethod + def from_expr(sp_obj: SympyType, optional: bool = False) -> type | None: # noqa: PLR0911 + """Deduce the `MathType` of an arbitrary sympy object (/expression). + + The "assumptions" system of `sympy` is relied on to determine the key properties of the expression. + To this end, it's important to note several of the shortcomings of the "assumptions" system: + + - All elements, especially symbols, must have well-defined assumptions, ex. `real=True`. + - Only the "narrowest" possible `MathType` will be deduced, ex. `5` may well be the result of a complex expression, but since it is now an integer, it will parse to `MathType.Integer`. This may break some + - For infinities, only real and complex infinities are distinguished between in `sympy` (`sp.oo` vs. `sp.zoo`) - aka. there is no "integer infinity" which will parse to `Integer` with this method. + + Warnings: + Using the "assumptions" system like this requires a lot of rigor in the entire program. + + Notes: + Any matrix-like object will have `MathType.combine()` run on all of its (flattened) elements. + This is an extremely **slow** operation, but accurate, according to the semantics of `MathType.combine()`. + + Note that `sp.MatrixSymbol` _cannot have assumptions_, and thus shouldn't be used in `sp_obj`. + + Returns: + A corresponding `MathType`; else, if `optional=True`, return `None`. + + Raises: + ValueError: If no corresponding `MathType` could be determined, and `optional=False`. + + """ + if isinstance(sp_obj, sp.MatrixBase): + return MathType.combine( + *[MathType.from_expr(v) for v in sp.flatten(sp_obj)] + ) + + if sp_obj.is_integer: + return MathType.Integer + if sp_obj.is_rational: + return MathType.Rational + if sp_obj.is_real: + return MathType.Real + if sp_obj.is_complex: + return MathType.Complex + + # Infinities + if sp_obj in [sp.oo, -sp.oo]: + return MathType.Real + if sp_obj in [sp.zoo, -sp.zoo]: + return MathType.Complex + + if optional: + return None + + msg = f"Can't determine MathType from sympy object: {sp_obj}" + raise ValueError(msg) + + @staticmethod + def from_pytype(dtype: type) -> type: + return { + int: MathType.Integer, + Fraction: MathType.Rational, + float: MathType.Real, + complex: MathType.Complex, + }[dtype] + + @staticmethod + def from_jax_array(data: jtyp.Shaped[jtyp.Array, '...']) -> type: + """Deduce the MathType corresponding to a JAX array. + + We go about this by leveraging that: + - `data` is of a homogeneous type. + - `data.item(0)` returns a single element of the array w/pure-python type. + + By combing this with `type()` and `MathType.from_pytype`, we can effectively deduce the `MathType` of the entire array with relative efficiency. + + Notes: + Should also work with numpy arrays. + """ + if len(data) > 0: + return MathType.from_pytype(type(data.item(0))) + + msg = 'Cannot determine MathType from empty jax array.' + raise ValueError(msg) + + #################### + # - Operations + #################### + @staticmethod + def combine(*mathtypes: list[typ.Self], optional: bool = False) -> typ.Self | None: + if MathType.Complex in mathtypes: + return MathType.Complex + if MathType.Real in mathtypes: + return MathType.Real + if MathType.Rational in mathtypes: + return MathType.Rational + if MathType.Integer in mathtypes: + return MathType.Integer + + if optional: + return None + + msg = f"Can't combine mathtypes {mathtypes}" + raise ValueError(msg) + + def is_compatible(self, other: typ.Self) -> bool: + MT = MathType + return ( + other + in { + MT.Integer: [MT.Integer], + MT.Rational: [MT.Integer, MT.Rational], + MT.Real: [MT.Integer, MT.Rational, MT.Real], + MT.Complex: [MT.Integer, MT.Rational, MT.Real, MT.Complex], + }[self] + ) + + def coerce_compatible_pyobj( + self, pyobj: bool | int | Fraction | float | complex + ) -> int | Fraction | float | complex: + """Coerce a pure-python object of numerical type to the _exact_ type indicated by this `MathType`. + + This is needed when ex. one has an integer, but it is important that that integer be passed as a complex number. + """ + MT = MathType + match self: + case MT.Integer: + return int(pyobj) + case MT.Rational if isinstance(pyobj, int): + return Fraction(pyobj, 1) + case MT.Rational if isinstance(pyobj, Fraction): + return pyobj + case MT.Real: + return float(pyobj) + case MT.Complex if isinstance(pyobj, int | Fraction): + return complex(float(pyobj), 0) + case MT.Complex if isinstance(pyobj, float): + return complex(pyobj, 0) + + @staticmethod + def from_symbolic_set( + s: typ.Literal[ + sp.Naturals + | sp.Naturals0 + | sp.Integers + | sp.Rationals + | sp.Reals + | sp.Complexes + ] + | sp.Set, + optional: bool = False, + ) -> typ.Self | None: + """Deduce the `MathType` from a particular symbolic set. + + Currently hard-coded. + Any deviation that might be expected to work, ex. `sp.Reals - {0}`, currently won't (currently). + + Raises: + ValueError: If a non-hardcoded symbolic set is passed. + """ + MT = MathType + match s: + case sp.Naturals | sp.Naturals0 | sp.Integers: + return MT.Integer + case sp.Rationals: + return MT.Rational + case sp.Reals: + return MT.Real + case sp.Complexes: + return MT.Complex + + if optional: + return None + + msg = f"Can't deduce MathType from symbolic set {s}" + raise ValueError(msg) + + #################### + # - Casting: Pytype + #################### + @property + def pytype(self) -> type: + """Deduce the pure-Python type that corresponds to this `MathType`.""" + MT = MathType + return { + MT.Integer: int, + MT.Rational: Fraction, + MT.Real: float, + MT.Complex: complex, + }[self] + + @property + def inf_finite(self) -> type: + """Opinionated finite representation of "infinity" within this `MathType`. + + These are chosen using `sys.maxsize` and `sys.float_info`. + As such, while not arbitrary, this "finite representation of infinity" certainly is opinionated. + + **Note** that, in practice, most systems will have no trouble working with values that exceed those defined here. + + Notes: + Values should be presumed to vary by-platform, as the `sys` attributes may be influenced by CPU architecture, OS, runtime environment, etc. . + + These values can be used directly in `jax` arrays, but at the cost of an overflow warning (in part because `jax` generally only allows the use of `float32`). + In this case, the warning doesn't matter, as the value will be cast to `jnp.inf` anyway. + + However, it's generally cleaner to directly use `jnp.inf` if infinite values must be defined in an array context. + """ + MT = MathType + Z = MT.Integer + R = MT.Integer + return { + MT.Integer: (-sys.maxsize, sys.maxsize), + MT.Rational: ( + Fraction(Z.inf_finite[0], 1), + Fraction(Z.inf_finite[1], 1), + ), + MT.Real: -(sys.float_info.min, sys.float_info.max), + MT.Complex: ( + complex(R.inf_finite[0], R.inf_finite[0]), + complex(R.inf_finite[1], R.inf_finite[1]), + ), + }[self] + + #################### + # - Casting: Symbolic + #################### + @property + def symbolic_set(self) -> sp.Set: + """Deduce the symbolic `sp.Set` type that corresponds to this `MathType`.""" + MT = MathType + return { + MT.Integer: sp.Integers, + MT.Rational: sp.Rationals, + MT.Real: sp.Reals, + MT.Complex: sp.Complexes, + }[self] + + @property + def sp_symbol_a(self) -> type: + MT = MathType + return { + MT.Integer: sp.Symbol('a', integer=True), + MT.Rational: sp.Symbol('a', rational=True), + MT.Real: sp.Symbol('a', real=True), + MT.Complex: sp.Symbol('a', complex=True), + }[self] + + #################### + # - Labels + #################### + @staticmethod + def to_str(value: typ.Self) -> type: + return { + MathType.Integer: 'ℤ', + MathType.Rational: 'ℚ', + MathType.Real: 'ℝ', + MathType.Complex: 'ℂ', + }[value] + + @property + def name(self) -> str: + """Simple non-unicode name of the math type.""" + return str(self) + + @property + def label_pretty(self) -> str: + return MathType.to_str(self) + + #################### + # - UI + #################### + @staticmethod + def to_name(value: typ.Self) -> str: + return MathType.to_str(value) + + @staticmethod + def to_icon(value: typ.Self) -> str: + return '' + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + return ( + str(self), + MathType.to_name(self), + MathType.to_name(self), + MathType.to_icon(self), + i, + ) diff --git a/src/blender_maxwell/utils/sympy_extra/number_size.py b/src/blender_maxwell/utils/sympy_extra/number_size.py new file mode 100644 index 0000000..14bf34d --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/number_size.py @@ -0,0 +1,148 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import enum +import typing as typ + +import sympy as sp + +from blender_maxwell import contracts as ct + + +#################### +# - Size: 1D +#################### +class NumberSize1D(enum.StrEnum): + """Valid 1D-constrained shape.""" + + Scalar = enum.auto() + Vec2 = enum.auto() + Vec3 = enum.auto() + Vec4 = enum.auto() + + @staticmethod + def to_name(value: typ.Self) -> str: + NS = NumberSize1D + return { + NS.Scalar: 'Scalar', + NS.Vec2: '2D', + NS.Vec3: '3D', + NS.Vec4: '4D', + }[value] + + @staticmethod + def to_icon(value: typ.Self) -> str: + NS = NumberSize1D + return { + NS.Scalar: '', + NS.Vec2: '', + NS.Vec3: '', + NS.Vec4: '', + }[value] + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + return ( + str(self), + NumberSize1D.to_name(self), + NumberSize1D.to_name(self), + NumberSize1D.to_icon(self), + i, + ) + + @staticmethod + def has_shape(shape: tuple[int, ...] | None): + return shape in [None, (2,), (3,), (4,), (2, 1), (3, 1), (4, 1)] + + def supports_shape(self, shape: tuple[int, ...] | None): + NS = NumberSize1D + match self: + case NS.Scalar: + return shape is None + case NS.Vec2: + return shape in ((2,), (2, 1)) + case NS.Vec3: + return shape in ((3,), (3, 1)) + case NS.Vec4: + return shape in ((4,), (4, 1)) + + @staticmethod + def from_shape(shape: tuple[typ.Literal[2, 3]] | None) -> typ.Self: + NS = NumberSize1D + return { + None: NS.Scalar, + (2,): NS.Vec2, + (3,): NS.Vec3, + (4,): NS.Vec4, + (2, 1): NS.Vec2, + (3, 1): NS.Vec3, + (4, 1): NS.Vec4, + }[shape] + + @property + def rows(self): + NS = NumberSize1D + return { + NS.Scalar: 1, + NS.Vec2: 2, + NS.Vec3: 3, + NS.Vec4: 4, + }[self] + + @property + def cols(self): + return 1 + + @property + def shape(self): + NS = NumberSize1D + return { + NS.Scalar: None, + NS.Vec2: (2,), + NS.Vec3: (3,), + NS.Vec4: (4,), + }[self] + + +def symbol_range(sym: sp.Symbol) -> str: + return f'{sym.name} ∈ ' + ( + 'ℂ' + if sym.is_complex + else ('ℝ' if sym.is_real else ('ℤ' if sym.is_integer else '?')) + ) + + +#################### +# - Symbol Sizes +#################### +class NumberSize2D(enum.StrEnum): + """Simple subset of sizes for rank-2 tensors.""" + + Scalar = enum.auto() + + # Vectors + Vec2 = enum.auto() ## 2x1 + Vec3 = enum.auto() ## 3x1 + Vec4 = enum.auto() ## 4x1 + + # Covectors + CoVec2 = enum.auto() ## 1x2 + CoVec3 = enum.auto() ## 1x3 + CoVec4 = enum.auto() ## 1x4 + + # Square Matrices + Mat22 = enum.auto() ## 2x2 + Mat33 = enum.auto() ## 3x3 + Mat44 = enum.auto() ## 4x4 diff --git a/src/blender_maxwell/utils/sympy_extra/parse_cast.py b/src/blender_maxwell/utils/sympy_extra/parse_cast.py new file mode 100644 index 0000000..37672f9 --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/parse_cast.py @@ -0,0 +1,119 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import jax +import jax.numpy as jnp +import sympy as sp + +from .. import logger +from .sympy_type import SympyType + +log = logger.get(__name__) + + +#################### +# - Parsing: Info from SympyType +#################### +def parse_shape(sp_obj: SympyType) -> int | None: + if isinstance(sp_obj, sp.MatrixBase): + return sp_obj.shape + + return None + + +#################### +# - Casting: Python +#################### +def sympy_to_python( + scalar: sp.Basic, use_jax_array: bool = False +) -> int | float | complex | tuple | jax.Array: + """Convert a scalar sympy expression to the directly corresponding Python type. + + Arguments: + scalar: A sympy expression that has no symbols, but is expressed as a Sympy type. + For expressions that are equivalent to a scalar (ex. "(2a + a)/a"), you must simplify the expression with ex. `sp.simplify()` before passing to this parameter. + + Returns: + A pure Python type that directly corresponds to the input scalar expression. + """ + if isinstance(scalar, sp.MatrixBase): + # Detect Single Column Vector + ## --> Flatten to Single Row Vector + if len(scalar.shape) == 2 and scalar.shape[1] == 1: + _scalar = scalar.T + else: + _scalar = scalar + + # Convert to Tuple of Tuples + matrix = tuple( + [tuple([sympy_to_python(el) for el in row]) for row in _scalar.tolist()] + ) + + # Detect Single Row Vector + ## --> This could be because the scalar had it. + ## --> This could also be because we flattened a column vector. + ## Either way, we should strip the pointless dimensions. + if len(matrix) == 1: + return matrix[0] if not use_jax_array else jnp.array(matrix[0]) + + return matrix if not use_jax_array else jnp.array(matrix) + if scalar.is_integer: + return int(scalar) + if scalar.is_rational or scalar.is_real: + return float(scalar) + if scalar.is_complex: + return complex(scalar) + + msg = f'Cannot convert sympy scalar expression "{scalar}" to a Python type. Check the assumptions on the expr (current expr assumptions: "{scalar._assumptions}")' # noqa: SLF001 + raise ValueError(msg) + + +#################### +# - Casting: Printing +#################### +_SYMPY_EXPR_PRINTER_STR = sp.printing.str.StrPrinter( + settings={ + 'abbrev': True, + } +) + + +def sp_to_str(sp_obj: SympyType) -> str: + """Converts a sympy object to an output-oriented string (w/abbreviated units), using a dedicated StrPrinter. + + This should be used whenever a **string for UI use** is needed from a `sympy` object. + + Notes: + This should **NOT** be used in cases where the string will be `sp.sympify()`ed back into a sympy expression. + For such cases, rely on `sp.srepr()`, which uses an _explicit_ representation. + + Parameters: + sp_obj: The `sympy` object to convert to a string. + + Returns: + A string representing the expression for human use. + _The string is not re-encodable to the expression._ + """ + ## TODO: A bool flag property that does a lot of find/replace to make it super pretty + return _SYMPY_EXPR_PRINTER_STR.doprint(sp_obj) + + +def pretty_symbol(sym: sp.Symbol) -> str: + return f'{sym.name} ∈ ' + ( + 'ℤ' + if sym.is_integer + else ('ℝ' if sym.is_real else ('ℂ' if sym.is_complex else '?')) + ) diff --git a/src/blender_maxwell/utils/sympy_extra/physical_type.py b/src/blender_maxwell/utils/sympy_extra/physical_type.py new file mode 100644 index 0000000..2adedda --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/physical_type.py @@ -0,0 +1,644 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""Implements `PhysicalType`, a convenient, UI-friendly way of deterministically handling the unit-dimensionality of arbitrary objects.""" + +import enum +import functools +import typing as typ + +import sympy.physics.units as spu + +from blender_maxwell import contracts as ct + +from ..staticproperty import staticproperty +from . import units as spux +from .math_type import MathType +from .sympy_expr import Unit +from .sympy_type import SympyType +from .unit_analysis import ( + compare_unit_dim_to_unit_dim_deps, + compare_unit_dims, + unit_to_unit_dim_deps, +) + + +#################### +# - Unit Dimensions +#################### +class DimsMeta(type): + """Metaclass allowing an implementing (ideally empty) class to access `spu.definitions.dimension_definitions` attributes directly via its own attribute.""" + + def __getattr__(cls, attr: str) -> spu.Dimension: + """Alias for `spu.definitions.dimension_definitions.*` (isn't that a mouthful?). + + Raises: + AttributeError: If the name cannot be found. + """ + if ( + attr in spu.definitions.dimension_definitions.__dir__() + and not attr.startswith('__') + ): + return getattr(spu.definitions.dimension_definitions, attr) + + raise AttributeError(name=attr, obj=Dims) + + +class Dims(metaclass=DimsMeta): + """Access `sympy.physics.units` dimensions with less hassle. + + Any unit dimension available in `sympy.physics.units.definitions.dimension_definitions` can be accessed as an attribute of `Dims`. + + An `AttributeError` is raised if the unit cannot be found in `sympy`. + + Examples: + The objects returned are a direct alias to `sympy`, with less hassle: + ```python + assert Dims.length == ( + sympy.physics.units.definitions.dimension_definitions.length + ) + ``` + """ + + +#################### +# - Physical Type +#################### +class PhysicalType(enum.StrEnum): + """An identifier of unit dimensionality with many useful properties.""" + + # Unitless + NonPhysical = enum.auto() + + # Global + Time = enum.auto() + Angle = enum.auto() + SolidAngle = enum.auto() + ## TODO: Some kind of 3D-specific orientation ex. a quaternion + Freq = enum.auto() + AngFreq = enum.auto() ## rad*hertz + # Cartesian + Length = enum.auto() + Area = enum.auto() + Volume = enum.auto() + # Mechanical + Vel = enum.auto() + Accel = enum.auto() + Mass = enum.auto() + Force = enum.auto() + Pressure = enum.auto() + # Energy + Work = enum.auto() ## joule + Power = enum.auto() ## watt + PowerFlux = enum.auto() ## watt + Temp = enum.auto() + # Electrodynamics + Current = enum.auto() ## ampere + CurrentDensity = enum.auto() + Charge = enum.auto() ## coulomb + Voltage = enum.auto() + Capacitance = enum.auto() ## farad + Impedance = enum.auto() ## ohm + Conductance = enum.auto() ## siemens + Conductivity = enum.auto() ## siemens / length + MFlux = enum.auto() ## weber + MFluxDensity = enum.auto() ## tesla + Inductance = enum.auto() ## henry + EField = enum.auto() + HField = enum.auto() + # Luminal + LumIntensity = enum.auto() + LumFlux = enum.auto() + Illuminance = enum.auto() + + #################### + # - Unit Dimensions + #################### + @functools.cached_property + def unit_dim(self) -> SympyType: + """The unit dimension expression associated with the `PhysicalType`. + + A `PhysicalType` is, in its essence, merely an identifier for a particular unit dimension expression. + """ + PT = PhysicalType + return { + PT.NonPhysical: None, + # Global + PT.Time: Dims.time, + PT.Angle: Dims.angle, + PT.SolidAngle: spu.steradian.dimension, ## MISSING + PT.Freq: Dims.frequency, + PT.AngFreq: Dims.angle * Dims.frequency, + # Cartesian + PT.Length: Dims.length, + PT.Area: Dims.length**2, + PT.Volume: Dims.length**3, + # Mechanical + PT.Vel: Dims.length / Dims.time, + PT.Accel: Dims.length / Dims.time**2, + PT.Mass: Dims.mass, + PT.Force: Dims.force, + PT.Pressure: Dims.pressure, + # Energy + PT.Work: Dims.energy, + PT.Power: Dims.power, + PT.PowerFlux: Dims.power / Dims.length**2, + PT.Temp: Dims.temperature, + # Electrodynamics + PT.Current: Dims.current, + PT.CurrentDensity: Dims.current / Dims.length**2, + PT.Charge: Dims.charge, + PT.Voltage: Dims.voltage, + PT.Capacitance: Dims.capacitance, + PT.Impedance: Dims.impedance, + PT.Conductance: Dims.conductance, + PT.Conductivity: Dims.conductance / Dims.length, + PT.MFlux: Dims.magnetic_flux, + PT.MFluxDensity: Dims.magnetic_density, + PT.Inductance: Dims.inductance, + PT.EField: Dims.voltage / Dims.length, + PT.HField: Dims.current / Dims.length, + # Luminal + PT.LumIntensity: Dims.luminous_intensity, + PT.LumFlux: Dims.luminous_intensity * spu.steradian.dimension, + PT.Illuminance: Dims.luminous_intensity / Dims.length**2, + }[self] + + @staticproperty + def unit_dims() -> dict[typ.Self, SympyType]: + """All unit dimensions supported by all `PhysicalType`s.""" + return { + physical_type: physical_type.unit_dim + for physical_type in list(PhysicalType) + } + + #################### + # - Convenience Properties + #################### + @functools.cached_property + def default_unit(self) -> list[Unit]: + """Subjective choice of 'default' unit from `self.valid_units`. + + There is no requirement to use this. + """ + PT = PhysicalType + return { + PT.NonPhysical: None, + # Global + PT.Time: spu.picosecond, + PT.Angle: spu.radian, + PT.SolidAngle: spu.steradian, + PT.Freq: spux.terahertz, + PT.AngFreq: spu.radian * spux.terahertz, + # Cartesian + PT.Length: spu.micrometer, + PT.Area: spu.um**2, + PT.Volume: spu.um**3, + # Mechanical + PT.Vel: spu.um / spu.second, + PT.Accel: spu.um / spu.second, + PT.Mass: spu.microgram, + PT.Force: spux.micronewton, + PT.Pressure: spux.millibar, + # Energy + PT.Work: spu.joule, + PT.Power: spu.watt, + PT.PowerFlux: spu.watt / spu.meter**2, + PT.Temp: spu.kelvin, + # Electrodynamics + PT.Current: spu.ampere, + PT.CurrentDensity: spu.ampere / spu.meter**2, + PT.Charge: spu.coulomb, + PT.Voltage: spu.volt, + PT.Capacitance: spu.farad, + PT.Impedance: spu.ohm, + PT.Conductance: spu.siemens, + PT.Conductivity: spu.siemens / spu.micrometer, + PT.MFlux: spu.weber, + PT.MFluxDensity: spu.tesla, + PT.Inductance: spu.henry, + PT.EField: spu.volt / spu.micrometer, + PT.HField: spu.ampere / spu.micrometer, + # Luminal + PT.LumIntensity: spu.candela, + PT.LumFlux: spu.candela * spu.steradian, + PT.Illuminance: spu.candela / spu.meter**2, + }[self] + + #################### + # - Creation + #################### + @staticmethod + def from_unit(unit: Unit | None, optional: bool = False) -> typ.Self | None: + """Attempt to determine a matching `PhysicalType` from a unit. + + NOTE: It is not guaranteed that `unit` is within `valid_units`, only that it can be converted to any unit in `valid_units`. + + Returns: + The matched `PhysicalType`. + + If none could be matched, then either return `None` (if `optional` is set) or error. + + Raises: + ValueError: If no `PhysicalType` could be matched, and `optional` is `False`. + """ + if unit is None: + return PhysicalType.NonPhysical + + ## TODO_ This enough? + if unit in [spu.radian, spu.degree]: + return PhysicalType.Angle + + unit_dim_deps = unit_to_unit_dim_deps(unit) + if unit_dim_deps is not None: + for physical_type, candidate_unit_dim in PhysicalType.unit_dims.items(): + if compare_unit_dim_to_unit_dim_deps(candidate_unit_dim, unit_dim_deps): + return physical_type + + if optional: + return None + msg = f'Could not determine PhysicalType for {unit}' + raise ValueError(msg) + + @staticmethod + def from_unit_dim( + unit_dim: SympyType | None, optional: bool = False + ) -> typ.Self | None: + """Attempts to match an arbitrary unit dimension expression to a corresponding `PhysicalType`. + + For comparing arbitrary unit dimensions (via expressions of `spu.dimensions.Dimension`), it is critical that equivalent dimensions are also compared as equal (ex. `mass*length/time^2 == force`). + To do so, we employ the `SI` unit conventions, for extracting the fundamental dimensional dependencies of unit dimension expressions. + + Returns: + The matched `PhysicalType`. + + If none could be matched, then either return `None` (if `optional` is set) or error. + + Raises: + ValueError: If no `PhysicalType` could be matched, and `optional` is `False`. + """ + for physical_type, candidate_unit_dim in PhysicalType.unit_dims.items(): + if compare_unit_dims(unit_dim, candidate_unit_dim): + return physical_type + + if optional: + return None + msg = f'Could not determine PhysicalType for {unit_dim}' + raise ValueError(msg) + + #################### + # - Valid Properties + #################### + @functools.cached_property + def valid_units(self) -> list[Unit]: + """Retrieve an ordered (by subjective usefulness) list of units for this physical type. + + Warnings: + **Altering the order of units hard-breaks backwards compatibility**, since enums based on it only keep an integer index. + + Notes: + The order in which valid units are declared is the exact same order that UI dropdowns display them. + """ + PT = PhysicalType + return { + PT.NonPhysical: [None], + # Global + PT.Time: [ + spu.picosecond, + spux.femtosecond, + spu.nanosecond, + spu.microsecond, + spu.millisecond, + spu.second, + spu.minute, + spu.hour, + spu.day, + ], + PT.Angle: [ + spu.radian, + spu.degree, + ], + PT.SolidAngle: [ + spu.steradian, + ], + PT.Freq: ( + _valid_freqs := [ + spux.terahertz, + spu.hertz, + spux.kilohertz, + spux.megahertz, + spux.gigahertz, + spux.petahertz, + spux.exahertz, + ] + ), + PT.AngFreq: [spu.radian * _unit for _unit in _valid_freqs], + # Cartesian + PT.Length: ( + _valid_lens := [ + spu.micrometer, + spu.nanometer, + spu.picometer, + spu.angstrom, + spu.millimeter, + spu.centimeter, + spu.meter, + spu.inch, + spu.foot, + spu.yard, + spu.mile, + ] + ), + PT.Area: [_unit**2 for _unit in _valid_lens], + PT.Volume: [_unit**3 for _unit in _valid_lens], + # Mechanical + PT.Vel: [_unit / spu.second for _unit in _valid_lens], + PT.Accel: [_unit / spu.second**2 for _unit in _valid_lens], + PT.Mass: [ + spu.kilogram, + spu.electron_rest_mass, + spu.dalton, + spu.microgram, + spu.milligram, + spu.gram, + spu.metric_ton, + ], + PT.Force: [ + spux.micronewton, + spux.nanonewton, + spux.millinewton, + spu.newton, + spu.kg * spu.meter / spu.second**2, + ], + PT.Pressure: [ + spu.bar, + spux.millibar, + spu.pascal, + spux.hectopascal, + spu.atmosphere, + spu.psi, + spu.mmHg, + spu.torr, + ], + # Energy + PT.Work: [ + spu.joule, + spu.electronvolt, + ], + PT.Power: [ + spu.watt, + ], + PT.PowerFlux: [ + spu.watt / spu.meter**2, + ], + PT.Temp: [ + spu.kelvin, + ], + # Electrodynamics + PT.Current: [ + spu.ampere, + ], + PT.CurrentDensity: [ + spu.ampere / spu.meter**2, + ], + PT.Charge: [ + spu.coulomb, + ], + PT.Voltage: [ + spu.volt, + ], + PT.Capacitance: [ + spu.farad, + ], + PT.Impedance: [ + spu.ohm, + ], + PT.Conductance: [ + spu.siemens, + ], + PT.Conductivity: [ + spu.siemens / spu.micrometer, + spu.siemens / spu.meter, + ], + PT.MFlux: [ + spu.weber, + ], + PT.MFluxDensity: [ + spu.tesla, + ], + PT.Inductance: [ + spu.henry, + ], + PT.EField: [ + spu.volt / spu.micrometer, + spu.volt / spu.meter, + ], + PT.HField: [ + spu.ampere / spu.micrometer, + spu.ampere / spu.meter, + ], + # Luminal + PT.LumIntensity: [ + spu.candela, + ], + PT.LumFlux: [ + spu.candela * spu.steradian, + ], + PT.Illuminance: [ + spu.candela / spu.meter**2, + ], + }[self] + + @functools.cached_property + def valid_shapes(self) -> list[typ.Literal[(3,), (2,)] | None]: + """All shapes with physical meaning in the context of a particular unit dimension.""" + PT = PhysicalType + overrides = { + # Cartesian + PT.Length: [None, (2,), (3,)], + # Mechanical + PT.Vel: [None, (2,), (3,)], + PT.Accel: [None, (2,), (3,)], + PT.Force: [None, (2,), (3,)], + # Energy + PT.Work: [None, (2,), (3,)], + PT.PowerFlux: [None, (2,), (3,)], + # Electrodynamics + PT.CurrentDensity: [None, (2,), (3,)], + PT.MFluxDensity: [None, (2,), (3,)], + PT.EField: [None, (2,), (3,)], + PT.HField: [None, (2,), (3,)], + # Luminal + PT.LumFlux: [None, (2,), (3,)], + } + + return overrides.get(self, [None]) + + @functools.cached_property + def valid_mathtypes(self) -> list[MathType]: + """Returns a list of valid mathematical types, especially whether it can be real- or complex-valued. + + Generally, all unit quantities are real, in the algebraic mathematical sense. + However, in electrodynamics especially, it becomes enormously useful to bake in a _rotational component_ as an imaginary value, be it simply to model phase or oscillation-oriented dampening. + This imaginary part has physical meaning, which can be expressed using the same mathematical formalism associated with unit systems. + In general, the value is a phasor. + + While it is difficult to arrive at a well-defined way of saying, "this is when a quantity is complex", an attempt has been made to form a sensible baseline based on when phasor math may apply. + + Notes: + - **Freq**/**AngFreq**: The imaginary part represents growth/dampening of the oscillation. + - **Current**/**Voltage**: The imaginary part represents the phase. + This also holds for any downstream units. + - **Charge**: Generally, it is real. + However, an imaginary phase term seems to have research applications when dealing with high-order harmonics in high-energy pulsed lasers: + - **Conductance**: The imaginary part represents the extinction, in the Drude-model sense. + + """ + MT = MathType + PT = PhysicalType + overrides = { + PT.NonPhysical: list(MT), ## Support All + # Cartesian + PT.Freq: [MT.Real, MT.Complex], ## Im -> Growth/Damping + PT.AngFreq: [MT.Real, MT.Complex], ## Im -> Growth/Damping + # Mechanical + # Energy + # Electrodynamics + PT.Current: [MT.Real, MT.Complex], ## Im -> Phase + PT.CurrentDensity: [MT.Real, MT.Complex], ## Im -> Phase + PT.Charge: [MT.Real, MT.Complex], ## Im -> Phase + PT.Voltage: [MT.Real, MT.Complex], ## Im -> Phase + PT.Capacitance: [MT.Real, MT.Complex], ## Im -> Phase + PT.Impedance: [MT.Real, MT.Complex], ## Im -> Reactance + PT.Inductance: [MT.Real, MT.Complex], ## Im -> Extinction + PT.Conductance: [MT.Real, MT.Complex], ## Im -> Extinction + PT.Conductivity: [MT.Real, MT.Complex], ## Im -> Extinction + PT.MFlux: [MT.Real, MT.Complex], ## Im -> Phase + PT.MFluxDensity: [MT.Real, MT.Complex], ## Im -> Phase + PT.EField: [MT.Real, MT.Complex], ## Im -> Phase + PT.HField: [MT.Real, MT.Complex], ## Im -> Phase + # Luminal + } + + return overrides.get(self, [MT.Real]) + + #################### + # - UI + #################### + @staticmethod + def to_name(value: typ.Self) -> str: + """A human-readable UI-oriented name for a physical type.""" + if value is PhysicalType.NonPhysical: + return 'Unitless' + return PhysicalType(value).name + + @staticmethod + def to_icon(_: typ.Self) -> str: + """No icons.""" + return '' + + def bl_enum_element(self, i: int) -> ct.BLEnumElement: + """Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`.""" + PT = PhysicalType + return ( + str(self), + PT.to_name(self), + PT.to_name(self), + PT.to_icon(self), + i, + ) + + @functools.cached_property + def color(self): + """A color corresponding to the physical type. + + The color selections were initially generated using AI, as this is a rote task that's better adjusted than invented. + The LLM provided the following rationale for its choices: + + > Non-Physical: Grey signifies neutrality and non-physical nature. + > Global: + > Time: Blue is often associated with calmness and the passage of time. + > Angle and Solid Angle: Different shades of blue and cyan suggest angular dimensions and spatial aspects. + > Frequency and Angular Frequency: Darker shades of blue to maintain the link to time. + > Cartesian: + > Length, Area, Volume: Shades of green to represent spatial dimensions, with intensity increasing with dimension. + > Mechanical: + > Velocity and Acceleration: Red signifies motion and dynamics, with lighter reds for related quantities. + > Mass: Dark red for the fundamental property. + > Force and Pressure: Shades of red indicating intensity. + > Energy: + > Work and Power: Orange signifies energy transformation, with lighter oranges for related quantities. + > Temperature: Yellow for heat. + > Electrodynamics: + > Current and related quantities: Cyan shades indicating flow. + > Voltage, Capacitance: Greenish and blueish cyan for electrical potential. + > Impedance, Conductance, Conductivity: Purples and magentas to signify resistance and conductance. + > Magnetic properties: Magenta shades for magnetism. + > Electric Field: Light blue. + > Magnetic Field: Grey, as it can be considered neutral in terms of direction. + > Luminal: + > Luminous properties: Yellows to signify light and illumination. + > + > This color mapping helps maintain intuitive connections for users interacting with these physical types. + """ + PT = PhysicalType + return { + PT.NonPhysical: (0.75, 0.75, 0.75, 1.0), # Light Grey: Non-physical + # Global + PT.Time: (0.5, 0.5, 1.0, 1.0), # Light Blue: Time + PT.Angle: (0.5, 0.75, 1.0, 1.0), # Light Blue: Angle + PT.SolidAngle: (0.5, 0.75, 0.75, 1.0), # Light Cyan: Solid Angle + PT.Freq: (0.5, 0.5, 0.9, 1.0), # Light Blue: Frequency + PT.AngFreq: (0.5, 0.5, 0.8, 1.0), # Light Blue: Angular Frequency + # Cartesian + PT.Length: (0.5, 1.0, 0.5, 1.0), # Light Green: Length + PT.Area: (0.6, 1.0, 0.6, 1.0), # Light Green: Area + PT.Volume: (0.7, 1.0, 0.7, 1.0), # Light Green: Volume + # Mechanical + PT.Vel: (1.0, 0.5, 0.5, 1.0), # Light Red: Velocity + PT.Accel: (1.0, 0.6, 0.6, 1.0), # Light Red: Acceleration + PT.Mass: (0.75, 0.5, 0.5, 1.0), # Light Red: Mass + PT.Force: (0.9, 0.5, 0.5, 1.0), # Light Red: Force + PT.Pressure: (1.0, 0.7, 0.7, 1.0), # Light Red: Pressure + # Energy + PT.Work: (1.0, 0.75, 0.5, 1.0), # Light Orange: Work + PT.Power: (1.0, 0.85, 0.5, 1.0), # Light Orange: Power + PT.PowerFlux: (1.0, 0.8, 0.6, 1.0), # Light Orange: Power Flux + PT.Temp: (1.0, 1.0, 0.5, 1.0), # Light Yellow: Temperature + # Electrodynamics + PT.Current: (0.5, 1.0, 1.0, 1.0), # Light Cyan: Current + PT.CurrentDensity: (0.5, 0.9, 0.9, 1.0), # Light Cyan: Current Density + PT.Charge: (0.5, 0.85, 0.85, 1.0), # Light Cyan: Charge + PT.Voltage: (0.5, 1.0, 0.75, 1.0), # Light Greenish Cyan: Voltage + PT.Capacitance: (0.5, 0.75, 1.0, 1.0), # Light Blueish Cyan: Capacitance + PT.Impedance: (0.6, 0.5, 0.75, 1.0), # Light Purple: Impedance + PT.Conductance: (0.7, 0.5, 0.8, 1.0), # Light Purple: Conductance + PT.Conductivity: (0.8, 0.5, 0.9, 1.0), # Light Purple: Conductivity + PT.MFlux: (0.75, 0.5, 0.75, 1.0), # Light Magenta: Magnetic Flux + PT.MFluxDensity: ( + 0.85, + 0.5, + 0.85, + 1.0, + ), # Light Magenta: Magnetic Flux Density + PT.Inductance: (0.8, 0.5, 0.8, 1.0), # Light Magenta: Inductance + PT.EField: (0.75, 0.75, 1.0, 1.0), # Light Blue: Electric Field + PT.HField: (0.75, 0.75, 0.75, 1.0), # Light Grey: Magnetic Field + # Luminal + PT.LumIntensity: (1.0, 0.95, 0.5, 1.0), # Light Yellow: Luminous Intensity + PT.LumFlux: (1.0, 0.95, 0.6, 1.0), # Light Yellow: Luminous Flux + PT.Illuminance: (1.0, 1.0, 0.75, 1.0), # Pale Yellow: Illuminance + }[self] diff --git a/src/blender_maxwell/utils/sympy_extra/sympy_expr.py b/src/blender_maxwell/utils/sympy_extra/sympy_expr.py new file mode 100644 index 0000000..cafd421 --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/sympy_expr.py @@ -0,0 +1,337 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import typing as typ +from fractions import Fraction + +import pydantic as pyd +import sympy as sp +import sympy.physics.units as spu +import typing_extensions as typx +from pydantic_core import core_schema as pyd_core_schema + +from . import units as spux +from .sympy_type import SympyType +from .unit_analysis import get_units, uses_units + + +#################### +# - Pydantic "Sympy Expr" +#################### +class _SympyExpr: + """Low-level `pydantic`, schema describing how to serialize/deserialize fields that have a `SympyType` (like `sp.Expr`), so we can cleanly use `sympy` types in `pyd.BaseModel`. + + Notes: + You probably want to use `SympyExpr`. + + Examples: + To be usable as a type annotation on `pyd.BaseModel`, attach this to `SympyType` using `typx.Annotated`: + + ```python + SympyExpr = typx.Annotated[SympyType, _SympyExpr] + + class Spam(pyd.BaseModel): + line: SympyExpr = sp.Eq(sp.y, 2*sp.Symbol(x, real=True) - 3) + ``` + """ + + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: SympyType, + _handler: pyd.GetCoreSchemaHandler, + ) -> pyd_core_schema.CoreSchema: + """Compute a schema that allows `pydantic` to validate a `sympy` type.""" + + def validate_from_str(sp_str: str | typ.Any) -> SympyType | typ.Any: + """Parse and validate a string expression. + + Parameters: + sp_str: A stringified `sympy` object, that will be parsed to a sympy type. + Before use, `isinstance(expr_str, str)` is checked. + If the object isn't a string, then the validation will be skipped. + + Returns: + Either a `sympy` object, if the input is parseable, or the same untouched object. + + Raises: + ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression. + """ + # Constrain to String + if not isinstance(sp_str, str): + return sp_str + + # Parse String -> Sympy + try: + expr = sp.sympify(sp_str) + except ValueError as ex: + msg = f'String {sp_str} is not a valid sympy expression' + raise ValueError(msg) from ex + + # Substitute Symbol -> Quantity + return expr.subs(spux.UNIT_BY_SYMBOL) + + def validate_from_pytype( + sp_pytype: int | Fraction | float | complex, + ) -> SympyType | typ.Any: + """Parse and validate a pure Python type. + + Parameters: + sp_str: A stringified `sympy` object, that will be parsed to a sympy type. + Before use, `isinstance(expr_str, str)` is checked. + If the object isn't a string, then the validation will be skipped. + + Returns: + Either a `sympy` object, if the input is parseable, or the same untouched object. + + Raises: + ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression. + """ + # Constrain to String + if not isinstance(sp_pytype, int | Fraction | float | complex): + return sp_pytype + + if isinstance(sp_pytype, int): + return sp.Integer(sp_pytype) + if isinstance(sp_pytype, Fraction): + return sp.Rational(sp_pytype.numerator, sp_pytype.denominator) + if isinstance(sp_pytype, float): + return sp.Float(sp_pytype) + + # sp_pytype => Complex + return sp_pytype.real + sp.I * sp_pytype.imag + + sympy_expr_schema = pyd_core_schema.chain_schema( + [ + pyd_core_schema.no_info_plain_validator_function(validate_from_str), + pyd_core_schema.no_info_plain_validator_function(validate_from_pytype), + pyd_core_schema.is_instance_schema(SympyType), + ] + ) + return pyd_core_schema.json_or_python_schema( + json_schema=sympy_expr_schema, + python_schema=sympy_expr_schema, + serialization=pyd_core_schema.plain_serializer_function_ser_schema( + lambda sp_obj: sp.srepr(sp_obj) + ), + ) + + +SympyExpr = typx.Annotated[ + sp.Basic, ## Treat all sympy types as sp.Basic + _SympyExpr, +] +## TODO: The type game between SympyType, SympyExpr, and the various flavors of ConstrSympyExpr(), is starting to be a bit much. Let's consolidate. + + +def ConstrSympyExpr( # noqa: N802, PLR0913 + # Features + allow_variables: bool = True, + allow_units: bool = True, + # Structures + allowed_sets: set[typ.Literal['integer', 'rational', 'real', 'complex']] + | None = None, + allowed_structures: set[typ.Literal['scalar', 'matrix']] | None = None, + # Element Class + max_symbols: int | None = None, + allowed_symbols: set[sp.Symbol] | None = None, + allowed_units: set[spu.Quantity] | None = None, + # Shape Class + allowed_matrix_shapes: set[tuple[int, int]] | None = None, +) -> SympyType: + """Constructs a `SympyExpr` type, which will validate `sympy` types when used in a `pyd.BaseModel`. + + Relies on the `sympy` assumptions system. + See + + Parameters (TBD): + + Returns: + A type that represents a constrained `sympy` expression. + """ + + def validate_expr(expr: SympyType): + if not (isinstance(expr, SympyType),): + msg = f"expr '{expr}' is not an allowed Sympy expression ({SympyType})" + raise ValueError(msg) + + msgs = set() + + # Validate Feature Class + if (not allow_variables) and (len(expr.free_symbols) > 0): + msgs.add( + f'allow_variables={allow_variables} does not match expression {expr}.' + ) + if (not allow_units) and uses_units(expr): + msgs.add(f'allow_units={allow_units} does not match expression {expr}.') + + # Validate Structure Class + if ( + allowed_sets + and isinstance(expr, sp.Expr) + and not any( + { + 'integer': expr.is_integer, + 'rational': expr.is_rational, + 'real': expr.is_real, + 'complex': expr.is_complex, + }[allowed_set] + for allowed_set in allowed_sets + ) + ): + msgs.add( + f"allowed_sets={allowed_sets} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))" + ) + if allowed_structures and not any( + { + 'scalar': True, + 'matrix': isinstance(expr, sp.MatrixBase), + }[allowed_set] + for allowed_set in allowed_structures + ): + msgs.add( + f"allowed_structures={allowed_structures} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))" + ) + + # Validate Element Class + if max_symbols and len(expr.free_symbols) > max_symbols: + msgs.add(f'max_symbols={max_symbols} does not match expression {expr}') + if allowed_symbols and expr.free_symbols.issubset(allowed_symbols): + msgs.add( + f'allowed_symbols={allowed_symbols} does not match expression {expr}' + ) + if allowed_units and get_units(expr).issubset(allowed_units): + msgs.add(f'allowed_units={allowed_units} does not match expression {expr}') + + # Validate Shape Class + if ( + allowed_matrix_shapes and isinstance(expr, sp.MatrixBase) + ) and expr.shape not in allowed_matrix_shapes: + msgs.add( + f'allowed_matrix_shapes={allowed_matrix_shapes} does not match expression {expr} with shape {expr.shape}' + ) + + # Error or Return + if msgs: + raise ValueError(str(msgs)) + return expr + + return typx.Annotated[ + sp.Basic, + _SympyExpr, + pyd.AfterValidator(validate_expr), + ] + + +#################### +# - Common ConstrSympyExpr +#################### +# Expression +ScalarUnitlessRealExpr: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=False, + allowed_structures={'scalar'}, + allowed_sets={'integer', 'rational', 'real'}, +) +ScalarUnitlessComplexExpr: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=False, + allowed_structures={'scalar'}, + allowed_sets={'integer', 'rational', 'real', 'complex'}, +) + +# Symbol +IntSymbol: typ.TypeAlias = ConstrSympyExpr( + allow_variables=True, + allow_units=False, + allowed_sets={'integer'}, + max_symbols=1, +) +RationalSymbol: typ.TypeAlias = ConstrSympyExpr( + allow_variables=True, + allow_units=False, + allowed_sets={'integer', 'rational'}, + max_symbols=1, +) +RealSymbol: typ.TypeAlias = ConstrSympyExpr( + allow_variables=True, + allow_units=False, + allowed_sets={'integer', 'rational', 'real'}, + max_symbols=1, +) +ComplexSymbol: typ.TypeAlias = ConstrSympyExpr( + allow_variables=True, + allow_units=False, + allowed_sets={'integer', 'rational', 'real', 'complex'}, + max_symbols=1, +) +Symbol: typ.TypeAlias = IntSymbol | RealSymbol | ComplexSymbol + +# Unit +UnitDimension: typ.TypeAlias = SympyExpr ## Actually spu.Dimension + +## Technically a "unit expression", which includes compound types. +## Support for this is the reason to prefer over raw spu.Quantity. +Unit: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=True, + allowed_structures={'scalar'}, +) + +# Number +IntNumber: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=False, + allowed_sets={'integer'}, + allowed_structures={'scalar'}, +) +RealNumber: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=False, + allowed_sets={'integer', 'rational', 'real'}, + allowed_structures={'scalar'}, +) +ComplexNumber: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=False, + allowed_sets={'integer', 'rational', 'real', 'complex'}, + allowed_structures={'scalar'}, +) +Number: typ.TypeAlias = IntNumber | RealNumber | ComplexNumber + +# Number +PhysicalRealNumber: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=True, + allowed_sets={'integer', 'rational', 'real'}, + allowed_structures={'scalar'}, +) +PhysicalComplexNumber: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=True, + allowed_sets={'integer', 'rational', 'real', 'complex'}, + allowed_structures={'scalar'}, +) +PhysicalNumber: typ.TypeAlias = PhysicalRealNumber | PhysicalComplexNumber + +# Vector +Real3DVector: typ.TypeAlias = ConstrSympyExpr( + allow_variables=False, + allow_units=False, + allowed_sets={'integer', 'rational', 'real'}, + allowed_structures={'matrix'}, + allowed_matrix_shapes={(3, 1)}, +) diff --git a/src/blender_maxwell/utils/sympy_extra/sympy_type.py b/src/blender_maxwell/utils/sympy_extra/sympy_type.py new file mode 100644 index 0000000..ecb736e --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/sympy_type.py @@ -0,0 +1,23 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import sympy as sp +import sympy.physics.units as spu + +#################### +# - Underlying "Sympy Type" +#################### +SympyType = sp.Basic | sp.MatrixBase | spu.Quantity | spu.Dimension diff --git a/src/blender_maxwell/utils/sympy_extra/unit_analysis.py b/src/blender_maxwell/utils/sympy_extra/unit_analysis.py new file mode 100644 index 0000000..3234407 --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/unit_analysis.py @@ -0,0 +1,287 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""Functions for characterizaiton, conversion and casting of `sympy` objects that use units.""" + +import functools + +import sympy as sp +import sympy.physics.units as spu + +from . import units as spux +from .parse_cast import sympy_to_python +from .sympy_type import SympyType + + +#################### +# - Unit Characterization +#################### +## TODO: Caching w/srepr'ed expression. +## TODO: An LFU cache could do better than an LRU. +def uses_units(sp_obj: SympyType) -> bool: + """Determines if an expression uses any units. + + Parameters: + expr: The sympy object that may contain units. + + Returns: + Whether or not units are present in the object. + """ + return sp_obj.has(spu.Quantity) + + +## TODO: Caching w/srepr'ed expression. +## TODO: An LFU cache could do better than an LRU. +def get_units(expr: sp.Expr) -> set[spu.Quantity]: + """Finds all units used by the expression, and returns them as a set. + + No information about _the relationship between units_ is exposed. + For example, compound units like `spu.meter / spu.second` would be mapped to `{spu.meter, spu.second}`. + + + Notes: + The expression graph is traversed depth-first with `sp.postorder_traversal`, to search for `sp.Quantity` elements. + + The performance is comparable to the performance of `sp.postorder_traversal`, since the **entire expression graph will always be traversed**, with the added overhead of one `isinstance` call per expression-graph-node. + + Parameters: + expr: The sympy expression that may contain units. + + Returns: + All units (`spu.Quantity`) used within the expression. + """ + return { + subexpr + for subexpr in sp.postorder_traversal(expr) + if isinstance(subexpr, spu.Quantity) + } + + +#################### +# - Dimensional Characterization +#################### +def unit_dim_to_unit_dim_deps( + unit_dims: SympyType, +) -> dict[spu.dimensions.Dimension, int] | None: + """Normalize an expression to a mapping of its dimensional dependencies. + + Comparing the dimensional dependencies of two `unit_dims` is a meaningful way of determining whether they are equivalent. + + Notes: + We adhere to SI unit conventions when determining dimensional dependencies, to ensure that ex. `freq -> 1/time` equivalences are normalized away. + This allows the output of this method to be compared meaningfully, to determine whether two dimensional expressions are equivalent. + + We choose to catch a `TypeError`, for cases where dimensional analysis is impossible (especially `+` or `-` between differing dimensions). + This may have a slight performance penalty. + + Returns: + The dimensional dependencies of the dimensional expression. + + If such a thing makes no sense, ex. if `+` or `-` is present between differing unit dimensions, then return None. + """ + dimsys_SI = spu.systems.si.dimsys_SI + + # Retrieve Dimensional Dependencies + try: + return dimsys_SI.get_dimensional_dependencies(unit_dims) + + # Catch TypeError + ## -> Happens if `+` or `-` is in `unit`. + ## -> Generally, it doesn't make sense to add/subtract differing unit dims. + ## -> Thus, when trying to figure out the unit dimension, there isn't one. + except TypeError: + return None + + +def unit_to_unit_dim_deps( + unit: SympyType, +) -> dict[spu.dimensions.Dimension, int] | None: + """Deduce the dimensional dependencies of a unit. + + Notes: + Using `.subs()` to replace `sp.Quantity`s with `spu.dimensions.Dimension`s seems to result in an expression that absolutely refuses to claim that it has anything other than raw `sp.Symbol`s. + + This is extremely problematic - dimensional analysis relies on the arithmetic properties of proper `Dimension` objects. + + For this reason, though we'd rather have a `unit_to_unit_dims()` function, we have not yet found a way to do this. + Luckily, most of our uses cases seem only to require the dimensional dictionary, which (surprisingly) seems accessible using `unit_dim_to_unit_dim_deps()`. + + """ + # Retrieve Dimensional Dependencies + ## -> NOTE: .subs() alone seems to produce sp.Symbol atoms. + ## -> This is extremely problematic; `Dims` arithmetic has key properties. + ## -> So we have to go all the way to the dimensional dependencies. + ## -> This isn't really respecting the args, but it seems to work :) + return unit_dim_to_unit_dim_deps( + unit.subs({arg: arg.dimension for arg in unit.atoms(spu.Quantity)}) + ) + + +def compare_unit_dims(unit_dim_l: SympyType, unit_dim_r: SympyType) -> bool: + """Compare the dimensional dependencies of two unit dimensions. + + Comparing the dimensional dependencies of two `unit_dims` is a meaningful way of determining whether they are equivalent. + """ + return unit_dim_to_unit_dim_deps(unit_dim_l) == unit_dim_to_unit_dim_deps( + unit_dim_r + ) + + +def compare_units_by_unit_dims(unit_l: SympyType, unit_r: SympyType) -> bool: + """Compare two units by their unit dimensions.""" + return unit_to_unit_dim_deps(unit_l) == unit_to_unit_dim_deps(unit_r) + + +def compare_unit_dim_to_unit_dim_deps( + unit_dim: SympyType, unit_dim_deps: dict[spu.dimensions.Dimension, int] +) -> bool: + """Compare the dimensional dependencies of unit dimensions to pre-defined unit dimensions.""" + return unit_dim_to_unit_dim_deps(unit_dim) == unit_dim_deps + + +#################### +# - Unit Casting +#################### +def strip_units(sp_obj: SympyType) -> SympyType: + """Strip all units by replacing them to `1`. + + This is a rather unsafe method. + You probably shouldn't use it. + + Warnings: + Absolutely no effort is made to determine whether stripping units is a _meaningful thing to do_. + + For example, using `+` expressions of compatible dimension, but different units, is a clear mistake. + For example, `8*meter + 9*millimeter` strips to `8(1) + 9(1) = 17`, which is a garbage result. + + The **user of this method** must themselves perform appropriate checks on th eobject before stripping units. + + Parameters: + sp_obj: A sympy object that contains unit symbols. + **NOTE**: Unit symbols (from `sympy.physics.units`) are not _free_ symbols, in that they are not unknown. + Nonetheless, they are not _numbers_ either, and thus they cannot be used in a numerical expression. + + Returns: + The sympy object with all unit symbols replaced by `1`, effectively extracting the unitless part of the object. + """ + return sp_obj.subs(spux.UNIT_TO_1) + + +def convert_to_unit(sp_obj: SympyType, unit: SympyType | None) -> SympyType: + """Convert a sympy object to the given unit. + + Supports a unit of `None`, which simply causes the object to have its units stripped. + """ + if unit is None: + return strip_units(sp_obj) + return spu.convert_to(sp_obj, unit) + + # msg = f'Sympy object "{sp_obj}" was scaled to the unit "{unit}" with the expectation that the result would be unitless, but the result "{unitless_expr}" has units "{get_units(unitless_expr)}"' + # raise ValueError(msg) + + +## TODO: Include sympy_to_python in 'scale_to' to match semantics of 'scale_to_unit_system' +## -- Introduce a 'strip_unit +def scale_to_unit( + sp_obj: SympyType, + unit: spu.Quantity | None, + cast_to_pytype: bool = False, + use_jax_array: bool = False, +) -> SympyType: + """Convert an expression that uses units to a different unit, then strip all units, leaving only a unitless `sympy` value. + + This is used whenever the unitless part of an expression is needed, but guaranteed expressed in a particular unit, aka. **unit system normalization**. + + Notes: + The unitless output is still an `sp.Expr`, which may contain ex. symbols. + + If you know that the output **should** work as a corresponding Python type (ex. `sp.Integer` vs. `int`), but it doesn't, you can use `sympy_to_python()` to produce a pure-Python type. + In this way, with a little care, broad compatiblity can be bridged between the `sympy.physics.units` unit system and the wider Python ecosystem. + + Parameters: + expr: The unit-containing expression to convert. + unit_to: The unit that is converted to. + + Returns: + The unitless part of `expr`, after scaling the entire expression to `unit`. + + Raises: + ValueError: If the result of unit-conversion and -stripping still has units, as determined by `uses_units()`. + """ + sp_obj_stripped = strip_units(convert_to_unit(sp_obj, unit)) + if cast_to_pytype: + return sympy_to_python( + sp_obj_stripped, + use_jax_array=use_jax_array, + ) + return sp_obj_stripped + + +def scaling_factor( + unit_from: SympyType, unit_to: SympyType +) -> int | float | complex | tuple | None: + """Compute the numerical scaling factor imposed on the unitless part of the expression when converting from one unit to another. + + Parameters: + unit_from: The unit that is converted from. + unit_to: The unit that is converted to. + + Returns: + The numerical scaling factor between the two units. + + If the units are incompatible, then we return None. + + Raises: + ValueError: If the two units don't share a common dimension. + """ + if compare_units_by_unit_dims(unit_from, unit_to): + return scale_to_unit(unit_from, unit_to) + return None + + +@functools.cache +def unit_str_to_unit(unit_str: str, optional: bool = False) -> SympyType | None: + """Determine the `sympy` unit expression that matches the given unit string. + + Parameters: + unit_str: A string parseable with `sp.sympify`, which contains a unit expression. + optional: Whether to return + **NOTE**: `None` is itself a valid "unit", denoting dimensionlessness, in general. + Ensure that appropriate checks are performed to account for this nuance. + + Returns: + The matching `sympy` unit. + + Raises: + ValueError: When no valid unit can be matched to the unit string, and `optional` is `False`. + """ + match unit_str: + # Special-Case 'degree' + ## -> sp.sympify('degree') produces the sp.degree(). + ## -> TODO: Proper Analysis analysis. + case 'degree': + unit = spu.degree + + case _: + unit = sp.sympify(unit_str).subs(spux.UNIT_BY_SYMBOL) + + if uses_units(unit): + return unit + + if optional: + return None + msg = f'No valid unit for unit string {unit_str}' + raise ValueError(msg) diff --git a/src/blender_maxwell/utils/sympy_extra/unit_system_analysis.py b/src/blender_maxwell/utils/sympy_extra/unit_system_analysis.py new file mode 100644 index 0000000..cb30375 --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/unit_system_analysis.py @@ -0,0 +1,93 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""Functions for conversion and casting of `sympy` objects that use units, via unit systems.""" + +import jax +import sympy.physics.units as spu + +from . import units as spux +from .parse_cast import sympy_to_python +from .physical_type import PhysicalType +from .sympy_type import SympyType +from .unit_analysis import get_units +from .unit_systems import UnitSystem + + +#################### +# - Conversion +#################### +def strip_unit_system( + sp_obj: SympyType, unit_system: UnitSystem | None = None +) -> SympyType: + """Strip units occurring in the given unit system from the expression. + + Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`". + Obviously, the semantic correctness of this operation depends entirely on _the units adding no semantic meaning to the expression_. + + Notes: + You should probably use `scale_to_unit_system()` or `convert_to_unit_system()`. + """ + if unit_system is None: + return sp_obj.subs(spux.UNIT_TO_1) + + return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None}) + + +def convert_to_unit_system( + sp_obj: SympyType, unit_system: UnitSystem | None +) -> SympyType: + """Convert an expression to the units of a given unit system.""" + if unit_system is None: + return sp_obj + + return spu.convert_to( + sp_obj, + {unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)}, + ) + + +#################### +# - Casting +#################### +def scale_to_unit_system( + sp_obj: SympyType, + unit_system: UnitSystem | None, + use_jax_array: bool = False, +) -> int | float | complex | tuple | jax.Array: + """Convert an expression to the units of a given unit system, then strip all units of the unit system. + + Afterwards, it is converted to an appropriate Python type. + + Notes: + For stability, and performance, reasons, this should only be used at the very last stage. + + Regarding performance: **This is not a fast function**. + + Parameters: + sp_obj: An arbitrary sympy object, presumably with units. + unit_system: A unit system mapping `PhysicalType` to particular choices of (compound) units. + Note that, in this context, only `unit_system.values()` is used. + + Returns: + An appropriate pure Python type, after scaling to the unit system and stripping all units away. + + If the returned type is array-like, and `use_jax_array` is specified, then (and **only** then) will a `jax.Array` be returned instead of a nested `tuple`. + """ + return sympy_to_python( + strip_unit_system(convert_to_unit_system(sp_obj, unit_system), unit_system), + use_jax_array=use_jax_array, + ) diff --git a/src/blender_maxwell/utils/sympy_extra/unit_systems.py b/src/blender_maxwell/utils/sympy_extra/unit_systems.py new file mode 100644 index 0000000..1de298a --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/unit_systems.py @@ -0,0 +1,80 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""Defines a common unit system representation, as well as a few of the most common / useful unit systems. + +Attributes: + UnitSystem: Type of a unit system representation, as an exhaustive mapping from `PhysicalType` to a unit expression. + **Compatibility between `PhysicalType` and unit must be manually guaranteed** when defining new unit systems. + UNITS_SI: Pre-defined go-to choice of unit system, which can also be a useful base to build other unit systems on. +""" + +import typing as typ + +import sympy.physics.units as spu + +from . import units as spux +from .physical_type import PhysicalType as PT # noqa: N817 +from .sympy_expr import Unit + +#################### +# - Unit System Representation +#################### +UnitSystem: typ.TypeAlias = dict[PT, Unit] + +#################### +# - Standard Unit Systems +#################### +UNITS_SI: UnitSystem = { + PT.NonPhysical: None, + # Global + PT.Time: spu.second, + PT.Angle: spu.radian, + PT.SolidAngle: spu.steradian, + PT.Freq: spu.hertz, + PT.AngFreq: spu.radian * spu.hertz, + # Cartesian + PT.Length: spu.meter, + PT.Area: spu.meter**2, + PT.Volume: spu.meter**3, + # Mechanical + PT.Vel: spu.meter / spu.second, + PT.Accel: spu.meter / spu.second**2, + PT.Mass: spu.kilogram, + PT.Force: spu.newton, + # Energy + PT.Work: spu.joule, + PT.Power: spu.watt, + PT.PowerFlux: spu.watt / spu.meter**2, + PT.Temp: spu.kelvin, + # Electrodynamics + PT.Current: spu.ampere, + PT.CurrentDensity: spu.ampere / spu.meter**2, + PT.Voltage: spu.volt, + PT.Capacitance: spu.farad, + PT.Impedance: spu.ohm, + PT.Conductance: spu.siemens, + PT.Conductivity: spu.siemens / spu.meter, + PT.MFlux: spu.weber, + PT.MFluxDensity: spu.tesla, + PT.Inductance: spu.henry, + PT.EField: spu.volt / spu.meter, + PT.HField: spu.ampere / spu.meter, + # Luminal + PT.LumIntensity: spu.candela, + PT.LumFlux: spux.lumen, + PT.Illuminance: spu.lux, +} diff --git a/src/blender_maxwell/utils/sympy_extra/units.py b/src/blender_maxwell/utils/sympy_extra/units.py new file mode 100644 index 0000000..9ffd4cf --- /dev/null +++ b/src/blender_maxwell/utils/sympy_extra/units.py @@ -0,0 +1,77 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import typing as typ + +import sympy as sp +import sympy.physics.units as spu + +#################### +# - Units +#################### +# Time +femtosecond = fs = spu.Quantity('femtosecond', abbrev='fs') +femtosecond.set_global_relative_scale_factor(spu.femto, spu.second) + +# Length +femtometer = fm = spu.Quantity('femtometer', abbrev='fm') +femtometer.set_global_relative_scale_factor(spu.femto, spu.meter) + +# Lum Flux +lumen = lm = spu.Quantity('lumen', abbrev='lm') +lumen.set_global_relative_scale_factor(1, spu.candela * spu.steradian) + +# Force +nanonewton = nN = spu.Quantity('nanonewton', abbrev='nN') # noqa: N816 +nanonewton.set_global_relative_scale_factor(spu.nano, spu.newton) + +micronewton = uN = spu.Quantity('micronewton', abbrev='μN') # noqa: N816 +micronewton.set_global_relative_scale_factor(spu.micro, spu.newton) + +millinewton = mN = spu.Quantity('micronewton', abbrev='mN') # noqa: N816 +micronewton.set_global_relative_scale_factor(spu.milli, spu.newton) + +# Frequency +kilohertz = KHz = spu.Quantity('kilohertz', abbrev='KHz') +kilohertz.set_global_relative_scale_factor(spu.kilo, spu.hertz) + +megahertz = MHz = spu.Quantity('megahertz', abbrev='MHz') +kilohertz.set_global_relative_scale_factor(spu.kilo, spu.hertz) + +gigahertz = GHz = spu.Quantity('gigahertz', abbrev='GHz') +gigahertz.set_global_relative_scale_factor(spu.giga, spu.hertz) + +terahertz = THz = spu.Quantity('terahertz', abbrev='THz') +terahertz.set_global_relative_scale_factor(spu.tera, spu.hertz) + +petahertz = PHz = spu.Quantity('petahertz', abbrev='PHz') +petahertz.set_global_relative_scale_factor(spu.peta, spu.hertz) + +exahertz = EHz = spu.Quantity('exahertz', abbrev='EHz') +exahertz.set_global_relative_scale_factor(spu.exa, spu.hertz) + +# Pressure +millibar = mbar = spu.Quantity('millibar', abbrev='mbar') +millibar.set_global_relative_scale_factor(spu.milli, spu.bar) + +hectopascal = hPa = spu.Quantity('hectopascal', abbrev='hPa') # noqa: N816 +hectopascal.set_global_relative_scale_factor(spu.hecto, spu.pascal) + +UNIT_BY_SYMBOL: dict[sp.Symbol, spu.Quantity] = { + unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity) +} | {unit.name: unit for unit in globals().values() if isinstance(unit, spu.Quantity)} + +UNIT_TO_1: dict[spu.Quantity, 1] = {unit: 1 for unit in UNIT_BY_SYMBOL.values()}