feat: deep refactors / fixes

We refactored the entire `extra_sympy_units` into a (rather decently sized)
proper package, and changed its name.
Additionally, we've seperated the operation enums from the math nodes
themselves and moved them to dedicated `math_system` package, which is a
very big breath of fresh air.
I'll need a moment to fix a few typos, but incredibly, this architecture
is kind of "just works" (TM) at this point - not like the `FlowKind`
refactoring debacle...

To go with that, we've greatly streamlined domain handling in
`SimSymbol`.
We now track a symbolic set as the domain, which is very simple and
effective, but comes with some headaches too (`sympy` is fighting me,
grumble grumble...)
We've also managed to enforce a few unit-conversions in the operate
node, and revamp the entire operation validity detection (weirdly hard).

The big pain point at the moment is determining the image of functions
we apply, on the `sp.Set` domain of `SimSymbol`s. We can express this
trivially, but `sympy` simply doesn't care to evaluate it.

One can use `SetExpr` and `AccumBounds` to "sometimes work" for some set
kinds, but nothing nearly good enough for even our relatively humble
needs.
Let alone stuff like fourier.

We've ended up deciding to hard-code this part of the process by-operation.
With domain-specific knowledge and little bit of suffering, we can
manually ensure the output domain of every operation makes it to the
output symbol.

As for "why bother", well, the entire premise of a symbolic nodal math system
that is tolerable to use, requires checking the valid domain of the input.
We do wish it were optional, but eh.
main
Sofus Albert Høgsbro Rose 2024-06-01 19:08:43 +02:00
parent 572d53f41e
commit b51c4f1889
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
66 changed files with 4449 additions and 3275 deletions

View File

@ -21,7 +21,7 @@ import typing as typ
import bpy
import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .socket_types import SocketType

View File

@ -22,7 +22,7 @@ import numpy as np
import pydantic as pyd
import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
log = logger.get(__name__)

View File

@ -16,7 +16,7 @@
import typing as typ
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from . import FlowKind

View File

@ -19,7 +19,7 @@ import functools
import typing as typ
from blender_maxwell.contracts import BLEnumElement
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils.staticproperty import staticproperty

View File

@ -18,8 +18,8 @@ import dataclasses
import functools
import typing as typ
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from .array import ArrayFlow
from .lazy_range import RangeFlow
@ -89,6 +89,7 @@ class InfoFlow:
return None
def dim_by_idx(self, idx: int) -> sim_symbols.SimSymbol | None:
"""Retrieve the dimension associated with a particular index."""
if idx > 0 and idx < len(self.dims) - 1:
return list(self.dims.keys())[idx]
return None
@ -179,15 +180,23 @@ class InfoFlow:
While that sounds fancy and all, it boils down to:
$$
\texttt{dims} + |\texttt{output}.\texttt{shape}|
|\texttt{dims}| + |\texttt{output}.\texttt{shape}|
$$
Doing so characterizes the full dimensionality of the tensor, which also perfectly matches the length of the raw data's shape exactly.
Doing so characterizes the full dimensionality of the tensor, which also perfectly matches the length of the raw data's shape.
Notes:
Corresponds to `len(raw_data.shape)`, if `raw_data` is the n-dimensional array corresponding to this `InfoFlow`.
"""
return len(self.input_mathtypes) + self.output_shape_len
return len(self.dims) + self.output_shape_len
@functools.cached_property
def is_scalar(self) -> tuple[spux.MathType, int, int]:
"""Whether the described object can be described as "scalar".
True when `self.order == 0`.
"""
return self.order == 0
####################
# - Properties
@ -204,6 +213,58 @@ class InfoFlow:
for dim, dim_idx in self.dims.items()
}
####################
# - Operations: Comparison
####################
def compare_dims_identical(self, other: typ.Self) -> bool:
"""Whether that the quantity and properites of all dimension `SimSymbol`s are "identical".
"Identical" is defined according to the semantics of `SimSymbol.compare()`, which generally means that everything but the exact name and unit are different.
"""
return len(self.dims) == len(other.dims) and all(
dim_l.compare(dim_r)
for dim_l, dim_r in zip(self.dims, other.dims, strict=True)
)
def compare_addable(
self, other: typ.Self, allow_differing_unit: bool = False
) -> bool:
"""Whether the two `InfoFlows` can be added/subtracted elementwise.
Parameters:
allow_differing_unit: When set,
Forces the user to be explicit about specifying
"""
return self.compare_dims_identical(other) and self.output.compare_addable(
other.output, allow_differing_unit=allow_differing_unit
)
def compare_multiplicable(self, other: typ.Self) -> bool:
"""Whether the two `InfoFlow`s can be multiplied (elementwise).
- The output `SimSymbol`s must be able to be multiplied.
- Either the LHS is a scalar, the RHS is a scalar, or the dimensions are identical.
"""
return self.output.compare_multiplicable(other.output) and (
(len(self.dims) == 0 and self.output.shape_len == 0)
or (len(other.dims) == 0 and other.output.shape_len == 0)
or self.compare_dims_identical(other)
)
def compare_exponentiable(self, other: typ.Self) -> bool:
"""Whether the two `InfoFlow`s can be exponentiated.
In general, we follow the rules of the "Hadamard Power" operator, which is also in use in `numpy` broadcasting rules.
- The output `SimSymbol`s must be able to be exponentiated (mainly, the exponent can't have a unit).
- Either the LHS is a scalar, the RHS is a scalar, or the dimensions are identical.
"""
return self.output.compare_exponentiable(other.output) and (
(len(self.dims) == 0 and self.output.shape_len == 0)
or (len(other.dims) == 0 and other.output.shape_len == 0)
or self.compare_dims_identical(other)
)
####################
# - Operations: Dimensions
####################
@ -319,6 +380,7 @@ class InfoFlow:
op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
unit_op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
) -> spux.SympyExpr:
"""Apply an operation between two the values and units of two `InfoFlow`s by reconstructing the properties of the new output `SimSymbol`."""
sym_name = sim_symbols.SimSymbolName.Expr
expr = op(self.output.sp_symbol_phy, other.output.sp_symbol_phy)
unit_expr = unit_op(self.output.unit_factor, other.output.unit_factor)
@ -341,11 +403,11 @@ class InfoFlow:
cols = self.output.cols
match (rows, cols):
case (1, 1):
new_output = self.output.set_size(len(last_idx), 1)
new_output = self.output.update(rows=len(last_idx), cols=1)
case (_, 1):
new_output = self.output.set_size(rows, len(last_idx))
new_output = self.output.update(rows=rows, cols=len(last_idx))
case (1, _):
new_output = self.output.set_size(len(last_idx), cols)
new_output = self.output.update(rows=len(last_idx), cols=cols)
case (_, _):
raise NotImplementedError ## Not yet :)

View File

@ -226,7 +226,7 @@ import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger, sim_symbols
from .array import ArrayFlow
@ -335,7 +335,7 @@ class FuncFlow(pyd.BaseModel):
disallow_jax: Don't use `self.func_jax` to evaluate, even if possible.
This is desirable when the overhead of `jax.jit()` is known in advance to exceed the performance benefits.
"""
if self.supports_jax:
if self.supports_jax and not disallow_jax:
return self.func_jax(
*params.scaled_func_args(symbol_values),
**params.scaled_func_kwargs(symbol_values),

View File

@ -24,8 +24,8 @@ import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from .array import ArrayFlow
@ -95,10 +95,12 @@ class RangeFlow(pyd.BaseModel):
stop: spux.ScalarUnitlessRealExpr
steps: int = 0
scaling: ScalingMode = ScalingMode.Lin
## TODO: No support for non-Lin (yet)
unit: spux.Unit | None = None
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
## TODO: No proper support for symbols (yet)
# Helper Attributes
pre_fourier_ideal_midpoint: spux.ScalarUnitlessRealExpr | None = None
@ -112,18 +114,26 @@ class RangeFlow(pyd.BaseModel):
steps: int = 50,
scaling: ScalingMode | str = ScalingMode.Lin,
) -> typ.Self:
if sym.domain.start.is_infinite or sym.domain.end.is_infinite:
use_steps = 0
else:
use_steps = steps
if (
sym.mathtype is not spux.MathType.Complex
and sym.rows == 1
and sym.cols == 1
):
if sym.domain.inf.is_infinite or sym.domain.sup.is_infinite:
_steps = 0
else:
_steps = steps
return RangeFlow(
start=sym.domain.start if sym.domain.start.is_finite else sp.S(-1),
stop=sym.domain.end if sym.domain.end.is_finite else sp.S(1),
steps=use_steps,
scaling=ScalingMode(scaling),
unit=sym.unit,
)
return RangeFlow(
start=sym.domain.inf if sym.domain.inf.is_finite else sp.S(-1),
stop=sym.domain.sup if sym.domain.sup.is_finite else sp.S(1),
steps=_steps,
scaling=ScalingMode(scaling),
unit=sym.unit,
)
msg = f'RangeFlow is incompatible with SimSymbol {sym}'
raise ValueError(msg)
def to_sym(
self,

View File

@ -23,7 +23,7 @@ import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger, sim_symbols
from .array import ArrayFlow
@ -53,11 +53,6 @@ class ParamsFlow(pyd.BaseModel):
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
] = pyd.Field(default_factory=dict)
@functools.cached_property
def diff_symbols(self) -> set[sim_symbols.SimSymbol]:
"""Set of all unrealized `SimSymbol`s that can act as inputs when differentiating the function for which this `ParamsFlow` tracks arguments."""
return {sym for sym in self.symbols if sym.can_diff}
####################
# - Symbols
####################

View File

@ -29,7 +29,7 @@ import tidy3d as td
from blender_maxwell.contracts import BLEnumElement
from blender_maxwell.services import tdcloud
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .flow_kinds.info import InfoFlow

View File

@ -28,7 +28,7 @@ import typing as typ
import sympy.physics.units as spu
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
####################
# - Unit Systems

View File

@ -20,7 +20,7 @@ import typing as typ
import bpy
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .. import bl_socket_map

View File

@ -0,0 +1,29 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from .filter import FilterOperation
from .map import MapOperation
from .operate import BinaryOperation
from .reduce import ReduceOperation
from .transform import TransformOperation
__all__ = [
'FilterOperation',
'MapOperation',
'BinaryOperation',
'ReduceOperation',
'TransformOperation',
]

View File

@ -0,0 +1,233 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import enum
import typing as typ
import jax.lax as jlax
import jax.numpy as jnp
from blender_maxwell.utils import logger, sim_symbols
from .. import contracts as ct
log = logger.get(__name__)
class FilterOperation(enum.StrEnum):
"""Valid operations for the `FilterMathNode`.
Attributes:
DimToVec: Shift last dimension to output.
DimsToMat: Shift last 2 dimensions to output.
PinLen1: Remove a len(1) dimension.
Pin: Remove a len(n) dimension by selecting a particular index.
Swap: Swap the positions of two dimensions.
"""
# Slice
Slice = enum.auto()
SliceIdx = enum.auto()
# Pin
PinLen1 = enum.auto()
Pin = enum.auto()
PinIdx = enum.auto()
# Dimension
Swap = enum.auto()
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
FO = FilterOperation
return {
# Slice
FO.Slice: '≈a[v₁:v₂]',
FO.SliceIdx: '=a[i:j]',
# Pin
FO.PinLen1: 'a[0] → a',
FO.Pin: 'a[v] ⇝ a',
FO.PinIdx: 'a[i] → a',
# Reinterpret
FO.Swap: 'a₁ ↔ a₂',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> str:
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
FO = FilterOperation
return (
str(self),
FO.to_name(self),
FO.to_name(self),
FO.to_icon(self),
i,
)
####################
# - Ops from Info
####################
@staticmethod
def by_info(info: ct.InfoFlow) -> list[typ.Self]:
FO = FilterOperation
operations = []
# Slice
if info.dims:
operations.append(FO.SliceIdx)
# Pin
## PinLen1
## -> There must be a dimension with length 1.
if 1 in [dim_idx for dim_idx in info.dims.values() if dim_idx is not None]:
operations.append(FO.PinLen1)
## Pin | PinIdx
## -> There must be a dimension, full stop.
if info.dims:
operations += [FO.Pin, FO.PinIdx]
# Reinterpret
## Swap
## -> There must be at least two dimensions.
if len(info.dims) >= 2: # noqa: PLR2004
operations.append(FO.Swap)
return operations
####################
# - Computed Properties
####################
@property
def func_args(self) -> list[sim_symbols.SimSymbol]:
FO = FilterOperation
return {
# Pin
FO.Pin: [sim_symbols.idx(None)],
FO.PinIdx: [sim_symbols.idx(None)],
}.get(self, [])
####################
# - Methods
####################
@property
def num_dim_inputs(self) -> None:
FO = FilterOperation
return {
# Slice
FO.Slice: 1,
FO.SliceIdx: 1,
# Pin
FO.PinLen1: 1,
FO.Pin: 1,
FO.PinIdx: 1,
# Reinterpret
FO.Swap: 2,
}[self]
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
FO = FilterOperation
match self:
# Slice
case FO.Slice:
return [dim for dim in info.dims if not info.has_idx_labels(dim)]
case FO.SliceIdx:
return [dim for dim in info.dims if not info.has_idx_labels(dim)]
# Pin
case FO.PinLen1:
return [
dim
for dim, dim_idx in info.dims.items()
if not info.has_idx_cont(dim) and len(dim_idx) == 1
]
case FO.Pin:
return info.dims
case FO.PinIdx:
return [dim for dim in info.dims if not info.has_idx_cont(dim)]
# Dimension
case FO.Swap:
return info.dims
return []
def are_dims_valid(
self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None
) -> bool:
"""Check whether the given dimension inputs are valid in the context of this operation, and of the information."""
if self.num_dim_inputs == 1:
return dim_0 in self.valid_dims(info)
if self.num_dim_inputs == 2: # noqa: PLR2004
valid_dims = self.valid_dims(info)
return dim_0 in valid_dims and dim_1 in valid_dims
return False
####################
# - UI
####################
def jax_func(
self,
axis_0: int | None,
axis_1: int | None,
slice_tuple: tuple[int, int, int] | None = None,
):
FO = FilterOperation
return {
# Pin
FO.Slice: lambda expr: jlax.slice_in_dim(
expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
),
FO.SliceIdx: lambda expr: jlax.slice_in_dim(
expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
),
# Pin
FO.PinLen1: lambda expr: jnp.squeeze(expr, axis_0),
FO.Pin: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
# Dimension
FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1),
}[self]
def transform_info(
self,
info: ct.InfoFlow,
dim_0: sim_symbols.SimSymbol,
dim_1: sim_symbols.SimSymbol,
pin_idx: int | None = None,
slice_tuple: tuple[int, int, int] | None = None,
):
FO = FilterOperation
return {
FO.Slice: lambda: info.slice_dim(dim_0, slice_tuple),
FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple),
# Pin
FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
FO.Pin: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
# Reinterpret
FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
}[self]()

View File

@ -0,0 +1,365 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import enum
import typing as typ
import jax.numpy as jnp
import sympy as sp
from blender_maxwell.utils import logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from .. import contracts as ct
log = logger.get(__name__)
class MapOperation(enum.StrEnum):
"""Valid operations for the `MapMathNode`.
Attributes:
Real: Compute the real part of the input.
Imag: Compute the imaginary part of the input.
Abs: Compute the absolute value of the input.
Sq: Square the input.
Sqrt: Compute the (principal) square root of the input.
InvSqrt: Compute the inverse square root of the input.
Cos: Compute the cosine of the input.
Sin: Compute the sine of the input.
Tan: Compute the tangent of the input.
Acos: Compute the inverse cosine of the input.
Asin: Compute the inverse sine of the input.
Atan: Compute the inverse tangent of the input.
Norm2: Compute the 2-norm (aka. length) of the input vector.
Det: Compute the determinant of the input matrix.
Cond: Compute the condition number of the input matrix.
NormFro: Compute the frobenius norm of the input matrix.
Rank: Compute the rank of the input matrix.
Diag: Compute the diagonal vector of the input matrix.
EigVals: Compute the eigenvalues vector of the input matrix.
SvdVals: Compute the singular values vector of the input matrix.
Inv: Compute the inverse matrix of the input matrix.
Tra: Compute the transpose matrix of the input matrix.
Qr: Compute the QR-factorized matrices of the input matrix.
Chol: Compute the Cholesky-factorized matrices of the input matrix.
Svd: Compute the SVD-factorized matrices of the input matrix.
"""
# By Number
Real = enum.auto()
Imag = enum.auto()
Abs = enum.auto()
Sq = enum.auto()
Sqrt = enum.auto()
InvSqrt = enum.auto()
Cos = enum.auto()
Sin = enum.auto()
Tan = enum.auto()
Acos = enum.auto()
Asin = enum.auto()
Atan = enum.auto()
Sinc = enum.auto()
# By Vector
Norm2 = enum.auto()
# By Matrix
Det = enum.auto()
Cond = enum.auto()
NormFro = enum.auto()
Rank = enum.auto()
Diag = enum.auto()
EigVals = enum.auto()
SvdVals = enum.auto()
Inv = enum.auto()
Tra = enum.auto()
Qr = enum.auto()
Chol = enum.auto()
Svd = enum.auto()
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
"""A human-readable UI-oriented name for a physical type."""
MO = MapOperation
return {
# By Number
MO.Real: '(v)',
MO.Imag: 'Im(v)',
MO.Abs: '|v|',
MO.Sq: '',
MO.Sqrt: '√v',
MO.InvSqrt: '1/√v',
MO.Cos: 'cos v',
MO.Sin: 'sin v',
MO.Tan: 'tan v',
MO.Acos: 'acos v',
MO.Asin: 'asin v',
MO.Atan: 'atan v',
MO.Sinc: 'sinc v',
# By Vector
MO.Norm2: '||v||₂',
# By Matrix
MO.Det: 'det V',
MO.Cond: 'κ(V)',
MO.NormFro: '||V||_F',
MO.Rank: 'rank V',
MO.Diag: 'diag V',
MO.EigVals: 'eigvals V',
MO.SvdVals: 'svdvals V',
MO.Inv: 'V⁻¹',
MO.Tra: 'Vt',
MO.Qr: 'qr V',
MO.Chol: 'chol V',
MO.Svd: 'svd V',
}[value]
@staticmethod
def to_icon(_: typ.Self) -> str:
"""No icons."""
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
"""Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`."""
MO = MapOperation
return (
str(self),
MO.to_name(self),
MO.to_name(self),
MO.to_icon(self),
i,
)
####################
# - Ops from Shape
####################
@staticmethod
def by_expr_info(info: ct.InfoFlow) -> list[typ.Self]:
## TODO: By info, not shape.
## TODO: Check valid domains/mathtypes for some functions.
MO = MapOperation
element_ops = [
MO.Real,
MO.Imag,
MO.Abs,
MO.Sq,
MO.Sqrt,
MO.InvSqrt,
MO.Cos,
MO.Sin,
MO.Tan,
MO.Acos,
MO.Asin,
MO.Atan,
MO.Sinc,
]
match (info.output.rows, info.output.cols):
case (1, 1):
return element_ops
case (_, 1):
return [*element_ops, MO.Norm2]
case (rows, cols) if rows == cols:
## TODO: Check hermitian/posdef for cholesky.
## - Can we even do this with just the output symbol approach?
return [
*element_ops,
MO.Det,
MO.Cond,
MO.NormFro,
MO.Rank,
MO.Diag,
MO.EigVals,
MO.SvdVals,
MO.Inv,
MO.Tra,
MO.Qr,
MO.Chol,
MO.Svd,
]
case (rows, cols):
return [
*element_ops,
MO.Cond,
MO.NormFro,
MO.Rank,
MO.SvdVals,
MO.Inv,
MO.Tra,
MO.Svd,
]
return []
####################
# - Function Properties
####################
@property
def sp_func(self):
MO = MapOperation
return {
# By Number
MO.Real: lambda expr: sp.re(expr),
MO.Imag: lambda expr: sp.im(expr),
MO.Abs: lambda expr: sp.Abs(expr),
MO.Sq: lambda expr: expr**2,
MO.Sqrt: lambda expr: sp.sqrt(expr),
MO.InvSqrt: lambda expr: 1 / sp.sqrt(expr),
MO.Cos: lambda expr: sp.cos(expr),
MO.Sin: lambda expr: sp.sin(expr),
MO.Tan: lambda expr: sp.tan(expr),
MO.Acos: lambda expr: sp.acos(expr),
MO.Asin: lambda expr: sp.asin(expr),
MO.Atan: lambda expr: sp.atan(expr),
MO.Sinc: lambda expr: sp.sinc(expr),
# By Vector
# Vector -> Number
MO.Norm2: lambda expr: sp.sqrt(expr.T @ expr)[0],
# By Matrix
# Matrix -> Number
MO.Det: lambda expr: sp.det(expr),
MO.Cond: lambda expr: expr.condition_number(),
MO.NormFro: lambda expr: expr.norm(ord='fro'),
MO.Rank: lambda expr: expr.rank(),
# Matrix -> Vec
MO.Diag: lambda expr: expr.diagonal(),
MO.EigVals: lambda expr: sp.Matrix(list(expr.eigenvals().keys())),
MO.SvdVals: lambda expr: expr.singular_values(),
# Matrix -> Matrix
MO.Inv: lambda expr: expr.inv(),
MO.Tra: lambda expr: expr.T,
# Matrix -> Matrices
MO.Qr: lambda expr: expr.QRdecomposition(),
MO.Chol: lambda expr: expr.cholesky(),
MO.Svd: lambda expr: expr.singular_value_decomposition(),
}[self]
@property
def jax_func(self):
MO = MapOperation
return {
# By Number
MO.Real: lambda expr: jnp.real(expr),
MO.Imag: lambda expr: jnp.imag(expr),
MO.Abs: lambda expr: jnp.abs(expr),
MO.Sq: lambda expr: jnp.square(expr),
MO.Sqrt: lambda expr: jnp.sqrt(expr),
MO.InvSqrt: lambda expr: 1 / jnp.sqrt(expr),
MO.Cos: lambda expr: jnp.cos(expr),
MO.Sin: lambda expr: jnp.sin(expr),
MO.Tan: lambda expr: jnp.tan(expr),
MO.Acos: lambda expr: jnp.acos(expr),
MO.Asin: lambda expr: jnp.asin(expr),
MO.Atan: lambda expr: jnp.atan(expr),
MO.Sinc: lambda expr: jnp.sinc(expr),
# By Vector
# Vector -> Number
MO.Norm2: lambda expr: jnp.linalg.norm(expr, ord=2, axis=-1),
# By Matrix
# Matrix -> Number
MO.Det: lambda expr: jnp.linalg.det(expr),
MO.Cond: lambda expr: jnp.linalg.cond(expr),
MO.NormFro: lambda expr: jnp.linalg.matrix_norm(expr, ord='fro'),
MO.Rank: lambda expr: jnp.linalg.matrix_rank(expr),
# Matrix -> Vec
MO.Diag: lambda expr: jnp.diagonal(expr, axis1=-2, axis2=-1),
MO.EigVals: lambda expr: jnp.linalg.eigvals(expr),
MO.SvdVals: lambda expr: jnp.linalg.svdvals(expr),
# Matrix -> Matrix
MO.Inv: lambda expr: jnp.linalg.inv(expr),
MO.Tra: lambda expr: jnp.matrix_transpose(expr),
# Matrix -> Matrices
MO.Qr: lambda expr: jnp.linalg.qr(expr),
MO.Chol: lambda expr: jnp.linalg.cholesky(expr),
MO.Svd: lambda expr: jnp.linalg.svd(expr),
}[self]
def transform_info(self, info: ct.InfoFlow):
MO = MapOperation
return {
# By Number
MO.Real: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Imag: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Abs: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Sq: lambda: info,
MO.Sqrt: lambda: info,
MO.InvSqrt: lambda: info,
MO.Cos: lambda: info,
MO.Sin: lambda: info,
MO.Tan: lambda: info,
MO.Acos: lambda: info,
MO.Asin: lambda: info,
MO.Atan: lambda: info,
MO.Sinc: lambda: info,
# By Vector
MO.Norm2: lambda: info.update_output(
mathtype=spux.MathType.Real,
rows=1,
cols=1,
# Interval
interval_finite_re=(0, sim_symbols.float_max),
interval_inf=(False, True),
interval_closed=(True, False),
),
# By Matrix
MO.Det: lambda: info.update_output(
rows=1,
cols=1,
),
MO.Cond: lambda: info.update_output(
mathtype=spux.MathType.Real,
rows=1,
cols=1,
physical_type=spux.PhysicalType.NonPhysical,
unit=None,
),
MO.NormFro: lambda: info.update_output(
mathtype=spux.MathType.Real,
rows=1,
cols=1,
# Interval
interval_finite_re=(0, sim_symbols.float_max),
interval_inf=(False, True),
interval_closed=(True, False),
),
MO.Rank: lambda: info.update_output(
mathtype=spux.MathType.Integer,
rows=1,
cols=1,
physical_type=spux.PhysicalType.NonPhysical,
unit=None,
# Interval
interval_finite_re=(0, sim_symbols.int_max),
interval_inf=(False, True),
interval_closed=(True, False),
),
# Matrix -> Vector ## TODO: ALL OF THESE
MO.Diag: lambda: info,
MO.EigVals: lambda: info,
MO.SvdVals: lambda: info,
# Matrix -> Matrix ## TODO: ALL OF THESE
MO.Inv: lambda: info,
MO.Tra: lambda: info,
# Matrix -> Matrices ## TODO: ALL OF THESE
MO.Qr: lambda: info,
MO.Chol: lambda: info,
MO.Svd: lambda: info,
}[self]()

View File

@ -0,0 +1,476 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import enum
import typing as typ
import jax.numpy as jnp
import sympy as sp
import sympy.physics.quantum as spq
import sympy.physics.units as spu
from blender_maxwell.utils import logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from .. import contracts as ct
log = logger.get(__name__)
def hadamard_power(lhs: spux.SympyType, rhs: spux.SympyType) -> spux.SympyType:
"""Implement the Hadamard Power.
Follows the specification in <https://docs.sympy.org/latest/modules/matrices/expressions.html#sympy.matrices.expressions.HadamardProduct>, which also conforms to `numpy` broadcasting rules for `**` on `np.ndarray`.
"""
match (isinstance(lhs, sp.MatrixBase), isinstance(rhs, sp.MatrixBase)):
case (False, False):
msg = f"Hadamard Power for two scalars is valid, but shouldn't be used - use normal power instead: {lhs} | {rhs}"
raise ValueError(msg)
case (True, False):
return lhs.applyfunc(lambda el: el**rhs)
case (False, True):
return rhs.applyfunc(lambda el: lhs**el)
case (True, True) if lhs.shape == rhs.shape:
common_shape = lhs.shape
return sp.ImmutableMatrix(
*common_shape, lambda i, j: lhs[i, j] ** rhs[i, j]
)
case _:
msg = f'Incompatible lhs and rhs for hadamard power: {lhs} | {rhs}'
raise ValueError(msg)
class BinaryOperation(enum.StrEnum):
"""Valid operations for the `OperateMathNode`.
Attributes:
Mul: Scalar multiplication.
Div: Scalar division.
Pow: Scalar exponentiation.
Add: Elementwise addition.
Sub: Elementwise subtraction.
HadamMul: Elementwise multiplication (hadamard product).
HadamPow: Principled shape-aware exponentiation (hadamard power).
Atan2: Quadrant-respecting 2D arctangent.
VecVecDot: Dot product for identically shaped vectors w/transpose.
Cross: Cross product between identically shaped 3D vectors.
VecVecOuter: Vector-vector outer product.
LinSolve: Solve a linear system.
LsqSolve: Minimize error of an underdetermined linear system.
VecMatOuter: Vector-matrix outer product.
MatMatDot: Matrix-matrix dot product.
"""
# Number | Number
Mul = enum.auto()
Div = enum.auto()
Pow = enum.auto()
# Elements | Elements
Add = enum.auto()
Sub = enum.auto()
HadamMul = enum.auto()
HadamPow = enum.auto()
HadamDiv = enum.auto()
Atan2 = enum.auto()
# Vector | Vector
VecVecDot = enum.auto()
Cross = enum.auto()
VecVecOuter = enum.auto()
# Matrix | Vector
LinSolve = enum.auto()
LsqSolve = enum.auto()
# Vector | Matrix
VecMatOuter = enum.auto()
# Matrix | Matrix
MatMatDot = enum.auto()
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
"""A human-readable UI-oriented name for a physical type."""
BO = BinaryOperation
return {
# Number | Number
BO.Mul: ' · r',
BO.Div: ' / r',
BO.Pow: ' ^ r', ## Also for square-matrix powers.
# Elements | Elements
BO.Add: ' + r',
BO.Sub: ' - r',
BO.HadamMul: '𝐋𝐑',
BO.HadamDiv: '𝐋 ⊙/ 𝐑',
BO.HadamPow: '𝐥 ⊙^ 𝐫',
BO.Atan2: 'atan2(:x, r:y)',
# Vector | Vector
BO.VecVecDot: '𝐥 · 𝐫',
BO.Cross: 'cross(𝐥,𝐫)',
BO.VecVecOuter: '𝐥𝐫',
# Matrix | Vector
BO.LinSolve: '𝐋 𝐫',
BO.LsqSolve: 'argminₓ∥𝐋𝐱𝐫∥₂',
# Vector | Matrix
BO.VecMatOuter: '𝐋𝐫',
# Matrix | Matrix
BO.MatMatDot: '𝐋 · 𝐑',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> str:
"""No icons."""
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
"""Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`."""
BO = BinaryOperation
return (
str(self),
BO.to_name(self),
BO.to_name(self),
BO.to_icon(self),
i,
)
def bl_enum_elements(
self, info_l: ct.InfoFlow, info_r: ct.InfoFlow
) -> list[ct.BLEnumElement]:
"""Generate a list of guaranteed-valid operations based on the passed `InfoFlow`s.
Returns a `bpy.props.EnumProperty.items`-compatible list.
"""
return [
operation.bl_enum_element(i)
for i, operation in enumerate(BinaryOperation.by_infos(info_l, info_r))
]
####################
# - Ops from Shape
####################
@staticmethod
def by_infos(info_l: ct.InfoFlow, info_r: ct.InfoFlow) -> list[typ.Self]:
"""Deduce valid binary operations from the shapes of the inputs."""
BO = BinaryOperation
ops = []
# Add/Sub
if info_l.compare_addable(info_r, allow_differing_unit=True):
ops += [BO.Add, BO.Sub]
# Mul/Div
## -> Mul is ambiguous; we differentiate Hadamard and Standard.
## -> Div additionally requires non-zero guarantees.
if info_l.compare_multiplicable(info_r):
match (info_l.order, info_r.order, info_r.output.is_nonzero):
case (ordl, ordr, True) if ordl == 0 and ordr == 0:
ops += [BO.Mul, BO.Div]
case (ordl, ordr, True) if ordl > 0 and ordr == 0:
ops += [BO.Mul, BO.Div]
case (ordl, ordr, True) if ordl == 0 and ordr > 0:
ops += [BO.Mul]
case (ordl, ordr, True) if ordl > 0 and ordr > 0:
ops += [BO.HadamMul, BO.HadamDiv]
case (ordl, ordr, False) if ordl == 0 and ordr == 0:
ops += [BO.Mul]
case (ordl, ordr, False) if ordl > 0 and ordr == 0:
ops += [BO.Mul]
case (ordl, ordr, True) if ordl == 0 and ordr > 0:
ops += [BO.Mul]
case (ordl, ordr, False) if ordl > 0 and ordr > 0:
ops += [BO.HadamMul]
# Pow
## -> We distinguish between "Hadamard Power" and "Power".
## -> For scalars, they are the same (but we only expose "power").
## -> For matrices, square matrices can be exp'ed by int powers.
## -> Any other combination is well-defined by the Hadamard Power.
if info_l.compare_exponentiable(info_r):
match (info_l.order, info_r.order, info_r.output.mathtype):
case (ordl, ordr, _) if ordl == 0 and ordr == 0:
ops += [BO.Pow]
case (ordl, ordr, spux.MathType.Integer) if (
ordl > 0 and ordr == 0 and info_l.output.rows == info_l.output.cols
):
ops += [BO.Pow, BO.HadamPow]
case _:
ops += [BO.HadamPow]
# Operations by-Output Length
match (
info_l.output.shape_len,
info_r.output.shape_len,
):
# Number | Number
case (0, 0) if info_l.is_scalar and info_r.is_scalar:
# atan2: PhysicalType Must Both be Length | NonPhysical
## -> atan2() produces radians from Cartesian coordinates.
## -> This wouldn't make sense on non-Length / non-Unitless.
if (
info_l.output.physical_type is spux.PhysicalType.Length
and info_r.output.physical_type is spux.PhysicalType.Length
) or (
info_l.output.physical_type is spux.PhysicalType.NonPhysical
and info_l.output.unit is None
and info_r.output.physical_type is spux.PhysicalType.NonPhysical
and info_r.output.unit is None
):
ops += [BO.Atan2]
return ops
# Vector | Vector
case (1, 1) if info_l.compare_dims_identical(info_r):
outl = info_l.output
outr = info_r.output
# 1D Orders: Outer Product is Valid
## -> We can't do per-element outer product.
## -> However, it's still super useful on its own.
if info_l.order == 1 and info_r.order == 1:
ops += [BO.VecVecOuter]
# Vector | Vector
if outl.rows > outl.cols and outr.rows > outr.cols:
ops += [BO.VecVecDot]
# Covector | Vector
if outl.rows < outl.cols and outr.rows > outr.cols:
ops += [BO.MatMatDot]
# Vector | Covector
if outl.rows > outl.cols and outr.rows < outr.cols:
ops += [BO.MatMatDot]
# Covector | Covector
if outl.rows < outl.cols and outr.rows < outr.cols:
ops += [BO.VecVecDot]
# Cross Product
## -> Works great element-wise.
## -> Enforce that both are 3x1 or 1x3.
## -> See https://docs.sympy.org/latest/modules/matrices/matrices.html#sympy.matrices.matrices.MatrixBase.cross
if (outl.rows == 3 and outr.rows == 3) or (
outl.cols == 3 and outl.cols == 3
):
ops += [BO.Cross]
# Vector | Matrix
## -> We can't do per-element outer product.
## -> However, it's still super useful on its own.
case (1, 2) if info_l.compare_dims_identical(
info_r
) and info_l.order == 1 and info_r.order == 2:
ops += [BO.VecMatOuter]
# Matrix | Vector
case (2, 1) if info_l.compare_dims_identical(info_r):
# Mat-Vec Dot: Enforce RHS Column Vector
if outr.rows > outl.cols:
ops += [BO.MatMatDot]
ops += [BO.LinSolve, BO.LsqSolve]
## Matrix | Matrix
case (2, 2):
ops += [BO.MatMatDot]
return ops
####################
# - Function Properties
####################
@property
def sp_func(self):
"""Deduce an appropriate sympy-based function that implements the binary operation for symbolic inputs."""
BO = BinaryOperation
## TODO: Make this compatible with sp.Matrix inputs
return {
# Number | Number
BO.Mul: lambda exprs: exprs[0] * exprs[1],
BO.Div: lambda exprs: exprs[0] / exprs[1],
BO.Pow: lambda exprs: exprs[0] ** exprs[1],
# Elements | Elements
BO.Add: lambda exprs: exprs[0] + exprs[1],
BO.Sub: lambda exprs: exprs[0] - exprs[1],
BO.HadamMul: lambda exprs: exprs[0].multiply_elementwise(exprs[1]),
BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]),
# Vector | Vector
BO.VecVecDot: lambda exprs: (exprs[0].T @ exprs[1])[0],
BO.Cross: lambda exprs: exprs[0].cross(exprs[1]),
BO.VecVecOuter: lambda exprs: exprs[0] @ exprs[1].T,
# Matrix | Vector
BO.LinSolve: lambda exprs: exprs[0].solve(exprs[1]),
BO.LsqSolve: lambda exprs: exprs[0].solve_least_squares(exprs[1]),
# Vector | Matrix
BO.VecMatOuter: lambda exprs: spq.TensorProduct(exprs[0], exprs[1]),
# Matrix | Matrix
BO.MatMatDot: lambda exprs: exprs[0] @ exprs[1],
}[self]
@property
def unit_func(self):
"""The binary function to apply to both unit expressions, in order to deduce the unit expression of the output."""
BO = BinaryOperation
## TODO: Make this compatible with sp.Matrix inputs
return {
# Number | Number
BO.Mul: BO.Mul.sp_func,
BO.Div: BO.Div.sp_func,
BO.Pow: BO.Pow.sp_func,
# Elements | Elements
BO.Add: BO.Add.sp_func,
BO.Sub: BO.Sub.sp_func,
BO.HadamMul: BO.Mul.sp_func,
# BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
BO.Atan2: lambda _: spu.radian,
# Vector | Vector
BO.VecVecDot: BO.Mul.sp_func,
BO.Cross: BO.Mul.sp_func,
BO.VecVecOuter: BO.Mul.sp_func,
# Matrix | Vector
## -> A,b in Ax = b have units, and the equality must hold.
## -> Therefore, A \ b must have the units [b]/[A].
BO.LinSolve: lambda exprs: exprs[1] / exprs[0],
BO.LsqSolve: lambda exprs: exprs[1] / exprs[0],
# Vector | Matrix
BO.VecMatOuter: BO.Mul.sp_func,
# Matrix | Matrix
BO.MatMatDot: BO.Mul.sp_func,
}[self]
@property
def jax_func(self):
"""Deduce an appropriate jax-based function that implements the binary operation for array inputs."""
## TODO: Scale the units of one side to the other.
BO = BinaryOperation
return {
# Number | Number
BO.Mul: lambda exprs: exprs[0] * exprs[1],
BO.Div: lambda exprs: exprs[0] / exprs[1],
BO.Pow: lambda exprs: exprs[0] ** exprs[1],
# Elements | Elements
BO.Add: lambda exprs: exprs[0] + exprs[1],
BO.Sub: lambda exprs: exprs[0] - exprs[1],
BO.HadamMul: lambda exprs: exprs[0].multiply_elementwise(exprs[1]),
BO.HadamDiv: lambda exprs: exprs[0].multiply_elementwise(
exprs[1].applyfunc(lambda el: 1 / el)
),
BO.HadamPow: lambda exprs: hadamard_power(exprs[0], exprs[1]),
BO.Atan2: lambda exprs: jnp.atan2(exprs[1], exprs[0]),
# Vector | Vector
BO.VecVecDot: lambda exprs: jnp.linalg.vecdot(exprs[0], exprs[1]),
BO.Cross: lambda exprs: jnp.cross(exprs[0], exprs[1]),
BO.VecVecOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
# Matrix | Vector
BO.LinSolve: lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]),
BO.LsqSolve: lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]),
# Vector | Matrix
BO.VecMatOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
# Matrix | Matrix
BO.MatMatDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]),
}[self]
####################
# - Transforms
####################
def transform_funcs(self, func_l: ct.FuncFlow, func_r: ct.FuncFlow) -> ct.FuncFlow:
"""Transform two input functions according to the current operation."""
BO = BinaryOperation
# Add/Sub: Normalize Unit of RHS to LHS
## -> We can only add/sub identical units.
## -> To be nice, we only require identical PhysicalType.
## -> The result of a binary operation should have one unit.
if self is BO.Add or self is BO.Sub:
norm_func_r = func_r.scale_to_unit(func_l.func_output.unit)
else:
norm_func_r = func_r
return (func_l, norm_func_r).compose_within(
self.jax_func,
enclosing_func_output=self.transform_outputs(
func_l.func_output, norm_func_r.func_output
),
supports_jax=True,
)
def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow):
"""Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression.
Warnings:
`self` MUST be an element of `BinaryOperation.by_infos(info_l, info_r).
If not, bad things will happen.
"""
return info_l.operate_output(
info_r,
lambda a, b: self.sp_func([a, b]),
lambda a, b: self.unit_func([a, b]),
)
####################
# - InfoFlow Transform
####################
def transform_outputs(
self, output_l: sim_symbols.SimSymbol, output_r: sim_symbols.SimSymbol
) -> sim_symbols.SimSymbol:
# TO = TransformOperation
return None
# match self:
# # Number | Number
# case TO.Mul:
# return
# case TO.Div:
# case TO.Pow:
# # Elements | Elements
# Add = enum.auto()
# Sub = enum.auto()
# HadamMul = enum.auto()
# HadamPow = enum.auto()
# HadamDiv = enum.auto()
# Atan2 = enum.auto()
# # Vector | Vector
# VecVecDot = enum.auto()
# Cross = enum.auto()
# VecVecOuter = enum.auto()
# # Matrix | Vector
# LinSolve = enum.auto()
# LsqSolve = enum.auto()
# # Vector | Matrix
# VecMatOuter = enum.auto()
# # Matrix | Matrix
# MatMatDot = enum.auto()

View File

@ -0,0 +1,116 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import enum
import typing as typ
import jax.numpy as jnp
import sympy as sp
from blender_maxwell.utils import logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from .. import contracts as ct
log = logger.get(__name__)
class ReduceOperation(enum.StrEnum):
# Summary
Count = enum.auto()
# Statistics
Mean = enum.auto()
Std = enum.auto()
Var = enum.auto()
StdErr = enum.auto()
Min = enum.auto()
Q25 = enum.auto()
Median = enum.auto()
Q75 = enum.auto()
Max = enum.auto()
Mode = enum.auto()
# Reductions
Sum = enum.auto()
Prod = enum.auto()
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
"""A human-readable UI-oriented name for a physical type."""
RO = ReduceOperation
return {
# Summary
RO.Count: '# [a]',
RO.Mode: 'mode [a]',
# Statistics
RO.Mean: 'μ [a]',
RO.Std: 'σ [a]',
RO.Var: 'σ² [a]',
RO.StdErr: 'stderr [a]',
RO.Min: 'min [a]',
RO.Q25: 'q₂₅ [a]',
RO.Median: 'median [a]',
RO.Q75: 'q₇₅ [a]',
RO.Min: 'max [a]',
# Reductions
RO.Sum: 'sum [a]',
RO.Prod: 'prod [a]',
}[value]
@staticmethod
def to_icon(_: typ.Self) -> str:
"""No icons."""
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
"""Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`."""
RO = ReduceOperation
return (
str(self),
RO.to_name(self),
RO.to_name(self),
RO.to_icon(self),
i,
)
####################
# - Derivation
####################
@staticmethod
def from_info(info: ct.InfoFlow) -> list[typ.Self]:
"""Derive valid reduction operations from the `InfoFlow` of the operand."""
pass
####################
# - Composable Functions
####################
@property
def jax_func(self):
RO = ReduceOperation
return {}[self]
####################
# - Transforms
####################
def transform_info(self, info: ct.InfoFlow):
pass

View File

@ -0,0 +1,336 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import enum
import typing as typ
import jax.numpy as jnp
import jaxtyping as jtyp
from blender_maxwell.utils import logger, sci_constants, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from .. import contracts as ct
log = logger.get(__name__)
class TransformOperation(enum.StrEnum):
"""Valid operations for the `TransformMathNode`.
Attributes:
FreqToVacWL: Transform an frequency dimension to vacuum wavelength.
VacWLToFreq: Transform a vacuum wavelength dimension to frequency.
ConvertIdxUnit: Convert the unit of a dimension to a compatible unit.
SetIdxUnit: Set all properties of a dimension.
FirstColToFirstIdx: Extract the first data column and set the first dimension's index array equal to it.
**For 2D integer-indexed data only**.
IntDimToComplex: Fold a last length-2 integer dimension into the output, transforming it from a real-like type to complex type.
DimToVec: Fold the last dimension into the scalar output, creating a vector output type.
DimsToMat: Fold the last two dimensions into the scalar output, creating a matrix output type.
FT: Compute the 1D fourier transform along a dimension.
New dimensional bounds are computing using the Nyquist Limit.
For higher dimensions, simply repeat along more dimensions.
InvFT1D: Compute the inverse 1D fourier transform along a dimension.
New dimensional bounds are computing using the Nyquist Limit.
For higher dimensions, simply repeat along more dimensions.
"""
# Covariant Transform
FreqToVacWL = enum.auto()
VacWLToFreq = enum.auto()
ConvertIdxUnit = enum.auto()
SetIdxUnit = enum.auto()
FirstColToFirstIdx = enum.auto()
# Fold
IntDimToComplex = enum.auto()
DimToVec = enum.auto()
DimsToMat = enum.auto()
# Fourier
FT1D = enum.auto()
InvFT1D = enum.auto()
# TODO: Affine
## TODO
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
TO = TransformOperation
return {
# Covariant Transform
TO.FreqToVacWL: '𝑓 → λᵥ',
TO.VacWLToFreq: 'λᵥ → 𝑓',
TO.ConvertIdxUnit: 'Convert Dim',
TO.SetIdxUnit: 'Set Dim',
TO.FirstColToFirstIdx: '1st Col → 1st Dim',
# Fold
TO.IntDimToComplex: '',
TO.DimToVec: '→ Vector',
TO.DimsToMat: '→ Matrix',
## TODO: Vector to new last-dim integer
## TODO: Matrix to two last-dim integers
# Fourier
TO.FT1D: 'FT',
TO.InvFT1D: 'iFT',
}[value]
@property
def name(self) -> str:
return TransformOperation.to_name(self)
@staticmethod
def to_icon(_: typ.Self) -> str:
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
TO = TransformOperation
return (
str(self),
TO.to_name(self),
TO.to_name(self),
TO.to_icon(self),
i,
)
####################
# - Methods
####################
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
TO = TransformOperation
match self:
case TO.FreqToVacWL:
return [
dim
for dim in info.dims
if dim.physical_type is spux.PhysicalType.Freq
]
case TO.VacWLToFreq:
return [
dim
for dim in info.dims
if dim.physical_type is spux.PhysicalType.Length
]
case TO.ConvertIdxUnit:
return [
dim
for dim in info.dims
if not info.has_idx_labels(dim)
and spux.PhysicalType.from_unit(dim.unit, optional=True) is not None
]
case TO.SetIdxUnit:
return [dim for dim in info.dims if not info.has_idx_labels(dim)]
## ColDimToComplex: Implicit Last Dimension
## DimToVec: Implicit Last Dimension
## DimsToMat: Implicit Last 2 Dimensions
case TO.FT1D | TO.InvFT1D:
# Filter by Axis Uniformity
## -> FT requires uniform axis (aka. must be RangeFlow).
## -> NOTE: If FT isn't popping up, check ExtractDataNode.
return [dim for dim in info.dims if info.is_idx_uniform(dim)]
return []
@staticmethod
def by_info(info: ct.InfoFlow) -> list[typ.Self]:
TO = TransformOperation
operations = []
# Covariant Transform
## Freq -> VacWL
if TO.FreqToVacWL.valid_dims(info):
operations += [TO.FreqToVacWL]
## VacWL -> Freq
if TO.VacWLToFreq.valid_dims(info):
operations += [TO.VacWLToFreq]
## Convert Index Unit
if TO.ConvertIdxUnit.valid_dims(info):
operations += [TO.ConvertIdxUnit]
if TO.SetIdxUnit.valid_dims(info):
operations += [TO.SetIdxUnit]
## Column to First Index (Array)
if (
len(info.dims) == 2 # noqa: PLR2004
and info.first_dim.mathtype is spux.MathType.Integer
and info.last_dim.mathtype is spux.MathType.Integer
and info.output.shape_len == 0
):
operations += [TO.FirstColToFirstIdx]
# Fold
## Last Dim -> Complex
if (
len(info.dims) >= 1
and (
info.output.mathtype
in [spux.MathType.Integer, spux.MathType.Rational, spux.MathType.Real]
)
and info.last_dim.mathtype is spux.MathType.Integer
and info.has_idx_labels(info.last_dim)
and len(info.dims[info.last_dim]) == 2 # noqa: PLR2004
):
operations += [TO.IntDimToComplex]
## Last Dim -> Vector
if len(info.dims) >= 1 and info.output.shape_len == 0:
operations += [TO.DimToVec]
## Last Dim -> Matrix
if len(info.dims) >= 2 and info.output.shape_len == 0: # noqa: PLR2004
operations += [TO.DimsToMat]
# Fourier
if TO.FT1D.valid_dims(info):
operations += [TO.FT1D]
if TO.InvFT1D.valid_dims(info):
operations += [TO.InvFT1D]
return operations
####################
# - Function Properties
####################
def jax_func(self, axis: int | None = None):
TO = TransformOperation
return {
# Covariant Transform
## -> Freq <-> WL is a rescale (noop) AND flip (not noop).
TO.FreqToVacWL: lambda expr: jnp.flip(expr, axis=axis),
TO.VacWLToFreq: lambda expr: jnp.flip(expr, axis=axis),
TO.ConvertIdxUnit: lambda expr: expr,
TO.SetIdxUnit: lambda expr: expr,
TO.FirstColToFirstIdx: lambda expr: jnp.delete(expr, 0, axis=1),
# Fold
## -> To Complex: This should generally be a no-op.
TO.IntDimToComplex: lambda expr: jnp.squeeze(
expr.view(dtype=jnp.complex64), axis=-1
),
TO.DimToVec: lambda expr: expr,
TO.DimsToMat: lambda expr: expr,
# Fourier
TO.FT1D: lambda expr: jnp.fft(expr, axis=axis),
TO.InvFT1D: lambda expr: jnp.ifft(expr, axis=axis),
}[self]
def transform_info(
self,
info: ct.InfoFlow,
dim: sim_symbols.SimSymbol | None = None,
data_col: jtyp.Shaped[jtyp.Array, ' size'] | None = None,
new_dim_name: str | None = None,
unit: spux.Unit | None = None,
physical_type: spux.PhysicalType | None = None,
) -> ct.InfoFlow:
TO = TransformOperation
return {
# Covariant Transform
TO.FreqToVacWL: lambda: info.replace_dim(
(f_dim := dim),
sim_symbols.wl(unit),
info.dims[f_dim].rescale(
lambda el: sci_constants.vac_speed_of_light / el,
reverse=True,
new_unit=unit,
),
),
TO.VacWLToFreq: lambda: info.replace_dim(
(wl_dim := dim),
sim_symbols.freq(unit),
info.dims[wl_dim].rescale(
lambda el: sci_constants.vac_speed_of_light / el,
reverse=True,
new_unit=unit,
),
),
TO.ConvertIdxUnit: lambda: info.replace_dim(
dim,
dim.update(unit=unit),
(
info.dims[dim].rescale_to_unit(unit)
if info.has_idx_discrete(dim)
else None ## Continuous -- dim SimSymbol already scaled
),
),
TO.SetIdxUnit: lambda: info.replace_dim(
dim,
dim.update(
sym_name=new_dim_name,
physical_type=physical_type,
unit=unit,
),
(
info.dims[dim].correct_unit(unit)
if info.has_idx_discrete(dim)
else None ## Continuous -- dim SimSymbol already scaled
),
),
TO.FirstColToFirstIdx: lambda: info.replace_dim(
info.first_dim,
info.first_dim.update(
sym_name=new_dim_name,
mathtype=spux.MathType.from_jax_array(data_col),
physical_type=physical_type,
unit=unit,
),
ct.RangeFlow.try_from_array(ct.ArrayFlow(values=data_col, unit=unit)),
).slice_dim(info.last_dim, (1, len(info.dims[info.last_dim]), 1)),
# Fold
TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output(
mathtype=spux.MathType.Complex
),
TO.DimToVec: lambda: info.fold_last_input(),
TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(),
# Fourier
TO.FT1D: lambda: info.replace_dim(
dim,
[
# FT'ed Unit: Reciprocal of the Original Unit
dim.update(
unit=1 / dim.unit if dim.unit is not None else 1
), ## TODO: Okay to not scale interval?
# FT'ed Bounds: Reciprocal of the Original Unit
info.dims[dim].bound_fourier_transform,
],
),
TO.InvFT1D: lambda: info.replace_dim(
info.last_dim,
[
# FT'ed Unit: Reciprocal of the Original Unit
dim.update(
unit=1 / dim.unit if dim.unit is not None else 1
), ## TODO: Okay to not scale interval?
# FT'ed Bounds: Reciprocal of the Original Unit
## -> Note the midpoint may revert to 0.
## -> See docs for `RangeFlow.bound_inv_fourier_transform` for more.
info.dims[dim].bound_inv_fourier_transform,
],
),
}[self]()

View File

@ -26,7 +26,7 @@ import sympy.physics.units as spu
import tidy3d as td
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import sockets

View File

@ -20,225 +20,15 @@ import enum
import typing as typ
import bpy
import jax.lax as jlax
import jax.numpy as jnp
import sympy as sp
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import bl_cache, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets
from .... import math_system, sockets
from ... import base, events
log = logger.get(__name__)
class FilterOperation(enum.StrEnum):
"""Valid operations for the `FilterMathNode`.
Attributes:
DimToVec: Shift last dimension to output.
DimsToMat: Shift last 2 dimensions to output.
PinLen1: Remove a len(1) dimension.
Pin: Remove a len(n) dimension by selecting a particular index.
Swap: Swap the positions of two dimensions.
"""
# Slice
Slice = enum.auto()
SliceIdx = enum.auto()
# Pin
PinLen1 = enum.auto()
Pin = enum.auto()
PinIdx = enum.auto()
# Dimension
Swap = enum.auto()
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
FO = FilterOperation
return {
# Slice
FO.Slice: '≈a[v₁:v₂]',
FO.SliceIdx: '=a[i:j]',
# Pin
FO.PinLen1: 'a[0] → a',
FO.Pin: 'a[v] ⇝ a',
FO.PinIdx: 'a[i] → a',
# Reinterpret
FO.Swap: 'a₁ ↔ a₂',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> str:
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
FO = FilterOperation
return (
str(self),
FO.to_name(self),
FO.to_name(self),
FO.to_icon(self),
i,
)
####################
# - Ops from Info
####################
@staticmethod
def by_info(info: ct.InfoFlow) -> list[typ.Self]:
FO = FilterOperation
operations = []
# Slice
if info.dims:
operations.append(FO.SliceIdx)
# Pin
## PinLen1
## -> There must be a dimension with length 1.
if 1 in [dim_idx for dim_idx in info.dims.values() if dim_idx is not None]:
operations.append(FO.PinLen1)
## Pin | PinIdx
## -> There must be a dimension, full stop.
if info.dims:
operations += [FO.Pin, FO.PinIdx]
# Reinterpret
## Swap
## -> There must be at least two dimensions.
if len(info.dims) >= 2: # noqa: PLR2004
operations.append(FO.Swap)
return operations
####################
# - Computed Properties
####################
@property
def func_args(self) -> list[sim_symbols.SimSymbol]:
FO = FilterOperation
return {
# Pin
FO.Pin: [sim_symbols.idx(None)],
FO.PinIdx: [sim_symbols.idx(None)],
}.get(self, [])
####################
# - Methods
####################
@property
def num_dim_inputs(self) -> None:
FO = FilterOperation
return {
# Slice
FO.Slice: 1,
FO.SliceIdx: 1,
# Pin
FO.PinLen1: 1,
FO.Pin: 1,
FO.PinIdx: 1,
# Reinterpret
FO.Swap: 2,
}[self]
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
FO = FilterOperation
match self:
# Slice
case FO.Slice:
return [dim for dim in info.dims if not info.has_idx_labels(dim)]
case FO.SliceIdx:
return [dim for dim in info.dims if not info.has_idx_labels(dim)]
# Pin
case FO.PinLen1:
return [
dim
for dim, dim_idx in info.dims.items()
if not info.has_idx_cont(dim) and len(dim_idx) == 1
]
case FO.Pin:
return info.dims
case FO.PinIdx:
return [dim for dim in info.dims if not info.has_idx_cont(dim)]
# Dimension
case FO.Swap:
return info.dims
return []
def are_dims_valid(
self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None
) -> bool:
"""Check whether the given dimension inputs are valid in the context of this operation, and of the information."""
if self.num_dim_inputs == 1:
return dim_0 in self.valid_dims(info)
if self.num_dim_inputs == 2: # noqa: PLR2004
valid_dims = self.valid_dims(info)
return dim_0 in valid_dims and dim_1 in valid_dims
return False
####################
# - UI
####################
def jax_func(
self,
axis_0: int | None,
axis_1: int | None,
slice_tuple: tuple[int, int, int] | None = None,
):
FO = FilterOperation
return {
# Pin
FO.Slice: lambda expr: jlax.slice_in_dim(
expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
),
FO.SliceIdx: lambda expr: jlax.slice_in_dim(
expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
),
# Pin
FO.PinLen1: lambda expr: jnp.squeeze(expr, axis_0),
FO.Pin: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
# Dimension
FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1),
}[self]
def transform_info(
self,
info: ct.InfoFlow,
dim_0: sim_symbols.SimSymbol,
dim_1: sim_symbols.SimSymbol,
pin_idx: int | None = None,
slice_tuple: tuple[int, int, int] | None = None,
):
FO = FilterOperation
return {
FO.Slice: lambda: info.slice_dim(dim_0, slice_tuple),
FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple),
# Pin
FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
FO.Pin: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
# Reinterpret
FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
}[self]()
class FilterMathNode(base.MaxwellSimNode):
r"""Applies a function that operates on the shape of the array.
@ -304,7 +94,7 @@ class FilterMathNode(base.MaxwellSimNode):
####################
# - Properties: Operation
####################
operation: FilterOperation = bl_cache.BLField(
operation: math_system.FilterOperation = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_info'},
)
@ -313,7 +103,9 @@ class FilterMathNode(base.MaxwellSimNode):
if self.expr_info is not None:
return [
operation.bl_enum_element(i)
for i, operation in enumerate(FilterOperation.by_info(self.expr_info))
for i, operation in enumerate(
math_system.FilterOperation.by_info(self.expr_info)
)
]
return []
@ -358,7 +150,7 @@ class FilterMathNode(base.MaxwellSimNode):
# - UI
####################
def draw_label(self):
FO = FilterOperation
FO = math_system.FilterOperation
match self.operation:
# Slice
case FO.SliceIdx:
@ -398,7 +190,7 @@ class FilterMathNode(base.MaxwellSimNode):
row.prop(self, self.blfields['active_dim_0'], text='')
row.prop(self, self.blfields['active_dim_1'], text='')
if self.operation is FilterOperation.SliceIdx:
if self.operation is math_system.FilterOperation.SliceIdx:
layout.prop(self, self.blfields['slice_tuple'], text='')
####################
@ -434,7 +226,7 @@ class FilterMathNode(base.MaxwellSimNode):
## -> Works with continuous / discrete indexes.
## -> The user will be given a socket w/correct mathtype, unit, etc. .
if (
props['operation'] is FilterOperation.Pin
props['operation'] is math_system.FilterOperation.Pin
and dim_0 is not None
and (info.has_idx_cont(dim_0) or info.has_idx_discrete(dim_0))
):
@ -460,7 +252,7 @@ class FilterMathNode(base.MaxwellSimNode):
# Loose Sockets: Pin Dim by-Value
## -> Works with discrete points / labelled integers.
elif (
props['operation'] is FilterOperation.PinIdx
props['operation'] is math_system.FilterOperation.PinIdx
and dim_0 is not None
and (info.has_idx_discrete(dim_0) or info.has_idx_labels(dim_0))
):
@ -594,7 +386,10 @@ class FilterMathNode(base.MaxwellSimNode):
# Pin by-Value: Compute Nearest IDX
## -> Presume a sorted index array to be able to use binary search.
if props['operation'] is FilterOperation.Pin and has_pinned_value:
if (
props['operation'] is math_system.FilterOperation.Pin
and has_pinned_value
):
nearest_idx_to_value = info.dims[dim_0].nearest_idx_of(
pinned_value, require_sorted=True
)
@ -605,7 +400,10 @@ class FilterMathNode(base.MaxwellSimNode):
)
# Pin by-Index
if props['operation'] is FilterOperation.PinIdx and has_pinned_axis:
if (
props['operation'] is math_system.FilterOperation.PinIdx
and has_pinned_axis
):
return params.compose_within(
enclosing_arg_targets=[sim_symbols.idx(None)],
enclosing_func_args=[sp.S(pinned_axis)],

View File

@ -16,363 +16,19 @@
"""Declares `MapMathNode`."""
import enum
import typing as typ
import bpy
import jax.numpy as jnp
import sympy as sp
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import bl_cache, logger
from .... import contracts as ct
from .... import sockets
from .... import math_system, sockets
from ... import base, events
log = logger.get(__name__)
####################
# - Operation Enum
####################
class MapOperation(enum.StrEnum):
"""Valid operations for the `MapMathNode`.
Attributes:
Real: Compute the real part of the input.
Imag: Compute the imaginary part of the input.
Abs: Compute the absolute value of the input.
Sq: Square the input.
Sqrt: Compute the (principal) square root of the input.
InvSqrt: Compute the inverse square root of the input.
Cos: Compute the cosine of the input.
Sin: Compute the sine of the input.
Tan: Compute the tangent of the input.
Acos: Compute the inverse cosine of the input.
Asin: Compute the inverse sine of the input.
Atan: Compute the inverse tangent of the input.
Norm2: Compute the 2-norm (aka. length) of the input vector.
Det: Compute the determinant of the input matrix.
Cond: Compute the condition number of the input matrix.
NormFro: Compute the frobenius norm of the input matrix.
Rank: Compute the rank of the input matrix.
Diag: Compute the diagonal vector of the input matrix.
EigVals: Compute the eigenvalues vector of the input matrix.
SvdVals: Compute the singular values vector of the input matrix.
Inv: Compute the inverse matrix of the input matrix.
Tra: Compute the transpose matrix of the input matrix.
Qr: Compute the QR-factorized matrices of the input matrix.
Chol: Compute the Cholesky-factorized matrices of the input matrix.
Svd: Compute the SVD-factorized matrices of the input matrix.
"""
# By Number
Real = enum.auto()
Imag = enum.auto()
Abs = enum.auto()
Sq = enum.auto()
Sqrt = enum.auto()
InvSqrt = enum.auto()
Cos = enum.auto()
Sin = enum.auto()
Tan = enum.auto()
Acos = enum.auto()
Asin = enum.auto()
Atan = enum.auto()
Sinc = enum.auto()
# By Vector
Norm2 = enum.auto()
# By Matrix
Det = enum.auto()
Cond = enum.auto()
NormFro = enum.auto()
Rank = enum.auto()
Diag = enum.auto()
EigVals = enum.auto()
SvdVals = enum.auto()
Inv = enum.auto()
Tra = enum.auto()
Qr = enum.auto()
Chol = enum.auto()
Svd = enum.auto()
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
MO = MapOperation
return {
# By Number
MO.Real: '(v)',
MO.Imag: 'Im(v)',
MO.Abs: '|v|',
MO.Sq: '',
MO.Sqrt: '√v',
MO.InvSqrt: '1/√v',
MO.Cos: 'cos v',
MO.Sin: 'sin v',
MO.Tan: 'tan v',
MO.Acos: 'acos v',
MO.Asin: 'asin v',
MO.Atan: 'atan v',
MO.Sinc: 'sinc v',
# By Vector
MO.Norm2: '||v||₂',
# By Matrix
MO.Det: 'det V',
MO.Cond: 'κ(V)',
MO.NormFro: '||V||_F',
MO.Rank: 'rank V',
MO.Diag: 'diag V',
MO.EigVals: 'eigvals V',
MO.SvdVals: 'svdvals V',
MO.Inv: 'V⁻¹',
MO.Tra: 'Vt',
MO.Qr: 'qr V',
MO.Chol: 'chol V',
MO.Svd: 'svd V',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> str:
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
MO = MapOperation
return (
str(self),
MO.to_name(self),
MO.to_name(self),
MO.to_icon(self),
i,
)
####################
# - Ops from Shape
####################
@staticmethod
def by_expr_info(info: ct.InfoFlow) -> list[typ.Self]:
## TODO: By info, not shape.
## TODO: Check valid domains/mathtypes for some functions.
MO = MapOperation
element_ops = [
MO.Real,
MO.Imag,
MO.Abs,
MO.Sq,
MO.Sqrt,
MO.InvSqrt,
MO.Cos,
MO.Sin,
MO.Tan,
MO.Acos,
MO.Asin,
MO.Atan,
MO.Sinc,
]
match (info.output.rows, info.output.cols):
case (1, 1):
return element_ops
case (_, 1):
return [*element_ops, MO.Norm2]
case (rows, cols) if rows == cols:
## TODO: Check hermitian/posdef for cholesky.
## - Can we even do this with just the output symbol approach?
return [
*element_ops,
MO.Det,
MO.Cond,
MO.NormFro,
MO.Rank,
MO.Diag,
MO.EigVals,
MO.SvdVals,
MO.Inv,
MO.Tra,
MO.Qr,
MO.Chol,
MO.Svd,
]
case (rows, cols):
return [
*element_ops,
MO.Cond,
MO.NormFro,
MO.Rank,
MO.SvdVals,
MO.Inv,
MO.Tra,
MO.Svd,
]
return []
####################
# - Function Properties
####################
@property
def sp_func(self):
MO = MapOperation
return {
# By Number
MO.Real: lambda expr: sp.re(expr),
MO.Imag: lambda expr: sp.im(expr),
MO.Abs: lambda expr: sp.Abs(expr),
MO.Sq: lambda expr: expr**2,
MO.Sqrt: lambda expr: sp.sqrt(expr),
MO.InvSqrt: lambda expr: 1 / sp.sqrt(expr),
MO.Cos: lambda expr: sp.cos(expr),
MO.Sin: lambda expr: sp.sin(expr),
MO.Tan: lambda expr: sp.tan(expr),
MO.Acos: lambda expr: sp.acos(expr),
MO.Asin: lambda expr: sp.asin(expr),
MO.Atan: lambda expr: sp.atan(expr),
MO.Sinc: lambda expr: sp.sinc(expr),
# By Vector
# Vector -> Number
MO.Norm2: lambda expr: sp.sqrt(expr.T @ expr)[0],
# By Matrix
# Matrix -> Number
MO.Det: lambda expr: sp.det(expr),
MO.Cond: lambda expr: expr.condition_number(),
MO.NormFro: lambda expr: expr.norm(ord='fro'),
MO.Rank: lambda expr: expr.rank(),
# Matrix -> Vec
MO.Diag: lambda expr: expr.diagonal(),
MO.EigVals: lambda expr: sp.Matrix(list(expr.eigenvals().keys())),
MO.SvdVals: lambda expr: expr.singular_values(),
# Matrix -> Matrix
MO.Inv: lambda expr: expr.inv(),
MO.Tra: lambda expr: expr.T,
# Matrix -> Matrices
MO.Qr: lambda expr: expr.QRdecomposition(),
MO.Chol: lambda expr: expr.cholesky(),
MO.Svd: lambda expr: expr.singular_value_decomposition(),
}[self]
@property
def jax_func(self):
MO = MapOperation
return {
# By Number
MO.Real: lambda expr: jnp.real(expr),
MO.Imag: lambda expr: jnp.imag(expr),
MO.Abs: lambda expr: jnp.abs(expr),
MO.Sq: lambda expr: jnp.square(expr),
MO.Sqrt: lambda expr: jnp.sqrt(expr),
MO.InvSqrt: lambda expr: 1 / jnp.sqrt(expr),
MO.Cos: lambda expr: jnp.cos(expr),
MO.Sin: lambda expr: jnp.sin(expr),
MO.Tan: lambda expr: jnp.tan(expr),
MO.Acos: lambda expr: jnp.acos(expr),
MO.Asin: lambda expr: jnp.asin(expr),
MO.Atan: lambda expr: jnp.atan(expr),
MO.Sinc: lambda expr: jnp.sinc(expr),
# By Vector
# Vector -> Number
MO.Norm2: lambda expr: jnp.linalg.norm(expr, ord=2, axis=-1),
# By Matrix
# Matrix -> Number
MO.Det: lambda expr: jnp.linalg.det(expr),
MO.Cond: lambda expr: jnp.linalg.cond(expr),
MO.NormFro: lambda expr: jnp.linalg.matrix_norm(expr, ord='fro'),
MO.Rank: lambda expr: jnp.linalg.matrix_rank(expr),
# Matrix -> Vec
MO.Diag: lambda expr: jnp.diagonal(expr, axis1=-2, axis2=-1),
MO.EigVals: lambda expr: jnp.linalg.eigvals(expr),
MO.SvdVals: lambda expr: jnp.linalg.svdvals(expr),
# Matrix -> Matrix
MO.Inv: lambda expr: jnp.linalg.inv(expr),
MO.Tra: lambda expr: jnp.matrix_transpose(expr),
# Matrix -> Matrices
MO.Qr: lambda expr: jnp.linalg.qr(expr),
MO.Chol: lambda expr: jnp.linalg.cholesky(expr),
MO.Svd: lambda expr: jnp.linalg.svd(expr),
}[self]
def transform_info(self, info: ct.InfoFlow):
MO = MapOperation
return {
# By Number
MO.Real: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Imag: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Abs: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Sq: lambda: info,
MO.Sqrt: lambda: info,
MO.InvSqrt: lambda: info,
MO.Cos: lambda: info,
MO.Sin: lambda: info,
MO.Tan: lambda: info,
MO.Acos: lambda: info,
MO.Asin: lambda: info,
MO.Atan: lambda: info,
MO.Sinc: lambda: info,
# By Vector
MO.Norm2: lambda: info.update_output(
mathtype=spux.MathType.Real,
rows=1,
cols=1,
# Interval
interval_finite_re=(0, sim_symbols.float_max),
interval_inf=(False, True),
interval_closed=(True, False),
),
# By Matrix
MO.Det: lambda: info.update_output(
rows=1,
cols=1,
),
MO.Cond: lambda: info.update_output(
mathtype=spux.MathType.Real,
rows=1,
cols=1,
physical_type=spux.PhysicalType.NonPhysical,
unit=None,
),
MO.NormFro: lambda: info.update_output(
mathtype=spux.MathType.Real,
rows=1,
cols=1,
# Interval
interval_finite_re=(0, sim_symbols.float_max),
interval_inf=(False, True),
interval_closed=(True, False),
),
MO.Rank: lambda: info.update_output(
mathtype=spux.MathType.Integer,
rows=1,
cols=1,
physical_type=spux.PhysicalType.NonPhysical,
unit=None,
# Interval
interval_finite_re=(0, sim_symbols.int_max),
interval_inf=(False, True),
interval_closed=(True, False),
),
# Matrix -> Vector ## TODO: ALL OF THESE
MO.Diag: lambda: info,
MO.EigVals: lambda: info,
MO.SvdVals: lambda: info,
# Matrix -> Matrix ## TODO: ALL OF THESE
MO.Inv: lambda: info,
MO.Tra: lambda: info,
# Matrix -> Matrices ## TODO: ALL OF THESE
MO.Qr: lambda: info,
MO.Chol: lambda: info,
MO.Svd: lambda: info,
}[self]()
####################
# - Node
####################
class MapMathNode(base.MaxwellSimNode):
r"""Applies a function by-structure to the data.
@ -495,7 +151,7 @@ class MapMathNode(base.MaxwellSimNode):
return info
return None
operation: MapOperation = bl_cache.BLField(
operation: math_system.MapOperation = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_info'},
)
@ -504,7 +160,9 @@ class MapMathNode(base.MaxwellSimNode):
if self.expr_info is not None:
return [
operation.bl_enum_element(i)
for i, operation in enumerate(MapOperation.by_expr_info(self.expr_info))
for i, operation in enumerate(
math_system.MapOperation.by_expr_info(self.expr_info)
)
]
return []
@ -513,7 +171,7 @@ class MapMathNode(base.MaxwellSimNode):
####################
def draw_label(self):
if self.operation is not None:
return 'Map: ' + MapOperation.to_name(self.operation)
return 'Map: ' + math_system.MapOperation.to_name(self.operation)
return self.bl_label

View File

@ -14,354 +14,24 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import enum
"""Implements the `OperateMathNode`.
See `blender_maxwell.maxwell_sim_nodes.math_system` for the actual mathematics implementation.
"""
import typing as typ
import bpy
import jax.numpy as jnp
import sympy as sp
import sympy.physics.quantum as spq
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct
from .... import sockets
from .... import math_system, sockets
from ... import base, events
log = logger.get(__name__)
####################
# - Operation Enum
####################
class BinaryOperation(enum.StrEnum):
"""Valid operations for the `OperateMathNode`.
Attributes:
Mul: Scalar multiplication.
Div: Scalar division.
Pow: Scalar exponentiation.
Add: Elementwise addition.
Sub: Elementwise subtraction.
HadamMul: Elementwise multiplication (hadamard product).
HadamPow: Principled shape-aware exponentiation (hadamard power).
Atan2: Quadrant-respecting 2D arctangent.
VecVecDot: Dot product for identically shaped vectors w/transpose.
Cross: Cross product between identically shaped 3D vectors.
VecVecOuter: Vector-vector outer product.
LinSolve: Solve a linear system.
LsqSolve: Minimize error of an underdetermined linear system.
VecMatOuter: Vector-matrix outer product.
MatMatDot: Matrix-matrix dot product.
"""
# Number | Number
Mul = enum.auto()
Div = enum.auto()
Pow = enum.auto()
# Elements | Elements
Add = enum.auto()
Sub = enum.auto()
HadamMul = enum.auto()
# HadamPow = enum.auto() ## TODO: Sympy's HadamardPower is problematic.
Atan2 = enum.auto()
# Vector | Vector
VecVecDot = enum.auto()
Cross = enum.auto()
VecVecOuter = enum.auto()
# Matrix | Vector
LinSolve = enum.auto()
LsqSolve = enum.auto()
# Vector | Matrix
VecMatOuter = enum.auto()
# Matrix | Matrix
MatMatDot = enum.auto()
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
BO = BinaryOperation
return {
# Number | Number
BO.Mul: ' · r',
BO.Div: ' / r',
BO.Pow: ' ^ r',
# Elements | Elements
BO.Add: ' + r',
BO.Sub: ' - r',
BO.HadamMul: '𝐋𝐑',
# BO.HadamPow: '𝐥 ⊙^ 𝐫',
BO.Atan2: 'atan2(:x, r:y)',
# Vector | Vector
BO.VecVecDot: '𝐥 · 𝐫',
BO.Cross: 'cross(𝐥,𝐫)',
BO.VecVecOuter: '𝐥𝐫',
# Matrix | Vector
BO.LinSolve: '𝐋 𝐫',
BO.LsqSolve: 'argminₓ∥𝐋𝐱𝐫∥₂',
# Vector | Matrix
BO.VecMatOuter: '𝐋𝐫',
# Matrix | Matrix
BO.MatMatDot: '𝐋 · 𝐑',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> str:
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
BO = BinaryOperation
return (
str(self),
BO.to_name(self),
BO.to_name(self),
BO.to_icon(self),
i,
)
####################
# - Ops from Shape
####################
@staticmethod
def by_infos(info_l: int, info_r: int) -> list[typ.Self]:
"""Deduce valid binary operations from the shapes of the inputs."""
BO = BinaryOperation
ops_el_el = [
BO.Add,
BO.Sub,
BO.HadamMul,
# BO.HadamPow,
]
outl = info_l.output
outr = info_r.output
match (outl.shape_len, outr.shape_len):
# Number | *
## Number | Number
case (0, 0):
ops = [
BO.Add,
BO.Sub,
BO.Mul,
]
# Check Non-Zero Right Hand Side
## -> Obviously, we can't ever divide by zero.
## -> Sympy's assumptions system must always guarantee rhs != 0.
## -> If it can't, then we simply don't expose division.
## -> The is_zero assumption must be provided elsewhere.
## -> NOTE: This may prevent some valid uses of division.
## -> Watch out for "division is missing" bugs.
if info_r.output.is_nonzero:
ops.append(BO.Div)
if (
info_l.output.physical_type == spux.PhysicalType.Length
and info_l.output.unit == info_r.output.unit
):
ops += [BO.Atan2]
return [*ops, BO.Pow]
## Number | Vector
case (0, 1):
return [BO.Mul] # , BO.HadamPow]
## Number | Matrix
case (0, 2):
return [BO.Mul] # , BO.HadamPow]
# Vector | *
## Vector | Number
case (1, 0):
return [BO.Mul] # , BO.HadamPow]
## Vector | Vector
case (1, 1):
ops = []
# Vector | Vector
## -> Dot: Convenience; utilize special vec-vec dot w/transp.
if outl.rows > outl.cols and outr.rows > outr.cols:
ops += [BO.VecVecDot, BO.VecVecOuter]
# Covector | Vector
## -> Dot: Directly use matrix-matrix dot, as it's now correct.
if outl.rows < outl.cols and outr.rows > outr.cols:
ops += [BO.MatMatDot, BO.VecVecOuter]
# Vector | Covector
## -> Dot: Directly use matrix-matrix dot, as it's now correct.
## -> These are both the same operation, in this case.
if outl.rows > outl.cols and outr.rows < outr.cols:
ops += [BO.MatMatDot, BO.VecVecOuter]
# Covector | Covector
## -> Dot: Convenience; utilize special vec-vec dot w/transp.
if outl.rows < outl.cols and outr.rows < outr.cols:
ops += [BO.VecVecDot, BO.VecVecOuter]
# Cross Product
## -> Enforce that both are 3x1 or 1x3.
## -> See https://docs.sympy.org/latest/modules/matrices/matrices.html#sympy.matrices.matrices.MatrixBase.cross
if (outl.rows == 3 and outr.rows == 3) or (
outl.cols == 3 and outl.cols == 3
):
ops += [BO.Cross]
return ops_el_el + ops
## Vector | Matrix
case (1, 2):
return [BO.VecMatOuter]
# Matrix | *
## Matrix | Number
case (2, 0):
return [BO.Mul] # , BO.HadamPow]
## Matrix | Vector
case (2, 1):
prepend_ops = []
# Mat-Vec Dot: Enforce RHS Column Vector
if outr.rows > outl.cols:
prepend_ops += [BO.MatMatDot]
return [*ops, BO.LinSolve, BO.LsqSolve] # , BO.HadamPow]
## Matrix | Matrix
case (2, 2):
return [*ops_el_el, BO.MatMatDot]
return []
####################
# - Function Properties
####################
@property
def sp_func(self):
"""Deduce an appropriate sympy-based function that implements the binary operation for symbolic inputs."""
BO = BinaryOperation
## TODO: Make this compatible with sp.Matrix inputs
return {
# Number | Number
BO.Mul: lambda exprs: exprs[0] * exprs[1],
BO.Div: lambda exprs: exprs[0] / exprs[1],
BO.Pow: lambda exprs: exprs[0] ** exprs[1],
# Elements | Elements
BO.Add: lambda exprs: exprs[0] + exprs[1],
BO.Sub: lambda exprs: exprs[0] - exprs[1],
BO.HadamMul: lambda exprs: sp.hadamard_product(exprs[0], exprs[1]),
# BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]),
# Vector | Vector
BO.VecVecDot: lambda exprs: (exprs[0].T @ exprs[1])[0],
BO.Cross: lambda exprs: exprs[0].cross(exprs[1]),
BO.VecVecOuter: lambda exprs: exprs[0] @ exprs[1].T,
# Matrix | Vector
BO.LinSolve: lambda exprs: exprs[0].solve(exprs[1]),
BO.LsqSolve: lambda exprs: exprs[0].solve_least_squares(exprs[1]),
# Vector | Matrix
BO.VecMatOuter: lambda exprs: spq.TensorProduct(exprs[0], exprs[1]),
# Matrix | Matrix
BO.MatMatDot: lambda exprs: exprs[0] @ exprs[1],
}[self]
@property
def unit_func(self):
"""The binary function to apply to both unit expressions, in order to deduce the unit expression of the output."""
BO = BinaryOperation
## TODO: Make this compatible with sp.Matrix inputs
return {
# Number | Number
BO.Mul: BO.Mul.sp_func,
BO.Div: BO.Div.sp_func,
BO.Pow: BO.Pow.sp_func,
# Elements | Elements
BO.Add: BO.Add.sp_func,
BO.Sub: BO.Sub.sp_func,
BO.HadamMul: BO.Mul.sp_func,
# BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
BO.Atan2: lambda _: spu.radian,
# Vector | Vector
BO.VecVecDot: BO.Mul.sp_func,
BO.Cross: BO.Mul.sp_func,
BO.VecVecOuter: BO.Mul.sp_func,
# Matrix | Vector
## -> A,b in Ax = b have units, and the equality must hold.
## -> Therefore, A \ b must have the units [b]/[A].
BO.LinSolve: lambda exprs: exprs[1] / exprs[0],
BO.LsqSolve: lambda exprs: exprs[1] / exprs[0],
# Vector | Matrix
BO.VecMatOuter: BO.Mul.sp_func,
# Matrix | Matrix
BO.MatMatDot: BO.Mul.sp_func,
}[self]
@property
def jax_func(self):
"""Deduce an appropriate jax-based function that implements the binary operation for array inputs."""
## TODO: Scale the units of one side to the other.
BO = BinaryOperation
return {
# Number | Number
BO.Mul: lambda exprs: exprs[0] * exprs[1],
BO.Div: lambda exprs: exprs[0] / exprs[1],
BO.Pow: lambda exprs: exprs[0] ** exprs[1],
# Elements | Elements
BO.Add: lambda exprs: exprs[0] + exprs[1],
BO.Sub: lambda exprs: exprs[0] - exprs[1],
BO.HadamMul: lambda exprs: exprs[0] * exprs[1],
# BO.HadamPow: lambda exprs: exprs[0] ** exprs[1],
BO.Atan2: lambda exprs: jnp.atan2(exprs[1], exprs[0]),
# Vector | Vector
BO.VecVecDot: lambda exprs: jnp.dot(exprs[0], exprs[1]),
BO.Cross: lambda exprs: jnp.cross(exprs[0], exprs[1]),
BO.VecVecOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
# Matrix | Vector
BO.LinSolve: lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]),
BO.LsqSolve: lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]),
# Vector | Matrix
BO.VecMatOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
# Matrix | Matrix
BO.MatMatDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]),
}[self]
####################
# - InfoFlow Transform
####################
def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow):
"""Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression.
Warnings:
`self` MUST be an element of `BinaryOperation.by_infos(info_l, info_r).
If not, bad things will happen.
"""
return info_l.operate_output(
info_r,
lambda a, b: self.sp_func([a, b]),
lambda a, b: self.unit_func([a, b]),
)
####################
# - Node
####################
class OperateMathNode(base.MaxwellSimNode):
r"""Applies a binary function between two expressions.
@ -386,7 +56,7 @@ class OperateMathNode(base.MaxwellSimNode):
}
####################
# - Properties
# - Properties: Incoming InfoFlows
####################
@events.on_value_changed(
# Trigger
@ -415,6 +85,7 @@ class OperateMathNode(base.MaxwellSimNode):
@bl_cache.cached_bl_property()
def expr_infos(self) -> tuple[ct.InfoFlow, ct.InfoFlow] | None:
"""Computed `InfoFlow`s of both expressions."""
info_l = self._compute_input('Expr L', kind=ct.FlowKind.Info)
info_r = self._compute_input('Expr R', kind=ct.FlowKind.Info)
@ -426,19 +97,18 @@ class OperateMathNode(base.MaxwellSimNode):
return None
operation: BinaryOperation = bl_cache.BLField(
####################
# - Property: Operation
####################
operation: math_system.BinaryOperation = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_infos'},
)
def search_operations(self) -> list[ct.BLEnumElement]:
"""Retrieve valid operations based on the input `InfoFlow`s."""
if self.expr_infos is not None:
return [
operation.bl_enum_element(i)
for i, operation in enumerate(
BinaryOperation.by_infos(*self.expr_infos)
)
]
return math_system.BinaryOperation.bl_enum_elements(*self.expr_infos)
return []
####################
@ -451,7 +121,7 @@ class OperateMathNode(base.MaxwellSimNode):
Called by Blender to determine the text to place in the node's header.
"""
if self.operation is not None:
return 'Op: ' + BinaryOperation.to_name(self.operation)
return 'Op: ' + math_system.BinaryOperation.to_name(self.operation)
return self.bl_label
@ -464,7 +134,7 @@ class OperateMathNode(base.MaxwellSimNode):
layout.prop(self, self.blfields['operation'], text='')
####################
# - FlowKind.Value|Func
# - FlowKind.Value
####################
@events.computes_output_socket(
'Expr',
@ -477,20 +147,22 @@ class OperateMathNode(base.MaxwellSimNode):
},
)
def compute_value(self, props: dict, input_sockets: dict):
operation = props['operation']
"""Binary operation on two symbolic input expressions."""
expr_l = input_sockets['Expr L']
expr_r = input_sockets['Expr R']
has_expr_l_value = not ct.FlowSignal.check(expr_l)
has_expr_r_value = not ct.FlowSignal.check(expr_r)
# Compute Sympy Function
## -> The operation enum directly provides the appropriate function.
operation = props['operation']
if has_expr_l_value and has_expr_r_value and operation is not None:
return operation.sp_func([expr_l, expr_r])
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Func,
@ -505,10 +177,7 @@ class OperateMathNode(base.MaxwellSimNode):
output_socket_kinds={'Expr': ct.FlowKind.Info},
)
def compute_func(self, props, input_sockets, output_sockets):
operation = props['operation']
if operation is None:
return ct.FlowSignal.FlowPending
"""Binary operation on two lazy-defined input expressions."""
expr_l = input_sockets['Expr L']
expr_r = input_sockets['Expr R']
output_info = output_sockets['Expr']
@ -517,14 +186,9 @@ class OperateMathNode(base.MaxwellSimNode):
has_expr_r = not ct.FlowSignal.check(expr_r)
has_output_info = not ct.FlowSignal.check(output_info)
# Compute Jax Function
## -> The operation enum directly provides the appropriate function.
if has_expr_l and has_expr_r and has_output_info:
return (expr_l | expr_r).compose_within(
operation.jax_func,
enclosing_func_output=output_info.output,
supports_jax=True,
)
operation = props['operation']
if operation is not None and has_expr_l and has_expr_r and has_output_info:
return self.operation.transform_funcs(expr_l, expr_r)
return ct.FlowSignal.FlowPending
####################
@ -541,22 +205,17 @@ class OperateMathNode(base.MaxwellSimNode):
},
)
def compute_info(self, props, input_sockets) -> ct.InfoFlow:
BO = BinaryOperation
operation = props['operation']
"""Transform the input information of both lazy inputs."""
info_l = input_sockets['Expr L']
info_r = input_sockets['Expr R']
has_info_l = not ct.FlowSignal.check(info_l)
has_info_r = not ct.FlowSignal.check(info_r)
# Compute Info
## -> The operation enum directly provides the appropriate transform.
operation = props['operation']
if (
has_info_l
and has_info_r
and operation is not None
and operation in BO.by_infos(info_l, info_r)
has_info_l and has_info_r and operation is not None
# and operation in BO.by_infos(info_l, info_r)
):
return operation.transform_infos(info_l, info_r)
@ -576,15 +235,14 @@ class OperateMathNode(base.MaxwellSimNode):
},
)
def compute_params(self, props, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
operation = props['operation']
"""Merge the lazy input parameters."""
params_l = input_sockets['Expr L']
params_r = input_sockets['Expr R']
has_params_l = not ct.FlowSignal.check(params_l)
has_params_r = not ct.FlowSignal.check(params_r)
# Compute Params
## -> Operations don't add new parameters, so just concatenate L|R.
operation = props['operation']
if has_params_l and has_params_r and operation is not None:
return params_l | params_r

View File

@ -20,334 +20,18 @@ import enum
import typing as typ
import bpy
import jax.numpy as jnp
import jaxtyping as jtyp
import sympy as sp
from blender_maxwell.utils import bl_cache, logger, sci_constants, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets
from .... import math_system, sockets
from ... import base, events
log = logger.get(__name__)
####################
# - Operation Enum
####################
class TransformOperation(enum.StrEnum):
"""Valid operations for the `TransformMathNode`.
Attributes:
FreqToVacWL: Transform an frequency dimension to vacuum wavelength.
VacWLToFreq: Transform a vacuum wavelength dimension to frequency.
ConvertIdxUnit: Convert the unit of a dimension to a compatible unit.
SetIdxUnit: Set all properties of a dimension.
FirstColToFirstIdx: Extract the first data column and set the first dimension's index array equal to it.
**For 2D integer-indexed data only**.
IntDimToComplex: Fold a last length-2 integer dimension into the output, transforming it from a real-like type to complex type.
DimToVec: Fold the last dimension into the scalar output, creating a vector output type.
DimsToMat: Fold the last two dimensions into the scalar output, creating a matrix output type.
FT: Compute the 1D fourier transform along a dimension.
New dimensional bounds are computing using the Nyquist Limit.
For higher dimensions, simply repeat along more dimensions.
InvFT1D: Compute the inverse 1D fourier transform along a dimension.
New dimensional bounds are computing using the Nyquist Limit.
For higher dimensions, simply repeat along more dimensions.
"""
# Covariant Transform
FreqToVacWL = enum.auto()
VacWLToFreq = enum.auto()
ConvertIdxUnit = enum.auto()
SetIdxUnit = enum.auto()
FirstColToFirstIdx = enum.auto()
# Fold
IntDimToComplex = enum.auto()
DimToVec = enum.auto()
DimsToMat = enum.auto()
# Fourier
FT1D = enum.auto()
InvFT1D = enum.auto()
# TODO: Affine
## TODO
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
TO = TransformOperation
return {
# Covariant Transform
TO.FreqToVacWL: '𝑓 → λᵥ',
TO.VacWLToFreq: 'λᵥ → 𝑓',
TO.ConvertIdxUnit: 'Convert Dim',
TO.SetIdxUnit: 'Set Dim',
TO.FirstColToFirstIdx: '1st Col → 1st Dim',
# Fold
TO.IntDimToComplex: '',
TO.DimToVec: '→ Vector',
TO.DimsToMat: '→ Matrix',
## TODO: Vector to new last-dim integer
## TODO: Matrix to two last-dim integers
# Fourier
TO.FT1D: 'FT',
TO.InvFT1D: 'iFT',
}[value]
@property
def name(self) -> str:
return TransformOperation.to_name(self)
@staticmethod
def to_icon(_: typ.Self) -> str:
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
TO = TransformOperation
return (
str(self),
TO.to_name(self),
TO.to_name(self),
TO.to_icon(self),
i,
)
####################
# - Methods
####################
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
TO = TransformOperation
match self:
case TO.FreqToVacWL:
return [
dim
for dim in info.dims
if dim.physical_type is spux.PhysicalType.Freq
]
case TO.VacWLToFreq:
return [
dim
for dim in info.dims
if dim.physical_type is spux.PhysicalType.Length
]
case TO.ConvertIdxUnit:
return [
dim
for dim in info.dims
if not info.has_idx_labels(dim)
and spux.PhysicalType.from_unit(dim.unit, optional=True) is not None
]
case TO.SetIdxUnit:
return [dim for dim in info.dims if not info.has_idx_labels(dim)]
## ColDimToComplex: Implicit Last Dimension
## DimToVec: Implicit Last Dimension
## DimsToMat: Implicit Last 2 Dimensions
case TO.FT1D | TO.InvFT1D:
# Filter by Axis Uniformity
## -> FT requires uniform axis (aka. must be RangeFlow).
## -> NOTE: If FT isn't popping up, check ExtractDataNode.
return [dim for dim in info.dims if info.is_idx_uniform(dim)]
return []
@staticmethod
def by_info(info: ct.InfoFlow) -> list[typ.Self]:
TO = TransformOperation
operations = []
# Covariant Transform
## Freq -> VacWL
if TO.FreqToVacWL.valid_dims(info):
operations += [TO.FreqToVacWL]
## VacWL -> Freq
if TO.VacWLToFreq.valid_dims(info):
operations += [TO.VacWLToFreq]
## Convert Index Unit
if TO.ConvertIdxUnit.valid_dims(info):
operations += [TO.ConvertIdxUnit]
if TO.SetIdxUnit.valid_dims(info):
operations += [TO.SetIdxUnit]
## Column to First Index (Array)
if (
len(info.dims) == 2 # noqa: PLR2004
and info.first_dim.mathtype is spux.MathType.Integer
and info.last_dim.mathtype is spux.MathType.Integer
and info.output.shape_len == 0
):
operations += [TO.FirstColToFirstIdx]
# Fold
## Last Dim -> Complex
if (
len(info.dims) >= 1
and (
info.output.mathtype
in [spux.MathType.Integer, spux.MathType.Rational, spux.MathType.Real]
)
and info.last_dim.mathtype is spux.MathType.Integer
and info.has_idx_labels(info.last_dim)
and len(info.dims[info.last_dim]) == 2 # noqa: PLR2004
):
operations += [TO.IntDimToComplex]
## Last Dim -> Vector
if len(info.dims) >= 1 and info.output.shape_len == 0:
operations += [TO.DimToVec]
## Last Dim -> Matrix
if len(info.dims) >= 2 and info.output.shape_len == 0: # noqa: PLR2004
operations += [TO.DimsToMat]
# Fourier
if TO.FT1D.valid_dims(info):
operations += [TO.FT1D]
if TO.InvFT1D.valid_dims(info):
operations += [TO.InvFT1D]
return operations
####################
# - Function Properties
####################
def jax_func(self, axis: int | None = None):
TO = TransformOperation
return {
# Covariant Transform
## -> Freq <-> WL is a rescale (noop) AND flip (not noop).
TO.FreqToVacWL: lambda expr: jnp.flip(expr, axis=axis),
TO.VacWLToFreq: lambda expr: jnp.flip(expr, axis=axis),
TO.ConvertIdxUnit: lambda expr: expr,
TO.SetIdxUnit: lambda expr: expr,
TO.FirstColToFirstIdx: lambda expr: jnp.delete(expr, 0, axis=1),
# Fold
## -> To Complex: This should generally be a no-op.
TO.IntDimToComplex: lambda expr: jnp.squeeze(
expr.view(dtype=jnp.complex64), axis=-1
),
TO.DimToVec: lambda expr: expr,
TO.DimsToMat: lambda expr: expr,
# Fourier
TO.FT1D: lambda expr: jnp.fft(expr, axis=axis),
TO.InvFT1D: lambda expr: jnp.ifft(expr, axis=axis),
}[self]
def transform_info(
self,
info: ct.InfoFlow,
dim: sim_symbols.SimSymbol | None = None,
data_col: jtyp.Shaped[jtyp.Array, ' size'] | None = None,
new_dim_name: str | None = None,
unit: spux.Unit | None = None,
physical_type: spux.PhysicalType | None = None,
) -> ct.InfoFlow:
TO = TransformOperation
return {
# Covariant Transform
TO.FreqToVacWL: lambda: info.replace_dim(
(f_dim := dim),
sim_symbols.wl(unit),
info.dims[f_dim].rescale(
lambda el: sci_constants.vac_speed_of_light / el,
reverse=True,
new_unit=unit,
),
),
TO.VacWLToFreq: lambda: info.replace_dim(
(wl_dim := dim),
sim_symbols.freq(unit),
info.dims[wl_dim].rescale(
lambda el: sci_constants.vac_speed_of_light / el,
reverse=True,
new_unit=unit,
),
),
TO.ConvertIdxUnit: lambda: info.replace_dim(
dim,
dim.update(unit=unit),
(
info.dims[dim].rescale_to_unit(unit)
if info.has_idx_discrete(dim)
else None ## Continuous -- dim SimSymbol already scaled
),
),
TO.SetIdxUnit: lambda: info.replace_dim(
dim,
dim.update(
sym_name=new_dim_name,
physical_type=physical_type,
unit=unit,
),
(
info.dims[dim].correct_unit(unit)
if info.has_idx_discrete(dim)
else None ## Continuous -- dim SimSymbol already scaled
),
),
TO.FirstColToFirstIdx: lambda: info.replace_dim(
info.first_dim,
info.first_dim.update(
sym_name=new_dim_name,
mathtype=spux.MathType.from_jax_array(data_col),
physical_type=physical_type,
unit=unit,
),
ct.RangeFlow.try_from_array(ct.ArrayFlow(values=data_col, unit=unit)),
).slice_dim(info.last_dim, (1, len(info.dims[info.last_dim]), 1)),
# Fold
TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output(
mathtype=spux.MathType.Complex
),
TO.DimToVec: lambda: info.fold_last_input(),
TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(),
# Fourier
TO.FT1D: lambda: info.replace_dim(
dim,
[
# FT'ed Unit: Reciprocal of the Original Unit
dim.update(
unit=1 / dim.unit if dim.unit is not None else 1
), ## TODO: Okay to not scale interval?
# FT'ed Bounds: Reciprocal of the Original Unit
info.dims[dim].bound_fourier_transform,
],
),
TO.InvFT1D: lambda: info.replace_dim(
info.last_dim,
[
# FT'ed Unit: Reciprocal of the Original Unit
dim.update(
unit=1 / dim.unit if dim.unit is not None else 1
), ## TODO: Okay to not scale interval?
# FT'ed Bounds: Reciprocal of the Original Unit
## -> Note the midpoint may revert to 0.
## -> See docs for `RangeFlow.bound_inv_fourier_transform` for more.
info.dims[dim].bound_inv_fourier_transform,
],
),
}[self]()
####################
# - Node
####################
class TransformMathNode(base.MaxwellSimNode):
r"""Applies a function to the array as a whole, with arbitrary results.
@ -409,7 +93,7 @@ class TransformMathNode(base.MaxwellSimNode):
####################
# - Properties: Operation
####################
operation: TransformOperation = bl_cache.BLField(
operation: math_system.TransformOperation = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_info'},
)
@ -419,7 +103,7 @@ class TransformMathNode(base.MaxwellSimNode):
return [
operation.bl_enum_element(i)
for i, operation in enumerate(
TransformOperation.by_info(self.expr_info)
math_system.TransformOperation.by_info(self.expr_info)
)
]
return []
@ -461,7 +145,7 @@ class TransformMathNode(base.MaxwellSimNode):
)
def search_units(self) -> list[ct.BLEnumElement]:
TO = TransformOperation
TO = math_system.TransformOperation
match self.operation:
# Covariant Transform
case TO.ConvertIdxUnit if self.dim is not None:
@ -521,7 +205,7 @@ class TransformMathNode(base.MaxwellSimNode):
return spux.sp_to_str(self.new_unit)
def draw_label(self):
TO = TransformOperation
TO = math_system.TransformOperation
match self.operation:
case TO.FreqToVacWL if self.dim is not None:
return f'T: {self.dim.name_pretty} | 𝑓{self.new_unit_str}'
@ -556,7 +240,7 @@ class TransformMathNode(base.MaxwellSimNode):
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, self.blfields['operation'], text='')
TO = TransformOperation
TO = math_system.TransformOperation
match self.operation:
case TO.ConvertIdxUnit:
row = layout.row(align=True)
@ -613,7 +297,7 @@ class TransformMathNode(base.MaxwellSimNode):
self, props, input_sockets, output_sockets
) -> ct.FuncFlow | ct.FlowSignal:
"""Transform the input `InfoFlow` depending on the transform operation."""
TO = TransformOperation
TO = math_system.TransformOperation
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
info = input_sockets['Expr'][ct.FlowKind.Info]
@ -662,7 +346,7 @@ class TransformMathNode(base.MaxwellSimNode):
self, props: dict, input_sockets: dict
) -> ct.InfoFlow | typ.Literal[ct.FlowSignal.FlowPending]:
"""Transform the input `InfoFlow` depending on the transform operation."""
TO = TransformOperation
TO = math_system.TransformOperation
operation = props['operation']
info = input_sockets['Expr'][ct.FlowKind.Info]

View File

@ -24,7 +24,7 @@ import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, image_ops, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets

View File

@ -22,7 +22,7 @@ import bpy
import sympy as sp
import tidy3d as td
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .... import contracts as ct

View File

@ -22,7 +22,7 @@ import bpy
import sympy as sp
import tidy3d as td
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .... import contracts as ct

View File

@ -19,7 +19,7 @@ import inspect
import typing as typ
from types import MappingProxyType
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .. import contracts as ct

View File

@ -21,7 +21,7 @@ import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, sci_constants, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets

View File

@ -16,12 +16,13 @@
import enum
import typing as typ
from fractions import Fraction
import bpy
import sympy as sp
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets
@ -50,6 +51,8 @@ class SymbolConstantNode(base.MaxwellSimNode):
)
size: spux.NumberSize1D = bl_cache.BLField(spux.NumberSize1D.Scalar)
## Use of NumberSize1D implicitly guarantees UI-realizability later.
mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real)
physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical)
@ -109,31 +112,110 @@ class SymbolConstantNode(base.MaxwellSimNode):
preview_value_re: float = bl_cache.BLField(0.0)
preview_value_im: float = bl_cache.BLField(0.0)
####################
# - Computed Properties
####################
@bl_cache.cached_bl_property(
depends_on={
'sym_name',
'size',
'mathtype',
'physical_type',
'unit',
'interval_finite_z',
'interval_finite_q',
'interval_finite_re',
'interval_inf',
'interval_closed',
'interval_finite_im',
'interval_inf_im',
'interval_closed_im',
}
)
def interval_finite(
self,
) -> (
tuple[int | Fraction | float, int | Fraction | float]
| tuple[tuple[float, float], tuple[float, float]]
):
"""Return the appropriate finite interval from the UI, as guided by `self.mathtype`."""
MT = spux.MathType
match self.mathtype:
case MT.Integer:
return self.interval_finite_z
case MT.Rational:
return [Fraction(*q) for q in self.interval_finite_q]
case MT.Real:
return self.interval_finite_re
case MT.Complex:
return (self.interval_finite_re, self.interval_finite_im)
@bl_cache.cached_bl_property(
depends_on={
'mathtype',
'preview_value_z',
'preview_value_q',
'preview_value_re',
'preview_value_im',
}
)
def preview_value(
self,
) -> int | Fraction | float | complex:
"""Return the appropriate finite interval from the UI, as guided by `self.mathtype`."""
MT = spux.MathType
match self.mathtype:
case MT.Integer:
return self.preview_value_z
case MT.Rational:
return Fraction(*self.preview_value_q)
case MT.Real:
return self.preview_value_re
case MT.Complex:
return complex(self.preview_value_re, self.preview_value_im)
@bl_cache.cached_bl_property(
depends_on={
'mathtype',
'interval_finite',
'interval_inf',
'interval_inf_im',
'interval_closed',
'interval_closed_im',
}
)
def domain(
self,
) -> sp.Interval | sp.sets.fancysets.CartesianComplexRegion:
"""Deduce the domain specified in the UI."""
MT = spux.MathType
match self.mathtype:
case MT.Integer | MT.Real | MT.Rational:
return sim_symbols.mk_interval(
self.interval_finite,
self.interval_inf,
self.interval_closed,
)
case MT.Complex:
region = self.interval_finite
domain_re = sim_symbols.mk_interval(
region[0],
self.interval_inf,
self.interval_closed,
)
domain_im = sim_symbols.mk_interval(
region[1],
self.interval_inf_im,
self.interval_closed_im,
)
return sp.ComplexRegion(domain_re, domain_im, polar=False)
####################
# - Computed Properties
####################
@bl_cache.cached_bl_property(
depends_on={
'sym_name',
'mathtype',
'physical_type',
'unit',
'size',
'domain',
'preview_value',
}
)
def symbol(self) -> sim_symbols.SimSymbol:
"""Generate the `SimSymbol` matching the user-specification."""
return sim_symbols.SimSymbol(
sym_name=self.sym_name,
mathtype=self.mathtype,
@ -141,18 +223,8 @@ class SymbolConstantNode(base.MaxwellSimNode):
unit=self.unit,
rows=self.size.rows,
cols=self.size.cols,
interval_finite_z=self.interval_finite_z,
interval_finite_q=self.interval_finite_q,
interval_finite_re=self.interval_finite_re,
interval_inf=self.interval_inf,
interval_closed=self.interval_closed,
interval_finite_im=self.interval_finite_im,
interval_inf_im=self.interval_inf_im,
interval_closed_im=self.interval_closed_im,
preview_value_z=self.preview_value_z,
preview_value_q=self.preview_value_q,
preview_value_re=self.preview_value_re,
preview_value_im=self.preview_value_im,
domain=self.domain,
preview_value=self.preview_value,
)
####################

View File

@ -23,7 +23,7 @@ import sympy as sp
import tidy3d as td
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets

View File

@ -23,7 +23,7 @@ import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import sockets

View File

@ -23,7 +23,7 @@ import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, logger, sci_constants
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import sockets

View File

@ -25,7 +25,7 @@ from tidy3d.material_library.material_library import MaterialItem as Tidy3DMediu
from tidy3d.material_library.material_library import VariantItem as Tidy3DMediumVariant
from blender_maxwell.utils import bl_cache, logger, sci_constants
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets

View File

@ -23,7 +23,7 @@ import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets

View File

@ -23,7 +23,7 @@ import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets

View File

@ -20,7 +20,7 @@ import sympy as sp
import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from ... import contracts as ct

View File

@ -20,7 +20,7 @@ from pathlib import Path
import bpy
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets

View File

@ -21,7 +21,7 @@ import sympy as sp
import tidy3d as td
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import sockets

View File

@ -22,7 +22,7 @@ import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from ... import contracts as ct

View File

@ -22,7 +22,7 @@ import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets

View File

@ -22,7 +22,7 @@ import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets

View File

@ -22,7 +22,7 @@ import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets

View File

@ -28,7 +28,7 @@ from tidy3d.components.data.data_array import TimeDataArray as td_TimeDataArray
from tidy3d.components.data.dataset import TimeDataset as td_TimeDataset
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets

View File

@ -20,7 +20,7 @@ import sympy as sp
import sympy.physics.units as spu
import tidy3d as td
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from ... import bl_socket_map, managed_objs, sockets

View File

@ -24,7 +24,7 @@ import tidy3d.plugins.adjoint as tdadj
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import managed_objs, sockets

View File

@ -21,7 +21,7 @@ import sympy.physics.units as spu
import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .... import contracts as ct

View File

@ -21,7 +21,7 @@ import sympy.physics.units as spu
import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .... import contracts as ct

View File

@ -24,7 +24,7 @@ import pydantic as pyd
import sympy as sp
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from .. import contracts as ct
from . import base
@ -155,12 +155,14 @@ class ExprBLSocket(base.MaxwellSimSocket):
unit=self.unit,
rows=self.size.rows,
cols=self.size.cols,
is_constant=True,
## TODO: Should we set preview values
exclude_zero=(
not self.value.is_zero
if self.value.is_zero is not None
else False
),
## TODO: Does this work for matrix elements?
## TODO: Does this 0-check work for matrix elements?
)
case ct.FlowKind.Range if self.symbols:
@ -208,7 +210,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
sim_symbols.SimSymbolName.Expr
)
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.Expr
sim_symbols.SimSymbolName.Constant
)
symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([])
@ -441,12 +443,11 @@ class ExprBLSocket(base.MaxwellSimSocket):
####################
def _to_raw_value(self, expr: spux.SympyExpr, force_complex: bool = False):
"""Cast the given expression to the appropriate raw value, with scaling guided by `self.unit`."""
if self.unit is not None:
pyvalue = spux.sympy_to_python(spux.scale_to_unit(expr, self.unit))
else:
pyvalue = spux.sympy_to_python(expr)
pyvalue = spux.scale_to_unit(expr, self.unit)
# Cast complex -> tuple[float, float]
## -> We can't set complex to BLProps.
## -> We must deconstruct it appropriately.
if isinstance(pyvalue, complex) or (
isinstance(pyvalue, int | float) and force_complex
):

View File

@ -24,7 +24,7 @@ import tidy3d as td
import tidy3d.plugins.adjoint as tdadj
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from .. import base

View File

@ -19,7 +19,7 @@ import sympy as sp
import sympy.physics.optics.polarization as spo_pol
import sympy.physics.units as spu
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from .. import base

View File

@ -17,22 +17,22 @@
from ..nodeps.utils import blender_type_enum, pydeps
from . import (
bl_cache,
extra_sympy_units,
image_ops,
logger,
sci_constants,
serialize,
staticproperty,
sympy_extra,
)
__all__ = [
'blender_type_enum',
'pydeps',
'bl_cache',
'extra_sympy_units',
'image_ops',
'logger',
'sci_constants',
'serialize',
'staticproperty',
'sympy_extra',
]

File diff suppressed because it is too large Load Diff

View File

@ -30,7 +30,7 @@ import matplotlib.figure
import seaborn as sns
from blender_maxwell import contracts as ct
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
sns.set_theme()

View File

@ -32,7 +32,7 @@ import scipy as sc
import sympy as sp
import sympy.physics.units as spu
from . import extra_sympy_units as spux
from . import sympy_extra as spux
SUPPORTED_SCIPY_PREFIX = '1.12'
if not sc.version.full_version.startswith(SUPPORTED_SCIPY_PREFIX):

View File

@ -31,8 +31,8 @@ import uuid
import msgspec
import sympy as sp
from . import extra_sympy_units as spux
from . import logger
from . import sympy_extra as spux
log = logger.get(__name__)

View File

@ -25,8 +25,8 @@ import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
from . import extra_sympy_units as spux
from . import logger, serialize
from . import sympy_extra as spux
int_min = -(2**64)
int_max = 2**64
@ -101,6 +101,12 @@ class SimSymbolName(enum.StrEnum):
BlochY = enum.auto()
BlochZ = enum.auto()
# New Backwards Compatible Entries
## -> Ordered lists carry a particular enum integer index.
## -> Therefore, anything but adding an index breaks backwards compat.
## -> ...With all previous files.
ConstantRange = enum.auto()
####################
# - UI
####################
@ -143,7 +149,8 @@ class SimSymbolName(enum.StrEnum):
}
| {
# Generic
SSN.Constant: 'constant',
SSN.Constant: 'cst',
SSN.ConstantRange: 'cst_range',
SSN.Expr: 'expr',
SSN.Data: 'data',
# Greek Letters
@ -210,24 +217,43 @@ def mk_interval(
interval_finite: tuple[int | Fraction | float, int | Fraction | float],
interval_inf: tuple[bool, bool],
interval_closed: tuple[bool, bool],
unit_factor: typ.Literal[1] | spux.Unit,
) -> sp.Interval:
"""Create a symbolic interval from the tuples (and unit) defining it."""
return sp.Interval(
start=(interval_finite[0] * unit_factor if not interval_inf[0] else -sp.oo),
end=(interval_finite[1] * unit_factor if not interval_inf[1] else sp.oo),
start=(interval_finite[0] if not interval_inf[0] else -sp.oo),
end=(interval_finite[1] if not interval_inf[1] else sp.oo),
left_open=(True if interval_inf[0] else not interval_closed[0]),
right_open=(True if interval_inf[1] else not interval_closed[1]),
)
class SimSymbol(pyd.BaseModel):
"""A declarative representation of a symbolic variable.
"""A convenient, constrained representation of a symbolic variable suitable for many tasks.
`sympy`'s symbols aren't quite flexible enough for our needs: The symbols that we're transporting often need exact domain information, an associated unit dimension, and a great deal of determinism in checks thereof.
The original motivation was to enhance `sp.Symbol` with greater flexibility, semantic context, and a UI-friendly representation.
Today, `SimSymbol` is a fully capable primitive for defining the interfaces between externally tracked mathematical elements, and planning the required operations between them.
A symbol represented as `SimSymbol` carries all the semantic meaning of that symbol, and comes with a comprehensive library of useful (computed) properties and methods.
It is immutable, hashable, and serializable, and as a `pydantic.BaseModel` with aggressive property caching, its performance properties should also be well-suited for use in the hot-loops of ex. UI draw methods.
Attributes:
sym_name: For humans and computers, symbol names induces a lot of implicit semantics.
mathtype: Symbols are associated with some set of valid values.
We choose to constrain `SimSymbol` to only associate with _mathematical_ (aka. number-like) sets.
This prohibits ex. booleans and predicate-logic applications, but eases a lot of burdens associated with actually using `SimSymbol`.
physical_type: Symbols may be associated with a particular unit dimension expression.
This allows the symbol to have _physical meaning_.
This information is **generally not** encoded in auxiliary attributes like `self.domain`, but **generally is** encoded by computed properties/methods.
unit: Symbols may be associated with a particular unit, which must be compatible with the `PhysicalType`.
**NOTE**: Unit expressions may well have physical meaning, without being strictly conformable to a pre-blessed `PhysicalType`s.
We do try to avoid such cases, but for the sake of correctness, our chosen convention is to let `self.physical_type` be "`NonPhysical`", while still allowing a unit.
size: Symbols may themselves have shape.
**NOTE**: We deliberately choose to constrain `SimSymbol`s to two dimensions, allowing them to represent scalars, vectors, covectors, and matrices, but **not** arbitrary tensors.
This is a practical tradeoff, made both to make it easier (in terms of mathematical analysis) to implement `SimSymbol`, but also to make it easier to define UI elements that drive / are driven by `SimSymbol`s.
domain: Symbols are associated with a _domain of valid values_, expressed with any mathematical set implemented as a subclass of `sympy.Set`.
By using a true symbolic set, we gain unbounded flexibility in how to define the validity of a set, including an extremely capable `* in self.domain` operator encapsulating a lot of otherwise very manual logic.
**NOTE** that `self.unit` is **not** baked into the domain, due to practicalities associated with subclasses of `sp.Set`.
This dataclass is UI-friendly, as it only uses field type annotations/defaults supported by `bl_cache.BLProp`.
It's easy to persist, easy to transport, and has many helpful properties which greatly simplify working with symbols.
"""
model_config = pyd.ConfigDict(frozen=True)
@ -238,11 +264,8 @@ class SimSymbol(pyd.BaseModel):
# Units
## -> 'None' indicates that no particular unit has yet been chosen.
## -> Not exposed in the UI; must be set some other way.
## -> When 'self.physical_type' is NonPhysical, can only be None.
unit: spux.Unit | None = None
## -> TODO: We currently allowing units that don't match PhysicalType
## -> -- In particular, NonPhysical w/units means "unknown units".
## -> -- This is essential for the Scientific Constant Node.
# Size
## -> All SimSymbol sizes are "2D", but interpreted by convention.
@ -253,39 +276,96 @@ class SimSymbol(pyd.BaseModel):
rows: int = 1
cols: int = 1
# Scalar Domain: "Interval"
## -> NOTE: interval_finite_*[0] must be strictly smaller than [1].
## -> See self.domain.
## -> We have to deconstruct symbolic interval semantics a bit for UI.
is_constant: bool = False
exclude_zero: bool = False
# Valid Domain
## -> Declares the valid set of values that may be given to this symbol.
## -> By convention, units are not encoded in the domain sp.Set.
## -> 'sp.Set's are extremely expressive and cool.
domain: spux.SympyExpr | None = None
interval_finite_z: tuple[int, int] = (0, 1)
interval_finite_q: tuple[tuple[int, int], tuple[int, int]] = ((0, 1), (1, 1))
interval_finite_re: tuple[float, float] = (0.0, 1.0)
interval_inf: tuple[bool, bool] = (True, True)
interval_closed: tuple[bool, bool] = (False, False)
@functools.cached_property
def domain_mat(self) -> sp.Set | sp.matrices.MatrixSet:
if self.rows > 1 or self.cols > 1:
return sp.matrices.MatrixSet(self.rows, self.cols, self.domain)
return self.domain
interval_finite_im: tuple[float, float] = (0.0, 1.0)
interval_inf_im: tuple[bool, bool] = (True, True)
interval_closed_im: tuple[bool, bool] = (False, False)
preview_value_z: int = 0
preview_value_q: tuple[int, int] = (0, 1)
preview_value_re: float = 0.0
preview_value_im: float = 0.0
preview_value: spux.SympyExpr | None = None
####################
# - Core
# - Validators
####################
## TODO: Check domain against MathType
## -- Surprisingly hard without a lot of special-casing.
## TODO: Check that size is valid for the PhysicalType.
## TODO: Check that constant value (domain=FiniteSet(cst)) is compatible with the MathType.
## TODO: Check that preview_value is in the domain.
@pyd.model_validator(mode='after')
def set_undefined_domain_from_mathtype(self) -> typ.Self:
"""When the domain is not set, then set it using the symbolic set of the MathType."""
if self.domain is None:
object.__setattr__(self, 'domain', self.mathtype.symbolic_set)
return self
@pyd.model_validator(mode='after')
def conform_undefined_preview_value_to_constant(self) -> typ.Self:
"""When the `SimSymbol` is a constant, but the preview value is not set, then set the preview value from the constant."""
if self.is_constant and not self.preview_value:
object.__setattr__(self, 'preview_value', self.constant_value)
return self
@pyd.model_validator(mode='after')
def conform_preview_value(self) -> typ.Self:
"""Conform the given preview value to the `SimSymbol`."""
if self.is_constant and not self.preview_value:
object.__setattr__(
self,
'preview_value',
self.conform(self.preview_value, strip_units=True),
)
return self
####################
# - Domain
####################
@functools.cached_property
def name(self) -> str:
"""Usable name for the symbol."""
return self.sym_name.name
def is_constant(self) -> bool:
"""When the symbol domain is a single-element `sp.FiniteSet`, then the symbol can be considered to be a constant."""
return isinstance(self.domain, sp.FiniteSet) and len(self.domain) == 1
@functools.cached_property
def constant_value(self) -> bool:
"""Get the constant when `is_constant` is True.
The `self.unit_factor` is multiplied onto the constant at this point.
"""
if self.is_constant:
return next(iter(self.domain)) * self.unit_factor
msg = 'Tried to get constant value of non-constant SimSymbol.'
raise ValueError(msg)
@functools.cached_property
def is_nonzero(self) -> bool:
"""Whether $0$ is a valid value for this symbol.
When shaped, $0$ refers to the relevant shaped object with all elements $0$.
Notes:
Most notably, this symbol cannot be used as the right hand side of a division operation when this property is `False`.
"""
return 0 in self.domain
####################
# - Labels
####################
@functools.cached_property
def name(self) -> str:
"""Usable string name for the symbol."""
return self.sym_name.name
@functools.cached_property
def name_pretty(self) -> str:
"""Pretty (possibly unicode) name for the thing."""
@ -340,7 +420,8 @@ class SimSymbol(pyd.BaseModel):
return self.unit if self.unit is not None else sp.S(1)
@functools.cached_property
def size(self) -> tuple[int, ...] | None:
def size(self) -> spux.NumberSize1D | None:
"""The 1D number size of this `SimSymbol`, if it has one; else None."""
return {
(1, 1): spux.NumberSize1D.Scalar,
(2, 1): spux.NumberSize1D.Vec2,
@ -350,13 +431,17 @@ class SimSymbol(pyd.BaseModel):
@functools.cached_property
def shape(self) -> tuple[int, ...]:
"""Deterministic chosen shape of this `SimSymbol`.
Derived from `self.rows` and `self.cols`.
Is never `None`; instead, empty tuple `()` is used.
"""
match (self.rows, self.cols):
case (1, 1):
return ()
case (_, 1):
return (self.rows,)
case (1, _):
return (1, self.rows)
case (_, _):
return (self.rows, self.cols)
@ -365,116 +450,6 @@ class SimSymbol(pyd.BaseModel):
"""Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking."""
return len(self.shape)
@functools.cached_property
def domain(self) -> sp.Interval | sp.Set:
"""Return the scalar domain of valid values for each element of the symbol.
For integer/rational/real symbols, the domain is an interval defined using the `interval_*` properties.
This interval **must** have the property`start <= stop`.
Otherwise, the domain is the symbolic set corresponding to `self.mathtype`.
"""
match self.mathtype:
case spux.MathType.Integer:
return mk_interval(
self.interval_finite_z,
self.interval_inf,
self.interval_closed,
self.unit_factor,
)
case spux.MathType.Rational:
return mk_interval(
Fraction(*self.interval_finite_q),
self.interval_inf,
self.interval_closed,
self.unit_factor,
)
case spux.MathType.Real:
return mk_interval(
self.interval_finite_re,
self.interval_inf,
self.interval_closed,
self.unit_factor,
)
case spux.MathType.Complex:
return (
mk_interval(
self.interval_finite_re,
self.interval_inf,
self.interval_closed,
self.unit_factor,
),
mk_interval(
self.interval_finite_im,
self.interval_inf_im,
self.interval_closed_im,
self.unit_factor,
),
)
@functools.cached_property
def valid_domain_value(self) -> spux.SympyExpr:
"""A single value guaranteed to be conformant to this `SimSymbol` and within `self.domain`."""
match (self.domain.start.is_finite, self.domain.end.is_finite):
case (True, True):
if self.mathtype is spux.MathType.Integer:
return (self.domain.start + self.domain.end) // 2
return (self.domain.start + self.domain.end) / 2
case (True, False):
one = sp.S(self.mathtype.coerce_compatible_pyobj(-1))
return self.domain.start + one
case (False, True):
one = sp.S(self.mathtype.coerce_compatible_pyobj(-1))
return self.domain.end - one
case (False, False):
return sp.S(self.mathtype.coerce_compatible_pyobj(-1))
@functools.cached_property
def is_nonzero(self) -> bool:
"""Whether or not the value of this symbol can ever be $0$.
Notes:
Most notably, this symbol cannot be used as the right hand side of a division operation when this property is `False`.
"""
if self.exclude_zero:
return True
def check_real_domain(real_domain):
return (
(
real_domain.left == 0
and real_domain.left_open
or real_domain.right == 0
and real_domain.right_open
)
or real_domain.left > 0
or real_domain.right < 0
)
if self.mathtype is spux.MathType.Complex:
return check_real_domain(self.domain[0]) and check_real_domain(
self.domain[1]
)
return check_real_domain(self.domain)
@functools.cached_property
def can_diff(self) -> bool:
"""Whether this symbol can be used as the input / output variable when differentiating."""
# Check Constants
## -> Constants (w/pinned values) are never differentiable.
if self.is_constant:
return False
# TODO: Discontinuities (especially across 0)?
return self.mathtype in [spux.MathType.Real, spux.MathType.Complex]
####################
# - Properties
####################
@ -511,9 +486,9 @@ class SimSymbol(pyd.BaseModel):
# Positive/Negative Assumption
if self.mathtype is not spux.MathType.Complex:
if self.domain.left >= 0:
if self.domain.inf >= 0:
mathtype_kwargs |= {'positive': True}
elif self.domain.right <= 0:
elif self.domain.sup < 0:
mathtype_kwargs |= {'negative': True}
# Scalar: Return Symbol
@ -571,7 +546,7 @@ class SimSymbol(pyd.BaseModel):
"""
if self.size is not None:
if self.unit in self.physical_type.valid_units:
return {
socket_info = {
'output_name': self.sym_name,
# Socket Interface
'size': self.size,
@ -580,23 +555,42 @@ class SimSymbol(pyd.BaseModel):
# Defaults: Units
'default_unit': self.unit,
'default_symbols': [],
# Defaults: FlowKind.Value
'default_value': self.conform(
self.valid_domain_value, strip_unit=True
),
# Defaults: FlowKind.Range
'default_min': self.conform(self.domain.start, strip_unit=True),
'default_max': self.conform(self.domain.end, strip_unit=True),
}
# Defaults: FlowKind.Value
if self.preview_value:
socket_info |= {
'default_value': self.conform(
self.preview_value, strip_unit=True
)
}
# Defaults: FlowKind.Range
if (
self.mathtype is not spux.MathType.Complex
and self.rows == 1
and self.cols == 1
):
socket_info |= {
'default_min': self.domain.inf,
'default_max': self.domain.sup,
}
## TODO: Handle discontinuities / disjointness / open boundaries.
msg = f'Tried to generate an ExprSocket from a SymSymbol "{self.name}", but its unit ({self.unit}) is not a valid unit of its physical type ({self.physical_type}) (SimSymbol={self})'
raise NotImplementedError(msg)
msg = f'Tried to generate an ExprSocket from a SymSymbol "{self.name}", but its size ({self.rows} by {self.cols}) is incompatible with ExprSocket (SimSymbol={self})'
raise NotImplementedError(msg)
####################
# - Operations
# - Operations: Raw Update
####################
def update(self, **kwargs) -> typ.Self:
"""Create a new `SimSymbol`, such that the given keyword arguments override the existing values."""
if not kwargs:
return self
def get_attr(attr: str):
_notfound = 'notfound'
if kwargs.get(attr, _notfound) is _notfound:
@ -610,61 +604,101 @@ class SimSymbol(pyd.BaseModel):
unit=get_attr('unit'),
rows=get_attr('rows'),
cols=get_attr('cols'),
interval_finite_z=get_attr('interval_finite_z'),
interval_finite_q=get_attr('interval_finite_q'),
interval_finite_re=get_attr('interval_finite_re'),
interval_inf=get_attr('interval_inf'),
interval_closed=get_attr('interval_closed'),
interval_finite_im=get_attr('interval_finite_im'),
interval_inf_im=get_attr('interval_inf_im'),
interval_closed_im=get_attr('interval_closed_im'),
domain=get_attr('domain'),
)
def set_finite_domain( # noqa: PLR0913
self,
start: int | float,
end: int | float,
start_closed: bool = True,
end_closed: bool = True,
start_im: bool = float,
end_im: bool = float,
start_closed_im: bool = True,
end_closed_im: bool = True,
) -> typ.Self:
"""Update the symbol with a finite range."""
closed_re = (start_closed, end_closed)
closed_im = (start_closed_im, end_closed_im)
match self.mathtype:
case spux.MathType.Integer:
return self.update(
interval_finite_z=(start, end),
interval_inf=(False, False),
interval_closed=closed_re,
)
case spux.MathType.Rational:
return self.update(
interval_finite_q=(start, end),
interval_inf=(False, False),
interval_closed=closed_re,
)
case spux.MathType.Real:
return self.update(
interval_finite_re=(start, end),
interval_inf=(False, False),
interval_closed=closed_re,
)
case spux.MathType.Complex:
return self.update(
interval_finite_re=(start, end),
interval_finite_im=(start_im, end_im),
interval_inf=(False, False),
interval_closed=closed_re,
interval_closed_im=closed_im,
)
####################
# - Operations: Comparison
####################
def compare(self, other: typ.Self) -> typ.Self:
"""Whether this SimSymbol can be considered equivalent to another, and thus universally usable in arbitrary mathematical operations together.
def set_size(self, rows: int, cols: int) -> typ.Self:
return self.update(rows=rows, cols=cols)
In particular, two attributes are ignored:
- **Name**: The particluar choices of name are not generally important.
- **Unit**: The particulars of unit equivilancy are not generally important; only that the `PhysicalType` is equal, and thus that they are compatible.
While not usable in all cases, this method ends up being very helpful for simplifying certain checks that would otherwise take up a lot of space.
"""
return (
self.mathtype is other.mathtype
and self.physical_type is other.physical_type
and self.compare_size(other)
and self.domain == other.domain
)
def compare_size(self, other: typ.Self) -> typ.Self:
"""Compare the size of this `SimSymbol` with another."""
return self.rows == other.rows and self.cols == other.cols
def compare_addable(
self, other: typ.Self, allow_differing_unit: bool = False
) -> bool:
"""Whether two `SimSymbol`s can be added."""
common = (
self.compare_size(other.output)
and self.physical_type is other.physical_type
and not (
self.physical_type is spux.NonPhysical
and self.unit is not None
and self.unit != other.unit
)
and not (
other.physical_type is spux.NonPhysical
and other.unit is not None
and self.unit != other.unit
)
)
if not allow_differing_unit:
return common and self.output.unit == other.output.unit
return common
def compare_multiplicable(self, other: typ.Self) -> bool:
"""Whether two `SimSymbol`s can be multiplied."""
return self.shape_len == 0 or self.compare_size(other)
def compare_exponentiable(self, other: typ.Self) -> bool:
"""Whether two `SimSymbol`s can be exponentiated.
"Hadamard Power" is defined for any combination of scalar/vector/matrix operands, for any `MathType` combination.
The only important thing to check is that the exponent cannot have a physical unit.
Sometimes, people write equations with units in the exponent.
This is a notational shorthand that only works in the context of an implicit, cancelling factor.
We reject such things.
See https://physics.stackexchange.com/questions/109995/exponential-or-logarithm-of-a-dimensionful-quantity
"""
return (
other.physical_type is spux.PhysicalType.NonPhysical and other.unit is None
)
####################
# - Operations: Copying Setters
####################
def set_constant(self, constant_value: spux.SympyType) -> typ.Self:
"""Set the constant value of this `SimSymbol`, by setting it as the only value in a `sp.FiniteSet` domain.
The `constant_value` will be conformed and stripped (with `self.conform()`) before being injected into the new `sp.FiniteSet` domain.
Warnings:
Keep in mind that domains do not encode units, for practical reasons related to the diverging ways in which various `sp.Set` subclasses interpret units.
This isn't noticeable in normal constant-symbol workflows, where the constant is retrieved using `self.constant_value` (which adds `self.unit_factor`).
However, **remember that retrieving the domain directly won't add the unit**.
Ye been warned!
"""
if self.is_constant:
return self.update(
domain=sp.FiniteSet(self.conform(constant_value, strip_unit=True))
)
msg = 'Tried to set constant value of non-constant SimSymbol.'
raise ValueError(msg)
####################
# - Operations: Conforming Mappers
####################
def conform(
self, sp_obj: spux.SympyType, strip_unit: bool = False
) -> spux.SympyType:
@ -732,6 +766,9 @@ class SimSymbol(pyd.BaseModel):
return res # noqa: RET504
####################
# - Creation
####################
@staticmethod
def from_expr(
sym_name: SimSymbolName,

View File

@ -0,0 +1,173 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Declares many useful primitives to greatly simplify working with `sympy` in the context of a unit-aware system."""
from .math_type import MathType
from .number_size import NumberSize1D, NumberSize2D
from .parse_cast import parse_shape, pretty_symbol, sp_to_str, sympy_to_python
from .physical_type import Dims, PhysicalType
from .sympy_expr import (
ComplexNumber,
ComplexSymbol,
ConstrSympyExpr,
IntNumber,
IntSymbol,
Number,
PhysicalComplexNumber,
PhysicalNumber,
PhysicalRealNumber,
RationalSymbol,
Real3DVector,
RealNumber,
RealSymbol,
ScalarUnitlessComplexExpr,
ScalarUnitlessRealExpr,
Symbol,
SympyExpr,
Unit,
UnitDimension,
)
from .sympy_type import SympyType
from .unit_analysis import (
compare_unit_dims,
compare_units_by_unit_dims,
convert_to_unit,
get_units,
scale_to_unit,
scaling_factor,
strip_units,
unit_dim_to_unit_dim_deps,
unit_str_to_unit,
unit_to_unit_dim_deps,
uses_units,
)
from .unit_system_analysis import (
convert_to_unit_system,
scale_to_unit_system,
strip_unit_system,
)
from .unit_systems import UNITS_SI, UnitSystem
from .units import (
UNIT_BY_SYMBOL,
UNIT_TO_1,
EHz,
GHz,
KHz,
MHz,
PHz,
THz,
exahertz,
femtometer,
femtosecond,
fm,
fs,
gigahertz,
hectopascal,
hPa,
kilohertz,
lm,
lumen,
mbar,
megahertz,
micronewton,
millibar,
millinewton,
mN,
nanonewton,
nN,
petahertz,
terahertz,
uN,
)
__all__ = [
'MathType',
'NumberSize1D',
'NumberSize2D',
'parse_shape',
'pretty_symbol',
'sp_to_str',
'sympy_to_python',
'Dims',
'PhysicalType',
'ComplexNumber',
'ComplexSymbol',
'ConstrSympyExpr',
'IntNumber',
'IntSymbol',
'Number',
'PhysicalComplexNumber',
'PhysicalNumber',
'PhysicalRealNumber',
'RationalSymbol',
'Real3DVector',
'RealNumber',
'RealSymbol',
'ScalarUnitlessComplexExpr',
'ScalarUnitlessRealExpr',
'Symbol',
'SympyExpr',
'Unit',
'UnitDimension',
'SympyType',
'compare_unit_dims',
'compare_units_by_unit_dims',
'convert_to_unit',
'get_units',
'scale_to_unit',
'scaling_factor',
'strip_units',
'unit_dim_to_unit_dim_deps',
'unit_str_to_unit',
'unit_to_unit_dim_deps',
'uses_units',
'strip_unit_system',
'UNITS_SI',
'UnitSystem',
'convert_to_unit_system',
'scale_to_unit_system',
'UNIT_BY_SYMBOL',
'UNIT_TO_1',
'EHz',
'GHz',
'KHz',
'MHz',
'PHz',
'THz',
'exahertz',
'femtometer',
'femtosecond',
'fm',
'fs',
'gigahertz',
'hectopascal',
'hPa',
'kilohertz',
'lm',
'lumen',
'mbar',
'megahertz',
'micronewton',
'millibar',
'millinewton',
'mN',
'nanonewton',
'nN',
'petahertz',
'terahertz',
'uN',
]

View File

@ -0,0 +1,362 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `MathType`, a convenient UI-friendly identifier of numerical identity."""
import enum
import sys
import typing as typ
from fractions import Fraction
import jax
import jaxtyping as jtyp
import sympy as sp
from blender_maxwell import contracts as ct
from .. import logger
from .sympy_type import SympyType
log = logger.get(__name__)
class MathType(enum.StrEnum):
"""A convenient, UI-friendly identifier of a numerical object's identity."""
Integer = enum.auto()
Rational = enum.auto()
Real = enum.auto()
Complex = enum.auto()
####################
# - Checks
####################
@staticmethod
def has_mathtype(obj: typ.Any) -> typ.Literal['pytype', 'jax', 'expr'] | None:
"""Determine whether an object of arbitrary type can be considered to have a `MathType`.
- **Pure Python**: The numerical Python types (`int | Fraction | float | complex`) are all valid.
- **Expression**: Sympy types / expression are in general considered to have a valid MathType.
- **Jax**: Non-empty `jax` arrays with a valid numerical Python type as the first element are valid.
Returns:
A string literal indicating how to parse the object for a valid `MathType`.
If the presence of a MathType couldn't be deduced, then return None.
"""
if isinstance(obj, int | Fraction | float | complex):
return 'pytype'
if (
isinstance(obj, jax.Array)
and obj
and isinstance(obj.item(0), int | Fraction | float | complex)
):
return 'jax'
if isinstance(obj, sp.Basic | sp.MatrixBase):
return 'expr'
## TODO: Should we check deeper?
return None
####################
# - Creation
####################
@staticmethod
def from_expr(sp_obj: SympyType, optional: bool = False) -> type | None: # noqa: PLR0911
"""Deduce the `MathType` of an arbitrary sympy object (/expression).
The "assumptions" system of `sympy` is relied on to determine the key properties of the expression.
To this end, it's important to note several of the shortcomings of the "assumptions" system:
- All elements, especially symbols, must have well-defined assumptions, ex. `real=True`.
- Only the "narrowest" possible `MathType` will be deduced, ex. `5` may well be the result of a complex expression, but since it is now an integer, it will parse to `MathType.Integer`. This may break some
- For infinities, only real and complex infinities are distinguished between in `sympy` (`sp.oo` vs. `sp.zoo`) - aka. there is no "integer infinity" which will parse to `Integer` with this method.
Warnings:
Using the "assumptions" system like this requires a lot of rigor in the entire program.
Notes:
Any matrix-like object will have `MathType.combine()` run on all of its (flattened) elements.
This is an extremely **slow** operation, but accurate, according to the semantics of `MathType.combine()`.
Note that `sp.MatrixSymbol` _cannot have assumptions_, and thus shouldn't be used in `sp_obj`.
Returns:
A corresponding `MathType`; else, if `optional=True`, return `None`.
Raises:
ValueError: If no corresponding `MathType` could be determined, and `optional=False`.
"""
if isinstance(sp_obj, sp.MatrixBase):
return MathType.combine(
*[MathType.from_expr(v) for v in sp.flatten(sp_obj)]
)
if sp_obj.is_integer:
return MathType.Integer
if sp_obj.is_rational:
return MathType.Rational
if sp_obj.is_real:
return MathType.Real
if sp_obj.is_complex:
return MathType.Complex
# Infinities
if sp_obj in [sp.oo, -sp.oo]:
return MathType.Real
if sp_obj in [sp.zoo, -sp.zoo]:
return MathType.Complex
if optional:
return None
msg = f"Can't determine MathType from sympy object: {sp_obj}"
raise ValueError(msg)
@staticmethod
def from_pytype(dtype: type) -> type:
return {
int: MathType.Integer,
Fraction: MathType.Rational,
float: MathType.Real,
complex: MathType.Complex,
}[dtype]
@staticmethod
def from_jax_array(data: jtyp.Shaped[jtyp.Array, '...']) -> type:
"""Deduce the MathType corresponding to a JAX array.
We go about this by leveraging that:
- `data` is of a homogeneous type.
- `data.item(0)` returns a single element of the array w/pure-python type.
By combing this with `type()` and `MathType.from_pytype`, we can effectively deduce the `MathType` of the entire array with relative efficiency.
Notes:
Should also work with numpy arrays.
"""
if len(data) > 0:
return MathType.from_pytype(type(data.item(0)))
msg = 'Cannot determine MathType from empty jax array.'
raise ValueError(msg)
####################
# - Operations
####################
@staticmethod
def combine(*mathtypes: list[typ.Self], optional: bool = False) -> typ.Self | None:
if MathType.Complex in mathtypes:
return MathType.Complex
if MathType.Real in mathtypes:
return MathType.Real
if MathType.Rational in mathtypes:
return MathType.Rational
if MathType.Integer in mathtypes:
return MathType.Integer
if optional:
return None
msg = f"Can't combine mathtypes {mathtypes}"
raise ValueError(msg)
def is_compatible(self, other: typ.Self) -> bool:
MT = MathType
return (
other
in {
MT.Integer: [MT.Integer],
MT.Rational: [MT.Integer, MT.Rational],
MT.Real: [MT.Integer, MT.Rational, MT.Real],
MT.Complex: [MT.Integer, MT.Rational, MT.Real, MT.Complex],
}[self]
)
def coerce_compatible_pyobj(
self, pyobj: bool | int | Fraction | float | complex
) -> int | Fraction | float | complex:
"""Coerce a pure-python object of numerical type to the _exact_ type indicated by this `MathType`.
This is needed when ex. one has an integer, but it is important that that integer be passed as a complex number.
"""
MT = MathType
match self:
case MT.Integer:
return int(pyobj)
case MT.Rational if isinstance(pyobj, int):
return Fraction(pyobj, 1)
case MT.Rational if isinstance(pyobj, Fraction):
return pyobj
case MT.Real:
return float(pyobj)
case MT.Complex if isinstance(pyobj, int | Fraction):
return complex(float(pyobj), 0)
case MT.Complex if isinstance(pyobj, float):
return complex(pyobj, 0)
@staticmethod
def from_symbolic_set(
s: typ.Literal[
sp.Naturals
| sp.Naturals0
| sp.Integers
| sp.Rationals
| sp.Reals
| sp.Complexes
]
| sp.Set,
optional: bool = False,
) -> typ.Self | None:
"""Deduce the `MathType` from a particular symbolic set.
Currently hard-coded.
Any deviation that might be expected to work, ex. `sp.Reals - {0}`, currently won't (currently).
Raises:
ValueError: If a non-hardcoded symbolic set is passed.
"""
MT = MathType
match s:
case sp.Naturals | sp.Naturals0 | sp.Integers:
return MT.Integer
case sp.Rationals:
return MT.Rational
case sp.Reals:
return MT.Real
case sp.Complexes:
return MT.Complex
if optional:
return None
msg = f"Can't deduce MathType from symbolic set {s}"
raise ValueError(msg)
####################
# - Casting: Pytype
####################
@property
def pytype(self) -> type:
"""Deduce the pure-Python type that corresponds to this `MathType`."""
MT = MathType
return {
MT.Integer: int,
MT.Rational: Fraction,
MT.Real: float,
MT.Complex: complex,
}[self]
@property
def inf_finite(self) -> type:
"""Opinionated finite representation of "infinity" within this `MathType`.
These are chosen using `sys.maxsize` and `sys.float_info`.
As such, while not arbitrary, this "finite representation of infinity" certainly is opinionated.
**Note** that, in practice, most systems will have no trouble working with values that exceed those defined here.
Notes:
Values should be presumed to vary by-platform, as the `sys` attributes may be influenced by CPU architecture, OS, runtime environment, etc. .
These values can be used directly in `jax` arrays, but at the cost of an overflow warning (in part because `jax` generally only allows the use of `float32`).
In this case, the warning doesn't matter, as the value will be cast to `jnp.inf` anyway.
However, it's generally cleaner to directly use `jnp.inf` if infinite values must be defined in an array context.
"""
MT = MathType
Z = MT.Integer
R = MT.Integer
return {
MT.Integer: (-sys.maxsize, sys.maxsize),
MT.Rational: (
Fraction(Z.inf_finite[0], 1),
Fraction(Z.inf_finite[1], 1),
),
MT.Real: -(sys.float_info.min, sys.float_info.max),
MT.Complex: (
complex(R.inf_finite[0], R.inf_finite[0]),
complex(R.inf_finite[1], R.inf_finite[1]),
),
}[self]
####################
# - Casting: Symbolic
####################
@property
def symbolic_set(self) -> sp.Set:
"""Deduce the symbolic `sp.Set` type that corresponds to this `MathType`."""
MT = MathType
return {
MT.Integer: sp.Integers,
MT.Rational: sp.Rationals,
MT.Real: sp.Reals,
MT.Complex: sp.Complexes,
}[self]
@property
def sp_symbol_a(self) -> type:
MT = MathType
return {
MT.Integer: sp.Symbol('a', integer=True),
MT.Rational: sp.Symbol('a', rational=True),
MT.Real: sp.Symbol('a', real=True),
MT.Complex: sp.Symbol('a', complex=True),
}[self]
####################
# - Labels
####################
@staticmethod
def to_str(value: typ.Self) -> type:
return {
MathType.Integer: '',
MathType.Rational: '',
MathType.Real: '',
MathType.Complex: '',
}[value]
@property
def name(self) -> str:
"""Simple non-unicode name of the math type."""
return str(self)
@property
def label_pretty(self) -> str:
return MathType.to_str(self)
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
return MathType.to_str(value)
@staticmethod
def to_icon(value: typ.Self) -> str:
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
return (
str(self),
MathType.to_name(self),
MathType.to_name(self),
MathType.to_icon(self),
i,
)

View File

@ -0,0 +1,148 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import enum
import typing as typ
import sympy as sp
from blender_maxwell import contracts as ct
####################
# - Size: 1D
####################
class NumberSize1D(enum.StrEnum):
"""Valid 1D-constrained shape."""
Scalar = enum.auto()
Vec2 = enum.auto()
Vec3 = enum.auto()
Vec4 = enum.auto()
@staticmethod
def to_name(value: typ.Self) -> str:
NS = NumberSize1D
return {
NS.Scalar: 'Scalar',
NS.Vec2: '2D',
NS.Vec3: '3D',
NS.Vec4: '4D',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> str:
NS = NumberSize1D
return {
NS.Scalar: '',
NS.Vec2: '',
NS.Vec3: '',
NS.Vec4: '',
}[value]
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
return (
str(self),
NumberSize1D.to_name(self),
NumberSize1D.to_name(self),
NumberSize1D.to_icon(self),
i,
)
@staticmethod
def has_shape(shape: tuple[int, ...] | None):
return shape in [None, (2,), (3,), (4,), (2, 1), (3, 1), (4, 1)]
def supports_shape(self, shape: tuple[int, ...] | None):
NS = NumberSize1D
match self:
case NS.Scalar:
return shape is None
case NS.Vec2:
return shape in ((2,), (2, 1))
case NS.Vec3:
return shape in ((3,), (3, 1))
case NS.Vec4:
return shape in ((4,), (4, 1))
@staticmethod
def from_shape(shape: tuple[typ.Literal[2, 3]] | None) -> typ.Self:
NS = NumberSize1D
return {
None: NS.Scalar,
(2,): NS.Vec2,
(3,): NS.Vec3,
(4,): NS.Vec4,
(2, 1): NS.Vec2,
(3, 1): NS.Vec3,
(4, 1): NS.Vec4,
}[shape]
@property
def rows(self):
NS = NumberSize1D
return {
NS.Scalar: 1,
NS.Vec2: 2,
NS.Vec3: 3,
NS.Vec4: 4,
}[self]
@property
def cols(self):
return 1
@property
def shape(self):
NS = NumberSize1D
return {
NS.Scalar: None,
NS.Vec2: (2,),
NS.Vec3: (3,),
NS.Vec4: (4,),
}[self]
def symbol_range(sym: sp.Symbol) -> str:
return f'{sym.name}' + (
''
if sym.is_complex
else ('' if sym.is_real else ('' if sym.is_integer else '?'))
)
####################
# - Symbol Sizes
####################
class NumberSize2D(enum.StrEnum):
"""Simple subset of sizes for rank-2 tensors."""
Scalar = enum.auto()
# Vectors
Vec2 = enum.auto() ## 2x1
Vec3 = enum.auto() ## 3x1
Vec4 = enum.auto() ## 4x1
# Covectors
CoVec2 = enum.auto() ## 1x2
CoVec3 = enum.auto() ## 1x3
CoVec4 = enum.auto() ## 1x4
# Square Matrices
Mat22 = enum.auto() ## 2x2
Mat33 = enum.auto() ## 3x3
Mat44 = enum.auto() ## 4x4

View File

@ -0,0 +1,119 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import jax
import jax.numpy as jnp
import sympy as sp
from .. import logger
from .sympy_type import SympyType
log = logger.get(__name__)
####################
# - Parsing: Info from SympyType
####################
def parse_shape(sp_obj: SympyType) -> int | None:
if isinstance(sp_obj, sp.MatrixBase):
return sp_obj.shape
return None
####################
# - Casting: Python
####################
def sympy_to_python(
scalar: sp.Basic, use_jax_array: bool = False
) -> int | float | complex | tuple | jax.Array:
"""Convert a scalar sympy expression to the directly corresponding Python type.
Arguments:
scalar: A sympy expression that has no symbols, but is expressed as a Sympy type.
For expressions that are equivalent to a scalar (ex. "(2a + a)/a"), you must simplify the expression with ex. `sp.simplify()` before passing to this parameter.
Returns:
A pure Python type that directly corresponds to the input scalar expression.
"""
if isinstance(scalar, sp.MatrixBase):
# Detect Single Column Vector
## --> Flatten to Single Row Vector
if len(scalar.shape) == 2 and scalar.shape[1] == 1:
_scalar = scalar.T
else:
_scalar = scalar
# Convert to Tuple of Tuples
matrix = tuple(
[tuple([sympy_to_python(el) for el in row]) for row in _scalar.tolist()]
)
# Detect Single Row Vector
## --> This could be because the scalar had it.
## --> This could also be because we flattened a column vector.
## Either way, we should strip the pointless dimensions.
if len(matrix) == 1:
return matrix[0] if not use_jax_array else jnp.array(matrix[0])
return matrix if not use_jax_array else jnp.array(matrix)
if scalar.is_integer:
return int(scalar)
if scalar.is_rational or scalar.is_real:
return float(scalar)
if scalar.is_complex:
return complex(scalar)
msg = f'Cannot convert sympy scalar expression "{scalar}" to a Python type. Check the assumptions on the expr (current expr assumptions: "{scalar._assumptions}")' # noqa: SLF001
raise ValueError(msg)
####################
# - Casting: Printing
####################
_SYMPY_EXPR_PRINTER_STR = sp.printing.str.StrPrinter(
settings={
'abbrev': True,
}
)
def sp_to_str(sp_obj: SympyType) -> str:
"""Converts a sympy object to an output-oriented string (w/abbreviated units), using a dedicated StrPrinter.
This should be used whenever a **string for UI use** is needed from a `sympy` object.
Notes:
This should **NOT** be used in cases where the string will be `sp.sympify()`ed back into a sympy expression.
For such cases, rely on `sp.srepr()`, which uses an _explicit_ representation.
Parameters:
sp_obj: The `sympy` object to convert to a string.
Returns:
A string representing the expression for human use.
_The string is not re-encodable to the expression._
"""
## TODO: A bool flag property that does a lot of find/replace to make it super pretty
return _SYMPY_EXPR_PRINTER_STR.doprint(sp_obj)
def pretty_symbol(sym: sp.Symbol) -> str:
return f'{sym.name}' + (
''
if sym.is_integer
else ('' if sym.is_real else ('' if sym.is_complex else '?'))
)

View File

@ -0,0 +1,644 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `PhysicalType`, a convenient, UI-friendly way of deterministically handling the unit-dimensionality of arbitrary objects."""
import enum
import functools
import typing as typ
import sympy.physics.units as spu
from blender_maxwell import contracts as ct
from ..staticproperty import staticproperty
from . import units as spux
from .math_type import MathType
from .sympy_expr import Unit
from .sympy_type import SympyType
from .unit_analysis import (
compare_unit_dim_to_unit_dim_deps,
compare_unit_dims,
unit_to_unit_dim_deps,
)
####################
# - Unit Dimensions
####################
class DimsMeta(type):
"""Metaclass allowing an implementing (ideally empty) class to access `spu.definitions.dimension_definitions` attributes directly via its own attribute."""
def __getattr__(cls, attr: str) -> spu.Dimension:
"""Alias for `spu.definitions.dimension_definitions.*` (isn't that a mouthful?).
Raises:
AttributeError: If the name cannot be found.
"""
if (
attr in spu.definitions.dimension_definitions.__dir__()
and not attr.startswith('__')
):
return getattr(spu.definitions.dimension_definitions, attr)
raise AttributeError(name=attr, obj=Dims)
class Dims(metaclass=DimsMeta):
"""Access `sympy.physics.units` dimensions with less hassle.
Any unit dimension available in `sympy.physics.units.definitions.dimension_definitions` can be accessed as an attribute of `Dims`.
An `AttributeError` is raised if the unit cannot be found in `sympy`.
Examples:
The objects returned are a direct alias to `sympy`, with less hassle:
```python
assert Dims.length == (
sympy.physics.units.definitions.dimension_definitions.length
)
```
"""
####################
# - Physical Type
####################
class PhysicalType(enum.StrEnum):
"""An identifier of unit dimensionality with many useful properties."""
# Unitless
NonPhysical = enum.auto()
# Global
Time = enum.auto()
Angle = enum.auto()
SolidAngle = enum.auto()
## TODO: Some kind of 3D-specific orientation ex. a quaternion
Freq = enum.auto()
AngFreq = enum.auto() ## rad*hertz
# Cartesian
Length = enum.auto()
Area = enum.auto()
Volume = enum.auto()
# Mechanical
Vel = enum.auto()
Accel = enum.auto()
Mass = enum.auto()
Force = enum.auto()
Pressure = enum.auto()
# Energy
Work = enum.auto() ## joule
Power = enum.auto() ## watt
PowerFlux = enum.auto() ## watt
Temp = enum.auto()
# Electrodynamics
Current = enum.auto() ## ampere
CurrentDensity = enum.auto()
Charge = enum.auto() ## coulomb
Voltage = enum.auto()
Capacitance = enum.auto() ## farad
Impedance = enum.auto() ## ohm
Conductance = enum.auto() ## siemens
Conductivity = enum.auto() ## siemens / length
MFlux = enum.auto() ## weber
MFluxDensity = enum.auto() ## tesla
Inductance = enum.auto() ## henry
EField = enum.auto()
HField = enum.auto()
# Luminal
LumIntensity = enum.auto()
LumFlux = enum.auto()
Illuminance = enum.auto()
####################
# - Unit Dimensions
####################
@functools.cached_property
def unit_dim(self) -> SympyType:
"""The unit dimension expression associated with the `PhysicalType`.
A `PhysicalType` is, in its essence, merely an identifier for a particular unit dimension expression.
"""
PT = PhysicalType
return {
PT.NonPhysical: None,
# Global
PT.Time: Dims.time,
PT.Angle: Dims.angle,
PT.SolidAngle: spu.steradian.dimension, ## MISSING
PT.Freq: Dims.frequency,
PT.AngFreq: Dims.angle * Dims.frequency,
# Cartesian
PT.Length: Dims.length,
PT.Area: Dims.length**2,
PT.Volume: Dims.length**3,
# Mechanical
PT.Vel: Dims.length / Dims.time,
PT.Accel: Dims.length / Dims.time**2,
PT.Mass: Dims.mass,
PT.Force: Dims.force,
PT.Pressure: Dims.pressure,
# Energy
PT.Work: Dims.energy,
PT.Power: Dims.power,
PT.PowerFlux: Dims.power / Dims.length**2,
PT.Temp: Dims.temperature,
# Electrodynamics
PT.Current: Dims.current,
PT.CurrentDensity: Dims.current / Dims.length**2,
PT.Charge: Dims.charge,
PT.Voltage: Dims.voltage,
PT.Capacitance: Dims.capacitance,
PT.Impedance: Dims.impedance,
PT.Conductance: Dims.conductance,
PT.Conductivity: Dims.conductance / Dims.length,
PT.MFlux: Dims.magnetic_flux,
PT.MFluxDensity: Dims.magnetic_density,
PT.Inductance: Dims.inductance,
PT.EField: Dims.voltage / Dims.length,
PT.HField: Dims.current / Dims.length,
# Luminal
PT.LumIntensity: Dims.luminous_intensity,
PT.LumFlux: Dims.luminous_intensity * spu.steradian.dimension,
PT.Illuminance: Dims.luminous_intensity / Dims.length**2,
}[self]
@staticproperty
def unit_dims() -> dict[typ.Self, SympyType]:
"""All unit dimensions supported by all `PhysicalType`s."""
return {
physical_type: physical_type.unit_dim
for physical_type in list(PhysicalType)
}
####################
# - Convenience Properties
####################
@functools.cached_property
def default_unit(self) -> list[Unit]:
"""Subjective choice of 'default' unit from `self.valid_units`.
There is no requirement to use this.
"""
PT = PhysicalType
return {
PT.NonPhysical: None,
# Global
PT.Time: spu.picosecond,
PT.Angle: spu.radian,
PT.SolidAngle: spu.steradian,
PT.Freq: spux.terahertz,
PT.AngFreq: spu.radian * spux.terahertz,
# Cartesian
PT.Length: spu.micrometer,
PT.Area: spu.um**2,
PT.Volume: spu.um**3,
# Mechanical
PT.Vel: spu.um / spu.second,
PT.Accel: spu.um / spu.second,
PT.Mass: spu.microgram,
PT.Force: spux.micronewton,
PT.Pressure: spux.millibar,
# Energy
PT.Work: spu.joule,
PT.Power: spu.watt,
PT.PowerFlux: spu.watt / spu.meter**2,
PT.Temp: spu.kelvin,
# Electrodynamics
PT.Current: spu.ampere,
PT.CurrentDensity: spu.ampere / spu.meter**2,
PT.Charge: spu.coulomb,
PT.Voltage: spu.volt,
PT.Capacitance: spu.farad,
PT.Impedance: spu.ohm,
PT.Conductance: spu.siemens,
PT.Conductivity: spu.siemens / spu.micrometer,
PT.MFlux: spu.weber,
PT.MFluxDensity: spu.tesla,
PT.Inductance: spu.henry,
PT.EField: spu.volt / spu.micrometer,
PT.HField: spu.ampere / spu.micrometer,
# Luminal
PT.LumIntensity: spu.candela,
PT.LumFlux: spu.candela * spu.steradian,
PT.Illuminance: spu.candela / spu.meter**2,
}[self]
####################
# - Creation
####################
@staticmethod
def from_unit(unit: Unit | None, optional: bool = False) -> typ.Self | None:
"""Attempt to determine a matching `PhysicalType` from a unit.
NOTE: It is not guaranteed that `unit` is within `valid_units`, only that it can be converted to any unit in `valid_units`.
Returns:
The matched `PhysicalType`.
If none could be matched, then either return `None` (if `optional` is set) or error.
Raises:
ValueError: If no `PhysicalType` could be matched, and `optional` is `False`.
"""
if unit is None:
return PhysicalType.NonPhysical
## TODO_ This enough?
if unit in [spu.radian, spu.degree]:
return PhysicalType.Angle
unit_dim_deps = unit_to_unit_dim_deps(unit)
if unit_dim_deps is not None:
for physical_type, candidate_unit_dim in PhysicalType.unit_dims.items():
if compare_unit_dim_to_unit_dim_deps(candidate_unit_dim, unit_dim_deps):
return physical_type
if optional:
return None
msg = f'Could not determine PhysicalType for {unit}'
raise ValueError(msg)
@staticmethod
def from_unit_dim(
unit_dim: SympyType | None, optional: bool = False
) -> typ.Self | None:
"""Attempts to match an arbitrary unit dimension expression to a corresponding `PhysicalType`.
For comparing arbitrary unit dimensions (via expressions of `spu.dimensions.Dimension`), it is critical that equivalent dimensions are also compared as equal (ex. `mass*length/time^2 == force`).
To do so, we employ the `SI` unit conventions, for extracting the fundamental dimensional dependencies of unit dimension expressions.
Returns:
The matched `PhysicalType`.
If none could be matched, then either return `None` (if `optional` is set) or error.
Raises:
ValueError: If no `PhysicalType` could be matched, and `optional` is `False`.
"""
for physical_type, candidate_unit_dim in PhysicalType.unit_dims.items():
if compare_unit_dims(unit_dim, candidate_unit_dim):
return physical_type
if optional:
return None
msg = f'Could not determine PhysicalType for {unit_dim}'
raise ValueError(msg)
####################
# - Valid Properties
####################
@functools.cached_property
def valid_units(self) -> list[Unit]:
"""Retrieve an ordered (by subjective usefulness) list of units for this physical type.
Warnings:
**Altering the order of units hard-breaks backwards compatibility**, since enums based on it only keep an integer index.
Notes:
The order in which valid units are declared is the exact same order that UI dropdowns display them.
"""
PT = PhysicalType
return {
PT.NonPhysical: [None],
# Global
PT.Time: [
spu.picosecond,
spux.femtosecond,
spu.nanosecond,
spu.microsecond,
spu.millisecond,
spu.second,
spu.minute,
spu.hour,
spu.day,
],
PT.Angle: [
spu.radian,
spu.degree,
],
PT.SolidAngle: [
spu.steradian,
],
PT.Freq: (
_valid_freqs := [
spux.terahertz,
spu.hertz,
spux.kilohertz,
spux.megahertz,
spux.gigahertz,
spux.petahertz,
spux.exahertz,
]
),
PT.AngFreq: [spu.radian * _unit for _unit in _valid_freqs],
# Cartesian
PT.Length: (
_valid_lens := [
spu.micrometer,
spu.nanometer,
spu.picometer,
spu.angstrom,
spu.millimeter,
spu.centimeter,
spu.meter,
spu.inch,
spu.foot,
spu.yard,
spu.mile,
]
),
PT.Area: [_unit**2 for _unit in _valid_lens],
PT.Volume: [_unit**3 for _unit in _valid_lens],
# Mechanical
PT.Vel: [_unit / spu.second for _unit in _valid_lens],
PT.Accel: [_unit / spu.second**2 for _unit in _valid_lens],
PT.Mass: [
spu.kilogram,
spu.electron_rest_mass,
spu.dalton,
spu.microgram,
spu.milligram,
spu.gram,
spu.metric_ton,
],
PT.Force: [
spux.micronewton,
spux.nanonewton,
spux.millinewton,
spu.newton,
spu.kg * spu.meter / spu.second**2,
],
PT.Pressure: [
spu.bar,
spux.millibar,
spu.pascal,
spux.hectopascal,
spu.atmosphere,
spu.psi,
spu.mmHg,
spu.torr,
],
# Energy
PT.Work: [
spu.joule,
spu.electronvolt,
],
PT.Power: [
spu.watt,
],
PT.PowerFlux: [
spu.watt / spu.meter**2,
],
PT.Temp: [
spu.kelvin,
],
# Electrodynamics
PT.Current: [
spu.ampere,
],
PT.CurrentDensity: [
spu.ampere / spu.meter**2,
],
PT.Charge: [
spu.coulomb,
],
PT.Voltage: [
spu.volt,
],
PT.Capacitance: [
spu.farad,
],
PT.Impedance: [
spu.ohm,
],
PT.Conductance: [
spu.siemens,
],
PT.Conductivity: [
spu.siemens / spu.micrometer,
spu.siemens / spu.meter,
],
PT.MFlux: [
spu.weber,
],
PT.MFluxDensity: [
spu.tesla,
],
PT.Inductance: [
spu.henry,
],
PT.EField: [
spu.volt / spu.micrometer,
spu.volt / spu.meter,
],
PT.HField: [
spu.ampere / spu.micrometer,
spu.ampere / spu.meter,
],
# Luminal
PT.LumIntensity: [
spu.candela,
],
PT.LumFlux: [
spu.candela * spu.steradian,
],
PT.Illuminance: [
spu.candela / spu.meter**2,
],
}[self]
@functools.cached_property
def valid_shapes(self) -> list[typ.Literal[(3,), (2,)] | None]:
"""All shapes with physical meaning in the context of a particular unit dimension."""
PT = PhysicalType
overrides = {
# Cartesian
PT.Length: [None, (2,), (3,)],
# Mechanical
PT.Vel: [None, (2,), (3,)],
PT.Accel: [None, (2,), (3,)],
PT.Force: [None, (2,), (3,)],
# Energy
PT.Work: [None, (2,), (3,)],
PT.PowerFlux: [None, (2,), (3,)],
# Electrodynamics
PT.CurrentDensity: [None, (2,), (3,)],
PT.MFluxDensity: [None, (2,), (3,)],
PT.EField: [None, (2,), (3,)],
PT.HField: [None, (2,), (3,)],
# Luminal
PT.LumFlux: [None, (2,), (3,)],
}
return overrides.get(self, [None])
@functools.cached_property
def valid_mathtypes(self) -> list[MathType]:
"""Returns a list of valid mathematical types, especially whether it can be real- or complex-valued.
Generally, all unit quantities are real, in the algebraic mathematical sense.
However, in electrodynamics especially, it becomes enormously useful to bake in a _rotational component_ as an imaginary value, be it simply to model phase or oscillation-oriented dampening.
This imaginary part has physical meaning, which can be expressed using the same mathematical formalism associated with unit systems.
In general, the value is a phasor.
While it is difficult to arrive at a well-defined way of saying, "this is when a quantity is complex", an attempt has been made to form a sensible baseline based on when phasor math may apply.
Notes:
- **Freq**/**AngFreq**: The imaginary part represents growth/dampening of the oscillation.
- **Current**/**Voltage**: The imaginary part represents the phase.
This also holds for any downstream units.
- **Charge**: Generally, it is real.
However, an imaginary phase term seems to have research applications when dealing with high-order harmonics in high-energy pulsed lasers: <https://iopscience.iop.org/article/10.1088/1361-6455/aac787>
- **Conductance**: The imaginary part represents the extinction, in the Drude-model sense.
"""
MT = MathType
PT = PhysicalType
overrides = {
PT.NonPhysical: list(MT), ## Support All
# Cartesian
PT.Freq: [MT.Real, MT.Complex], ## Im -> Growth/Damping
PT.AngFreq: [MT.Real, MT.Complex], ## Im -> Growth/Damping
# Mechanical
# Energy
# Electrodynamics
PT.Current: [MT.Real, MT.Complex], ## Im -> Phase
PT.CurrentDensity: [MT.Real, MT.Complex], ## Im -> Phase
PT.Charge: [MT.Real, MT.Complex], ## Im -> Phase
PT.Voltage: [MT.Real, MT.Complex], ## Im -> Phase
PT.Capacitance: [MT.Real, MT.Complex], ## Im -> Phase
PT.Impedance: [MT.Real, MT.Complex], ## Im -> Reactance
PT.Inductance: [MT.Real, MT.Complex], ## Im -> Extinction
PT.Conductance: [MT.Real, MT.Complex], ## Im -> Extinction
PT.Conductivity: [MT.Real, MT.Complex], ## Im -> Extinction
PT.MFlux: [MT.Real, MT.Complex], ## Im -> Phase
PT.MFluxDensity: [MT.Real, MT.Complex], ## Im -> Phase
PT.EField: [MT.Real, MT.Complex], ## Im -> Phase
PT.HField: [MT.Real, MT.Complex], ## Im -> Phase
# Luminal
}
return overrides.get(self, [MT.Real])
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
"""A human-readable UI-oriented name for a physical type."""
if value is PhysicalType.NonPhysical:
return 'Unitless'
return PhysicalType(value).name
@staticmethod
def to_icon(_: typ.Self) -> str:
"""No icons."""
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
"""Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`."""
PT = PhysicalType
return (
str(self),
PT.to_name(self),
PT.to_name(self),
PT.to_icon(self),
i,
)
@functools.cached_property
def color(self):
"""A color corresponding to the physical type.
The color selections were initially generated using AI, as this is a rote task that's better adjusted than invented.
The LLM provided the following rationale for its choices:
> Non-Physical: Grey signifies neutrality and non-physical nature.
> Global:
> Time: Blue is often associated with calmness and the passage of time.
> Angle and Solid Angle: Different shades of blue and cyan suggest angular dimensions and spatial aspects.
> Frequency and Angular Frequency: Darker shades of blue to maintain the link to time.
> Cartesian:
> Length, Area, Volume: Shades of green to represent spatial dimensions, with intensity increasing with dimension.
> Mechanical:
> Velocity and Acceleration: Red signifies motion and dynamics, with lighter reds for related quantities.
> Mass: Dark red for the fundamental property.
> Force and Pressure: Shades of red indicating intensity.
> Energy:
> Work and Power: Orange signifies energy transformation, with lighter oranges for related quantities.
> Temperature: Yellow for heat.
> Electrodynamics:
> Current and related quantities: Cyan shades indicating flow.
> Voltage, Capacitance: Greenish and blueish cyan for electrical potential.
> Impedance, Conductance, Conductivity: Purples and magentas to signify resistance and conductance.
> Magnetic properties: Magenta shades for magnetism.
> Electric Field: Light blue.
> Magnetic Field: Grey, as it can be considered neutral in terms of direction.
> Luminal:
> Luminous properties: Yellows to signify light and illumination.
>
> This color mapping helps maintain intuitive connections for users interacting with these physical types.
"""
PT = PhysicalType
return {
PT.NonPhysical: (0.75, 0.75, 0.75, 1.0), # Light Grey: Non-physical
# Global
PT.Time: (0.5, 0.5, 1.0, 1.0), # Light Blue: Time
PT.Angle: (0.5, 0.75, 1.0, 1.0), # Light Blue: Angle
PT.SolidAngle: (0.5, 0.75, 0.75, 1.0), # Light Cyan: Solid Angle
PT.Freq: (0.5, 0.5, 0.9, 1.0), # Light Blue: Frequency
PT.AngFreq: (0.5, 0.5, 0.8, 1.0), # Light Blue: Angular Frequency
# Cartesian
PT.Length: (0.5, 1.0, 0.5, 1.0), # Light Green: Length
PT.Area: (0.6, 1.0, 0.6, 1.0), # Light Green: Area
PT.Volume: (0.7, 1.0, 0.7, 1.0), # Light Green: Volume
# Mechanical
PT.Vel: (1.0, 0.5, 0.5, 1.0), # Light Red: Velocity
PT.Accel: (1.0, 0.6, 0.6, 1.0), # Light Red: Acceleration
PT.Mass: (0.75, 0.5, 0.5, 1.0), # Light Red: Mass
PT.Force: (0.9, 0.5, 0.5, 1.0), # Light Red: Force
PT.Pressure: (1.0, 0.7, 0.7, 1.0), # Light Red: Pressure
# Energy
PT.Work: (1.0, 0.75, 0.5, 1.0), # Light Orange: Work
PT.Power: (1.0, 0.85, 0.5, 1.0), # Light Orange: Power
PT.PowerFlux: (1.0, 0.8, 0.6, 1.0), # Light Orange: Power Flux
PT.Temp: (1.0, 1.0, 0.5, 1.0), # Light Yellow: Temperature
# Electrodynamics
PT.Current: (0.5, 1.0, 1.0, 1.0), # Light Cyan: Current
PT.CurrentDensity: (0.5, 0.9, 0.9, 1.0), # Light Cyan: Current Density
PT.Charge: (0.5, 0.85, 0.85, 1.0), # Light Cyan: Charge
PT.Voltage: (0.5, 1.0, 0.75, 1.0), # Light Greenish Cyan: Voltage
PT.Capacitance: (0.5, 0.75, 1.0, 1.0), # Light Blueish Cyan: Capacitance
PT.Impedance: (0.6, 0.5, 0.75, 1.0), # Light Purple: Impedance
PT.Conductance: (0.7, 0.5, 0.8, 1.0), # Light Purple: Conductance
PT.Conductivity: (0.8, 0.5, 0.9, 1.0), # Light Purple: Conductivity
PT.MFlux: (0.75, 0.5, 0.75, 1.0), # Light Magenta: Magnetic Flux
PT.MFluxDensity: (
0.85,
0.5,
0.85,
1.0,
), # Light Magenta: Magnetic Flux Density
PT.Inductance: (0.8, 0.5, 0.8, 1.0), # Light Magenta: Inductance
PT.EField: (0.75, 0.75, 1.0, 1.0), # Light Blue: Electric Field
PT.HField: (0.75, 0.75, 0.75, 1.0), # Light Grey: Magnetic Field
# Luminal
PT.LumIntensity: (1.0, 0.95, 0.5, 1.0), # Light Yellow: Luminous Intensity
PT.LumFlux: (1.0, 0.95, 0.6, 1.0), # Light Yellow: Luminous Flux
PT.Illuminance: (1.0, 1.0, 0.75, 1.0), # Pale Yellow: Illuminance
}[self]

View File

@ -0,0 +1,337 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import typing as typ
from fractions import Fraction
import pydantic as pyd
import sympy as sp
import sympy.physics.units as spu
import typing_extensions as typx
from pydantic_core import core_schema as pyd_core_schema
from . import units as spux
from .sympy_type import SympyType
from .unit_analysis import get_units, uses_units
####################
# - Pydantic "Sympy Expr"
####################
class _SympyExpr:
"""Low-level `pydantic`, schema describing how to serialize/deserialize fields that have a `SympyType` (like `sp.Expr`), so we can cleanly use `sympy` types in `pyd.BaseModel`.
Notes:
You probably want to use `SympyExpr`.
Examples:
To be usable as a type annotation on `pyd.BaseModel`, attach this to `SympyType` using `typx.Annotated`:
```python
SympyExpr = typx.Annotated[SympyType, _SympyExpr]
class Spam(pyd.BaseModel):
line: SympyExpr = sp.Eq(sp.y, 2*sp.Symbol(x, real=True) - 3)
```
"""
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: SympyType,
_handler: pyd.GetCoreSchemaHandler,
) -> pyd_core_schema.CoreSchema:
"""Compute a schema that allows `pydantic` to validate a `sympy` type."""
def validate_from_str(sp_str: str | typ.Any) -> SympyType | typ.Any:
"""Parse and validate a string expression.
Parameters:
sp_str: A stringified `sympy` object, that will be parsed to a sympy type.
Before use, `isinstance(expr_str, str)` is checked.
If the object isn't a string, then the validation will be skipped.
Returns:
Either a `sympy` object, if the input is parseable, or the same untouched object.
Raises:
ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression.
"""
# Constrain to String
if not isinstance(sp_str, str):
return sp_str
# Parse String -> Sympy
try:
expr = sp.sympify(sp_str)
except ValueError as ex:
msg = f'String {sp_str} is not a valid sympy expression'
raise ValueError(msg) from ex
# Substitute Symbol -> Quantity
return expr.subs(spux.UNIT_BY_SYMBOL)
def validate_from_pytype(
sp_pytype: int | Fraction | float | complex,
) -> SympyType | typ.Any:
"""Parse and validate a pure Python type.
Parameters:
sp_str: A stringified `sympy` object, that will be parsed to a sympy type.
Before use, `isinstance(expr_str, str)` is checked.
If the object isn't a string, then the validation will be skipped.
Returns:
Either a `sympy` object, if the input is parseable, or the same untouched object.
Raises:
ValueError: If `sp_str` is a string, but can't be parsed into a `sympy` expression.
"""
# Constrain to String
if not isinstance(sp_pytype, int | Fraction | float | complex):
return sp_pytype
if isinstance(sp_pytype, int):
return sp.Integer(sp_pytype)
if isinstance(sp_pytype, Fraction):
return sp.Rational(sp_pytype.numerator, sp_pytype.denominator)
if isinstance(sp_pytype, float):
return sp.Float(sp_pytype)
# sp_pytype => Complex
return sp_pytype.real + sp.I * sp_pytype.imag
sympy_expr_schema = pyd_core_schema.chain_schema(
[
pyd_core_schema.no_info_plain_validator_function(validate_from_str),
pyd_core_schema.no_info_plain_validator_function(validate_from_pytype),
pyd_core_schema.is_instance_schema(SympyType),
]
)
return pyd_core_schema.json_or_python_schema(
json_schema=sympy_expr_schema,
python_schema=sympy_expr_schema,
serialization=pyd_core_schema.plain_serializer_function_ser_schema(
lambda sp_obj: sp.srepr(sp_obj)
),
)
SympyExpr = typx.Annotated[
sp.Basic, ## Treat all sympy types as sp.Basic
_SympyExpr,
]
## TODO: The type game between SympyType, SympyExpr, and the various flavors of ConstrSympyExpr(), is starting to be a bit much. Let's consolidate.
def ConstrSympyExpr( # noqa: N802, PLR0913
# Features
allow_variables: bool = True,
allow_units: bool = True,
# Structures
allowed_sets: set[typ.Literal['integer', 'rational', 'real', 'complex']]
| None = None,
allowed_structures: set[typ.Literal['scalar', 'matrix']] | None = None,
# Element Class
max_symbols: int | None = None,
allowed_symbols: set[sp.Symbol] | None = None,
allowed_units: set[spu.Quantity] | None = None,
# Shape Class
allowed_matrix_shapes: set[tuple[int, int]] | None = None,
) -> SympyType:
"""Constructs a `SympyExpr` type, which will validate `sympy` types when used in a `pyd.BaseModel`.
Relies on the `sympy` assumptions system.
See <https://docs.sympy.org/latest/guides/assumptions.html#predicates>
Parameters (TBD):
Returns:
A type that represents a constrained `sympy` expression.
"""
def validate_expr(expr: SympyType):
if not (isinstance(expr, SympyType),):
msg = f"expr '{expr}' is not an allowed Sympy expression ({SympyType})"
raise ValueError(msg)
msgs = set()
# Validate Feature Class
if (not allow_variables) and (len(expr.free_symbols) > 0):
msgs.add(
f'allow_variables={allow_variables} does not match expression {expr}.'
)
if (not allow_units) and uses_units(expr):
msgs.add(f'allow_units={allow_units} does not match expression {expr}.')
# Validate Structure Class
if (
allowed_sets
and isinstance(expr, sp.Expr)
and not any(
{
'integer': expr.is_integer,
'rational': expr.is_rational,
'real': expr.is_real,
'complex': expr.is_complex,
}[allowed_set]
for allowed_set in allowed_sets
)
):
msgs.add(
f"allowed_sets={allowed_sets} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))"
)
if allowed_structures and not any(
{
'scalar': True,
'matrix': isinstance(expr, sp.MatrixBase),
}[allowed_set]
for allowed_set in allowed_structures
):
msgs.add(
f"allowed_structures={allowed_structures} does not match expression {expr} (remember to add assumptions to symbols, ex. `x = sp.Symbol('x', real=True))"
)
# Validate Element Class
if max_symbols and len(expr.free_symbols) > max_symbols:
msgs.add(f'max_symbols={max_symbols} does not match expression {expr}')
if allowed_symbols and expr.free_symbols.issubset(allowed_symbols):
msgs.add(
f'allowed_symbols={allowed_symbols} does not match expression {expr}'
)
if allowed_units and get_units(expr).issubset(allowed_units):
msgs.add(f'allowed_units={allowed_units} does not match expression {expr}')
# Validate Shape Class
if (
allowed_matrix_shapes and isinstance(expr, sp.MatrixBase)
) and expr.shape not in allowed_matrix_shapes:
msgs.add(
f'allowed_matrix_shapes={allowed_matrix_shapes} does not match expression {expr} with shape {expr.shape}'
)
# Error or Return
if msgs:
raise ValueError(str(msgs))
return expr
return typx.Annotated[
sp.Basic,
_SympyExpr,
pyd.AfterValidator(validate_expr),
]
####################
# - Common ConstrSympyExpr
####################
# Expression
ScalarUnitlessRealExpr: typ.TypeAlias = ConstrSympyExpr(
allow_variables=False,
allow_units=False,
allowed_structures={'scalar'},
allowed_sets={'integer', 'rational', 'real'},
)
ScalarUnitlessComplexExpr: typ.TypeAlias = ConstrSympyExpr(
allow_variables=False,
allow_units=False,
allowed_structures={'scalar'},
allowed_sets={'integer', 'rational', 'real', 'complex'},
)
# Symbol
IntSymbol: typ.TypeAlias = ConstrSympyExpr(
allow_variables=True,
allow_units=False,
allowed_sets={'integer'},
max_symbols=1,
)
RationalSymbol: typ.TypeAlias = ConstrSympyExpr(
allow_variables=True,
allow_units=False,
allowed_sets={'integer', 'rational'},
max_symbols=1,
)
RealSymbol: typ.TypeAlias = ConstrSympyExpr(
allow_variables=True,
allow_units=False,
allowed_sets={'integer', 'rational', 'real'},
max_symbols=1,
)
ComplexSymbol: typ.TypeAlias = ConstrSympyExpr(
allow_variables=True,
allow_units=False,
allowed_sets={'integer', 'rational', 'real', 'complex'},
max_symbols=1,
)
Symbol: typ.TypeAlias = IntSymbol | RealSymbol | ComplexSymbol
# Unit
UnitDimension: typ.TypeAlias = SympyExpr ## Actually spu.Dimension
## Technically a "unit expression", which includes compound types.
## Support for this is the reason to prefer over raw spu.Quantity.
Unit: typ.TypeAlias = ConstrSympyExpr(
allow_variables=False,
allow_units=True,
allowed_structures={'scalar'},
)
# Number
IntNumber: typ.TypeAlias = ConstrSympyExpr(
allow_variables=False,
allow_units=False,
allowed_sets={'integer'},
allowed_structures={'scalar'},
)
RealNumber: typ.TypeAlias = ConstrSympyExpr(
allow_variables=False,
allow_units=False,
allowed_sets={'integer', 'rational', 'real'},
allowed_structures={'scalar'},
)
ComplexNumber: typ.TypeAlias = ConstrSympyExpr(
allow_variables=False,
allow_units=False,
allowed_sets={'integer', 'rational', 'real', 'complex'},
allowed_structures={'scalar'},
)
Number: typ.TypeAlias = IntNumber | RealNumber | ComplexNumber
# Number
PhysicalRealNumber: typ.TypeAlias = ConstrSympyExpr(
allow_variables=False,
allow_units=True,
allowed_sets={'integer', 'rational', 'real'},
allowed_structures={'scalar'},
)
PhysicalComplexNumber: typ.TypeAlias = ConstrSympyExpr(
allow_variables=False,
allow_units=True,
allowed_sets={'integer', 'rational', 'real', 'complex'},
allowed_structures={'scalar'},
)
PhysicalNumber: typ.TypeAlias = PhysicalRealNumber | PhysicalComplexNumber
# Vector
Real3DVector: typ.TypeAlias = ConstrSympyExpr(
allow_variables=False,
allow_units=False,
allowed_sets={'integer', 'rational', 'real'},
allowed_structures={'matrix'},
allowed_matrix_shapes={(3, 1)},
)

View File

@ -0,0 +1,23 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import sympy as sp
import sympy.physics.units as spu
####################
# - Underlying "Sympy Type"
####################
SympyType = sp.Basic | sp.MatrixBase | spu.Quantity | spu.Dimension

View File

@ -0,0 +1,287 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Functions for characterizaiton, conversion and casting of `sympy` objects that use units."""
import functools
import sympy as sp
import sympy.physics.units as spu
from . import units as spux
from .parse_cast import sympy_to_python
from .sympy_type import SympyType
####################
# - Unit Characterization
####################
## TODO: Caching w/srepr'ed expression.
## TODO: An LFU cache could do better than an LRU.
def uses_units(sp_obj: SympyType) -> bool:
"""Determines if an expression uses any units.
Parameters:
expr: The sympy object that may contain units.
Returns:
Whether or not units are present in the object.
"""
return sp_obj.has(spu.Quantity)
## TODO: Caching w/srepr'ed expression.
## TODO: An LFU cache could do better than an LRU.
def get_units(expr: sp.Expr) -> set[spu.Quantity]:
"""Finds all units used by the expression, and returns them as a set.
No information about _the relationship between units_ is exposed.
For example, compound units like `spu.meter / spu.second` would be mapped to `{spu.meter, spu.second}`.
Notes:
The expression graph is traversed depth-first with `sp.postorder_traversal`, to search for `sp.Quantity` elements.
The performance is comparable to the performance of `sp.postorder_traversal`, since the **entire expression graph will always be traversed**, with the added overhead of one `isinstance` call per expression-graph-node.
Parameters:
expr: The sympy expression that may contain units.
Returns:
All units (`spu.Quantity`) used within the expression.
"""
return {
subexpr
for subexpr in sp.postorder_traversal(expr)
if isinstance(subexpr, spu.Quantity)
}
####################
# - Dimensional Characterization
####################
def unit_dim_to_unit_dim_deps(
unit_dims: SympyType,
) -> dict[spu.dimensions.Dimension, int] | None:
"""Normalize an expression to a mapping of its dimensional dependencies.
Comparing the dimensional dependencies of two `unit_dims` is a meaningful way of determining whether they are equivalent.
Notes:
We adhere to SI unit conventions when determining dimensional dependencies, to ensure that ex. `freq -> 1/time` equivalences are normalized away.
This allows the output of this method to be compared meaningfully, to determine whether two dimensional expressions are equivalent.
We choose to catch a `TypeError`, for cases where dimensional analysis is impossible (especially `+` or `-` between differing dimensions).
This may have a slight performance penalty.
Returns:
The dimensional dependencies of the dimensional expression.
If such a thing makes no sense, ex. if `+` or `-` is present between differing unit dimensions, then return None.
"""
dimsys_SI = spu.systems.si.dimsys_SI
# Retrieve Dimensional Dependencies
try:
return dimsys_SI.get_dimensional_dependencies(unit_dims)
# Catch TypeError
## -> Happens if `+` or `-` is in `unit`.
## -> Generally, it doesn't make sense to add/subtract differing unit dims.
## -> Thus, when trying to figure out the unit dimension, there isn't one.
except TypeError:
return None
def unit_to_unit_dim_deps(
unit: SympyType,
) -> dict[spu.dimensions.Dimension, int] | None:
"""Deduce the dimensional dependencies of a unit.
Notes:
Using `.subs()` to replace `sp.Quantity`s with `spu.dimensions.Dimension`s seems to result in an expression that absolutely refuses to claim that it has anything other than raw `sp.Symbol`s.
This is extremely problematic - dimensional analysis relies on the arithmetic properties of proper `Dimension` objects.
For this reason, though we'd rather have a `unit_to_unit_dims()` function, we have not yet found a way to do this.
Luckily, most of our uses cases seem only to require the dimensional dictionary, which (surprisingly) seems accessible using `unit_dim_to_unit_dim_deps()`.
"""
# Retrieve Dimensional Dependencies
## -> NOTE: .subs() alone seems to produce sp.Symbol atoms.
## -> This is extremely problematic; `Dims` arithmetic has key properties.
## -> So we have to go all the way to the dimensional dependencies.
## -> This isn't really respecting the args, but it seems to work :)
return unit_dim_to_unit_dim_deps(
unit.subs({arg: arg.dimension for arg in unit.atoms(spu.Quantity)})
)
def compare_unit_dims(unit_dim_l: SympyType, unit_dim_r: SympyType) -> bool:
"""Compare the dimensional dependencies of two unit dimensions.
Comparing the dimensional dependencies of two `unit_dims` is a meaningful way of determining whether they are equivalent.
"""
return unit_dim_to_unit_dim_deps(unit_dim_l) == unit_dim_to_unit_dim_deps(
unit_dim_r
)
def compare_units_by_unit_dims(unit_l: SympyType, unit_r: SympyType) -> bool:
"""Compare two units by their unit dimensions."""
return unit_to_unit_dim_deps(unit_l) == unit_to_unit_dim_deps(unit_r)
def compare_unit_dim_to_unit_dim_deps(
unit_dim: SympyType, unit_dim_deps: dict[spu.dimensions.Dimension, int]
) -> bool:
"""Compare the dimensional dependencies of unit dimensions to pre-defined unit dimensions."""
return unit_dim_to_unit_dim_deps(unit_dim) == unit_dim_deps
####################
# - Unit Casting
####################
def strip_units(sp_obj: SympyType) -> SympyType:
"""Strip all units by replacing them to `1`.
This is a rather unsafe method.
You probably shouldn't use it.
Warnings:
Absolutely no effort is made to determine whether stripping units is a _meaningful thing to do_.
For example, using `+` expressions of compatible dimension, but different units, is a clear mistake.
For example, `8*meter + 9*millimeter` strips to `8(1) + 9(1) = 17`, which is a garbage result.
The **user of this method** must themselves perform appropriate checks on th eobject before stripping units.
Parameters:
sp_obj: A sympy object that contains unit symbols.
**NOTE**: Unit symbols (from `sympy.physics.units`) are not _free_ symbols, in that they are not unknown.
Nonetheless, they are not _numbers_ either, and thus they cannot be used in a numerical expression.
Returns:
The sympy object with all unit symbols replaced by `1`, effectively extracting the unitless part of the object.
"""
return sp_obj.subs(spux.UNIT_TO_1)
def convert_to_unit(sp_obj: SympyType, unit: SympyType | None) -> SympyType:
"""Convert a sympy object to the given unit.
Supports a unit of `None`, which simply causes the object to have its units stripped.
"""
if unit is None:
return strip_units(sp_obj)
return spu.convert_to(sp_obj, unit)
# msg = f'Sympy object "{sp_obj}" was scaled to the unit "{unit}" with the expectation that the result would be unitless, but the result "{unitless_expr}" has units "{get_units(unitless_expr)}"'
# raise ValueError(msg)
## TODO: Include sympy_to_python in 'scale_to' to match semantics of 'scale_to_unit_system'
## -- Introduce a 'strip_unit
def scale_to_unit(
sp_obj: SympyType,
unit: spu.Quantity | None,
cast_to_pytype: bool = False,
use_jax_array: bool = False,
) -> SympyType:
"""Convert an expression that uses units to a different unit, then strip all units, leaving only a unitless `sympy` value.
This is used whenever the unitless part of an expression is needed, but guaranteed expressed in a particular unit, aka. **unit system normalization**.
Notes:
The unitless output is still an `sp.Expr`, which may contain ex. symbols.
If you know that the output **should** work as a corresponding Python type (ex. `sp.Integer` vs. `int`), but it doesn't, you can use `sympy_to_python()` to produce a pure-Python type.
In this way, with a little care, broad compatiblity can be bridged between the `sympy.physics.units` unit system and the wider Python ecosystem.
Parameters:
expr: The unit-containing expression to convert.
unit_to: The unit that is converted to.
Returns:
The unitless part of `expr`, after scaling the entire expression to `unit`.
Raises:
ValueError: If the result of unit-conversion and -stripping still has units, as determined by `uses_units()`.
"""
sp_obj_stripped = strip_units(convert_to_unit(sp_obj, unit))
if cast_to_pytype:
return sympy_to_python(
sp_obj_stripped,
use_jax_array=use_jax_array,
)
return sp_obj_stripped
def scaling_factor(
unit_from: SympyType, unit_to: SympyType
) -> int | float | complex | tuple | None:
"""Compute the numerical scaling factor imposed on the unitless part of the expression when converting from one unit to another.
Parameters:
unit_from: The unit that is converted from.
unit_to: The unit that is converted to.
Returns:
The numerical scaling factor between the two units.
If the units are incompatible, then we return None.
Raises:
ValueError: If the two units don't share a common dimension.
"""
if compare_units_by_unit_dims(unit_from, unit_to):
return scale_to_unit(unit_from, unit_to)
return None
@functools.cache
def unit_str_to_unit(unit_str: str, optional: bool = False) -> SympyType | None:
"""Determine the `sympy` unit expression that matches the given unit string.
Parameters:
unit_str: A string parseable with `sp.sympify`, which contains a unit expression.
optional: Whether to return
**NOTE**: `None` is itself a valid "unit", denoting dimensionlessness, in general.
Ensure that appropriate checks are performed to account for this nuance.
Returns:
The matching `sympy` unit.
Raises:
ValueError: When no valid unit can be matched to the unit string, and `optional` is `False`.
"""
match unit_str:
# Special-Case 'degree'
## -> sp.sympify('degree') produces the sp.degree().
## -> TODO: Proper Analysis analysis.
case 'degree':
unit = spu.degree
case _:
unit = sp.sympify(unit_str).subs(spux.UNIT_BY_SYMBOL)
if uses_units(unit):
return unit
if optional:
return None
msg = f'No valid unit for unit string {unit_str}'
raise ValueError(msg)

View File

@ -0,0 +1,93 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Functions for conversion and casting of `sympy` objects that use units, via unit systems."""
import jax
import sympy.physics.units as spu
from . import units as spux
from .parse_cast import sympy_to_python
from .physical_type import PhysicalType
from .sympy_type import SympyType
from .unit_analysis import get_units
from .unit_systems import UnitSystem
####################
# - Conversion
####################
def strip_unit_system(
sp_obj: SympyType, unit_system: UnitSystem | None = None
) -> SympyType:
"""Strip units occurring in the given unit system from the expression.
Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`".
Obviously, the semantic correctness of this operation depends entirely on _the units adding no semantic meaning to the expression_.
Notes:
You should probably use `scale_to_unit_system()` or `convert_to_unit_system()`.
"""
if unit_system is None:
return sp_obj.subs(spux.UNIT_TO_1)
return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None})
def convert_to_unit_system(
sp_obj: SympyType, unit_system: UnitSystem | None
) -> SympyType:
"""Convert an expression to the units of a given unit system."""
if unit_system is None:
return sp_obj
return spu.convert_to(
sp_obj,
{unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)},
)
####################
# - Casting
####################
def scale_to_unit_system(
sp_obj: SympyType,
unit_system: UnitSystem | None,
use_jax_array: bool = False,
) -> int | float | complex | tuple | jax.Array:
"""Convert an expression to the units of a given unit system, then strip all units of the unit system.
Afterwards, it is converted to an appropriate Python type.
Notes:
For stability, and performance, reasons, this should only be used at the very last stage.
Regarding performance: **This is not a fast function**.
Parameters:
sp_obj: An arbitrary sympy object, presumably with units.
unit_system: A unit system mapping `PhysicalType` to particular choices of (compound) units.
Note that, in this context, only `unit_system.values()` is used.
Returns:
An appropriate pure Python type, after scaling to the unit system and stripping all units away.
If the returned type is array-like, and `use_jax_array` is specified, then (and **only** then) will a `jax.Array` be returned instead of a nested `tuple`.
"""
return sympy_to_python(
strip_unit_system(convert_to_unit_system(sp_obj, unit_system), unit_system),
use_jax_array=use_jax_array,
)

View File

@ -0,0 +1,80 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Defines a common unit system representation, as well as a few of the most common / useful unit systems.
Attributes:
UnitSystem: Type of a unit system representation, as an exhaustive mapping from `PhysicalType` to a unit expression.
**Compatibility between `PhysicalType` and unit must be manually guaranteed** when defining new unit systems.
UNITS_SI: Pre-defined go-to choice of unit system, which can also be a useful base to build other unit systems on.
"""
import typing as typ
import sympy.physics.units as spu
from . import units as spux
from .physical_type import PhysicalType as PT # noqa: N817
from .sympy_expr import Unit
####################
# - Unit System Representation
####################
UnitSystem: typ.TypeAlias = dict[PT, Unit]
####################
# - Standard Unit Systems
####################
UNITS_SI: UnitSystem = {
PT.NonPhysical: None,
# Global
PT.Time: spu.second,
PT.Angle: spu.radian,
PT.SolidAngle: spu.steradian,
PT.Freq: spu.hertz,
PT.AngFreq: spu.radian * spu.hertz,
# Cartesian
PT.Length: spu.meter,
PT.Area: spu.meter**2,
PT.Volume: spu.meter**3,
# Mechanical
PT.Vel: spu.meter / spu.second,
PT.Accel: spu.meter / spu.second**2,
PT.Mass: spu.kilogram,
PT.Force: spu.newton,
# Energy
PT.Work: spu.joule,
PT.Power: spu.watt,
PT.PowerFlux: spu.watt / spu.meter**2,
PT.Temp: spu.kelvin,
# Electrodynamics
PT.Current: spu.ampere,
PT.CurrentDensity: spu.ampere / spu.meter**2,
PT.Voltage: spu.volt,
PT.Capacitance: spu.farad,
PT.Impedance: spu.ohm,
PT.Conductance: spu.siemens,
PT.Conductivity: spu.siemens / spu.meter,
PT.MFlux: spu.weber,
PT.MFluxDensity: spu.tesla,
PT.Inductance: spu.henry,
PT.EField: spu.volt / spu.meter,
PT.HField: spu.ampere / spu.meter,
# Luminal
PT.LumIntensity: spu.candela,
PT.LumFlux: spux.lumen,
PT.Illuminance: spu.lux,
}

View File

@ -0,0 +1,77 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import typing as typ
import sympy as sp
import sympy.physics.units as spu
####################
# - Units
####################
# Time
femtosecond = fs = spu.Quantity('femtosecond', abbrev='fs')
femtosecond.set_global_relative_scale_factor(spu.femto, spu.second)
# Length
femtometer = fm = spu.Quantity('femtometer', abbrev='fm')
femtometer.set_global_relative_scale_factor(spu.femto, spu.meter)
# Lum Flux
lumen = lm = spu.Quantity('lumen', abbrev='lm')
lumen.set_global_relative_scale_factor(1, spu.candela * spu.steradian)
# Force
nanonewton = nN = spu.Quantity('nanonewton', abbrev='nN') # noqa: N816
nanonewton.set_global_relative_scale_factor(spu.nano, spu.newton)
micronewton = uN = spu.Quantity('micronewton', abbrev='μN') # noqa: N816
micronewton.set_global_relative_scale_factor(spu.micro, spu.newton)
millinewton = mN = spu.Quantity('micronewton', abbrev='mN') # noqa: N816
micronewton.set_global_relative_scale_factor(spu.milli, spu.newton)
# Frequency
kilohertz = KHz = spu.Quantity('kilohertz', abbrev='KHz')
kilohertz.set_global_relative_scale_factor(spu.kilo, spu.hertz)
megahertz = MHz = spu.Quantity('megahertz', abbrev='MHz')
kilohertz.set_global_relative_scale_factor(spu.kilo, spu.hertz)
gigahertz = GHz = spu.Quantity('gigahertz', abbrev='GHz')
gigahertz.set_global_relative_scale_factor(spu.giga, spu.hertz)
terahertz = THz = spu.Quantity('terahertz', abbrev='THz')
terahertz.set_global_relative_scale_factor(spu.tera, spu.hertz)
petahertz = PHz = spu.Quantity('petahertz', abbrev='PHz')
petahertz.set_global_relative_scale_factor(spu.peta, spu.hertz)
exahertz = EHz = spu.Quantity('exahertz', abbrev='EHz')
exahertz.set_global_relative_scale_factor(spu.exa, spu.hertz)
# Pressure
millibar = mbar = spu.Quantity('millibar', abbrev='mbar')
millibar.set_global_relative_scale_factor(spu.milli, spu.bar)
hectopascal = hPa = spu.Quantity('hectopascal', abbrev='hPa') # noqa: N816
hectopascal.set_global_relative_scale_factor(spu.hecto, spu.pascal)
UNIT_BY_SYMBOL: dict[sp.Symbol, spu.Quantity] = {
unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity)
} | {unit.name: unit for unit in globals().values() if isinstance(unit, spu.Quantity)}
UNIT_TO_1: dict[spu.Quantity, 1] = {unit: 1 for unit in UNIT_BY_SYMBOL.values()}