refactor: revamped symbolic flow (inaccurate unit conversions)
parent
353a2c997e
commit
bcba444a8b
|
@ -18,8 +18,8 @@ from .array import ArrayFlow
|
||||||
from .capabilities import CapabilitiesFlow
|
from .capabilities import CapabilitiesFlow
|
||||||
from .flow_kinds import FlowKind
|
from .flow_kinds import FlowKind
|
||||||
from .info import InfoFlow
|
from .info import InfoFlow
|
||||||
from .lazy_range import RangeFlow, ScalingMode
|
|
||||||
from .lazy_func import FuncFlow
|
from .lazy_func import FuncFlow
|
||||||
|
from .lazy_range import RangeFlow, ScalingMode
|
||||||
from .params import ParamsFlow
|
from .params import ParamsFlow
|
||||||
from .value import ValueFlow
|
from .value import ValueFlow
|
||||||
|
|
||||||
|
|
|
@ -50,11 +50,6 @@ class ArrayFlow:
|
||||||
####################
|
####################
|
||||||
# - Computed Properties
|
# - Computed Properties
|
||||||
####################
|
####################
|
||||||
@property
|
|
||||||
def is_symbolic(self) -> bool:
|
|
||||||
"""Always False, as ArrayFlows are never unrealized."""
|
|
||||||
return False
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Outer length of the contained array."""
|
"""Outer length of the contained array."""
|
||||||
return len(self.values)
|
return len(self.values)
|
||||||
|
@ -196,5 +191,11 @@ class ArrayFlow:
|
||||||
"""
|
"""
|
||||||
return self.rescale(lambda v: v, new_unit=new_unit)
|
return self.rescale(lambda v: v, new_unit=new_unit)
|
||||||
|
|
||||||
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self:
|
def rescale_to_unit_system(self, unit_system: spux.UnitSystem | None) -> typ.Self:
|
||||||
raise NotImplementedError
|
if unit_system is None:
|
||||||
|
return self.values
|
||||||
|
|
||||||
|
return self.correct_unit(None).rescale(
|
||||||
|
lambda v: spux.scale_to_unit_system(v * self.unit, unit_system),
|
||||||
|
new_unit=spux.convert_to_unit_system(self.unit, unit_system),
|
||||||
|
)
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
from types import MappingProxyType
|
||||||
|
|
||||||
from ..socket_types import SocketType
|
from ..socket_types import SocketType
|
||||||
from .flow_kinds import FlowKind
|
from .flow_kinds import FlowKind
|
||||||
|
@ -25,6 +26,7 @@ from .flow_kinds import FlowKind
|
||||||
class CapabilitiesFlow:
|
class CapabilitiesFlow:
|
||||||
socket_type: SocketType
|
socket_type: SocketType
|
||||||
active_kind: FlowKind
|
active_kind: FlowKind
|
||||||
|
allow_out_to_in: dict[FlowKind, FlowKind] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
is_universal: bool = False
|
is_universal: bool = False
|
||||||
|
|
||||||
|
@ -40,7 +42,13 @@ class CapabilitiesFlow:
|
||||||
def is_compatible_with(self, other: typ.Self) -> bool:
|
def is_compatible_with(self, other: typ.Self) -> bool:
|
||||||
return other.is_universal or (
|
return other.is_universal or (
|
||||||
self.socket_type == other.socket_type
|
self.socket_type == other.socket_type
|
||||||
and self.active_kind == other.active_kind
|
and (
|
||||||
|
self.active_kind == other.active_kind
|
||||||
|
or (
|
||||||
|
other.active_kind in other.allow_out_to_in
|
||||||
|
and self.active_kind == other.allow_out_to_in[other.active_kind]
|
||||||
|
)
|
||||||
|
)
|
||||||
# == Constraint
|
# == Constraint
|
||||||
and all(
|
and all(
|
||||||
name in other.must_match
|
name in other.must_match
|
||||||
|
|
|
@ -67,8 +67,9 @@ class InfoFlow:
|
||||||
default_factory=dict
|
default_factory=dict
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Access
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def last_dim(self) -> sim_symbols.SimSymbol | None:
|
def first_dim(self) -> sim_symbols.SimSymbol | None:
|
||||||
"""The integer axis occupied by the dimension.
|
"""The integer axis occupied by the dimension.
|
||||||
|
|
||||||
Can be used to index `.shape` of the represented raw array.
|
Can be used to index `.shape` of the represented raw array.
|
||||||
|
@ -87,13 +88,24 @@ class InfoFlow:
|
||||||
return list(self.dims.keys())[-1]
|
return list(self.dims.keys())[-1]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def dim_axis(self, dim: sim_symbols.SimSymbol) -> int:
|
def dim_by_idx(self, idx: int) -> sim_symbols.SimSymbol | None:
|
||||||
|
if idx > 0 and idx < len(self.dims) - 1:
|
||||||
|
return list(self.dims.keys())[idx]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def dim_by_name(self, dim_name: str) -> int:
|
||||||
"""The integer axis occupied by the dimension.
|
"""The integer axis occupied by the dimension.
|
||||||
|
|
||||||
Can be used to index `.shape` of the represented raw array.
|
Can be used to index `.shape` of the represented raw array.
|
||||||
"""
|
"""
|
||||||
return list(self.dims.keys()).index(dim)
|
dims_with_name = [dim for dim in self.dims if dim.name == dim_name]
|
||||||
|
if len(dims_with_name) == 1:
|
||||||
|
return dims_with_name[0]
|
||||||
|
|
||||||
|
msg = f'Dim name {dim_name} not found in InfoFlow (or >1 found)'
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# Information By-Dim
|
||||||
def has_idx_cont(self, dim: sim_symbols.SimSymbol) -> bool:
|
def has_idx_cont(self, dim: sim_symbols.SimSymbol) -> bool:
|
||||||
"""Whether the dim's index is continuous, and therefore index array.
|
"""Whether the dim's index is continuous, and therefore index array.
|
||||||
|
|
||||||
|
@ -114,6 +126,23 @@ class InfoFlow:
|
||||||
return isinstance(self.dims[dim], list)
|
return isinstance(self.dims[dim], list)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def is_idx_uniform(self, dim: sim_symbols.SimSymbol) -> bool:
|
||||||
|
"""Whether the (int) dim has explicitly uniform indexing.
|
||||||
|
|
||||||
|
This is needed primarily to check whether a Fourier Transform can be meaningfully performed on the data over the dimension's axis.
|
||||||
|
|
||||||
|
In practice, we've decided that only `RangeFlow` really truly _guarantees_ uniform indexing.
|
||||||
|
While `ArrayFlow` may be uniform in practice, it's a very expensive to check, and it's far better to enforce that the user perform that check and opt for a `RangeFlow` instead, at the time of dimension definition.
|
||||||
|
"""
|
||||||
|
return isinstance(self.dims[dim], RangeFlow) and self.dims[dim].scaling == 'lin'
|
||||||
|
|
||||||
|
def dim_axis(self, dim: sim_symbols.SimSymbol) -> int:
|
||||||
|
"""The integer axis occupied by the dimension.
|
||||||
|
|
||||||
|
Can be used to index `.shape` of the represented raw array.
|
||||||
|
"""
|
||||||
|
return list(self.dims.keys()).index(dim)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Output: Contravariant Value
|
# - Output: Contravariant Value
|
||||||
####################
|
####################
|
||||||
|
@ -128,6 +157,49 @@ class InfoFlow:
|
||||||
default_factory=dict
|
default_factory=dict
|
||||||
)
|
)
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Properties
|
||||||
|
####################
|
||||||
|
@functools.cached_property
|
||||||
|
def input_mathtypes(self) -> tuple[spux.MathType, ...]:
|
||||||
|
return tuple([dim.mathtype for dim in self.dims])
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def output_mathtypes(self) -> tuple[spux.MathType, int, int]:
|
||||||
|
return [self.output.mathtype for _ in range(len(self.output.shape) + 1)]
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def order(self) -> tuple[spux.MathType, ...]:
|
||||||
|
r"""The order of the tensor represented by this info.
|
||||||
|
|
||||||
|
While that sounds fancy and all, it boils down to:
|
||||||
|
|
||||||
|
$$
|
||||||
|
\texttt{dims} + |\texttt{output}.\texttt{shape}|
|
||||||
|
$$
|
||||||
|
|
||||||
|
Doing so characterizes the full dimensionality of the tensor, which also perfectly matches the length of the raw data's shape exactly.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
Corresponds to `len(raw_data.shape)`, if `raw_data` is the n-dimensional array corresponding to this `InfoFlow`.
|
||||||
|
"""
|
||||||
|
return len(self.input_mathtypes) + self.output_shape_len
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Properties
|
||||||
|
####################
|
||||||
|
@functools.cached_property
|
||||||
|
def dim_labels(self) -> dict[str, dict[str, str]]:
|
||||||
|
"""Return a dictionary mapping pretty dim names to information oriented for columnar information display."""
|
||||||
|
return {
|
||||||
|
dim.name_pretty: {
|
||||||
|
'length': str(len(dim_idx)) if dim_idx is not None else '∞',
|
||||||
|
'mathtype': dim.mathtype.label_pretty,
|
||||||
|
'unit': dim.unit_label,
|
||||||
|
}
|
||||||
|
for dim, dim_idx in self.dims.items()
|
||||||
|
}
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Operations: Dimensions
|
# - Operations: Dimensions
|
||||||
####################
|
####################
|
||||||
|
@ -147,9 +219,11 @@ class InfoFlow:
|
||||||
"""Slice a dimensional array by-index along a particular dimension."""
|
"""Slice a dimensional array by-index along a particular dimension."""
|
||||||
return InfoFlow(
|
return InfoFlow(
|
||||||
dims={
|
dims={
|
||||||
_dim: dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
|
_dim: (
|
||||||
|
dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
|
||||||
if _dim == dim
|
if _dim == dim
|
||||||
else _dim
|
else dim_idx
|
||||||
|
)
|
||||||
for _dim, dim_idx in self.dims.items()
|
for _dim, dim_idx in self.dims.items()
|
||||||
},
|
},
|
||||||
output=self.output,
|
output=self.output,
|
||||||
|
@ -166,7 +240,7 @@ class InfoFlow:
|
||||||
return InfoFlow(
|
return InfoFlow(
|
||||||
dims={
|
dims={
|
||||||
(new_dim if _dim == old_dim else _dim): (
|
(new_dim if _dim == old_dim else _dim): (
|
||||||
new_dim_idx if _dim == old_dim else _dim
|
new_dim_idx if _dim == old_dim else dim_idx
|
||||||
)
|
)
|
||||||
for _dim, dim_idx in self.dims.items()
|
for _dim, dim_idx in self.dims.items()
|
||||||
},
|
},
|
||||||
|
@ -235,6 +309,26 @@ class InfoFlow:
|
||||||
pinned_values=self.pinned_values,
|
pinned_values=self.pinned_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def operate_output(
|
||||||
|
self,
|
||||||
|
other: typ.Self,
|
||||||
|
op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
|
||||||
|
unit_op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
|
||||||
|
) -> spux.SympyExpr:
|
||||||
|
if self.dims == other.dims:
|
||||||
|
sym_name = sim_symbols.SimSymbolName.Expr
|
||||||
|
expr = op(self.output.sp_symbol_phy, other.output.sp_symbol_phy)
|
||||||
|
unit_expr = unit_op(self.output.unit_factor, other.output.unit_factor)
|
||||||
|
|
||||||
|
return InfoFlow(
|
||||||
|
dims=self.dims,
|
||||||
|
output=sim_symbols.SimSymbol.from_expr(sym_name, expr, unit_expr),
|
||||||
|
pinned_values=self.pinned_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = f'InfoFlow: operate_output cannot be used when dimensions are not identical ({self.dims} | {other.dims}).'
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Operations: Fold
|
# - Operations: Fold
|
||||||
####################
|
####################
|
||||||
|
|
|
@ -22,7 +22,7 @@ from types import MappingProxyType
|
||||||
import jax
|
import jax
|
||||||
|
|
||||||
from blender_maxwell.utils import extra_sympy_units as spux
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
from blender_maxwell.utils import logger
|
from blender_maxwell.utils import logger, sim_symbols
|
||||||
|
|
||||||
from .params import ParamsFlow
|
from .params import ParamsFlow
|
||||||
|
|
||||||
|
@ -244,17 +244,10 @@ class FuncFlow:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
func: LazyFunction
|
func: LazyFunction
|
||||||
func_args: list[spux.MathType | spux.PhysicalType] = dataclasses.field(
|
func_args: list[sim_symbols.SimSymbol] = dataclasses.field(default_factory=list)
|
||||||
default_factory=list
|
func_kwargs: dict[str, sim_symbols.SimSymbol] = dataclasses.field(
|
||||||
)
|
|
||||||
func_kwargs: dict[str, spux.MathType | spux.PhysicalType] = dataclasses.field(
|
|
||||||
default_factory=dict
|
default_factory=dict
|
||||||
)
|
)
|
||||||
## TODO: Use SimSymbol instead of the MathType|PT union.
|
|
||||||
## -- SimSymbol is an ideal pivot point for both, as well as valid domains.
|
|
||||||
## -- SimSymbol has more semantic meaning, including a name.
|
|
||||||
## -- If desired, SimSymbols could maybe even require a specific unit.
|
|
||||||
## It could greatly simplify a whole lot of pain associated with func_args.
|
|
||||||
supports_jax: bool = False
|
supports_jax: bool = False
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
@ -315,17 +308,18 @@ class FuncFlow:
|
||||||
def realize(
|
def realize(
|
||||||
self,
|
self,
|
||||||
params: ParamsFlow,
|
params: ParamsFlow,
|
||||||
unit_system: spux.UnitSystem | None = None,
|
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||||
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
|
{}
|
||||||
|
),
|
||||||
) -> typ.Self:
|
) -> typ.Self:
|
||||||
if self.supports_jax:
|
if self.supports_jax:
|
||||||
return self.func_jax(
|
return self.func_jax(
|
||||||
*params.scaled_func_args(unit_system, symbol_values),
|
*params.scaled_func_args(self.func_args, symbol_values),
|
||||||
*params.scaled_func_kwargs(unit_system, symbol_values),
|
*params.scaled_func_kwargs(self.func_args, symbol_values),
|
||||||
)
|
)
|
||||||
return self.func(
|
return self.func(
|
||||||
*params.scaled_func_args(unit_system, symbol_values),
|
*params.scaled_func_args(self.func_kwargs, symbol_values),
|
||||||
*params.scaled_func_kwargs(unit_system, symbol_values),
|
*params.scaled_func_kwargs(self.func_kwargs, symbol_values),
|
||||||
)
|
)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
|
@ -18,17 +18,18 @@ import dataclasses
|
||||||
import enum
|
import enum
|
||||||
import functools
|
import functools
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
from fractions import Fraction
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jaxtyping as jtyp
|
import jaxtyping as jtyp
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
|
import sympy.physics.units as spu
|
||||||
|
|
||||||
from blender_maxwell.utils import extra_sympy_units as spux
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
from blender_maxwell.utils import logger
|
from blender_maxwell.utils import logger, sim_symbols
|
||||||
|
|
||||||
from .array import ArrayFlow
|
from .array import ArrayFlow
|
||||||
from .lazy_func import FuncFlow
|
|
||||||
|
|
||||||
log = logger.get(__name__)
|
log = logger.get(__name__)
|
||||||
|
|
||||||
|
@ -62,7 +63,7 @@ class ScalingMode(enum.StrEnum):
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||||
class RangeFlow:
|
class RangeFlow:
|
||||||
r"""Represents a spaced array using symbolic boundary expressions.
|
r"""Represents a finite spaced array using symbolic boundary expressions.
|
||||||
|
|
||||||
Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous.
|
Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous.
|
||||||
|
|
||||||
|
@ -79,33 +80,76 @@ class RangeFlow:
|
||||||
Since the bounds support exact (ex. rational) calculations and symbolic manipulations (_by virtue of being symbolic expressions_), the opportunities for certain kinds of numerical instability are mitigated.
|
Since the bounds support exact (ex. rational) calculations and symbolic manipulations (_by virtue of being symbolic expressions_), the opportunities for certain kinds of numerical instability are mitigated.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
start: An expression generating a scalar, unitless, complex value for the array's lower bound.
|
start: An expression representing the unitless part of the finite, scalar, complex value for the array's lower bound.
|
||||||
_Integer, rational, and real values are also supported._
|
_Integer, rational, and real values are also supported._
|
||||||
stop: An expression generating a scalar, unitless, complex value for the array's upper bound.
|
start: An expression representing the unitless part of the finite, scalar, complex value for the array's upper bound.
|
||||||
_Integer, rational, and real values are also supported._
|
_Integer, rational, and real values are also supported._
|
||||||
steps: The amount of steps (**inclusive**) to generate from `start` to `stop`.
|
steps: The amount of steps (**inclusive**) to generate from `start` to `stop`.
|
||||||
scaling: The method of distributing `step` values between the two endpoints.
|
scaling: The method of distributing `step` values between the two endpoints.
|
||||||
Generally, the linear default is sufficient.
|
|
||||||
|
|
||||||
unit: The unit of the generated array values
|
unit: The unit to interpret the values as.
|
||||||
|
|
||||||
symbols: Set of variables from which `start` and/or `stop` are determined.
|
symbols: Set of variables from which `start` and/or `stop` are determined.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
start: spux.ScalarUnitlessComplexExpr
|
start: spux.ScalarUnitlessComplexExpr
|
||||||
stop: spux.ScalarUnitlessComplexExpr
|
stop: spux.ScalarUnitlessComplexExpr
|
||||||
steps: int
|
steps: int = 0
|
||||||
scaling: ScalingMode = ScalingMode.Lin
|
scaling: ScalingMode = ScalingMode.Lin
|
||||||
|
|
||||||
unit: spux.Unit | None = None
|
unit: spux.Unit | None = None
|
||||||
|
|
||||||
symbols: frozenset[spux.Symbol] = frozenset()
|
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
|
||||||
|
|
||||||
|
# Helper Attributes
|
||||||
|
pre_fourier_ideal_midpoint: spux.ScalarUnitlessComplexExpr | None = None
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Computed Properties
|
# - SimSymbol Interop
|
||||||
|
####################
|
||||||
|
@staticmethod
|
||||||
|
def from_sym(
|
||||||
|
sym: sim_symbols.SimSymbol,
|
||||||
|
steps: int = 50,
|
||||||
|
scaling: ScalingMode | str = ScalingMode.Lin,
|
||||||
|
) -> typ.Self:
|
||||||
|
if sym.domain.start.is_infinite or sym.domain.end.is_infinite:
|
||||||
|
use_steps = 0
|
||||||
|
else:
|
||||||
|
use_steps = steps
|
||||||
|
|
||||||
|
return RangeFlow(
|
||||||
|
start=sym.domain.start if sym.domain.start.is_finite else sp.S(-1),
|
||||||
|
stop=sym.domain.end if sym.domain.end.is_finite else sp.S(1),
|
||||||
|
steps=use_steps,
|
||||||
|
scaling=ScalingMode(scaling),
|
||||||
|
unit=sym.unit,
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_sym(
|
||||||
|
self,
|
||||||
|
sym_name: sim_symbols.SimSymbolName,
|
||||||
|
) -> typ.Self:
|
||||||
|
physical_type = spux.PhysicalType.from_unit(self.unit, optional=True)
|
||||||
|
|
||||||
|
return sim_symbols.SimSymbol(
|
||||||
|
sym_name=sym_name,
|
||||||
|
mathtype=self.mathtype,
|
||||||
|
physical_type=(
|
||||||
|
physical_type
|
||||||
|
if physical_type is not None
|
||||||
|
else spux.PhysicalType.NonPhysical
|
||||||
|
),
|
||||||
|
unit=self.unit,
|
||||||
|
rows=1,
|
||||||
|
cols=1,
|
||||||
|
).set_domain(start=self.realize_start(), end=self.realize_end())
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Symbols
|
||||||
####################
|
####################
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def sorted_symbols(self) -> list[sp.Symbol]:
|
def sorted_symbols(self) -> list[sim_symbols.SimSymbol]:
|
||||||
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
||||||
|
|
||||||
The order is guaranteed to be **deterministic**.
|
The order is guaranteed to be **deterministic**.
|
||||||
|
@ -115,10 +159,21 @@ class RangeFlow:
|
||||||
"""
|
"""
|
||||||
return sorted(self.symbols, key=lambda sym: sym.name)
|
return sorted(self.symbols, key=lambda sym: sym.name)
|
||||||
|
|
||||||
@property
|
@functools.cached_property
|
||||||
def is_symbolic(self) -> bool:
|
def sorted_sp_symbols(self) -> list[spux.Symbol]:
|
||||||
"""Whether the `RangeFlow` has unrealized symbols."""
|
"""Computes `sympy` symbols from `self.sorted_symbols`.
|
||||||
return len(self.symbols) > 0
|
|
||||||
|
Returns:
|
||||||
|
All symbols valid for use in the expression.
|
||||||
|
"""
|
||||||
|
return [sym.sp_symbol for sym in self.sorted_symbols]
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Properties
|
||||||
|
####################
|
||||||
|
@functools.cached_property
|
||||||
|
def unit_factor(self) -> spux.SympyExpr:
|
||||||
|
return self.unit if self.unit is not None else sp.S(1)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Compute the length of the array that would be realized.
|
"""Compute the length of the array that would be realized.
|
||||||
|
@ -166,6 +221,14 @@ class RangeFlow:
|
||||||
####################
|
####################
|
||||||
# - Methods
|
# - Methods
|
||||||
####################
|
####################
|
||||||
|
@property
|
||||||
|
def ideal_midpoint(self) -> spux.SympyExpr:
|
||||||
|
return (self.stop + self.start) / 2
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ideal_range(self) -> spux.SympyExpr:
|
||||||
|
return self.stop - self.start
|
||||||
|
|
||||||
def rescale(
|
def rescale(
|
||||||
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
|
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
|
||||||
) -> typ.Self:
|
) -> typ.Self:
|
||||||
|
@ -181,8 +244,8 @@ class RangeFlow:
|
||||||
new_pre_start = self.start if not reverse else self.stop
|
new_pre_start = self.start if not reverse else self.stop
|
||||||
new_pre_stop = self.stop if not reverse else self.start
|
new_pre_stop = self.stop if not reverse else self.start
|
||||||
|
|
||||||
new_start = rescale_func(new_pre_start * self.unit)
|
new_start = rescale_func(new_pre_start * self.unit_factor)
|
||||||
new_stop = rescale_func(new_pre_stop * self.unit)
|
new_stop = rescale_func(new_pre_stop * self.unit_factor)
|
||||||
|
|
||||||
return RangeFlow(
|
return RangeFlow(
|
||||||
start=(
|
start=(
|
||||||
|
@ -204,6 +267,99 @@ class RangeFlow:
|
||||||
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
|
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def bound_fourier_transform(self):
|
||||||
|
r"""Treat this `RangeFlow` it as an axis along which a fourier transform is being performed, such that its bounds scale according to the Nyquist Limit.
|
||||||
|
|
||||||
|
# Sampling Theory
|
||||||
|
In general, the Fourier Transform is an operator that works on infinite, periodic, continuous functions.
|
||||||
|
In this context alone it is a ideal transform (in terms of information retention), one which degrades quite gracefully in the face of practicalities like windowing (allowing them to apply analytically to non-periodic functions too).
|
||||||
|
While often used to transform an axis of time to an axis of frequency, in general this transform "simply" extracts all repeating structures from a function.
|
||||||
|
This is illustrated beautifully in the way that the output unit becomes the reciprocal of the input unit, which is the theory underlying why we say that measurements recieved as a reciprocal unit are in a "reciprocal space" (also called "k-space").
|
||||||
|
|
||||||
|
The real world is not so nice, of course, and as such we must generally make do with the Discrete Fourier Transform.
|
||||||
|
Even with bounded discrete information, we can annoy many mathematicians by defining a DFT in such a way that "structure per thing" ($\frac{1}{\texttt{unit}}$) still makes sense to us (to them, maybe not).
|
||||||
|
A DFT can still only retain the information given to it, but so long as we have enough "original structure", any "repeating structure" should be extractable with sufficient clarity to be useful.
|
||||||
|
|
||||||
|
What "sufficient clarity" means is the basis for the entire field of "sampling theory".
|
||||||
|
The theoretical maximum for the "fineness of repetition" that is "noticeable" in the fourier-transformed of some data is characterized by a theoretical upper bound called the Nyquist Frequency / Limit, which ends up being half of the sampling rate.
|
||||||
|
Thus, to determine bounds on the data, use of the Nyquist Limit is generally a good starting point.
|
||||||
|
|
||||||
|
Of course, when the discrete data comes from a discretization of a continuous signal, information from higher frequencies might still affect the discrete results.
|
||||||
|
They do little else than cause havoc, though - best causing noise, and at worst causing structured artifacts (sometimes called "aliasing").
|
||||||
|
Some of the first innovations in sampling theory were related to "anti-aliasing" filters, whose sole purpose is to try to remove informational frequencies above the Nyquist Limit of the discrete sensor.
|
||||||
|
|
||||||
|
In FDTD simulation, we're generally already ahead when it comes to aliasing, since our field values come from an already-discrete process.
|
||||||
|
That is, unless we start overly "binning" (averaging over $n$ discrete observations); in this case, care should be taken to make sure that interesting results aren't merely unfortunately structured aliasing artifacts.
|
||||||
|
|
||||||
|
For more, see <https://en.wikipedia.org/wiki/Nyquist%E2%80%93Shannon_sampling_theorem>
|
||||||
|
|
||||||
|
# Implementation
|
||||||
|
In practice, our goal in `RangeFlow` is to compute the bounds of the index array along the fourier-transformed data axis.
|
||||||
|
The reciprocal of the unit will be taken (when unitless, `1/1=`).
|
||||||
|
The raw Nyquist Limit $n_q$ will be used to bound the unitless part of the output as $[-n_q, n_q]$
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If `self.scaling` is not linear, since the FT can only be performed on uniformly spaced data.
|
||||||
|
"""
|
||||||
|
if self.scaling is ScalingMode.Lin:
|
||||||
|
nyquist_limit = self.steps / self.ideal_range
|
||||||
|
|
||||||
|
# Return New Bounds w/Nyquist Theorem
|
||||||
|
## -> The Nyquist Limit describes "max repeated info per sample".
|
||||||
|
## -> Information can still record "faster" than the Nyquist Limit.
|
||||||
|
## -> It will just be either noise (best case), or banded artifacts.
|
||||||
|
## -> This is called "aliasing", and it's best to try and filter it.
|
||||||
|
## -> Sims generally "bin" yee cells, which sacrifices some NyqLim.
|
||||||
|
return RangeFlow(
|
||||||
|
start=-nyquist_limit,
|
||||||
|
stop=nyquist_limit,
|
||||||
|
scaling=self.scaling,
|
||||||
|
unit=1 / self.unit if self.unit is not None else None,
|
||||||
|
pre_fourier_ideal_midpoint=self.ideal_midpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = f'Cant fourier-transform an index array as a boundary, when the RangeArray has a non-linear bound {self.scaling}'
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def bound_inv_fourier_transform(self):
|
||||||
|
r"""Treat this `RangeFlow` as an axis along which an inverse fourier transform is being performed, such that its Nyquist-limit bounds are transformed back into values along the original unit dimension.
|
||||||
|
|
||||||
|
See `self.bound_fourier_transform` for the theoretical concepts.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
**The discrete inverse fourier transform always centers its output at $0$**.
|
||||||
|
|
||||||
|
Of course, it's entirely probable that the original signal was not centered at $0$.
|
||||||
|
For this reason, when performing a Fourier transform, `self.bound_fourier_transform` sets a special variable, `self.pre_fourier_ideal_midpoint`.
|
||||||
|
When set, it will retain the `self.ideal_midpoint` around which both `self.start` and `self.stop` should be centered after an inverse FT.
|
||||||
|
|
||||||
|
If `self.pre_fourier_ideal_midpoint` is set, then it will be used as the midpoint of the output's `start`/`stop`.
|
||||||
|
Otherwise, $0$ will be used - in which case the user should themselves, manually, shift the output if needed.
|
||||||
|
"""
|
||||||
|
if self.scaling is ScalingMode.Lin:
|
||||||
|
orig_ideal_range = self.steps / self.ideal_range
|
||||||
|
|
||||||
|
orig_start_centered = -orig_ideal_range
|
||||||
|
orig_stop_centered = orig_ideal_range
|
||||||
|
orig_ideal_midpoint = (
|
||||||
|
self.pre_fourier_ideal_midpoint
|
||||||
|
if self.pre_fourier_ideal_midpoint is not None
|
||||||
|
else sp.S(0)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return New Bounds w/Inverse of Nyquist Theorem
|
||||||
|
return RangeFlow(
|
||||||
|
start=-orig_start_centered + orig_ideal_midpoint,
|
||||||
|
stop=orig_stop_centered + orig_ideal_midpoint,
|
||||||
|
scaling=self.scaling,
|
||||||
|
unit=1 / self.unit if self.unit is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = f'Cant fourier-transform an index array as a boundary, when the RangeArray has a non-linear bound {self.scaling}'
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Exporters
|
# - Exporters
|
||||||
####################
|
####################
|
||||||
|
@ -237,15 +393,15 @@ class RangeFlow:
|
||||||
"""Create a function that can compute the non-lazy output array as a function of the symbols in the expressions for `start` and `stop`.
|
"""Create a function that can compute the non-lazy output array as a function of the symbols in the expressions for `start` and `stop`.
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
The ordering of the symbols is identical to `self.symbols`, which is guaranteed to be a deterministically sorted list of symbols.
|
The ordering of the symbols is identical to `self.sorted_symbols`, which is guaranteed to be a deterministically sorted list of symbols.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `FuncFlow` that, given the input symbols defined in `self.symbols`,
|
A function that generates a 1D numerical array equivalent to the range represented in this `RangeFlow`.
|
||||||
"""
|
"""
|
||||||
# Compile JAX Functions for Start/End Expressions
|
# Compile JAX Functions for Start/End Expressions
|
||||||
## -> FYI, JAX-in-JAX works perfectly fine.
|
## -> FYI, JAX-in-JAX works perfectly fine.
|
||||||
start_jax = sp.lambdify(self.sorted_symbols, self.start, 'jax')
|
start_jax = sp.lambdify(self.sorted_sp_symbols, self.start, 'jax')
|
||||||
stop_jax = sp.lambdify(self.sorted_symbols, self.stop, 'jax')
|
stop_jax = sp.lambdify(self.sorted_sp_symbols, self.stop, 'jax')
|
||||||
|
|
||||||
# Compile ArrayGen Function
|
# Compile ArrayGen Function
|
||||||
def gen_array(
|
def gen_array(
|
||||||
|
@ -256,54 +412,80 @@ class RangeFlow:
|
||||||
# Return ArrayGen Function
|
# Return ArrayGen Function
|
||||||
return gen_array
|
return gen_array
|
||||||
|
|
||||||
@functools.cached_property
|
|
||||||
def as_lazy_func(self) -> FuncFlow:
|
|
||||||
"""Creates a `FuncFlow` using the output of `self.as_func`.
|
|
||||||
|
|
||||||
This is useful for ex. parameterizing the first array in the node graph, without binding an entire computed array.
|
|
||||||
|
|
||||||
Notes:
|
|
||||||
The the function enclosed in the `FuncFlow` is identical to the one returned by `self.as_func`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A `FuncFlow` containing `self.as_func`, as well as appropriate supporting settings.
|
|
||||||
"""
|
|
||||||
return FuncFlow(
|
|
||||||
func=self.as_func,
|
|
||||||
func_args=[(spux.MathType.from_expr(sym)) for sym in self.symbols],
|
|
||||||
supports_jax=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Realization
|
# - Realization
|
||||||
####################
|
####################
|
||||||
|
def realize_symbols(
|
||||||
|
self,
|
||||||
|
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||||
|
{}
|
||||||
|
),
|
||||||
|
) -> dict[sp.Symbol, spux.ScalarUnitlessComplexExpr]:
|
||||||
|
"""Realize **all** input symbols to the `RangeFlow`.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
symbol_values: A scalar, unitless, complex `sympy` expression for each symbol defined in `self.symbols`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary directly usable in expression substitutions using `sp.Basic.subs()`.
|
||||||
|
"""
|
||||||
|
if self.symbols == set(symbol_values.keys()):
|
||||||
|
realized_syms = {}
|
||||||
|
for sym in self.sorted_symbols:
|
||||||
|
sym_value = symbol_values[sym]
|
||||||
|
|
||||||
|
# Sympy Expression
|
||||||
|
## -> We need to conform the expression to the SimSymbol.
|
||||||
|
## -> Mainly, this is
|
||||||
|
if (
|
||||||
|
isinstance(sym_value, spux.SympyType)
|
||||||
|
and not isinstance(sym_value, sp.MatrixBase)
|
||||||
|
and not spux.uses_units(sym_value)
|
||||||
|
):
|
||||||
|
v = sym.conform(sym_value)
|
||||||
|
else:
|
||||||
|
msg = f'RangeFlow: No realization support for symbolic value {sym_value} (type={type(sym_value)})'
|
||||||
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
realized_syms |= {sym: v}
|
||||||
|
|
||||||
|
msg = f'RangeFlow: Not all symbols were given a value during realization (symbols={self.symbols}, symbol_values={symbol_values})'
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
def realize_start(
|
def realize_start(
|
||||||
self,
|
self,
|
||||||
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
|
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||||
|
{}
|
||||||
|
),
|
||||||
) -> int | float | complex:
|
) -> int | float | complex:
|
||||||
"""Realize the start-bound by inserting particular values for each symbol."""
|
"""Realize the start-bound by inserting particular values for each symbol."""
|
||||||
return spux.sympy_to_python(
|
realized_symbols = self.realize_symbols(symbol_values)
|
||||||
self.start.subs({sym: symbol_values[sym.name] for sym in self.symbols})
|
return spux.sympy_to_python(self.start.subs(realized_symbols))
|
||||||
)
|
|
||||||
|
|
||||||
def realize_stop(
|
def realize_stop(
|
||||||
self,
|
self,
|
||||||
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
|
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||||
|
{}
|
||||||
|
),
|
||||||
) -> int | float | complex:
|
) -> int | float | complex:
|
||||||
"""Realize the stop-bound by inserting particular values for each symbol."""
|
"""Realize the stop-bound by inserting particular values for each symbol."""
|
||||||
return spux.sympy_to_python(
|
realized_symbols = self.realize_symbols(symbol_values)
|
||||||
self.stop.subs({sym: symbol_values[sym.name] for sym in self.symbols})
|
return spux.sympy_to_python(self.stop.subs(realized_symbols))
|
||||||
)
|
|
||||||
|
|
||||||
def realize_step_size(
|
def realize_step_size(
|
||||||
self,
|
self,
|
||||||
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
|
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||||
|
{}
|
||||||
|
),
|
||||||
) -> int | float | complex:
|
) -> int | float | complex:
|
||||||
"""Realize the stop-bound by inserting particular values for each symbol."""
|
"""Realize the stop-bound by inserting particular values for each symbol."""
|
||||||
if self.scaling is not ScalingMode.Lin:
|
if self.scaling is not ScalingMode.Lin:
|
||||||
raise NotImplementedError('Non-linear scaling mode not yet suported')
|
msg = 'Non-linear scaling mode not yet suported'
|
||||||
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
raw_step_size = (self.realize_stop() - self.realize_start() + 1) / self.steps
|
raw_step_size = (
|
||||||
|
self.realize_stop(symbol_values) - self.realize_start(symbol_values) + 1
|
||||||
|
) / self.steps
|
||||||
|
|
||||||
if self.mathtype is spux.MathType.Integer and raw_step_size.is_integer():
|
if self.mathtype is spux.MathType.Integer and raw_step_size.is_integer():
|
||||||
return int(raw_step_size)
|
return int(raw_step_size)
|
||||||
|
@ -311,7 +493,9 @@ class RangeFlow:
|
||||||
|
|
||||||
def realize(
|
def realize(
|
||||||
self,
|
self,
|
||||||
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
|
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||||
|
{}
|
||||||
|
),
|
||||||
) -> ArrayFlow:
|
) -> ArrayFlow:
|
||||||
"""Realize the array represented by this `RangeFlow` by realizing each bound, then generating all intermediate values as an array.
|
"""Realize the array represented by this `RangeFlow` by realizing each bound, then generating all intermediate values as an array.
|
||||||
|
|
||||||
|
@ -324,16 +508,34 @@ class RangeFlow:
|
||||||
## TODO: Check symbol values for coverage.
|
## TODO: Check symbol values for coverage.
|
||||||
|
|
||||||
return ArrayFlow(
|
return ArrayFlow(
|
||||||
values=self.as_func(*[symbol_values[sym] for sym in self.sorted_symbols]),
|
values=self.as_func(
|
||||||
|
*[
|
||||||
|
spux.scale_to_unit_system(symbol_values[sym])
|
||||||
|
for sym in self.sorted_symbols
|
||||||
|
]
|
||||||
|
),
|
||||||
unit=self.unit,
|
unit=self.unit,
|
||||||
is_sorted=True,
|
is_sorted=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def realize_array(self) -> ArrayFlow:
|
def realize_array(self) -> ArrayFlow:
|
||||||
"""Standardized access to `self.realize()` when there are no symbols."""
|
"""Standardized access to `self.realize()` when there are no symbols.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If there are symbols defined in `self.symbols`.
|
||||||
|
"""
|
||||||
|
if not self.symbols:
|
||||||
return self.realize()
|
return self.realize()
|
||||||
|
|
||||||
|
msg = f'RangeFlow: Cannot use ".realize_array" when symbols are defined (symbols={self.symbols}, RangeFlow={self}'
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def values(self) -> jtyp.Inexact[jtyp.Array, '...']:
|
||||||
|
"""Alias for `realize_array.values`."""
|
||||||
|
return self.realize_array.values
|
||||||
|
|
||||||
def __getitem__(self, subscript: slice):
|
def __getitem__(self, subscript: slice):
|
||||||
"""Implement indexing and slicing in a sane way.
|
"""Implement indexing and slicing in a sane way.
|
||||||
|
|
||||||
|
@ -379,12 +581,6 @@ class RangeFlow:
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
||||||
"""
|
"""
|
||||||
if self.unit is not None:
|
|
||||||
log.debug(
|
|
||||||
'%s: Corrected unit to %s',
|
|
||||||
self,
|
|
||||||
corrected_unit,
|
|
||||||
)
|
|
||||||
return RangeFlow(
|
return RangeFlow(
|
||||||
start=self.start,
|
start=self.start,
|
||||||
stop=self.stop,
|
stop=self.stop,
|
||||||
|
@ -394,9 +590,6 @@ class RangeFlow:
|
||||||
symbols=self.symbols,
|
symbols=self.symbols,
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"'
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
def rescale_to_unit(self, unit: spux.Unit) -> typ.Self:
|
def rescale_to_unit(self, unit: spux.Unit) -> typ.Self:
|
||||||
"""Replaces the unit, **with** rescaling of the bounds.
|
"""Replaces the unit, **with** rescaling of the bounds.
|
||||||
|
|
||||||
|
@ -410,11 +603,6 @@ class RangeFlow:
|
||||||
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
||||||
"""
|
"""
|
||||||
if self.unit is not None:
|
if self.unit is not None:
|
||||||
log.debug(
|
|
||||||
'%s: Scaled to unit %s',
|
|
||||||
self,
|
|
||||||
unit,
|
|
||||||
)
|
|
||||||
return RangeFlow(
|
return RangeFlow(
|
||||||
start=spux.scale_to_unit(self.start * self.unit, unit),
|
start=spux.scale_to_unit(self.start * self.unit, unit),
|
||||||
stop=spux.scale_to_unit(self.stop * self.unit, unit),
|
stop=spux.scale_to_unit(self.stop * self.unit, unit),
|
||||||
|
@ -423,11 +611,18 @@ class RangeFlow:
|
||||||
unit=unit,
|
unit=unit,
|
||||||
symbols=self.symbols,
|
symbols=self.symbols,
|
||||||
)
|
)
|
||||||
|
return RangeFlow(
|
||||||
|
start=self.start * unit,
|
||||||
|
stop=self.stop * unit,
|
||||||
|
steps=self.steps,
|
||||||
|
scaling=self.scaling,
|
||||||
|
unit=unit,
|
||||||
|
symbols=self.symbols,
|
||||||
|
)
|
||||||
|
|
||||||
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}'
|
def rescale_to_unit_system(
|
||||||
raise ValueError(msg)
|
self, unit_system: spux.UnitSystem | None = None
|
||||||
|
) -> typ.Self:
|
||||||
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self:
|
|
||||||
"""Replaces the units, **with** rescaling of the bounds.
|
"""Replaces the units, **with** rescaling of the bounds.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
|
@ -439,28 +634,11 @@ class RangeFlow:
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
||||||
"""
|
"""
|
||||||
if self.unit is not None:
|
|
||||||
log.debug(
|
|
||||||
'%s: Scaled to new unit system (new unit = %s)',
|
|
||||||
self,
|
|
||||||
unit_system[spux.PhysicalType.from_unit(self.unit)],
|
|
||||||
)
|
|
||||||
return RangeFlow(
|
return RangeFlow(
|
||||||
start=spux.strip_unit_system(
|
start=spux.scale_to_unit_system(self.start * self.unit, unit_system),
|
||||||
spux.convert_to_unit_system(self.start * self.unit, unit_system),
|
stop=spux.scale_to_unit_system(self.stop * self.unit, unit_system),
|
||||||
unit_system,
|
|
||||||
),
|
|
||||||
stop=spux.strip_unit_system(
|
|
||||||
spux.convert_to_unit_system(self.stop * self.unit, unit_system),
|
|
||||||
unit_system,
|
|
||||||
),
|
|
||||||
steps=self.steps,
|
steps=self.steps,
|
||||||
scaling=self.scaling,
|
scaling=self.scaling,
|
||||||
unit=unit_system[spux.PhysicalType.from_unit(self.unit)],
|
unit=spux.convert_to_unit_system(self.unit, unit_system),
|
||||||
symbols=self.symbols,
|
symbols=self.symbols,
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = (
|
|
||||||
f'Tried to rescale unitless LazyDataValueRange to unit system {unit_system}'
|
|
||||||
)
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
|
@ -17,15 +17,19 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import functools
|
import functools
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
from fractions import Fraction
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
|
|
||||||
|
import jaxtyping as jtyp
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
|
|
||||||
from blender_maxwell.utils import extra_sympy_units as spux
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
from blender_maxwell.utils import logger, sim_symbols
|
from blender_maxwell.utils import logger, sim_symbols
|
||||||
|
|
||||||
|
from .array import ArrayFlow
|
||||||
from .expr_info import ExprInfo
|
from .expr_info import ExprInfo
|
||||||
from .flow_kinds import FlowKind
|
from .flow_kinds import FlowKind
|
||||||
|
from .lazy_range import RangeFlow
|
||||||
|
|
||||||
# from .info import InfoFlow
|
# from .info import InfoFlow
|
||||||
|
|
||||||
|
@ -34,13 +38,22 @@ log = logger.get(__name__)
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||||
class ParamsFlow:
|
class ParamsFlow:
|
||||||
|
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
All symbols valid for use in the expression.
|
||||||
|
"""
|
||||||
|
|
||||||
func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list)
|
func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list)
|
||||||
func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict)
|
func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
|
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Symbols
|
||||||
|
####################
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def sorted_symbols(self) -> list[sp.Symbol]:
|
def sorted_symbols(self) -> list[sim_symbols.SimSymbol]:
|
||||||
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -48,52 +61,179 @@ class ParamsFlow:
|
||||||
"""
|
"""
|
||||||
return sorted(self.symbols, key=lambda sym: sym.name)
|
return sorted(self.symbols, key=lambda sym: sym.name)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def sorted_sp_symbols(self) -> list[sp.Symbol | sp.MatrixSymbol]:
|
||||||
|
"""Computes `sympy` symbols from `self.sorted_symbols`.
|
||||||
|
|
||||||
|
When the output is shaped, a single shaped symbol (`sp.MatrixSymbol`) is used to represent the symbolic name and shaping.
|
||||||
|
This choice is made due to `MatrixSymbol`'s compatibility with `.lambdify` JIT.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
All symbols valid for use in the expression.
|
||||||
|
"""
|
||||||
|
return [sym.sp_symbol_matsym for sym in self.sorted_symbols]
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - JIT'ed Callables for Numerical Function Arguments
|
||||||
|
####################
|
||||||
|
def func_args_n(
|
||||||
|
self, target_syms: list[sim_symbols.SimSymbol]
|
||||||
|
) -> list[
|
||||||
|
typ.Callable[
|
||||||
|
[int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
|
||||||
|
int | float | complex | jtyp.Inexact[jtyp.Array, '...'],
|
||||||
|
]
|
||||||
|
]:
|
||||||
|
"""Callable functions for evaluating each `self.func_args` entry numerically.
|
||||||
|
|
||||||
|
Before simplification, each `self.func_args` entry will be conformed to the corresponding (by-index) `SimSymbol` in `target_syms`.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
Before using any `sympy` expressions as arguments to the returned callablees, they **must** be fully conformed and scaled to the corresponding `self.symbols` entry using that entry's `SimSymbol.scale()` method.
|
||||||
|
|
||||||
|
This ensures conformance to the `SimSymbol` properties (like units), as well as adherance to a numerical type identity compatible with `sp.lambdify()`.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
target_syms: `SimSymbol`s describing how a particular `ParamsFlow` function argument should be scaled when performing a purely numerical insertion.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
sp.lambdify(
|
||||||
|
self.sorted_sp_symbols,
|
||||||
|
target_sym.conform(func_arg, strip_unit=True),
|
||||||
|
'jax',
|
||||||
|
)
|
||||||
|
for func_arg, target_sym in zip(self.func_args, target_syms, strict=True)
|
||||||
|
]
|
||||||
|
|
||||||
|
def func_kwargs_n(
|
||||||
|
self, target_syms: dict[str, sim_symbols.SimSymbol]
|
||||||
|
) -> dict[
|
||||||
|
str,
|
||||||
|
typ.Callable[
|
||||||
|
[int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
|
||||||
|
int | float | complex | jtyp.Inexact[jtyp.Array, '...'],
|
||||||
|
],
|
||||||
|
]:
|
||||||
|
"""Callable functions for evaluating each `self.func_kwargs` entry numerically.
|
||||||
|
|
||||||
|
The arguments of each function **must** be pre-treated using `SimSymbol.scale()`.
|
||||||
|
This ensures conformance to the `SimSymbol` properties, as well as adherance to a numerical type identity compatible with `sp.lambdify()`
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
func_arg_key: sp.lambdify(
|
||||||
|
self.sorted_sp_symbols,
|
||||||
|
target_syms[func_arg_key].scale(func_arg),
|
||||||
|
'jax',
|
||||||
|
)
|
||||||
|
for func_arg_key, func_arg in self.func_kwargs.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Realization
|
||||||
|
####################
|
||||||
|
def realize_symbols(
|
||||||
|
self,
|
||||||
|
symbol_values: dict[
|
||||||
|
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
|
||||||
|
] = MappingProxyType({}),
|
||||||
|
) -> dict[
|
||||||
|
sp.Symbol,
|
||||||
|
int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...'] :,
|
||||||
|
]:
|
||||||
|
"""Fully realize all symbols by assigning them a value.
|
||||||
|
|
||||||
|
Three kinds of values for `symbol_values` are supported, fundamentally:
|
||||||
|
|
||||||
|
- **Sympy Expression**: When the value is a sympy expression with units, the unit of the `SimSymbol` key which unit the value if converted to.
|
||||||
|
If the `SimSymbol`'s unit is `None`, then the value is left as-is.
|
||||||
|
- **Range**: When the value is a `RangeFlow`, units are converted to the `SimSymbol`'s unit using `.rescale_to_unit()`.
|
||||||
|
If the `SimSymbol`'s unit is `None`, then the value is left as-is.
|
||||||
|
- **Array**: When the value is an `ArrayFlow`, units are converted to the `SimSymbol`'s unit using `.rescale_to_unit()`.
|
||||||
|
If the `SimSymbol`'s unit is `None`, then the value is left as-is.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary almost with `.subs()`, other than `jax` arrays.
|
||||||
|
"""
|
||||||
|
if set(self.symbols) == set(symbol_values.keys()):
|
||||||
|
realized_syms = {}
|
||||||
|
for sym in self.sorted_symbols:
|
||||||
|
sym_value = symbol_values[sym]
|
||||||
|
|
||||||
|
if isinstance(sym_value, spux.SympyType):
|
||||||
|
v = sym.scale(sym_value)
|
||||||
|
|
||||||
|
elif isinstance(sym_value, ArrayFlow | RangeFlow):
|
||||||
|
v = sym_value.rescale_to_unit(sym.unit).values
|
||||||
|
## NOTE: RangeFlow must not be symbolic.
|
||||||
|
|
||||||
|
else:
|
||||||
|
msg = f'No support for symbolic value {sym_value} (type={type(sym_value)})'
|
||||||
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
realized_syms |= {sym: v}
|
||||||
|
|
||||||
|
return realized_syms
|
||||||
|
|
||||||
|
msg = f'ParamsFlow: Not all symbols were given a value during realization (symbols={self.symbols}, symbol_values={symbol_values})'
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Realize Arguments
|
# - Realize Arguments
|
||||||
####################
|
####################
|
||||||
def scaled_func_args(
|
def scaled_func_args(
|
||||||
self,
|
self,
|
||||||
unit_system: spux.UnitSystem | None = None,
|
target_syms: list[sim_symbols.SimSymbol] = (),
|
||||||
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||||
{}
|
{}
|
||||||
),
|
),
|
||||||
):
|
) -> list[
|
||||||
"""Realize the function arguments contained in this `ParamsFlow`, making it ready for insertion into `Func.func()`.
|
int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...']
|
||||||
|
]:
|
||||||
|
"""Realize correctly conformed numerical arguments for `self.func_args`.
|
||||||
|
|
||||||
For all `arg`s in `self.func_args`, the following operations are performed.
|
Because we allow symbols to be used in `self.func_args`, producing a numerical value that can be passed directly to a `FuncFlow` becomes a two-step process:
|
||||||
|
|
||||||
Notes:
|
1. Conform Symbols: Arbitrary `sympy` expressions passed as `symbol_values` must first be conformed to match the ex. units of `SimSymbol`s found in `self.symbols`, before they can be used.
|
||||||
This method is created for the purpose of being able to make this exact call in an `events.on_value_changed` method:
|
|
||||||
|
2. Conform Function Arguments: Arbitrary `sympy` expressions encoded in `self.func_args` must, **after** inserting the conformed numerical symbols, themselves be conformed to the expected ex. units of the function that they are to be used within.
|
||||||
|
**`ParamsFlow` doesn't contain information about the `SimSymbol`s that `self.func_args` are expected to conform to** (on purpose).
|
||||||
|
Therefore, the user is required to pass a `target_syms` with identical length to `self.func_args`, describing the `SimSymbol`s to conform the function arguments to.
|
||||||
|
|
||||||
|
Our implementation attempts to utilize simple, powerful primitives to accomplish this in roughly three steps:
|
||||||
|
|
||||||
|
1. **Realize Symbols**: Particular passed symbolic values `symbol_values`, which are arbitrary `sympy` expressions, are conformed to the definitions in `self.symbols` (ex. to match units), then cast to numerical values (pure Python / jax array).
|
||||||
|
|
||||||
|
2. **Lazy Function Arguments**: Stored function arguments `self.func_args`, which are arbitrary `sympy` expressions, are conformed to the definitions in `target_syms` (ex. to match units), then cast to numerical values (pure Python / jax array).
|
||||||
|
_Technically, this happens as part of `self.func_args_n`._
|
||||||
|
|
||||||
|
3. **Numerical Evaluation**: The numerical values for each symbol are passed as parameters to each (callable) element of `self.func_args_n`, which produces a correct numerical value for each function argument.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
target_syms: `SimSymbol`s describing how the function arguments returned by this method are intended to be used.
|
||||||
|
**Generally**, the parallel `FuncFlow.func_args` should be inserted here, and guarantees correct results when this output is inserted into `FuncFlow.func(...)`.
|
||||||
|
symbol_values: Particular values for all symbols in `self.symbols`, which will be conformed and used to compute the function arguments (before they are conformed to `target_syms`).
|
||||||
"""
|
"""
|
||||||
if not all(sym in self.symbols for sym in symbol_values):
|
realized_symbols = list(self.realize_symbols(symbol_values).values())
|
||||||
msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
## TODO: MutableDenseMatrix causes error with 'in' check bc it isn't hashable.
|
|
||||||
return [
|
return [
|
||||||
(
|
func_arg_n(*realized_symbols)
|
||||||
spux.scale_to_unit_system(arg, unit_system, use_jax_array=True)
|
for func_arg_n in self.func_args_n(target_syms)
|
||||||
if arg not in symbol_values
|
|
||||||
else symbol_values[arg]
|
|
||||||
)
|
|
||||||
for arg in self.func_args
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def scaled_func_kwargs(
|
def scaled_func_kwargs(
|
||||||
self,
|
self,
|
||||||
unit_system: spux.UnitSystem | None = None,
|
target_syms: list[sim_symbols.SimSymbol] = (),
|
||||||
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
|
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
|
||||||
):
|
) -> dict[
|
||||||
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments."""
|
str, int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...']
|
||||||
if not all(sym in self.symbols for sym in symbol_values):
|
]:
|
||||||
msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}"
|
"""Realize correctly conformed numerical arguments for `self.func_kwargs`.
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
|
Other than the `dict[str, ...]` key, the semantics are identical to `self.scaled_func_args()`.
|
||||||
|
"""
|
||||||
|
realized_symbols = self.realize_symbols(symbol_values)
|
||||||
return {
|
return {
|
||||||
arg_name: spux.convert_to_unit_system(arg, unit_system, use_jax_array=True)
|
func_arg_name: func_arg_n(**realized_symbols)
|
||||||
if arg not in symbol_values
|
for func_arg_name, func_arg_n in self.func_kwargs_n(target_syms).items()
|
||||||
else symbol_values[arg]
|
|
||||||
for arg_name, arg in self.func_kwargs.items()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
@ -129,8 +269,8 @@ class ParamsFlow:
|
||||||
####################
|
####################
|
||||||
# - Generate ExprSocketDef
|
# - Generate ExprSocketDef
|
||||||
####################
|
####################
|
||||||
def sym_expr_infos(self, info, use_range: bool = False) -> dict[str, ExprInfo]:
|
def sym_expr_infos(self, use_range: bool = False) -> dict[str, ExprInfo]:
|
||||||
"""Generate all information needed to define expressions that realize all symbolic parameters in this `ParamsFlow`.
|
"""Generate keyword arguments for defining all `ExprSocket`s needed to realize all `self.symbols`.
|
||||||
|
|
||||||
Many nodes need actual data, and as such, they require that the user select actual values for any symbols in the `ParamsFlow`.
|
Many nodes need actual data, and as such, they require that the user select actual values for any symbols in the `ParamsFlow`.
|
||||||
The best way to do this is to create one `ExprSocket` for each symbol that needs realizing.
|
The best way to do this is to create one `ExprSocket` for each symbol that needs realizing.
|
||||||
|
@ -151,35 +291,22 @@ class ParamsFlow:
|
||||||
|
|
||||||
The `ExprInfo`s can be directly defererenced `**expr_info`)
|
The `ExprInfo`s can be directly defererenced `**expr_info`)
|
||||||
"""
|
"""
|
||||||
for sim_sym in self.sorted_symbols:
|
for sym in self.sorted_symbols:
|
||||||
if use_range and sim_sym.mathtype is spux.MathType.Complex:
|
if use_range and sym.mathtype is spux.MathType.Complex:
|
||||||
msg = 'No support for complex range in ExprInfo'
|
msg = 'No support for complex range in ExprInfo'
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
if use_range and (sim_sym.rows > 1 or sim_sym.cols > 1):
|
if use_range and (sym.rows > 1 or sym.cols > 1):
|
||||||
msg = 'No support for non-scalar elements of range in ExprInfo'
|
msg = 'No support for non-scalar elements of range in ExprInfo'
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
if sim_sym.rows > 3 or sim_sym.cols > 1:
|
if sym.rows > 3 or sym.cols > 1:
|
||||||
msg = 'No support for >Vec3 / Matrix values in ExprInfo'
|
msg = 'No support for >Vec3 / Matrix values in ExprInfo'
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
sim_sym.name: {
|
sym.name: {
|
||||||
# Declare Kind/Size
|
|
||||||
## -> Kind: Value prevents user-alteration of config.
|
|
||||||
## -> Size: Always scalar, since symbols are scalar (for now).
|
|
||||||
'active_kind': FlowKind.Value if not use_range else FlowKind.Range,
|
'active_kind': FlowKind.Value if not use_range else FlowKind.Range,
|
||||||
'size': spux.NumberSize1D.Scalar,
|
|
||||||
# Declare MathType/PhysicalType
|
|
||||||
## -> MathType: Lookup symbol name in info dimensions.
|
|
||||||
## -> PhysicalType: Same.
|
|
||||||
'mathtype': self.dims[sim_sym].mathtype,
|
|
||||||
'physical_type': self.dims[sim_sym].physical_type,
|
|
||||||
# TODO: Default Value
|
|
||||||
# FlowKind.Value: Default Value
|
|
||||||
#'default_value':
|
|
||||||
# FlowKind.Range: Default Min/Max/Steps
|
|
||||||
'default_min': sim_sym.domain.start,
|
|
||||||
'default_max': sim_sym.domain.end,
|
|
||||||
'default_steps': 50,
|
'default_steps': 50,
|
||||||
}
|
}
|
||||||
for sim_sym in self.sorted_symbols
|
| sym.expr_info
|
||||||
|
for sym in self.sorted_symbols
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,7 @@ import tidy3d as td
|
||||||
|
|
||||||
from blender_maxwell.contracts import BLEnumElement
|
from blender_maxwell.contracts import BLEnumElement
|
||||||
from blender_maxwell.services import tdcloud
|
from blender_maxwell.services import tdcloud
|
||||||
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
from blender_maxwell.utils import logger
|
from blender_maxwell.utils import logger
|
||||||
|
|
||||||
from .flow_kinds.info import InfoFlow
|
from .flow_kinds.info import InfoFlow
|
||||||
|
@ -482,6 +483,93 @@ class DataFileFormat(enum.StrEnum):
|
||||||
## - When sidecars aren't found, the user would "fill in the blanks".
|
## - When sidecars aren't found, the user would "fill in the blanks".
|
||||||
## - ...Thus achieving the same result as if there were a sidecar.
|
## - ...Thus achieving the same result as if there were a sidecar.
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Functions: DataFrame
|
||||||
|
####################
|
||||||
|
@staticmethod
|
||||||
|
def to_df(
|
||||||
|
data: jtyp.Shaped[jtyp.Array, 'x_size y_size'], info: InfoFlow
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""Utility method to convert raw data to a `polars.DataFrame`, as guided by an `InfoFlow`.
|
||||||
|
|
||||||
|
Only works with 2D data (obviously).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the data has more than two dimensions, all `info` dimensions are not discrete/labelled, or the dimensionality of `info` doesn't match.
|
||||||
|
"""
|
||||||
|
if info.order > 2: # noqa: PLR2004
|
||||||
|
msg = f'Data may not have more than two dimensions (info={info}, data.shape={data.shape})'
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
if any(info.has_idx_cont(dim) for dim in info.dims):
|
||||||
|
msg = f'To convert data|info to a dataframe, no dimensions can have continuous indices (info={info})'
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
data_np = np.array(data)
|
||||||
|
|
||||||
|
MT = spux.MathType
|
||||||
|
match (
|
||||||
|
info.input_mathtypes,
|
||||||
|
info.output.mathtype,
|
||||||
|
info.output.rows,
|
||||||
|
info.output.cols,
|
||||||
|
):
|
||||||
|
# (R,Z) -> Complex Scalar
|
||||||
|
## -> Polars (also pandas) doesn't have a complex type.
|
||||||
|
## -> Will be treated as (R, Z, 2) -> Real Scalar.
|
||||||
|
case ((MT.Rational | MT.Real, MT.Integer), MT.Complex, 1, 1):
|
||||||
|
row_dim = info.first_dim
|
||||||
|
col_dim = info.last_dim
|
||||||
|
|
||||||
|
return pl.DataFrame(
|
||||||
|
{row_dim.name: info.dims[row_dim]}
|
||||||
|
| {
|
||||||
|
col_label + postfix: re_im(data_np[:, col])
|
||||||
|
for col, col_label in enumerate(info.dims[col_dim])
|
||||||
|
for postfix, re_im in [('_re', np.real), ('_im', np.imag)]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# (R,Z) -> Scalar
|
||||||
|
case ((MT.Rational | MT.Real, MT.Integer), _, 1, 1):
|
||||||
|
row_dim = info.first_dim
|
||||||
|
col_dim = info.last_dim
|
||||||
|
|
||||||
|
return pl.DataFrame(
|
||||||
|
{row_dim.name: info.dims[row_dim]}
|
||||||
|
| {
|
||||||
|
col_label: data_np[:, col]
|
||||||
|
for col, col_label in enumerate(info.dims[col_dim])
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# (Z) -> Complex Vector/Covector
|
||||||
|
case ((MT.Integer,), MT.Complex, r, c) if (r > 1 and c == 1) or (
|
||||||
|
r == 1 and c > 1
|
||||||
|
):
|
||||||
|
col_dim = info.last_dim
|
||||||
|
|
||||||
|
return pl.DataFrame(
|
||||||
|
{
|
||||||
|
col_label + postfix: re_im(data_np[col, :])
|
||||||
|
for col, col_label in enumerate(info.dims[col_dim])
|
||||||
|
for postfix, re_im in [('_re', np.real), ('_im', np.imag)]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# (Z) -> Real Vector
|
||||||
|
## -> Each integer index will be treated as a column index.
|
||||||
|
## -> This will effectively transpose the data.
|
||||||
|
case ((MT.Integer,), _, r, c) if (r > 1 and c == 1) or (r == 1 and c > 1):
|
||||||
|
col_dim = info.last_dim
|
||||||
|
|
||||||
|
return pl.DataFrame(
|
||||||
|
{
|
||||||
|
col_label: data_np[col, :]
|
||||||
|
for col, col_label in enumerate(info.dims[col_dim])
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Functions: Saver
|
# - Functions: Saver
|
||||||
####################
|
####################
|
||||||
|
@ -506,50 +594,7 @@ class DataFileFormat(enum.StrEnum):
|
||||||
np.savetxt(path, data)
|
np.savetxt(path, data)
|
||||||
|
|
||||||
def save_csv(path, data, info):
|
def save_csv(path, data, info):
|
||||||
data_np = np.array(data)
|
df = self.to_df(data, info)
|
||||||
|
|
||||||
# Extract Input Coordinates
|
|
||||||
dim_columns = {
|
|
||||||
dim.name: np.array(dim_idx.realize_array)
|
|
||||||
for i, (dim, dim_idx) in enumerate(info.dims)
|
|
||||||
} ## TODO: realize_array might not be defined on some index arrays
|
|
||||||
|
|
||||||
# Declare Function to Extract Output Values
|
|
||||||
output_columns = {}
|
|
||||||
|
|
||||||
def declare_output_col(data_col, output_idx=0, use_output_idx=False):
|
|
||||||
nonlocal output_columns
|
|
||||||
|
|
||||||
# Complex: Split to Two Columns
|
|
||||||
output_idx_str = f'[{output_idx}]' if use_output_idx else ''
|
|
||||||
if bool(np.any(np.iscomplex(data_col))):
|
|
||||||
output_columns |= {
|
|
||||||
f'{info.output.name}{output_idx_str}_re': np.real(data_col),
|
|
||||||
f'{info.output.name}{output_idx_str}_im': np.imag(data_col),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Else: Use Array Directly
|
|
||||||
else:
|
|
||||||
output_columns |= {
|
|
||||||
f'{info.output.name}{output_idx_str}': data_col,
|
|
||||||
}
|
|
||||||
|
|
||||||
## TODO: Maybe a check to ensure dtype!=object?
|
|
||||||
|
|
||||||
# Extract Output Values
|
|
||||||
## -> 2D: Iterate over columns by-index.
|
|
||||||
## -> 1D: Declare the array as the only column.
|
|
||||||
if len(data_np.shape) == 2:
|
|
||||||
for output_idx in data_np.shape[1]:
|
|
||||||
declare_output_col(data_np[:, output_idx], output_idx, True)
|
|
||||||
else:
|
|
||||||
declare_output_col(data_np)
|
|
||||||
|
|
||||||
# Compute DataFrame & Write CSV
|
|
||||||
df = pl.DataFrame(dim_columns | output_columns)
|
|
||||||
|
|
||||||
log.debug('Writing Polars DataFrame to CSV:')
|
|
||||||
log.debug(df)
|
|
||||||
df.write_csv(path)
|
df.write_csv(path)
|
||||||
|
|
||||||
def save_npy(path, data, info):
|
def save_npy(path, data, info):
|
||||||
|
|
|
@ -264,17 +264,22 @@ class ManagedBLImage(base.ManagedObj):
|
||||||
# times = [time.perf_counter()]
|
# times = [time.perf_counter()]
|
||||||
|
|
||||||
# Compute Plot Dimensions
|
# Compute Plot Dimensions
|
||||||
aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = (
|
# aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = (
|
||||||
self.gen_image_geometry(width_inches, height_inches, dpi)
|
# self.gen_image_geometry(width_inches, height_inches, dpi)
|
||||||
)
|
# )
|
||||||
# times.append(['Image Geometry', time.perf_counter() - times[0]])
|
# times.append(['Image Geometry', time.perf_counter() - times[0]])
|
||||||
|
# log.critical(
|
||||||
|
# [aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px]
|
||||||
|
# )
|
||||||
|
|
||||||
# Create MPL Figure, Axes, and Compute Figure Geometry
|
# Create MPL Figure, Axes, and Compute Figure Geometry
|
||||||
fig, canvas, ax = image_ops.mpl_fig_canvas_ax(
|
# fig, canvas, ax = image_ops.mpl_fig_canvas_ax(
|
||||||
_width_inches, _height_inches, _dpi
|
# _width_inches, _height_inches, _dpi
|
||||||
)
|
# )
|
||||||
|
fig, canvas, ax = image_ops.mpl_fig_canvas_ax(width_inches, height_inches, dpi)
|
||||||
# times.append(['MPL Fig Canvas Axis', time.perf_counter() - times[0]])
|
# times.append(['MPL Fig Canvas Axis', time.perf_counter() - times[0]])
|
||||||
|
|
||||||
|
# fig.clear()
|
||||||
ax.clear()
|
ax.clear()
|
||||||
# times.append(['Clear Axis', time.perf_counter() - times[0]])
|
# times.append(['Clear Axis', time.perf_counter() - times[0]])
|
||||||
|
|
||||||
|
|
|
@ -46,6 +46,7 @@ class FilterOperation(enum.StrEnum):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Slice
|
# Slice
|
||||||
|
Slice = enum.auto()
|
||||||
SliceIdx = enum.auto()
|
SliceIdx = enum.auto()
|
||||||
|
|
||||||
# Pin
|
# Pin
|
||||||
|
@ -53,9 +54,8 @@ class FilterOperation(enum.StrEnum):
|
||||||
Pin = enum.auto()
|
Pin = enum.auto()
|
||||||
PinIdx = enum.auto()
|
PinIdx = enum.auto()
|
||||||
|
|
||||||
# Reinterpret
|
# Dimension
|
||||||
Swap = enum.auto()
|
Swap = enum.auto()
|
||||||
SetDim = enum.auto()
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - UI
|
# - UI
|
||||||
|
@ -65,14 +65,14 @@ class FilterOperation(enum.StrEnum):
|
||||||
FO = FilterOperation
|
FO = FilterOperation
|
||||||
return {
|
return {
|
||||||
# Slice
|
# Slice
|
||||||
FO.SliceIdx: 'a[...]',
|
FO.Slice: '=a[i:j]',
|
||||||
|
FO.SliceIdx: '≈a[v₁:v₂]',
|
||||||
# Pin
|
# Pin
|
||||||
FO.PinLen1: 'pinₐ =1',
|
FO.PinLen1: 'pinₐ',
|
||||||
FO.Pin: 'pinₐ ≈v',
|
FO.Pin: 'pinₐ ≈v',
|
||||||
FO.PinIdx: 'pinₐ =a[v]',
|
FO.PinIdx: 'pinₐ =i',
|
||||||
# Reinterpret
|
# Reinterpret
|
||||||
FO.Swap: 'a₁ ↔ a₂',
|
FO.Swap: 'a₁ ↔ a₂',
|
||||||
FO.SetDim: 'setₐ =v',
|
|
||||||
}[value]
|
}[value]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -118,11 +118,6 @@ class FilterOperation(enum.StrEnum):
|
||||||
if len(info.dims) >= 2: # noqa: PLR2004
|
if len(info.dims) >= 2: # noqa: PLR2004
|
||||||
operations.append(FO.Swap)
|
operations.append(FO.Swap)
|
||||||
|
|
||||||
## SetDim
|
|
||||||
## -> There must be a dimension to correct.
|
|
||||||
if info.dims:
|
|
||||||
operations.append(FO.SetDim)
|
|
||||||
|
|
||||||
return operations
|
return operations
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
@ -145,6 +140,7 @@ class FilterOperation(enum.StrEnum):
|
||||||
FO = FilterOperation
|
FO = FilterOperation
|
||||||
return {
|
return {
|
||||||
# Slice
|
# Slice
|
||||||
|
FO.Slice: 1,
|
||||||
FO.SliceIdx: 1,
|
FO.SliceIdx: 1,
|
||||||
# Pin
|
# Pin
|
||||||
FO.PinLen1: 1,
|
FO.PinLen1: 1,
|
||||||
|
@ -152,40 +148,35 @@ class FilterOperation(enum.StrEnum):
|
||||||
FO.PinIdx: 1,
|
FO.PinIdx: 1,
|
||||||
# Reinterpret
|
# Reinterpret
|
||||||
FO.Swap: 2,
|
FO.Swap: 2,
|
||||||
FO.SetDim: 1,
|
|
||||||
}[self]
|
}[self]
|
||||||
|
|
||||||
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
|
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
|
||||||
FO = FilterOperation
|
FO = FilterOperation
|
||||||
match self:
|
match self:
|
||||||
case FO.SliceIdx | FO.Swap:
|
# Slice
|
||||||
return info.dims
|
case FO.Slice:
|
||||||
|
return [dim for dim in info.dims if not dim.has_idx_labels(dim)]
|
||||||
|
|
||||||
# PinLen1: Only allow dimensions with length=1.
|
case FO.SliceIdx:
|
||||||
|
return [dim for dim in info.dims if not dim.has_idx_labels(dim)]
|
||||||
|
|
||||||
|
# Pin
|
||||||
case FO.PinLen1:
|
case FO.PinLen1:
|
||||||
return [
|
return [
|
||||||
dim
|
dim
|
||||||
for dim, dim_idx in info.dims.items()
|
for dim, dim_idx in info.dims.items()
|
||||||
if dim_idx is not None and len(dim_idx) == 1
|
if not info.has_idx_cont(dim) and len(dim_idx) == 1
|
||||||
]
|
]
|
||||||
|
|
||||||
# Pin: Only allow dimensions with discrete index.
|
case FO.Pin:
|
||||||
## TODO: Shouldn't 'Pin' be allowed to index continuous indices too?
|
return info.dims
|
||||||
case FO.Pin | FO.PinIdx:
|
|
||||||
return [
|
|
||||||
dim
|
|
||||||
for dim, dim_idx in info.dims
|
|
||||||
if dim_idx is not None and len(dim_idx) > 0
|
|
||||||
]
|
|
||||||
|
|
||||||
case FO.SetDim:
|
case FO.PinIdx:
|
||||||
return [
|
return [dim for dim in info.dims if not info.has_idx_cont(dim)]
|
||||||
dim
|
|
||||||
for dim, dim_idx in info.dims
|
# Dimension
|
||||||
if dim_idx is not None
|
case FO.Swap:
|
||||||
and not isinstance(dim_idx, list)
|
return info.dims
|
||||||
and dim_idx.mathtype == spux.MathType.Integer
|
|
||||||
]
|
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@ -193,9 +184,14 @@ class FilterOperation(enum.StrEnum):
|
||||||
self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None
|
self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check whether the given dimension inputs are valid in the context of this operation, and of the information."""
|
"""Check whether the given dimension inputs are valid in the context of this operation, and of the information."""
|
||||||
return (self.num_dim_inputs in [1, 2] and dim_0 in self.valid_dims(info)) or (
|
if self.num_dim_inputs == 1:
|
||||||
self.num_dim_inputs == 2 and dim_1 in self.valid_dims(info)
|
return dim_0 in self.valid_dims(info)
|
||||||
)
|
|
||||||
|
if self.num_dim_inputs == 2: # noqa: PLR2004
|
||||||
|
valid_dims = self.valid_dims(info)
|
||||||
|
return dim_0 in valid_dims and dim_1 in valid_dims
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - UI
|
# - UI
|
||||||
|
@ -209,6 +205,9 @@ class FilterOperation(enum.StrEnum):
|
||||||
FO = FilterOperation
|
FO = FilterOperation
|
||||||
return {
|
return {
|
||||||
# Pin
|
# Pin
|
||||||
|
FO.Slice: lambda expr: jlax.slice_in_dim(
|
||||||
|
expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
|
||||||
|
),
|
||||||
FO.SliceIdx: lambda expr: jlax.slice_in_dim(
|
FO.SliceIdx: lambda expr: jlax.slice_in_dim(
|
||||||
expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
|
expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
|
||||||
),
|
),
|
||||||
|
@ -216,9 +215,8 @@ class FilterOperation(enum.StrEnum):
|
||||||
FO.PinLen1: lambda expr: jnp.squeeze(expr, axis_0),
|
FO.PinLen1: lambda expr: jnp.squeeze(expr, axis_0),
|
||||||
FO.Pin: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
|
FO.Pin: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
|
||||||
FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
|
FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
|
||||||
# Reinterpret
|
# Dimension
|
||||||
FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1),
|
FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1),
|
||||||
FO.SetDim: lambda expr: expr,
|
|
||||||
}[self]
|
}[self]
|
||||||
|
|
||||||
def transform_info(
|
def transform_info(
|
||||||
|
@ -228,10 +226,10 @@ class FilterOperation(enum.StrEnum):
|
||||||
dim_1: sim_symbols.SimSymbol,
|
dim_1: sim_symbols.SimSymbol,
|
||||||
pin_idx: int | None = None,
|
pin_idx: int | None = None,
|
||||||
slice_tuple: tuple[int, int, int] | None = None,
|
slice_tuple: tuple[int, int, int] | None = None,
|
||||||
replaced_dim: tuple[str, tuple[str, ct.ArrayFlow | ct.RangeFlow]] | None = None,
|
|
||||||
):
|
):
|
||||||
FO = FilterOperation
|
FO = FilterOperation
|
||||||
return {
|
return {
|
||||||
|
FO.Slice: lambda: info.slice_dim(dim_0, slice_tuple),
|
||||||
FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple),
|
FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple),
|
||||||
# Pin
|
# Pin
|
||||||
FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
|
FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
|
||||||
|
@ -239,7 +237,6 @@ class FilterOperation(enum.StrEnum):
|
||||||
FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
|
FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
|
||||||
# Reinterpret
|
# Reinterpret
|
||||||
FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
|
FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
|
||||||
FO.SetDim: lambda: info.replace_dim(*replaced_dim),
|
|
||||||
}[self]()
|
}[self]()
|
||||||
|
|
||||||
|
|
||||||
|
@ -330,8 +327,8 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
def search_dims(self) -> list[ct.BLEnumElement]:
|
def search_dims(self) -> list[ct.BLEnumElement]:
|
||||||
if self.expr_info is not None and self.operation is not None:
|
if self.expr_info is not None and self.operation is not None:
|
||||||
return [
|
return [
|
||||||
(dim_name, dim_name, dim_name, '', i)
|
(dim.name, dim.name_pretty, dim.name, '', i)
|
||||||
for i, dim_name in enumerate(self.operation.valid_dims(self.expr_info))
|
for i, dim in enumerate(self.operation.valid_dims(self.expr_info))
|
||||||
]
|
]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@ -380,8 +377,6 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
# Reinterpret
|
# Reinterpret
|
||||||
case FO.Swap:
|
case FO.Swap:
|
||||||
return f'Filter: Swap [{self.active_dim_0}]|[{self.active_dim_1}]'
|
return f'Filter: Swap [{self.active_dim_0}]|[{self.active_dim_1}]'
|
||||||
case FO.SetDim:
|
|
||||||
return f'Filter: Set [{self.active_dim_0}]'
|
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
return self.bl_label
|
return self.bl_label
|
||||||
|
@ -480,30 +475,6 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
# Loose Sockets: Set Dim
|
|
||||||
## -> The user must provide a (ℤ) -> ℝ array.
|
|
||||||
## -> It must be of identical length to the replaced axis.
|
|
||||||
elif props['operation'] is FilterOperation.SetDim and dim_0 is not None:
|
|
||||||
dim = dim_0
|
|
||||||
current_bl_socket = self.loose_input_sockets.get('Dim')
|
|
||||||
if (
|
|
||||||
current_bl_socket is None
|
|
||||||
or current_bl_socket.active_kind != ct.FlowKind.Func
|
|
||||||
or current_bl_socket.size is not spux.NumberSize1D.Scalar
|
|
||||||
or current_bl_socket.mathtype != dim.mathtype
|
|
||||||
or current_bl_socket.physical_type != dim.physical_type
|
|
||||||
):
|
|
||||||
self.loose_input_sockets = {
|
|
||||||
'Dim': sockets.ExprSocketDef(
|
|
||||||
active_kind=ct.FlowKind.Func,
|
|
||||||
physical_type=dim.physical_type,
|
|
||||||
mathtype=dim.mathtype,
|
|
||||||
default_unit=dim.unit,
|
|
||||||
show_func_ui=False,
|
|
||||||
show_info_columns=True,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
# No Loose Value: Remove Input Sockets
|
# No Loose Value: Remove Input Sockets
|
||||||
elif self.loose_input_sockets:
|
elif self.loose_input_sockets:
|
||||||
self.loose_input_sockets = {}
|
self.loose_input_sockets = {}
|
||||||
|
@ -570,60 +541,11 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
|
|
||||||
has_info = not ct.FlowSignal.check(info)
|
has_info = not ct.FlowSignal.check(info)
|
||||||
|
|
||||||
# Dim (Op.SetDim)
|
|
||||||
dim_func = input_sockets['Dim'][ct.FlowKind.Func]
|
|
||||||
dim_params = input_sockets['Dim'][ct.FlowKind.Params]
|
|
||||||
dim_info = input_sockets['Dim'][ct.FlowKind.Info]
|
|
||||||
|
|
||||||
has_dim_func = not ct.FlowSignal.check(dim_func)
|
|
||||||
has_dim_params = not ct.FlowSignal.check(dim_params)
|
|
||||||
has_dim_info = not ct.FlowSignal.check(dim_info)
|
|
||||||
|
|
||||||
# Dimension(s)
|
# Dimension(s)
|
||||||
dim_0 = props['dim_0']
|
dim_0 = props['dim_0']
|
||||||
dim_1 = props['dim_1']
|
dim_1 = props['dim_1']
|
||||||
slice_tuple = props['slice_tuple']
|
slice_tuple = props['slice_tuple']
|
||||||
if has_info and operation is not None:
|
if has_info and operation is not None:
|
||||||
# Set Dimension: Retrieve Array
|
|
||||||
if props['operation'] is FilterOperation.SetDim:
|
|
||||||
new_dim = (
|
|
||||||
next(dim_info.dims.keys()) if len(dim_info.dims) >= 1 else None
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
dim_0 is not None
|
|
||||||
and new_dim is not None
|
|
||||||
and has_dim_info
|
|
||||||
and has_dim_params
|
|
||||||
# Check New Dimension Index Array Sizing
|
|
||||||
and len(dim_info.dims) == 1
|
|
||||||
and dim_info.output.rows == 1
|
|
||||||
and dim_info.output.cols == 1
|
|
||||||
# Check Lack of Params Symbols
|
|
||||||
and not dim_params.symbols
|
|
||||||
# Check Expr Dim | New Dim Compatibility
|
|
||||||
and info.has_idx_discrete(dim_0)
|
|
||||||
and dim_info.has_idx_discrete(new_dim)
|
|
||||||
and len(info.dims[dim_0]) == len(dim_info.dims[new_dim])
|
|
||||||
):
|
|
||||||
# Retrieve Dimension Coordinate Array
|
|
||||||
## -> It must be strictly compatible.
|
|
||||||
values = dim_func.realize(dim_params, spux.UNITS_SI)
|
|
||||||
|
|
||||||
# Transform Info w/Corrected Dimension
|
|
||||||
## -> The existing dimension will be replaced.
|
|
||||||
new_dim_idx = ct.ArrayFlow(
|
|
||||||
values=values,
|
|
||||||
unit=spux.convert_to_unit_system(
|
|
||||||
dim_info.output.unit, spux.UNITS_SI
|
|
||||||
),
|
|
||||||
).rescale_to_unit(dim_info.output.unit)
|
|
||||||
|
|
||||||
replaced_dim = [dim_0, (dim_info.output.name, new_dim_idx)]
|
|
||||||
return operation.transform_info(
|
|
||||||
info, dim_0, dim_1, replaced_dim=replaced_dim
|
|
||||||
)
|
|
||||||
return ct.FlowSignal.FlowPending
|
|
||||||
return operation.transform_info(info, dim_0, dim_1, slice_tuple=slice_tuple)
|
return operation.transform_info(info, dim_0, dim_1, slice_tuple=slice_tuple)
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
|
|
@ -496,7 +496,7 @@ class MapMathNode(base.MaxwellSimNode):
|
||||||
)
|
)
|
||||||
|
|
||||||
def search_operations(self) -> list[ct.BLEnumElement]:
|
def search_operations(self) -> list[ct.BLEnumElement]:
|
||||||
if self.info is not None:
|
if self.expr_info is not None:
|
||||||
return [
|
return [
|
||||||
operation.bl_enum_element(i)
|
operation.bl_enum_element(i)
|
||||||
for i, operation in enumerate(MapOperation.by_expr_info(self.expr_info))
|
for i, operation in enumerate(MapOperation.by_expr_info(self.expr_info))
|
||||||
|
|
|
@ -20,8 +20,11 @@ import typing as typ
|
||||||
import bpy
|
import bpy
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
|
import sympy.physics.quantum as spq
|
||||||
|
import sympy.physics.units as spu
|
||||||
|
|
||||||
from blender_maxwell.utils import bl_cache, logger
|
from blender_maxwell.utils import bl_cache, logger
|
||||||
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
|
|
||||||
from .... import contracts as ct
|
from .... import contracts as ct
|
||||||
from .... import sockets
|
from .... import sockets
|
||||||
|
@ -37,37 +40,47 @@ class BinaryOperation(enum.StrEnum):
|
||||||
"""Valid operations for the `OperateMathNode`.
|
"""Valid operations for the `OperateMathNode`.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
Add: Addition w/broadcasting.
|
Mul: Scalar multiplication.
|
||||||
Sub: Subtraction w/broadcasting.
|
Div: Scalar division.
|
||||||
Mul: Hadamard-product multiplication.
|
Pow: Scalar exponentiation.
|
||||||
Div: Hadamard-product based division.
|
Add: Elementwise addition.
|
||||||
Pow: Elementwise expontiation.
|
Sub: Elementwise subtraction.
|
||||||
Atan2: Quadrant-respecting arctangent variant.
|
HadamMul: Elementwise multiplication (hadamard product).
|
||||||
VecVecDot: Dot product for vectors.
|
HadamPow: Principled shape-aware exponentiation (hadamard power).
|
||||||
Cross: Cross product.
|
Atan2: Quadrant-respecting 2D arctangent.
|
||||||
MatVecDot: Matrix-Vector dot product.
|
VecVecDot: Dot product for identically shaped vectors w/transpose.
|
||||||
|
Cross: Cross product between identically shaped 3D vectors.
|
||||||
|
VecVecOuter: Vector-vector outer product.
|
||||||
LinSolve: Solve a linear system.
|
LinSolve: Solve a linear system.
|
||||||
LsqSolve: Minimize error of an underdetermined linear system.
|
LsqSolve: Minimize error of an underdetermined linear system.
|
||||||
MatMatDot: Matrix-Matrix dot product.
|
VecMatOuter: Vector-matrix outer product.
|
||||||
|
MatMatDot: Matrix-matrix dot product.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Number | Number
|
# Number | Number
|
||||||
Add = enum.auto()
|
|
||||||
Sub = enum.auto()
|
|
||||||
Mul = enum.auto()
|
Mul = enum.auto()
|
||||||
Div = enum.auto()
|
Div = enum.auto()
|
||||||
Pow = enum.auto()
|
Pow = enum.auto()
|
||||||
|
|
||||||
|
# Elements | Elements
|
||||||
|
Add = enum.auto()
|
||||||
|
Sub = enum.auto()
|
||||||
|
HadamMul = enum.auto()
|
||||||
|
# HadamPow = enum.auto() ## TODO: Sympy's HadamardPower is problematic.
|
||||||
Atan2 = enum.auto()
|
Atan2 = enum.auto()
|
||||||
|
|
||||||
# Vector | Vector
|
# Vector | Vector
|
||||||
VecVecDot = enum.auto()
|
VecVecDot = enum.auto()
|
||||||
Cross = enum.auto()
|
Cross = enum.auto()
|
||||||
|
VecVecOuter = enum.auto()
|
||||||
|
|
||||||
# Matrix | Vector
|
# Matrix | Vector
|
||||||
MatVecDot = enum.auto()
|
|
||||||
LinSolve = enum.auto()
|
LinSolve = enum.auto()
|
||||||
LsqSolve = enum.auto()
|
LsqSolve = enum.auto()
|
||||||
|
|
||||||
|
# Vector | Matrix
|
||||||
|
VecMatOuter = enum.auto()
|
||||||
|
|
||||||
# Matrix | Matrix
|
# Matrix | Matrix
|
||||||
MatMatDot = enum.auto()
|
MatMatDot = enum.auto()
|
||||||
|
|
||||||
|
@ -79,19 +92,24 @@ class BinaryOperation(enum.StrEnum):
|
||||||
BO = BinaryOperation
|
BO = BinaryOperation
|
||||||
return {
|
return {
|
||||||
# Number | Number
|
# Number | Number
|
||||||
|
BO.Mul: 'ℓ · r',
|
||||||
|
BO.Div: 'ℓ / r',
|
||||||
|
BO.Pow: 'ℓ ^ r',
|
||||||
|
# Elements | Elements
|
||||||
BO.Add: 'ℓ + r',
|
BO.Add: 'ℓ + r',
|
||||||
BO.Sub: 'ℓ - r',
|
BO.Sub: 'ℓ - r',
|
||||||
BO.Mul: 'ℓ ⊙ r', ## Notation for Hadamard Product
|
BO.HadamMul: '𝐋 ⊙ 𝐑',
|
||||||
BO.Div: 'ℓ / r',
|
# BO.HadamPow: '𝐥 ⊙^ 𝐫',
|
||||||
BO.Pow: 'ℓʳ',
|
BO.Atan2: 'atan2(ℓ:x, r:y)',
|
||||||
BO.Atan2: 'atan2(ℓ,r)',
|
|
||||||
# Vector | Vector
|
# Vector | Vector
|
||||||
BO.VecVecDot: '𝐥 · 𝐫',
|
BO.VecVecDot: '𝐥 · 𝐫',
|
||||||
BO.Cross: 'cross(L,R)',
|
BO.Cross: 'cross(𝐥,𝐫)',
|
||||||
|
BO.VecVecOuter: '𝐥 ⊗ 𝐫',
|
||||||
# Matrix | Vector
|
# Matrix | Vector
|
||||||
BO.MatVecDot: '𝐋 · 𝐫',
|
|
||||||
BO.LinSolve: '𝐋 ∖ 𝐫',
|
BO.LinSolve: '𝐋 ∖ 𝐫',
|
||||||
BO.LsqSolve: 'argminₓ∥𝐋𝐱−𝐫∥₂',
|
BO.LsqSolve: 'argminₓ∥𝐋𝐱−𝐫∥₂',
|
||||||
|
# Vector | Matrix
|
||||||
|
BO.VecMatOuter: '𝐋 ⊗ 𝐫',
|
||||||
# Matrix | Matrix
|
# Matrix | Matrix
|
||||||
BO.MatMatDot: '𝐋 · 𝐑',
|
BO.MatMatDot: '𝐋 · 𝐑',
|
||||||
}[value]
|
}[value]
|
||||||
|
@ -118,56 +136,104 @@ class BinaryOperation(enum.StrEnum):
|
||||||
"""Deduce valid binary operations from the shapes of the inputs."""
|
"""Deduce valid binary operations from the shapes of the inputs."""
|
||||||
BO = BinaryOperation
|
BO = BinaryOperation
|
||||||
|
|
||||||
ops_number_number = [
|
ops_el_el = [
|
||||||
|
BO.Add,
|
||||||
|
BO.Sub,
|
||||||
|
BO.HadamMul,
|
||||||
|
# BO.HadamPow,
|
||||||
|
]
|
||||||
|
|
||||||
|
outl = info_l.output
|
||||||
|
outr = info_r.output
|
||||||
|
match (outl.shape_len, outr.shape_len):
|
||||||
|
# match (ol.shape_len, info_r.output.shape_len):
|
||||||
|
# Number | *
|
||||||
|
## Number | Number
|
||||||
|
case (0, 0):
|
||||||
|
ops = [
|
||||||
BO.Add,
|
BO.Add,
|
||||||
BO.Sub,
|
BO.Sub,
|
||||||
BO.Mul,
|
BO.Mul,
|
||||||
BO.Div,
|
BO.Div,
|
||||||
BO.Pow,
|
BO.Pow,
|
||||||
BO.Atan2,
|
|
||||||
]
|
]
|
||||||
|
if (
|
||||||
match (info_l.output_shape_len, info_r.output_shape_len):
|
info_l.output.physical_type == spux.PhysicalType.Length
|
||||||
# Number | *
|
and info_l.output.unit == info_r.output.unit
|
||||||
## Number | Number
|
):
|
||||||
case (0, 0):
|
ops += [BO.Atan2]
|
||||||
return ops_number_number
|
return ops
|
||||||
|
|
||||||
## Number | Vector
|
## Number | Vector
|
||||||
## -> Broadcasting allows Number|Number ops to work as-is.
|
|
||||||
case (0, 1):
|
case (0, 1):
|
||||||
return ops_number_number
|
return [BO.Mul] # , BO.HadamPow]
|
||||||
|
|
||||||
## Number | Matrix
|
## Number | Matrix
|
||||||
## -> Broadcasting allows Number|Number ops to work as-is.
|
|
||||||
case (0, 2):
|
case (0, 2):
|
||||||
return ops_number_number
|
return [BO.Mul] # , BO.HadamPow]
|
||||||
|
|
||||||
# Vector | *
|
# Vector | *
|
||||||
## Vector | Number
|
## Vector | Number
|
||||||
case (1, 0):
|
case (1, 0):
|
||||||
return ops_number_number
|
return [BO.Mul] # , BO.HadamPow]
|
||||||
|
|
||||||
## Vector | Number
|
## Vector | Vector
|
||||||
case (1, 1):
|
case (1, 1):
|
||||||
return [*ops_number_number, BO.VecVecDot, BO.Cross]
|
ops = []
|
||||||
|
|
||||||
|
# Vector | Vector
|
||||||
|
## -> Dot: Convenience; utilize special vec-vec dot w/transp.
|
||||||
|
if outl.rows > outl.cols and outr.rows > outr.cols:
|
||||||
|
ops += [BO.VecVecDot, BO.VecVecOuter]
|
||||||
|
|
||||||
|
# Covector | Vector
|
||||||
|
## -> Dot: Directly use matrix-matrix dot, as it's now correct.
|
||||||
|
if outl.rows < outl.cols and outr.rows > outr.cols:
|
||||||
|
ops += [BO.MatMatDot, BO.VecVecOuter]
|
||||||
|
|
||||||
|
# Vector | Covector
|
||||||
|
## -> Dot: Directly use matrix-matrix dot, as it's now correct.
|
||||||
|
## -> These are both the same operation, in this case.
|
||||||
|
if outl.rows > outl.cols and outr.rows < outr.cols:
|
||||||
|
ops += [BO.MatMatDot, BO.VecVecOuter]
|
||||||
|
|
||||||
|
# Covector | Covector
|
||||||
|
## -> Dot: Convenience; utilize special vec-vec dot w/transp.
|
||||||
|
if outl.rows < outl.cols and outr.rows < outr.cols:
|
||||||
|
ops += [BO.VecVecDot, BO.VecVecOuter]
|
||||||
|
|
||||||
|
# Cross Product
|
||||||
|
## -> Enforce that both are 3x1 or 1x3.
|
||||||
|
## -> See https://docs.sympy.org/latest/modules/matrices/matrices.html#sympy.matrices.matrices.MatrixBase.cross
|
||||||
|
if (outl.rows == 3 and outr.rows == 3) or (
|
||||||
|
outl.cols == 3 and outl.cols == 3
|
||||||
|
):
|
||||||
|
ops += [BO.Cross]
|
||||||
|
|
||||||
|
return ops
|
||||||
|
|
||||||
## Vector | Matrix
|
## Vector | Matrix
|
||||||
case (1, 2):
|
case (1, 2):
|
||||||
return []
|
return [BO.VecMatOuter]
|
||||||
|
|
||||||
# Matrix | *
|
# Matrix | *
|
||||||
## Matrix | Number
|
## Matrix | Number
|
||||||
case (2, 0):
|
case (2, 0):
|
||||||
return [*ops_number_number, BO.MatMatDot]
|
return [BO.Mul] # , BO.HadamPow]
|
||||||
|
|
||||||
## Matrix | Vector
|
## Matrix | Vector
|
||||||
case (2, 1):
|
case (2, 1):
|
||||||
return [BO.MatVecDot, BO.LinSolve, BO.LsqSolve]
|
prepend_ops = []
|
||||||
|
|
||||||
|
# Mat-Vec Dot: Enforce RHS Column Vector
|
||||||
|
if outr.rows > outl.cols:
|
||||||
|
prepend_ops += [BO.MatMatDot]
|
||||||
|
|
||||||
|
return [*ops, BO.LinSolve, BO.LsqSolve] # , BO.HadamPow]
|
||||||
|
|
||||||
## Matrix | Matrix
|
## Matrix | Matrix
|
||||||
case (2, 2):
|
case (2, 2):
|
||||||
return [*ops_number_number, BO.MatMatDot]
|
return [*ops_el_el, BO.MatMatDot]
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@ -182,34 +248,86 @@ class BinaryOperation(enum.StrEnum):
|
||||||
## TODO: Make this compatible with sp.Matrix inputs
|
## TODO: Make this compatible with sp.Matrix inputs
|
||||||
return {
|
return {
|
||||||
# Number | Number
|
# Number | Number
|
||||||
BO.Add: lambda exprs: exprs[0] + exprs[1],
|
|
||||||
BO.Sub: lambda exprs: exprs[0] - exprs[1],
|
|
||||||
BO.Mul: lambda exprs: exprs[0] * exprs[1],
|
BO.Mul: lambda exprs: exprs[0] * exprs[1],
|
||||||
BO.Div: lambda exprs: exprs[0] / exprs[1],
|
BO.Div: lambda exprs: exprs[0] / exprs[1],
|
||||||
BO.Pow: lambda exprs: exprs[0] ** exprs[1],
|
BO.Pow: lambda exprs: exprs[0] ** exprs[1],
|
||||||
|
# Elements | Elements
|
||||||
|
BO.Add: lambda exprs: exprs[0] + exprs[1],
|
||||||
|
BO.Sub: lambda exprs: exprs[0] - exprs[1],
|
||||||
|
BO.HadamMul: lambda exprs: sp.hadamard_product(exprs[0], exprs[1]),
|
||||||
|
# BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
|
||||||
BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]),
|
BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]),
|
||||||
|
# Vector | Vector
|
||||||
|
BO.VecVecDot: lambda exprs: (exprs[0].T @ exprs[1])[0],
|
||||||
|
BO.Cross: lambda exprs: exprs[0].cross(exprs[1]),
|
||||||
|
BO.VecVecOuter: lambda exprs: exprs[0] @ exprs[1].T,
|
||||||
|
# Matrix | Vector
|
||||||
|
BO.LinSolve: lambda exprs: exprs[0].solve(exprs[1]),
|
||||||
|
BO.LsqSolve: lambda exprs: exprs[0].solve_least_squares(exprs[1]),
|
||||||
|
# Vector | Matrix
|
||||||
|
BO.VecMatOuter: lambda exprs: spq.TensorProduct(exprs[0], exprs[1]),
|
||||||
|
# Matrix | Matrix
|
||||||
|
BO.MatMatDot: lambda exprs: exprs[0] @ exprs[1],
|
||||||
|
}[self]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unit_func(self):
|
||||||
|
"""The binary function to apply to both unit expressions, in order to deduce the unit expression of the output."""
|
||||||
|
BO = BinaryOperation
|
||||||
|
|
||||||
|
## TODO: Make this compatible with sp.Matrix inputs
|
||||||
|
return {
|
||||||
|
# Number | Number
|
||||||
|
BO.Mul: BO.Mul.sp_func,
|
||||||
|
BO.Div: BO.Div.sp_func,
|
||||||
|
BO.Pow: BO.Pow.sp_func,
|
||||||
|
# Elements | Elements
|
||||||
|
BO.Add: BO.Add.sp_func,
|
||||||
|
BO.Sub: BO.Sub.sp_func,
|
||||||
|
BO.HadamMul: BO.Mul.sp_func,
|
||||||
|
# BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
|
||||||
|
BO.Atan2: lambda _: spu.radian,
|
||||||
|
# Vector | Vector
|
||||||
|
BO.VecVecDot: BO.Mul.sp_func,
|
||||||
|
BO.Cross: BO.Mul.sp_func,
|
||||||
|
BO.VecVecOuter: BO.Mul.sp_func,
|
||||||
|
# Matrix | Vector
|
||||||
|
## -> A,b in Ax = b have units, and the equality must hold.
|
||||||
|
## -> Therefore, A \ b must have the units [b]/[A].
|
||||||
|
BO.LinSolve: lambda exprs: exprs[1] / exprs[0],
|
||||||
|
BO.LsqSolve: lambda exprs: exprs[1] / exprs[0],
|
||||||
|
# Vector | Matrix
|
||||||
|
BO.VecMatOuter: BO.Mul.sp_func,
|
||||||
|
# Matrix | Matrix
|
||||||
|
BO.MatMatDot: BO.Mul.sp_func,
|
||||||
}[self]
|
}[self]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def jax_func(self):
|
def jax_func(self):
|
||||||
"""Deduce an appropriate jax-based function that implements the binary operation for array inputs."""
|
"""Deduce an appropriate jax-based function that implements the binary operation for array inputs."""
|
||||||
|
## TODO: Scale the units of one side to the other.
|
||||||
BO = BinaryOperation
|
BO = BinaryOperation
|
||||||
|
|
||||||
return {
|
return {
|
||||||
# Number | Number
|
# Number | Number
|
||||||
BO.Add: lambda exprs: exprs[0] + exprs[1],
|
|
||||||
BO.Sub: lambda exprs: exprs[0] - exprs[1],
|
|
||||||
BO.Mul: lambda exprs: exprs[0] * exprs[1],
|
BO.Mul: lambda exprs: exprs[0] * exprs[1],
|
||||||
BO.Div: lambda exprs: exprs[0] / exprs[1],
|
BO.Div: lambda exprs: exprs[0] / exprs[1],
|
||||||
BO.Pow: lambda exprs: exprs[0] ** exprs[1],
|
BO.Pow: lambda exprs: exprs[0] ** exprs[1],
|
||||||
BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]),
|
# Elements | Elements
|
||||||
|
BO.Add: lambda exprs: exprs[0] + exprs[1],
|
||||||
|
BO.Sub: lambda exprs: exprs[0] - exprs[1],
|
||||||
|
BO.HadamMul: lambda exprs: exprs[0] * exprs[1],
|
||||||
|
# BO.HadamPow: lambda exprs: exprs[0] ** exprs[1],
|
||||||
|
BO.Atan2: lambda exprs: jnp.atan2(exprs[1], exprs[0]),
|
||||||
# Vector | Vector
|
# Vector | Vector
|
||||||
BO.VecVecDot: lambda exprs: jnp.dot(exprs[0], exprs[1]),
|
BO.VecVecDot: lambda exprs: jnp.dot(exprs[0], exprs[1]),
|
||||||
BO.Cross: lambda exprs: jnp.cross(exprs[0], exprs[1]),
|
BO.Cross: lambda exprs: jnp.cross(exprs[0], exprs[1]),
|
||||||
|
BO.VecVecOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
|
||||||
# Matrix | Vector
|
# Matrix | Vector
|
||||||
BO.MatVecDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]),
|
|
||||||
BO.LinSolve: lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]),
|
BO.LinSolve: lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]),
|
||||||
BO.LsqSolve: lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]),
|
BO.LsqSolve: lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]),
|
||||||
|
# Vector | Matrix
|
||||||
|
BO.VecMatOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
|
||||||
# Matrix | Matrix
|
# Matrix | Matrix
|
||||||
BO.MatMatDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]),
|
BO.MatMatDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]),
|
||||||
}[self]
|
}[self]
|
||||||
|
@ -218,30 +336,12 @@ class BinaryOperation(enum.StrEnum):
|
||||||
# - InfoFlow Transform
|
# - InfoFlow Transform
|
||||||
####################
|
####################
|
||||||
def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow):
|
def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow):
|
||||||
BO = BinaryOperation
|
"""Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression."""
|
||||||
|
return info_l.operate_output(
|
||||||
info_largest = (
|
info_r,
|
||||||
info_l if info_l.output_shape_len > info_l.output_shape_len else info_l
|
lambda a, b: self.sp_func([a, b]),
|
||||||
|
lambda a, b: self.unit_func([a, b]),
|
||||||
)
|
)
|
||||||
info_any = info_largest
|
|
||||||
return {
|
|
||||||
# Number | * or * | Number
|
|
||||||
BO.Add: info_largest,
|
|
||||||
BO.Sub: info_largest,
|
|
||||||
BO.Mul: info_largest,
|
|
||||||
BO.Div: info_largest,
|
|
||||||
BO.Pow: info_largest,
|
|
||||||
BO.Atan2: info_largest,
|
|
||||||
# Vector | Vector
|
|
||||||
BO.VecVecDot: info_any,
|
|
||||||
BO.Cross: info_any,
|
|
||||||
# Matrix | Vector
|
|
||||||
BO.MatVecDot: info_r,
|
|
||||||
BO.LinSolve: info_r,
|
|
||||||
BO.LsqSolve: info_r,
|
|
||||||
# Matrix | Matrix
|
|
||||||
BO.MatMatDot: info_any,
|
|
||||||
}[self]
|
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
@ -367,9 +467,9 @@ class OperateMathNode(base.MaxwellSimNode):
|
||||||
# Compute Sympy Function
|
# Compute Sympy Function
|
||||||
## -> The operation enum directly provides the appropriate function.
|
## -> The operation enum directly provides the appropriate function.
|
||||||
if has_expr_l_value and has_expr_r_value and operation is not None:
|
if has_expr_l_value and has_expr_r_value and operation is not None:
|
||||||
operation.sp_func([expr_l, expr_r])
|
return operation.sp_func([expr_l, expr_r])
|
||||||
|
|
||||||
return ct.Flowsignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Expr',
|
'Expr',
|
||||||
|
@ -396,7 +496,7 @@ class OperateMathNode(base.MaxwellSimNode):
|
||||||
## -> The operation enum directly provides the appropriate function.
|
## -> The operation enum directly provides the appropriate function.
|
||||||
if has_expr_l and has_expr_r:
|
if has_expr_l and has_expr_r:
|
||||||
return (expr_l | expr_r).compose_within(
|
return (expr_l | expr_r).compose_within(
|
||||||
operation.jax_func,
|
enclosing_func=operation.jax_func,
|
||||||
supports_jax=True,
|
supports_jax=True,
|
||||||
)
|
)
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
|
@ -17,13 +17,13 @@
|
||||||
"""Declares `TransformMathNode`."""
|
"""Declares `TransformMathNode`."""
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
|
import functools
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
import bpy
|
import bpy
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jaxtyping as jtyp
|
import jaxtyping as jtyp
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
import sympy.physics.units as spu
|
|
||||||
|
|
||||||
from blender_maxwell.utils import bl_cache, logger, sci_constants, sim_symbols
|
from blender_maxwell.utils import bl_cache, logger, sci_constants, sim_symbols
|
||||||
from blender_maxwell.utils import extra_sympy_units as spux
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
|
@ -51,6 +51,9 @@ class TransformOperation(enum.StrEnum):
|
||||||
# Covariant Transform
|
# Covariant Transform
|
||||||
FreqToVacWL = enum.auto()
|
FreqToVacWL = enum.auto()
|
||||||
VacWLToFreq = enum.auto()
|
VacWLToFreq = enum.auto()
|
||||||
|
ConvertIdxUnit = enum.auto()
|
||||||
|
SetIdxUnit = enum.auto()
|
||||||
|
FirstColToFirstIdx = enum.auto()
|
||||||
|
|
||||||
# Fold
|
# Fold
|
||||||
IntDimToComplex = enum.auto()
|
IntDimToComplex = enum.auto()
|
||||||
|
@ -58,8 +61,8 @@ class TransformOperation(enum.StrEnum):
|
||||||
DimsToMat = enum.auto()
|
DimsToMat = enum.auto()
|
||||||
|
|
||||||
# Fourier
|
# Fourier
|
||||||
FFT1D = enum.auto()
|
FT1D = enum.auto()
|
||||||
InvFFT1D = enum.auto()
|
InvFT1D = enum.auto()
|
||||||
|
|
||||||
# TODO: Affine
|
# TODO: Affine
|
||||||
## TODO
|
## TODO
|
||||||
|
@ -74,17 +77,22 @@ class TransformOperation(enum.StrEnum):
|
||||||
# Covariant Transform
|
# Covariant Transform
|
||||||
TO.FreqToVacWL: '𝑓 → λᵥ',
|
TO.FreqToVacWL: '𝑓 → λᵥ',
|
||||||
TO.VacWLToFreq: 'λᵥ → 𝑓',
|
TO.VacWLToFreq: 'λᵥ → 𝑓',
|
||||||
|
TO.ConvertIdxUnit: 'Convert Dim',
|
||||||
|
TO.SetIdxUnit: 'Set Dim',
|
||||||
|
TO.FirstColToFirstIdx: '1st Col → Dim',
|
||||||
# Fold
|
# Fold
|
||||||
TO.IntDimToComplex: '→ ℂ',
|
TO.IntDimToComplex: '→ ℂ',
|
||||||
TO.DimToVec: '→ Vector',
|
TO.DimToVec: '→ Vector',
|
||||||
TO.DimsToMat: '→ Matrix',
|
TO.DimsToMat: '→ Matrix',
|
||||||
|
## TODO: Vector to new last-dim integer
|
||||||
|
## TODO: Matrix to two last-dim integers
|
||||||
# Fourier
|
# Fourier
|
||||||
TO.FFT1D: 't → 𝑓',
|
TO.FT1D: '→ 𝑓',
|
||||||
TO.InvFFT1D: '𝑓 → t',
|
TO.InvFT1D: '𝑓 →',
|
||||||
}[value]
|
}[value]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_icon(value: typ.Self) -> str:
|
def to_icon(_: typ.Self) -> str:
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
|
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
|
||||||
|
@ -98,121 +106,216 @@ class TransformOperation(enum.StrEnum):
|
||||||
)
|
)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Ops from Shape
|
# - Methods
|
||||||
####################
|
####################
|
||||||
|
@property
|
||||||
|
def num_dim_inputs(self) -> None:
|
||||||
|
"""The number of axes that should be passed as inputs to `func_jax` when evaluating it.
|
||||||
|
|
||||||
|
Especially useful for `ParamFlow`, when deciding whether to pass an integer-axis argument based on a user-selected dimension.
|
||||||
|
"""
|
||||||
|
TO = TransformOperation
|
||||||
|
return {
|
||||||
|
# Covariant Transform
|
||||||
|
TO.FreqToVacWL: 1,
|
||||||
|
TO.VacWLToFreq: 1,
|
||||||
|
TO.ConvertIdxUnit: 1,
|
||||||
|
TO.SetIdxUnit: 1,
|
||||||
|
TO.FirstColToFirstIdx: 0,
|
||||||
|
# Fold
|
||||||
|
TO.IntDimToComplex: 0,
|
||||||
|
TO.DimToVec: 0,
|
||||||
|
TO.DimsToMat: 0,
|
||||||
|
## TODO: Vector to new last-dim integer
|
||||||
|
## TODO: Matrix to two last-dim integers
|
||||||
|
# Fourier
|
||||||
|
TO.FT1D: 1,
|
||||||
|
TO.InvFT1D: 1,
|
||||||
|
}[self]
|
||||||
|
|
||||||
|
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
|
||||||
|
TO = TransformOperation
|
||||||
|
match self:
|
||||||
|
case TO.FreqToVacWL | TO.FT1D:
|
||||||
|
return [
|
||||||
|
dim
|
||||||
|
for dim in info.dims
|
||||||
|
if dim.physical_type is spux.PhysicalType.Freq
|
||||||
|
]
|
||||||
|
|
||||||
|
case TO.VacWLToFreq | TO.InvFT1D:
|
||||||
|
return [
|
||||||
|
dim
|
||||||
|
for dim in info.dims
|
||||||
|
if dim.physical_type is spux.PhysicalType.Length
|
||||||
|
]
|
||||||
|
|
||||||
|
case TO.ConvertIdxUnit | TO.SetIdxUnit:
|
||||||
|
return [dim for dim in info.dims if not info.has_idx_labels(dim)]
|
||||||
|
|
||||||
|
## ColDimToComplex: Implicit Last Dimension
|
||||||
|
## DimToVec: Implicit Last Dimension
|
||||||
|
## DimsToMat: Implicit Last 2 Dimensions
|
||||||
|
|
||||||
|
case TO.FT1D | TO.InvFT1D:
|
||||||
|
# Filter by Axis Uniformity
|
||||||
|
## -> FT requires uniform axis (aka. must be RangeFlow).
|
||||||
|
## -> NOTE: If FT isn't popping up, check ExtractDataNode.
|
||||||
|
return [dim for dim in info.dims if info.is_idx_uniform(dim)]
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def by_element_shape(info: ct.InfoFlow) -> list[typ.Self]:
|
def by_info(info: ct.InfoFlow) -> list[typ.Self]:
|
||||||
TO = TransformOperation
|
TO = TransformOperation
|
||||||
operations = []
|
operations = []
|
||||||
|
|
||||||
# Covariant Transform
|
# Covariant Transform
|
||||||
## Freq <-> VacWL
|
## Freq -> VacWL
|
||||||
for dim in info.dims:
|
if TO.FreqToVacWL.valid_dims(info):
|
||||||
if dim.physical_type == spux.PhysicalType.Freq:
|
operations += [TO.FreqToVacWL]
|
||||||
operations.append(TO.FreqToVacWL)
|
|
||||||
|
|
||||||
if dim.physical_type == spux.PhysicalType.Freq:
|
## VacWL -> Freq
|
||||||
operations.append(TO.VacWLToFreq)
|
if TO.VacWLToFreq.valid_dims(info):
|
||||||
|
operations += [TO.VacWLToFreq]
|
||||||
|
|
||||||
|
## Convert Index Unit
|
||||||
|
if TO.ConvertIdxUnit.valid_dims(info):
|
||||||
|
operations += [TO.ConvertIdxUnit]
|
||||||
|
|
||||||
|
if TO.SetIdxUnit.valid_dims(info):
|
||||||
|
operations += [TO.SetIdxUnit]
|
||||||
|
|
||||||
|
## Column to First Index (Array)
|
||||||
|
if (
|
||||||
|
len(info.dims) == 2 # noqa: PLR2004
|
||||||
|
and info.first_dim.mathtype is spux.MathType.Integer
|
||||||
|
and info.last_dim.mathtype is spux.MathType.Integer
|
||||||
|
and info.output.shape_len == 0
|
||||||
|
):
|
||||||
|
operations += [TO.FirstColToFirstIdx]
|
||||||
|
|
||||||
# Fold
|
# Fold
|
||||||
## (Last) Int Dim (=2) to Complex
|
## Last Dim -> Complex
|
||||||
if len(info.dims) >= 1:
|
if (
|
||||||
if not info.has_idx_labels(info.last_dim) and len(info.last_dim) == 2: # noqa: PLR2004
|
info.dims
|
||||||
operations.append(TO.IntDimToComplex)
|
# Output is Int|Rat|Real
|
||||||
|
and (
|
||||||
|
info.output.mathtype
|
||||||
|
in [spux.MathType.Integer, spux.MathType.Rational, spux.MathType.Real]
|
||||||
|
)
|
||||||
|
# Last Axis is Integer of Length 2
|
||||||
|
and info.last_dim.mathtype is spux.MathType.Integer
|
||||||
|
and info.has_idx_labels(info.last_dim)
|
||||||
|
and len(info.dims[info.last_dim]) == 2 # noqa: PLR2004
|
||||||
|
):
|
||||||
|
operations += [TO.IntDimToComplex]
|
||||||
|
|
||||||
## To Vector
|
## Last Dim -> Vector
|
||||||
if len(info.dims) >= 1:
|
if len(info.dims) >= 1 and info.output.shape_len == 0:
|
||||||
operations.append(TO.DimToVec)
|
operations += [TO.DimToVec]
|
||||||
|
|
||||||
## To Matrix
|
## Last Dim -> Matrix
|
||||||
if len(info.dims) >= 2: # noqa: PLR2004
|
if len(info.dims) >= 2 and info.output.shape_len == 0: # noqa: PLR2004
|
||||||
operations.append(TO.DimsToMat)
|
operations += [TO.DimsToMat]
|
||||||
|
|
||||||
# Fourier
|
# Fourier
|
||||||
## 1D Fourier
|
if TO.FT1D.valid_dims(info):
|
||||||
if info.dims:
|
operations += [TO.FT1D]
|
||||||
last_physical_type = info.last_dim.physical_type
|
|
||||||
if last_physical_type == spux.PhysicalType.Time:
|
if TO.InvFT1D.valid_dims(info):
|
||||||
operations.append(TO.FFT1D)
|
operations += [TO.InvFT1D]
|
||||||
if last_physical_type == spux.PhysicalType.Freq:
|
|
||||||
operations.append(TO.InvFFT1D)
|
|
||||||
|
|
||||||
return operations
|
return operations
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Function Properties
|
# - Function Properties
|
||||||
####################
|
####################
|
||||||
@property
|
@functools.cached_property
|
||||||
def sp_func(self):
|
|
||||||
TO = TransformOperation
|
|
||||||
return {
|
|
||||||
# Covariant Transform
|
|
||||||
TO.FreqToVacWL: lambda expr: expr,
|
|
||||||
TO.VacWLToFreq: lambda expr: expr,
|
|
||||||
# Fold
|
|
||||||
# TO.IntDimToComplex: lambda expr: expr, ## TODO: Won't work?
|
|
||||||
TO.DimToVec: lambda expr: expr,
|
|
||||||
TO.DimsToMat: lambda expr: expr,
|
|
||||||
# Fourier
|
|
||||||
TO.FFT1D: lambda expr: sp.fourier_transform(
|
|
||||||
expr, sim_symbols.t, sim_symbols.freq
|
|
||||||
),
|
|
||||||
TO.InvFFT1D: lambda expr: sp.fourier_transform(
|
|
||||||
expr, sim_symbols.freq, sim_symbols.t
|
|
||||||
),
|
|
||||||
}[self]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def jax_func(self):
|
def jax_func(self):
|
||||||
TO = TransformOperation
|
TO = TransformOperation
|
||||||
return {
|
return {
|
||||||
# Covariant Transform
|
# Covariant Transform
|
||||||
TO.FreqToVacWL: lambda expr: expr,
|
## -> Freq <-> WL is a rescale (noop) AND flip (not noop).
|
||||||
TO.VacWLToFreq: lambda expr: expr,
|
TO.FreqToVacWL: lambda expr, axis: jnp.flip(expr, axis=axis),
|
||||||
|
TO.VacWLToFreq: lambda expr, axis: jnp.flip(expr, axis=axis),
|
||||||
|
TO.ConvertIdxUnit: lambda expr: expr,
|
||||||
|
TO.SetIdxUnit: lambda expr: expr,
|
||||||
|
TO.FirstColToFirstIdx: lambda expr: jnp.delete(expr, 0, axis=1),
|
||||||
# Fold
|
# Fold
|
||||||
## -> To Complex: With a little imagination, this is a noop :)
|
## -> To Complex: This should generally be a no-op.
|
||||||
## -> **Requires** dims[-1] to be integer-indexed w/length of 2.
|
TO.IntDimToComplex: lambda expr: jnp.squeeze(
|
||||||
TO.IntDimToComplex: lambda expr: expr.view(dtype=jnp.complex64).squeeze(),
|
expr.view(dtype=jnp.complex64), axis=-1
|
||||||
|
),
|
||||||
TO.DimToVec: lambda expr: expr,
|
TO.DimToVec: lambda expr: expr,
|
||||||
TO.DimsToMat: lambda expr: expr,
|
TO.DimsToMat: lambda expr: expr,
|
||||||
# Fourier
|
# Fourier
|
||||||
TO.FFT1D: lambda expr: jnp.fft(expr),
|
TO.FT1D: lambda expr, axis: jnp.fft(expr, axis=axis),
|
||||||
TO.InvFFT1D: lambda expr: jnp.ifft(expr),
|
TO.InvFT1D: lambda expr, axis: jnp.ifft(expr, axis=axis),
|
||||||
}[self]
|
}[self]
|
||||||
|
|
||||||
def transform_info(
|
def transform_info(
|
||||||
self,
|
self,
|
||||||
info: ct.InfoFlow | None,
|
info: ct.InfoFlow,
|
||||||
data: jtyp.Shaped[jtyp.Array, '...'] | None = None,
|
dim: sim_symbols.SimSymbol | None = None,
|
||||||
|
data_col: jtyp.Shaped[jtyp.Array, ' size'] | None = None,
|
||||||
|
new_dim_name: str | None = None,
|
||||||
unit: spux.Unit | None = None,
|
unit: spux.Unit | None = None,
|
||||||
) -> ct.InfoFlow | None:
|
physical_type: spux.PhysicalType | None = None,
|
||||||
|
) -> ct.InfoFlow:
|
||||||
TO = TransformOperation
|
TO = TransformOperation
|
||||||
if not info.dims:
|
|
||||||
return None
|
|
||||||
return {
|
return {
|
||||||
# Covariant Transform
|
# Covariant Transform
|
||||||
TO.FreqToVacWL: lambda: info.replace_dim(
|
TO.FreqToVacWL: lambda: info.replace_dim(
|
||||||
(f_dim := info.last_dim),
|
(f_dim := dim),
|
||||||
[
|
[
|
||||||
sim_symbols.wl(spu.nanometer),
|
sim_symbols.wl(unit),
|
||||||
info.dims[f_dim].rescale(
|
info.dims[f_dim].rescale(
|
||||||
lambda el: sci_constants.vac_speed_of_light / el,
|
lambda el: sci_constants.vac_speed_of_light / el,
|
||||||
reverse=True,
|
reverse=True,
|
||||||
new_unit=spu.nanometer,
|
new_unit=unit,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
TO.VacWLToFreq: lambda: info.replace_dim(
|
TO.VacWLToFreq: lambda: info.replace_dim(
|
||||||
(wl_dim := info.last_dim),
|
(wl_dim := dim),
|
||||||
[
|
[
|
||||||
sim_symbols.freq(spux.THz),
|
sim_symbols.freq(unit),
|
||||||
info.dims[wl_dim].rescale(
|
info.dims[wl_dim].rescale(
|
||||||
lambda el: sci_constants.vac_speed_of_light / el,
|
lambda el: sci_constants.vac_speed_of_light / el,
|
||||||
reverse=True,
|
reverse=True,
|
||||||
new_unit=spux.THz,
|
new_unit=unit,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
TO.ConvertIdxUnit: lambda: info.replace_dim(
|
||||||
|
dim,
|
||||||
|
dim.update(unit=unit),
|
||||||
|
(
|
||||||
|
info.dims[dim].rescale_to_unit(unit)
|
||||||
|
if info.has_idx_discrete(dim)
|
||||||
|
else None ## Continuous -- dim SimSymbol already scaled
|
||||||
|
),
|
||||||
|
),
|
||||||
|
TO.SetIdxUnit: lambda: info.replace_dim(
|
||||||
|
dim,
|
||||||
|
dim.update(
|
||||||
|
sym_name=new_dim_name, physical_type=physical_type, unit=unit
|
||||||
|
),
|
||||||
|
(
|
||||||
|
info.dims[dim].correct_unit(unit)
|
||||||
|
if info.has_idx_discrete(dim)
|
||||||
|
else None ## Continuous -- dim SimSymbol already scaled
|
||||||
|
),
|
||||||
|
),
|
||||||
|
TO.FirstColToFirstIdx: lambda: info.replace_dim(
|
||||||
|
info.first_dim,
|
||||||
|
info.first_dim.update(
|
||||||
|
mathtype=spux.MathType.from_jax_array(data_col),
|
||||||
|
unit=unit,
|
||||||
|
),
|
||||||
|
ct.ArrayFlow(values=data_col, unit=unit),
|
||||||
|
).slice_dim(info.last_dim, (1, len(info.dims[info.last_dim]), 1)),
|
||||||
# Fold
|
# Fold
|
||||||
TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output(
|
TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output(
|
||||||
mathtype=spux.MathType.Complex
|
mathtype=spux.MathType.Complex
|
||||||
|
@ -220,21 +323,31 @@ class TransformOperation(enum.StrEnum):
|
||||||
TO.DimToVec: lambda: info.fold_last_input(),
|
TO.DimToVec: lambda: info.fold_last_input(),
|
||||||
TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(),
|
TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(),
|
||||||
# Fourier
|
# Fourier
|
||||||
TO.FFT1D: lambda: info.replace_dim(
|
TO.FT1D: lambda: info.replace_dim(
|
||||||
info.last_dim,
|
dim,
|
||||||
[
|
[
|
||||||
sim_symbols.freq(spux.THz),
|
# FT'ed Unit: Reciprocal of the Original Unit
|
||||||
None,
|
dim.update(
|
||||||
|
unit=1 / dim.unit if dim.unit is not None else 1
|
||||||
|
), ## TODO: Okay to not scale interval?
|
||||||
|
# FT'ed Bounds: Reciprocal of the Original Unit
|
||||||
|
info.dims[dim].bound_fourier_transform,
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
TO.InvFFT1D: info.replace_dim(
|
TO.InvFT1D: lambda: info.replace_dim(
|
||||||
info.last_dim,
|
info.last_dim,
|
||||||
[
|
[
|
||||||
sim_symbols.t(spu.second),
|
# FT'ed Unit: Reciprocal of the Original Unit
|
||||||
None,
|
dim.update(
|
||||||
|
unit=1 / dim.unit if dim.unit is not None else 1
|
||||||
|
), ## TODO: Okay to not scale interval?
|
||||||
|
# FT'ed Bounds: Reciprocal of the Original Unit
|
||||||
|
## -> Note the midpoint may revert to 0.
|
||||||
|
## -> See docs for `RangeFlow.bound_inv_fourier_transform` for more.
|
||||||
|
info.dims[dim].bound_inv_fourier_transform,
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
}.get(self, lambda: info)()
|
}[self]()
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
@ -274,7 +387,6 @@ class TransformMathNode(base.MaxwellSimNode):
|
||||||
)
|
)
|
||||||
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
|
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
|
||||||
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
|
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
|
||||||
|
|
||||||
info_pending = ct.FlowSignal.check_single(
|
info_pending = ct.FlowSignal.check_single(
|
||||||
input_sockets['Expr'], ct.FlowSignal.FlowPending
|
input_sockets['Expr'], ct.FlowSignal.FlowPending
|
||||||
)
|
)
|
||||||
|
@ -304,45 +416,125 @@ class TransformMathNode(base.MaxwellSimNode):
|
||||||
return [
|
return [
|
||||||
operation.bl_enum_element(i)
|
operation.bl_enum_element(i)
|
||||||
for i, operation in enumerate(
|
for i, operation in enumerate(
|
||||||
TransformOperation.by_element_shape(self.expr_info)
|
TransformOperation.by_info(self.expr_info)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Properties: Dimension Selection
|
||||||
|
####################
|
||||||
|
active_dim: enum.StrEnum = bl_cache.BLField(
|
||||||
|
enum_cb=lambda self, _: self.search_dims(),
|
||||||
|
cb_depends_on={'operation', 'expr_info'},
|
||||||
|
)
|
||||||
|
|
||||||
|
def search_dims(self) -> list[ct.BLEnumElement]:
|
||||||
|
if self.expr_info is not None and self.operation is not None:
|
||||||
|
return [
|
||||||
|
(dim.name, dim.name_pretty, dim.name, '', i)
|
||||||
|
for i, dim in enumerate(self.operation.valid_dims(self.expr_info))
|
||||||
|
]
|
||||||
|
return []
|
||||||
|
|
||||||
|
@bl_cache.cached_bl_property(depends_on={'expr_info', 'active_dim'})
|
||||||
|
def dim(self) -> sim_symbols.SimSymbol | None:
|
||||||
|
if self.expr_info is not None and self.active_dim is not None:
|
||||||
|
return self.expr_info.dim_by_name(self.active_dim)
|
||||||
|
return None
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Properties: New Dimension Properties
|
||||||
|
####################
|
||||||
|
new_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||||
|
sim_symbols.SimSymbolName.Expr
|
||||||
|
)
|
||||||
|
new_physical_type: spux.PhysicalType = bl_cache.BLField(
|
||||||
|
spux.PhysicalType.NonPhysical
|
||||||
|
)
|
||||||
|
active_new_unit: enum.StrEnum = bl_cache.BLField(
|
||||||
|
enum_cb=lambda self, _: self.search_units(),
|
||||||
|
cb_depends_on={'dim', 'new_physical_type'},
|
||||||
|
)
|
||||||
|
|
||||||
|
def search_units(self) -> list[ct.BLEnumElement]:
|
||||||
|
if self.dim is not None:
|
||||||
|
if self.dim.physical_type is not spux.PhysicalType.NonPhysical:
|
||||||
|
unit_name = sp.sstr(self.dim.unit)
|
||||||
|
return [
|
||||||
|
(
|
||||||
|
sp.sstr(unit),
|
||||||
|
spux.sp_to_str(unit),
|
||||||
|
sp.sstr(unit),
|
||||||
|
'',
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
for unit in self.dim.physical_type.valid_units
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.dim.unit is not None:
|
||||||
|
unit_name = sp.sstr(self.dim.unit)
|
||||||
|
return [
|
||||||
|
(
|
||||||
|
unit_name,
|
||||||
|
spux.sp_to_str(self.dim.unit),
|
||||||
|
unit_name,
|
||||||
|
'',
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if self.new_physical_type is not spux.PhysicalType.NonPhysical:
|
||||||
|
return [
|
||||||
|
(
|
||||||
|
sp.sstr(unit),
|
||||||
|
spux.sp_to_str(unit),
|
||||||
|
sp.sstr(unit),
|
||||||
|
'',
|
||||||
|
i,
|
||||||
|
)
|
||||||
|
for i, unit in enumerate(self.new_physical_type.valid_units)
|
||||||
|
]
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
@bl_cache.cached_bl_property(depends_on={'active_new_unit'})
|
||||||
|
def new_unit(self) -> spux.Unit:
|
||||||
|
if self.active_new_unit is not None:
|
||||||
|
return spux.unit_str_to_unit(self.active_new_unit)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - UI
|
# - UI
|
||||||
####################
|
####################
|
||||||
def draw_label(self):
|
def draw_label(self):
|
||||||
if self.operation is not None:
|
if self.operation is not None:
|
||||||
return 'Transform: ' + TransformOperation.to_name(self.operation)
|
return 'T: ' + TransformOperation.to_name(self.operation)
|
||||||
|
|
||||||
return self.bl_label
|
return self.bl_label
|
||||||
|
|
||||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||||
layout.prop(self, self.blfields['operation'], text='')
|
layout.prop(self, self.blfields['operation'], text='')
|
||||||
|
|
||||||
|
if self.operation is not None and self.operation.num_dim_inputs == 1:
|
||||||
|
TO = TransformOperation
|
||||||
|
layout.prop(self, self.blfields['active_dim'], text='')
|
||||||
|
|
||||||
|
if self.operation in [TO.ConvertIdxUnit, TO.SetIdxUnit]:
|
||||||
|
col = layout.column(align=True)
|
||||||
|
if self.operation is TransformOperation.ConvertIdxUnit:
|
||||||
|
col.prop(self, self.blfields['active_new_unit'], text='')
|
||||||
|
|
||||||
|
if self.operation is TransformOperation.SetIdxUnit:
|
||||||
|
col.prop(self, self.blfields['new_physical_type'], text='')
|
||||||
|
|
||||||
|
row = col.row(align=True)
|
||||||
|
row.prop(self, self.blfields['new_name'], text='')
|
||||||
|
row.prop(self, self.blfields['active_new_unit'], text='')
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Compute: Func / Array
|
# - Compute: Func / Array
|
||||||
####################
|
####################
|
||||||
@events.computes_output_socket(
|
|
||||||
'Expr',
|
|
||||||
kind=ct.FlowKind.Value,
|
|
||||||
props={'operation'},
|
|
||||||
input_sockets={'Expr'},
|
|
||||||
)
|
|
||||||
def compute_value(self, props, input_sockets) -> ct.ValueFlow | ct.FlowSignal:
|
|
||||||
operation = props['operation']
|
|
||||||
expr = input_sockets['Expr']
|
|
||||||
|
|
||||||
has_expr_value = not ct.FlowSignal.check(expr)
|
|
||||||
|
|
||||||
# Compute Sympy Function
|
|
||||||
## -> The operation enum directly provides the appropriate function.
|
|
||||||
if has_expr_value and operation is not None:
|
|
||||||
return operation.sp_func(expr)
|
|
||||||
|
|
||||||
return ct.Flowsignal.FlowPending
|
|
||||||
|
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Expr',
|
'Expr',
|
||||||
kind=ct.FlowKind.Func,
|
kind=ct.FlowKind.Func,
|
||||||
|
@ -354,54 +546,103 @@ class TransformMathNode(base.MaxwellSimNode):
|
||||||
)
|
)
|
||||||
def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal:
|
def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal:
|
||||||
operation = props['operation']
|
operation = props['operation']
|
||||||
expr = input_sockets['Expr']
|
lazy_func = input_sockets['Expr']
|
||||||
|
|
||||||
has_expr = not ct.FlowSignal.check(expr)
|
has_lazy_func = not ct.FlowSignal.check(lazy_func)
|
||||||
|
|
||||||
if has_expr and operation is not None:
|
if has_lazy_func and operation is not None:
|
||||||
return expr.compose_within(
|
return lazy_func.compose_within(
|
||||||
operation.jax_func,
|
operation.jax_func,
|
||||||
supports_jax=True,
|
supports_jax=True,
|
||||||
)
|
)
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - FlowKind.Info|Params
|
# - FlowKind.Info
|
||||||
####################
|
####################
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Expr',
|
'Expr',
|
||||||
kind=ct.FlowKind.Info,
|
kind=ct.FlowKind.Info,
|
||||||
props={'operation'},
|
props={'operation', 'dim', 'new_name', 'new_unit', 'new_physical_type'},
|
||||||
input_sockets={'Expr'},
|
input_sockets={'Expr'},
|
||||||
input_socket_kinds={'Expr': ct.FlowKind.Info},
|
input_socket_kinds={
|
||||||
|
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
def compute_info(
|
def compute_info(
|
||||||
self, props: dict, input_sockets: dict
|
self, props: dict, input_sockets: dict
|
||||||
) -> ct.InfoFlow | typ.Literal[ct.FlowSignal.FlowPending]:
|
) -> ct.InfoFlow | typ.Literal[ct.FlowSignal.FlowPending]:
|
||||||
operation = props['operation']
|
operation = props['operation']
|
||||||
info = input_sockets['Expr']
|
info = input_sockets['Expr'][ct.FlowKind.Info]
|
||||||
|
|
||||||
has_info = not ct.FlowSignal.check(info)
|
has_info = not ct.FlowSignal.check(info)
|
||||||
|
|
||||||
|
dim = props['dim']
|
||||||
|
new_name = props['new_name']
|
||||||
|
new_unit = props['new_unit']
|
||||||
|
new_physical_type = props['new_physical_type']
|
||||||
if has_info and operation is not None:
|
if has_info and operation is not None:
|
||||||
transformed_info = operation.transform_info(info)
|
# First Column to First Index
|
||||||
|
## -> We have to evaluate the lazy function at this point.
|
||||||
if transformed_info is None:
|
## -> It's the only way to get at the column data.
|
||||||
return ct.FlowSignal.FlowPending
|
if operation is TransformOperation.FirstColToFirstIdx:
|
||||||
return transformed_info
|
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
|
||||||
|
params = input_sockets['Expr'][ct.FlowKind.Params]
|
||||||
|
has_lazy_func = not ct.FlowSignal.check(lazy_func)
|
||||||
|
has_params = not ct.FlowSignal.check(lazy_func)
|
||||||
|
|
||||||
|
if has_lazy_func and has_params and not params.symbols:
|
||||||
|
data = lazy_func.realize(params)
|
||||||
|
if data.shape is not None and len(data.shape) == 2:
|
||||||
|
data_col = data[:, 0]
|
||||||
|
return operation.transform_info(info, data_col=data_col)
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
# Check Not-Yet-Updated Dimension
|
||||||
|
## - Operation changes before dimensions.
|
||||||
|
## - If InfoFlow is requested in this interim, big problem.
|
||||||
|
if dim is None and operation.num_dim_inputs > 0:
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
return operation.transform_info(
|
||||||
|
info,
|
||||||
|
dim=dim,
|
||||||
|
new_dim_name=new_name,
|
||||||
|
unit=new_unit,
|
||||||
|
physical_type=new_physical_type,
|
||||||
|
)
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - FlowKind.Params
|
||||||
|
####################
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Expr',
|
'Expr',
|
||||||
kind=ct.FlowKind.Params,
|
kind=ct.FlowKind.Params,
|
||||||
|
props={'operation', 'dim'},
|
||||||
input_sockets={'Expr'},
|
input_sockets={'Expr'},
|
||||||
input_socket_kinds={'Expr': ct.FlowKind.Params},
|
input_socket_kinds={'Expr': {ct.FlowKind.Params, ct.FlowKind.Info}},
|
||||||
)
|
)
|
||||||
def compute_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal:
|
def compute_params(self, props, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
|
||||||
has_params = not ct.FlowSignal.check(input_sockets['Expr'])
|
info = input_sockets['Expr'][ct.FlowKind.Info]
|
||||||
if has_params:
|
params = input_sockets['Expr'][ct.FlowKind.Params]
|
||||||
return input_sockets['Expr']
|
|
||||||
|
has_info = not ct.FlowSignal.check(info)
|
||||||
|
has_params = not ct.FlowSignal.check(params)
|
||||||
|
|
||||||
|
operation = props['operation']
|
||||||
|
dim = props['dim']
|
||||||
|
if has_info and has_params and operation is not None:
|
||||||
|
# Axis Required: Insert by-Dimension
|
||||||
|
## -> Some transformations ex. FT require setting an axis.
|
||||||
|
## -> The user selects which dimension the op should be done along.
|
||||||
|
## -> This dimension is converted to an axis integer.
|
||||||
|
## -> Finally, we pass the argument via params.
|
||||||
|
if operation.num_dim_inputs == 1:
|
||||||
|
axis = info.dim_axis(dim) if dim is not None else None
|
||||||
|
return params.compose_within(enclosing_func_args=[axis])
|
||||||
|
|
||||||
|
return params
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ import bpy
|
||||||
import jaxtyping as jtyp
|
import jaxtyping as jtyp
|
||||||
import matplotlib.axis as mpl_ax
|
import matplotlib.axis as mpl_ax
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
|
import sympy.physics.units as spu
|
||||||
|
|
||||||
from blender_maxwell.utils import bl_cache, image_ops, logger, sim_symbols
|
from blender_maxwell.utils import bl_cache, image_ops, logger, sim_symbols
|
||||||
from blender_maxwell.utils import extra_sympy_units as spux
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
|
@ -74,8 +75,33 @@ class VizMode(enum.StrEnum):
|
||||||
SqueezedHeatmap2D = enum.auto()
|
SqueezedHeatmap2D = enum.auto()
|
||||||
Heatmap3D = enum.auto()
|
Heatmap3D = enum.auto()
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - UI
|
||||||
|
####################
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None:
|
def to_name(value: typ.Self) -> str:
|
||||||
|
return {
|
||||||
|
VizMode.BoxPlot1D: 'Box Plot',
|
||||||
|
VizMode.Curve2D: 'Curve',
|
||||||
|
VizMode.Points2D: 'Points',
|
||||||
|
VizMode.Bar: 'Bar',
|
||||||
|
VizMode.Curves2D: 'Curves',
|
||||||
|
VizMode.FilledCurves2D: 'Filled Curves',
|
||||||
|
VizMode.Heatmap2D: 'Heatmap',
|
||||||
|
VizMode.SqueezedHeatmap2D: 'Heatmap (Squeezed)',
|
||||||
|
VizMode.Heatmap3D: 'Heatmap (3D)',
|
||||||
|
}[value]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_icon(value: typ.Self) -> ct.BLIcon:
|
||||||
|
return ''
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Validity
|
||||||
|
####################
|
||||||
|
@staticmethod
|
||||||
|
def by_info(info: ct.InfoFlow) -> list[typ.Self] | None:
|
||||||
|
"""Given the input `InfoFlow`, deduce which visualization modes are valid to use with the described data."""
|
||||||
Z = spux.MathType.Integer
|
Z = spux.MathType.Integer
|
||||||
R = spux.MathType.Real
|
R = spux.MathType.Real
|
||||||
VM = VizMode
|
VM = VizMode
|
||||||
|
@ -102,15 +128,18 @@ class VizMode(enum.StrEnum):
|
||||||
],
|
],
|
||||||
}.get(
|
}.get(
|
||||||
(
|
(
|
||||||
tuple([dim.mathtype for dim in info.dims.values()]),
|
tuple([dim.mathtype for dim in info.dims]),
|
||||||
(info.output.rows, info.output.cols, info.output.mathtype),
|
(info.output.rows, info.output.cols, info.output.mathtype),
|
||||||
),
|
),
|
||||||
[],
|
[],
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
####################
|
||||||
def to_plotter(
|
# - Properties
|
||||||
value: typ.Self,
|
####################
|
||||||
|
@property
|
||||||
|
def mpl_plotter(
|
||||||
|
self,
|
||||||
) -> typ.Callable[
|
) -> typ.Callable[
|
||||||
[jtyp.Float32[jtyp.Array, '...'], ct.InfoFlow, mpl_ax.Axis], None
|
[jtyp.Float32[jtyp.Array, '...'], ct.InfoFlow, mpl_ax.Axis], None
|
||||||
]:
|
]:
|
||||||
|
@ -124,25 +153,7 @@ class VizMode(enum.StrEnum):
|
||||||
VizMode.Heatmap2D: image_ops.plot_heatmap_2d,
|
VizMode.Heatmap2D: image_ops.plot_heatmap_2d,
|
||||||
# NO PLOTTER: VizMode.SqueezedHeatmap2D
|
# NO PLOTTER: VizMode.SqueezedHeatmap2D
|
||||||
# NO PLOTTER: VizMode.Heatmap3D
|
# NO PLOTTER: VizMode.Heatmap3D
|
||||||
}[value]
|
}[self]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def to_name(value: typ.Self) -> str:
|
|
||||||
return {
|
|
||||||
VizMode.BoxPlot1D: 'Box Plot',
|
|
||||||
VizMode.Curve2D: 'Curve',
|
|
||||||
VizMode.Points2D: 'Points',
|
|
||||||
VizMode.Bar: 'Bar',
|
|
||||||
VizMode.Curves2D: 'Curves',
|
|
||||||
VizMode.FilledCurves2D: 'Filled Curves',
|
|
||||||
VizMode.Heatmap2D: 'Heatmap',
|
|
||||||
VizMode.SqueezedHeatmap2D: 'Heatmap (Squeezed)',
|
|
||||||
VizMode.Heatmap3D: 'Heatmap (3D)',
|
|
||||||
}[value]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def to_icon(value: typ.Self) -> ct.BLIcon:
|
|
||||||
return ''
|
|
||||||
|
|
||||||
|
|
||||||
class VizTarget(enum.StrEnum):
|
class VizTarget(enum.StrEnum):
|
||||||
|
@ -181,6 +192,10 @@ class VizTarget(enum.StrEnum):
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
||||||
|
sym_x_um = sim_symbols.space_x(spu.um)
|
||||||
|
x_um = sym_x_um.sp_symbol
|
||||||
|
|
||||||
|
|
||||||
class VizNode(base.MaxwellSimNode):
|
class VizNode(base.MaxwellSimNode):
|
||||||
"""Node for visualizing simulation data, by querying its monitors.
|
"""Node for visualizing simulation data, by querying its monitors.
|
||||||
|
|
||||||
|
@ -188,7 +203,6 @@ class VizNode(base.MaxwellSimNode):
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
colormap: Colormap to apply to 0..1 output.
|
colormap: Colormap to apply to 0..1 output.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
node_type = ct.NodeType.Viz
|
node_type = ct.NodeType.Viz
|
||||||
|
@ -201,8 +215,8 @@ class VizNode(base.MaxwellSimNode):
|
||||||
input_sockets: typ.ClassVar = {
|
input_sockets: typ.ClassVar = {
|
||||||
'Expr': sockets.ExprSocketDef(
|
'Expr': sockets.ExprSocketDef(
|
||||||
active_kind=ct.FlowKind.Func,
|
active_kind=ct.FlowKind.Func,
|
||||||
default_symbols=[sim_symbols.x],
|
default_symbols=[sym_x_um],
|
||||||
default_value=2 * sim_symbols.x.sp_symbol,
|
default_value=sp.exp(-(x_um**2)),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
output_sockets: typ.ClassVar = {
|
output_sockets: typ.ClassVar = {
|
||||||
|
@ -240,16 +254,21 @@ class VizNode(base.MaxwellSimNode):
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
viz_mode: enum.StrEnum = bl_cache.BLField(
|
viz_mode: VizMode = bl_cache.BLField(
|
||||||
enum_cb=lambda self, _: self.search_viz_modes(),
|
enum_cb=lambda self, _: self.search_viz_modes(),
|
||||||
cb_depends_on={'expr_info'},
|
cb_depends_on={'expr_info'},
|
||||||
)
|
)
|
||||||
viz_target: enum.StrEnum = bl_cache.BLField(
|
viz_target: VizTarget = bl_cache.BLField(
|
||||||
enum_cb=lambda self, _: self.search_targets(),
|
enum_cb=lambda self, _: self.search_targets(),
|
||||||
cb_depends_on={'viz_mode'},
|
cb_depends_on={'viz_mode'},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mode-Dependent Properties
|
# Plot
|
||||||
|
plot_width: float = bl_cache.BLField(6.0, abs_min=0.1)
|
||||||
|
plot_height: float = bl_cache.BLField(3.0, abs_min=0.1)
|
||||||
|
plot_dpi: int = bl_cache.BLField(150, abs_min=25)
|
||||||
|
|
||||||
|
# Pixels
|
||||||
colormap: image_ops.Colormap = bl_cache.BLField(
|
colormap: image_ops.Colormap = bl_cache.BLField(
|
||||||
image_ops.Colormap.Viridis,
|
image_ops.Colormap.Viridis,
|
||||||
)
|
)
|
||||||
|
@ -267,7 +286,7 @@ class VizNode(base.MaxwellSimNode):
|
||||||
VizMode.to_icon(viz_mode),
|
VizMode.to_icon(viz_mode),
|
||||||
i,
|
i,
|
||||||
)
|
)
|
||||||
for i, viz_mode in enumerate(VizMode.valid_modes_for(self.expr_info))
|
for i, viz_mode in enumerate(VizMode.by_info(self.expr_info))
|
||||||
]
|
]
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
@ -300,9 +319,22 @@ class VizNode(base.MaxwellSimNode):
|
||||||
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
|
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
|
||||||
col.prop(self, self.blfields['viz_mode'], text='')
|
col.prop(self, self.blfields['viz_mode'], text='')
|
||||||
col.prop(self, self.blfields['viz_target'], text='')
|
col.prop(self, self.blfields['viz_target'], text='')
|
||||||
|
|
||||||
if self.viz_target in [VizTarget.Pixels, VizTarget.PixelsPlane]:
|
if self.viz_target in [VizTarget.Pixels, VizTarget.PixelsPlane]:
|
||||||
col.prop(self, self.blfields['colormap'], text='')
|
col.prop(self, self.blfields['colormap'], text='')
|
||||||
|
|
||||||
|
if self.viz_target is VizTarget.Plot2D:
|
||||||
|
row = col.row(align=True)
|
||||||
|
row.alignment = 'CENTER'
|
||||||
|
row.label(text='Width/Height/DPI')
|
||||||
|
|
||||||
|
row = col.row(align=True)
|
||||||
|
row.prop(self, self.blfields['plot_width'], text='')
|
||||||
|
row.prop(self, self.blfields['plot_height'], text='')
|
||||||
|
|
||||||
|
row = col.row(align=True)
|
||||||
|
col.prop(self, self.blfields['plot_dpi'], text='')
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Events
|
# - Events
|
||||||
####################
|
####################
|
||||||
|
@ -320,16 +352,16 @@ class VizNode(base.MaxwellSimNode):
|
||||||
has_info = not ct.FlowSignal.check(info)
|
has_info = not ct.FlowSignal.check(info)
|
||||||
has_params = not ct.FlowSignal.check(params)
|
has_params = not ct.FlowSignal.check(params)
|
||||||
|
|
||||||
# Provide Sockets for Symbol Realization
|
# Declare Loose Sockets that Realize Symbols
|
||||||
## -> This happens if Params contains not-yet-realized symbols.
|
## -> This happens if Params contains not-yet-realized symbols.
|
||||||
if has_info and has_params and params.symbols:
|
if has_info and has_params and params.symbols:
|
||||||
if set(self.loose_input_sockets) != {
|
if set(self.loose_input_sockets) != {
|
||||||
dim.name for dim in params.symbols if dim in info.dims
|
sym.name for sym in params.symbols if sym in info.dims
|
||||||
}:
|
}:
|
||||||
self.loose_input_sockets = {
|
self.loose_input_sockets = {
|
||||||
dim_name: sockets.ExprSocketDef(**expr_info)
|
dim_name: sockets.ExprSocketDef(**expr_info)
|
||||||
for dim_name, expr_info in params.sym_expr_infos(
|
for dim_name, expr_info in params.sym_expr_infos(
|
||||||
info, use_range=True
|
use_range=True
|
||||||
).items()
|
).items()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -343,7 +375,14 @@ class VizNode(base.MaxwellSimNode):
|
||||||
'Preview',
|
'Preview',
|
||||||
kind=ct.FlowKind.Value,
|
kind=ct.FlowKind.Value,
|
||||||
# Loaded
|
# Loaded
|
||||||
props={'viz_mode', 'viz_target', 'colormap'},
|
props={
|
||||||
|
'viz_mode',
|
||||||
|
'viz_target',
|
||||||
|
'colormap',
|
||||||
|
'plot_width',
|
||||||
|
'plot_height',
|
||||||
|
'plot_dpi',
|
||||||
|
},
|
||||||
input_sockets={'Expr'},
|
input_sockets={'Expr'},
|
||||||
input_socket_kinds={
|
input_socket_kinds={
|
||||||
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
|
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
|
||||||
|
@ -359,7 +398,14 @@ class VizNode(base.MaxwellSimNode):
|
||||||
#####################
|
#####################
|
||||||
@events.on_show_plot(
|
@events.on_show_plot(
|
||||||
managed_objs={'plot'},
|
managed_objs={'plot'},
|
||||||
props={'viz_mode', 'viz_target', 'colormap'},
|
props={
|
||||||
|
'viz_mode',
|
||||||
|
'viz_target',
|
||||||
|
'colormap',
|
||||||
|
'plot_width',
|
||||||
|
'plot_height',
|
||||||
|
'plot_dpi',
|
||||||
|
},
|
||||||
input_sockets={'Expr'},
|
input_sockets={'Expr'},
|
||||||
input_socket_kinds={
|
input_socket_kinds={
|
||||||
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
|
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
|
||||||
|
@ -370,7 +416,6 @@ class VizNode(base.MaxwellSimNode):
|
||||||
def on_show_plot(
|
def on_show_plot(
|
||||||
self, managed_objs, props, input_sockets, loose_input_sockets
|
self, managed_objs, props, input_sockets, loose_input_sockets
|
||||||
) -> None:
|
) -> None:
|
||||||
# Retrieve Inputs
|
|
||||||
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
|
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
|
||||||
info = input_sockets['Expr'][ct.FlowKind.Info]
|
info = input_sockets['Expr'][ct.FlowKind.Info]
|
||||||
params = input_sockets['Expr'][ct.FlowKind.Params]
|
params = input_sockets['Expr'][ct.FlowKind.Params]
|
||||||
|
@ -378,43 +423,51 @@ class VizNode(base.MaxwellSimNode):
|
||||||
has_info = not ct.FlowSignal.check(info)
|
has_info = not ct.FlowSignal.check(info)
|
||||||
has_params = not ct.FlowSignal.check(params)
|
has_params = not ct.FlowSignal.check(params)
|
||||||
|
|
||||||
if (
|
plot = managed_objs['plot']
|
||||||
not has_info
|
viz_mode = props['viz_mode']
|
||||||
or not has_params
|
viz_target = props['viz_target']
|
||||||
or props['viz_mode'] is None
|
if has_info and has_params and viz_mode is not None and viz_target is not None:
|
||||||
or props['viz_target'] is None
|
# Realize Data w/Realized Symbols
|
||||||
):
|
## -> The loose input socket values are user-selected symbol values.
|
||||||
return
|
## -> These expressions are used to realize the lazy data.
|
||||||
|
## -> `.realize()` ensures all ex. units are correctly conformed.
|
||||||
# Compute Ranges for Symbols from Loose Sockets
|
realized_syms = {
|
||||||
## -> In a quite nice turn of events, all this is cached lookups.
|
sym: loose_input_sockets[sym.name] for sym in params.sorted_symbols
|
||||||
## -> ...Unless something changed, in which case, well. It changed.
|
|
||||||
symbol_array_values = {
|
|
||||||
sim_syms: (
|
|
||||||
loose_input_sockets[sim_syms]
|
|
||||||
.rescale_to_unit(sim_syms.unit)
|
|
||||||
.realize_array
|
|
||||||
)
|
|
||||||
for sim_syms in params.sorted_symbols
|
|
||||||
}
|
}
|
||||||
data = lazy_func.realize(params, symbol_values=symbol_array_values)
|
output_data = lazy_func.realize(params, symbol_values=realized_syms)
|
||||||
|
|
||||||
# Replace InfoFlow Indices w/Realized Symbolic Ranges
|
data = {
|
||||||
## -> This ensures correct axis scaling.
|
dim: (
|
||||||
if params.symbols:
|
realized_syms[dim].values
|
||||||
info = info.replace_dims(symbol_array_values)
|
if dim in realized_syms
|
||||||
|
else info.dims[dim]
|
||||||
|
)
|
||||||
|
for dim in info.dims
|
||||||
|
} | {info.output: output_data}
|
||||||
|
|
||||||
match props['viz_target']:
|
# Match Viz Type & Perform Visualization
|
||||||
|
## -> Viz Target determines how to plot.
|
||||||
|
## -> Viz Mode may help select a particular plotting method.
|
||||||
|
## -> Other parameters may be uses, depending on context.
|
||||||
|
match viz_target:
|
||||||
case VizTarget.Plot2D:
|
case VizTarget.Plot2D:
|
||||||
managed_objs['plot'].mpl_plot_to_image(
|
plot_width = props['plot_width']
|
||||||
lambda ax: VizMode.to_plotter(props['viz_mode'])(data, info, ax),
|
plot_height = props['plot_height']
|
||||||
|
plot_dpi = props['plot_dpi']
|
||||||
|
plot.mpl_plot_to_image(
|
||||||
|
lambda ax: viz_mode.mpl_plotter(data, ax),
|
||||||
|
width_inches=plot_width,
|
||||||
|
height_inches=plot_height,
|
||||||
|
dpi=plot_dpi,
|
||||||
bl_select=True,
|
bl_select=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
case VizTarget.Pixels:
|
case VizTarget.Pixels:
|
||||||
managed_objs['plot'].map_2d_to_image(
|
colormap = props['colormap']
|
||||||
|
if colormap is not None:
|
||||||
|
plot.map_2d_to_image(
|
||||||
data,
|
data,
|
||||||
colormap=props['colormap'],
|
colormap=colormap,
|
||||||
bl_select=True,
|
bl_select=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -610,7 +610,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
||||||
## -> Anyone needing results will need to wait on preinit().
|
## -> Anyone needing results will need to wait on preinit().
|
||||||
return ct.FlowSignal.FlowInitializing
|
return ct.FlowSignal.FlowInitializing
|
||||||
|
|
||||||
if optional:
|
# if optional:
|
||||||
return ct.FlowSignal.NoFlow
|
return ct.FlowSignal.NoFlow
|
||||||
|
|
||||||
msg = f'{self.sim_node_name}: Input socket "{input_socket_name}" cannot be computed, as it is not an active input socket'
|
msg = f'{self.sim_node_name}: Input socket "{input_socket_name}" cannot be computed, as it is not an active input socket'
|
||||||
|
@ -659,11 +659,12 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
||||||
return output_socket_methods[0](self)
|
return output_socket_methods[0](self)
|
||||||
|
|
||||||
# Auxiliary Fallbacks
|
# Auxiliary Fallbacks
|
||||||
if optional or kind in [ct.FlowKind.Info, ct.FlowKind.Params]:
|
|
||||||
return ct.FlowSignal.NoFlow
|
return ct.FlowSignal.NoFlow
|
||||||
|
# if optional or kind in [ct.FlowKind.Info, ct.FlowKind.Params]:
|
||||||
|
# return ct.FlowSignal.NoFlow
|
||||||
|
|
||||||
msg = f'No output method for ({output_socket_name}, {kind})'
|
# msg = f'No output method for ({output_socket_name}, {kind})'
|
||||||
raise ValueError(msg)
|
# raise ValueError(msg)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Event Trigger
|
# - Event Trigger
|
||||||
|
|
|
@ -30,6 +30,7 @@ class ExprConstantNode(base.MaxwellSimNode):
|
||||||
input_sockets: typ.ClassVar = {
|
input_sockets: typ.ClassVar = {
|
||||||
'Expr': sockets.ExprSocketDef(
|
'Expr': sockets.ExprSocketDef(
|
||||||
active_kind=ct.FlowKind.Func,
|
active_kind=ct.FlowKind.Func,
|
||||||
|
show_name_selector=True,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
output_sockets: typ.ClassVar = {
|
output_sockets: typ.ClassVar = {
|
||||||
|
|
|
@ -17,8 +17,11 @@
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
import bpy
|
import bpy
|
||||||
|
import sympy as sp
|
||||||
|
import sympy.physics.units as spu
|
||||||
|
|
||||||
from blender_maxwell.utils import bl_cache, sci_constants
|
from blender_maxwell.utils import bl_cache, sci_constants, sim_symbols
|
||||||
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
|
|
||||||
from .... import contracts as ct
|
from .... import contracts as ct
|
||||||
from .... import sockets
|
from .... import sockets
|
||||||
|
@ -30,15 +33,19 @@ class ScientificConstantNode(base.MaxwellSimNode):
|
||||||
bl_label = 'Scientific Constant'
|
bl_label = 'Scientific Constant'
|
||||||
|
|
||||||
output_sockets: typ.ClassVar = {
|
output_sockets: typ.ClassVar = {
|
||||||
'Value': sockets.ExprSocketDef(),
|
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
|
||||||
}
|
}
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Properties
|
# - Properties
|
||||||
####################
|
####################
|
||||||
sci_constant: str = bl_cache.BLField(
|
use_symbol: bool = bl_cache.BLField(False)
|
||||||
|
|
||||||
|
sci_constant_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||||
|
sim_symbols.SimSymbolName.LowerU
|
||||||
|
)
|
||||||
|
sci_constant_str: str = bl_cache.BLField(
|
||||||
'',
|
'',
|
||||||
prop_ui=True,
|
|
||||||
str_cb=lambda self, _, edit_text: self.search_sci_constants(edit_text),
|
str_cb=lambda self, _, edit_text: self.search_sci_constants(edit_text),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -52,27 +59,139 @@ class ScientificConstantNode(base.MaxwellSimNode):
|
||||||
if edit_text.lower() in name.lower()
|
if edit_text.lower() in name.lower()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@bl_cache.cached_bl_property(depends_on={'sci_constant_str'})
|
||||||
|
def sci_constant(self) -> spux.SympyExpr | None:
|
||||||
|
"""Retrieve the expression for the scientific constant."""
|
||||||
|
return sci_constants.SCI_CONSTANTS.get(self.sci_constant_str)
|
||||||
|
|
||||||
|
@bl_cache.cached_bl_property(depends_on={'sci_constant_str'})
|
||||||
|
def sci_constant_info(self) -> spux.SympyExpr | None:
|
||||||
|
"""Retrieve the information for the selected scientific constant."""
|
||||||
|
return sci_constants.SCI_CONSTANTS_INFO.get(self.sci_constant_str)
|
||||||
|
|
||||||
|
@bl_cache.cached_bl_property(
|
||||||
|
depends_on={'sci_constant', 'sci_constant_info', 'sci_constant_name'}
|
||||||
|
)
|
||||||
|
def sci_constant_sym(self) -> spux.SympyExpr | None:
|
||||||
|
"""Retrieve a symbol for the scientific constant."""
|
||||||
|
if self.sci_constant is not None and self.sci_constant_info is not None:
|
||||||
|
unit = self.sci_constant_info['units']
|
||||||
|
return sim_symbols.SimSymbol(
|
||||||
|
sym_name=self.sci_constant_name,
|
||||||
|
mathtype=spux.MathType.from_expr(self.sci_constant),
|
||||||
|
# physical_type= ## TODO: Formalize unit w/o physical_type
|
||||||
|
unit=unit,
|
||||||
|
is_constant=True,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - UI
|
# - UI
|
||||||
####################
|
####################
|
||||||
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
||||||
col.prop(self, self.blfields['sci_constant'], text='')
|
col.prop(self, self.blfields['sci_constant_str'], text='')
|
||||||
|
|
||||||
|
row = col.row(align=True)
|
||||||
|
row.alignment = 'CENTER'
|
||||||
|
row.label(text='Assign Symbol')
|
||||||
|
col.prop(self, self.blfields['sci_constant_name'], text='')
|
||||||
|
col.prop(self, self.blfields['use_symbol'], text='Use Symbol', toggle=True)
|
||||||
|
|
||||||
def draw_info(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
def draw_info(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
||||||
if self.sci_constant:
|
box = col.box()
|
||||||
col.label(
|
split = box.split(factor=0.25, align=True)
|
||||||
text=f'Units: {sci_constants.SCI_CONSTANTS_INFO[self.sci_constant]["units"]}'
|
|
||||||
)
|
# Left: Units
|
||||||
col.label(
|
_col = split.column(align=True)
|
||||||
text=f'Uncertainty: {sci_constants.SCI_CONSTANTS_INFO[self.sci_constant]["uncertainty"]}'
|
row = _col.row(align=True)
|
||||||
)
|
# row.alignment = 'CENTER'
|
||||||
|
row.label(text='Src')
|
||||||
|
|
||||||
|
if self.sci_constant_info:
|
||||||
|
row = _col.row(align=True)
|
||||||
|
# row.alignment = 'CENTER'
|
||||||
|
row.label(text='Unit')
|
||||||
|
|
||||||
|
row = _col.row(align=True)
|
||||||
|
# row.alignment = 'CENTER'
|
||||||
|
row.label(text='Err')
|
||||||
|
|
||||||
|
# Right: Values
|
||||||
|
_col = split.column(align=True)
|
||||||
|
row = _col.row(align=True)
|
||||||
|
# row.alignment = 'CENTER'
|
||||||
|
row.label(text='CODATA2018')
|
||||||
|
|
||||||
|
if self.sci_constant_info:
|
||||||
|
row = _col.row(align=True)
|
||||||
|
# row.alignment = 'CENTER'
|
||||||
|
row.label(text=f'{self.sci_constant_info["units"]}')
|
||||||
|
|
||||||
|
row = _col.row(align=True)
|
||||||
|
# row.alignment = 'CENTER'
|
||||||
|
row.label(text=f'{self.sci_constant_info["uncertainty"]}')
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Output
|
# - Output
|
||||||
####################
|
####################
|
||||||
@events.computes_output_socket('Value', props={'sci_constant'})
|
@events.computes_output_socket(
|
||||||
def compute_value(self, props: dict) -> typ.Any:
|
'Expr',
|
||||||
return sci_constants.SCI_CONSTANTS[props['sci_constant']]
|
props={'use_symbol', 'sci_constant', 'sci_constant_sym'},
|
||||||
|
)
|
||||||
|
def compute_value(self, props) -> typ.Any:
|
||||||
|
sci_constant = props['sci_constant']
|
||||||
|
sci_constant_sym = props['sci_constant_sym']
|
||||||
|
|
||||||
|
if props['use_symbol'] and sci_constant_sym is not None:
|
||||||
|
return sci_constant_sym.sp_symbol
|
||||||
|
|
||||||
|
if sci_constant is not None:
|
||||||
|
return sci_constant
|
||||||
|
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Expr',
|
||||||
|
kind=ct.FlowKind.Func,
|
||||||
|
props={'sci_constant', 'sci_constant_sym'},
|
||||||
|
)
|
||||||
|
def compute_lazy_func(self, props) -> typ.Any:
|
||||||
|
sci_constant = props['sci_constant']
|
||||||
|
sci_constant_sym = props['sci_constant_sym']
|
||||||
|
|
||||||
|
if sci_constant is not None:
|
||||||
|
return ct.FuncFlow(
|
||||||
|
func=sp.lambdify(
|
||||||
|
[sci_constant_sym.sp_symbol], sci_constant_sym.sp_symbol, 'jax'
|
||||||
|
),
|
||||||
|
func_args=[sci_constant_sym],
|
||||||
|
supports_jax=True,
|
||||||
|
)
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Expr',
|
||||||
|
kind=ct.FlowKind.Info,
|
||||||
|
props={'sci_constant_sym'},
|
||||||
|
)
|
||||||
|
def compute_info(self, props: dict) -> typ.Any:
|
||||||
|
sci_constant_sym = props['sci_constant_sym']
|
||||||
|
|
||||||
|
if sci_constant_sym is not None:
|
||||||
|
return ct.InfoFlow(output=sci_constant_sym)
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Expr',
|
||||||
|
kind=ct.FlowKind.Params,
|
||||||
|
props={'sci_constant'},
|
||||||
|
)
|
||||||
|
def compute_params(self, props: dict) -> typ.Any:
|
||||||
|
sci_constant = props['sci_constant']
|
||||||
|
|
||||||
|
if sci_constant is not None:
|
||||||
|
return ct.ParamsFlow(func_args=[sci_constant])
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
|
@ -95,62 +95,35 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
||||||
####################
|
####################
|
||||||
# - Info Guides
|
# - Info Guides
|
||||||
####################
|
####################
|
||||||
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(sim_symbols.SimSymbolName)
|
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||||
output_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
|
sim_symbols.SimSymbolName.Data
|
||||||
|
)
|
||||||
|
output_mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real)
|
||||||
output_physical_type: spux.PhysicalType = bl_cache.BLField(
|
output_physical_type: spux.PhysicalType = bl_cache.BLField(
|
||||||
spux.PhysicalType.NonPhysical
|
spux.PhysicalType.NonPhysical
|
||||||
)
|
)
|
||||||
output_unit: enum.StrEnum = bl_cache.BLField(
|
output_unit: enum.StrEnum = bl_cache.BLField(
|
||||||
enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type),
|
enum_cb=lambda self, _: self.search_units(self.output_physical_type),
|
||||||
cb_depends_on={'output_physical_type'},
|
cb_depends_on={'output_physical_type'},
|
||||||
)
|
)
|
||||||
|
|
||||||
dim_0_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
dim_0_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||||
sim_symbols.SimSymbolName.LowerA
|
sim_symbols.SimSymbolName.LowerA
|
||||||
)
|
)
|
||||||
dim_0_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
|
|
||||||
dim_0_physical_type: spux.PhysicalType = bl_cache.BLField(
|
|
||||||
spux.PhysicalType.NonPhysical
|
|
||||||
)
|
|
||||||
dim_0_unit: enum.StrEnum = bl_cache.BLField(
|
|
||||||
enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type),
|
|
||||||
cb_depends_on={'dim_0_physical_type'},
|
|
||||||
)
|
|
||||||
|
|
||||||
dim_1_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
dim_1_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||||
sim_symbols.SimSymbolName.LowerB
|
sim_symbols.SimSymbolName.LowerB
|
||||||
)
|
)
|
||||||
dim_1_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
|
|
||||||
dim_1_physical_type: spux.PhysicalType = bl_cache.BLField(
|
|
||||||
spux.PhysicalType.NonPhysical
|
|
||||||
)
|
|
||||||
dim_1_unit: enum.StrEnum = bl_cache.BLField(
|
|
||||||
enum_cb=lambda self, _: self.search_units(self.dim_1_physical_type),
|
|
||||||
cb_depends_on={'dim_1_physical_type'},
|
|
||||||
)
|
|
||||||
|
|
||||||
dim_2_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
dim_2_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||||
sim_symbols.SimSymbolName.LowerC
|
sim_symbols.SimSymbolName.LowerC
|
||||||
)
|
)
|
||||||
dim_2_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
|
|
||||||
dim_2_physical_type: spux.PhysicalType = bl_cache.BLField(
|
|
||||||
spux.PhysicalType.NonPhysical
|
|
||||||
)
|
|
||||||
dim_2_unit: enum.StrEnum = bl_cache.BLField(
|
|
||||||
enum_cb=lambda self, _: self.search_units(self.dim_2_physical_type),
|
|
||||||
cb_depends_on={'dim_2_physical_type'},
|
|
||||||
)
|
|
||||||
|
|
||||||
dim_3_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
dim_3_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||||
sim_symbols.SimSymbolName.LowerD
|
sim_symbols.SimSymbolName.LowerD
|
||||||
)
|
)
|
||||||
dim_3_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
|
dim_4_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||||
dim_3_physical_type: spux.PhysicalType = bl_cache.BLField(
|
sim_symbols.SimSymbolName.LowerE
|
||||||
spux.PhysicalType.NonPhysical
|
|
||||||
)
|
)
|
||||||
dim_3_unit: enum.StrEnum = bl_cache.BLField(
|
dim_5_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||||
enum_cb=lambda self, _: self.search_units(self.dim_3_physical_type),
|
sim_symbols.SimSymbolName.LowerF
|
||||||
cb_depends_on={'dim_3_physical_type'},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def search_units(self, physical_type: spux.PhysicalType) -> list[ct.BLEnumElement]:
|
def search_units(self, physical_type: spux.PhysicalType) -> list[ct.BLEnumElement]:
|
||||||
|
@ -161,19 +134,6 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
||||||
]
|
]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def dim(self, i: int):
|
|
||||||
dim_name = getattr(self, f'dim_{i}_name')
|
|
||||||
dim_mathtype = getattr(self, f'dim_{i}_mathtype')
|
|
||||||
dim_physical_type = getattr(self, f'dim_{i}_physical_type')
|
|
||||||
dim_unit = getattr(self, f'dim_{i}_unit')
|
|
||||||
|
|
||||||
return sim_symbols.SimSymbol(
|
|
||||||
sym_name=dim_name,
|
|
||||||
mathtype=dim_mathtype,
|
|
||||||
physical_type=dim_physical_type,
|
|
||||||
unit=spux.unit_str_to_unit(dim_unit),
|
|
||||||
)
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - UI
|
# - UI
|
||||||
####################
|
####################
|
||||||
|
@ -202,19 +162,21 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
||||||
|
|
||||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||||
"""Draw loaded properties."""
|
"""Draw loaded properties."""
|
||||||
for i in range(len(self.expr_info.dims)):
|
|
||||||
col = layout.column(align=True)
|
col = layout.column(align=True)
|
||||||
|
if self.expr_info is not None:
|
||||||
|
for i in range(len(self.expr_info.dims)):
|
||||||
|
col.prop(self, self.blfields[f'dim_{i}_name'], text=f'Dim {i}')
|
||||||
|
|
||||||
row = col.row(align=True)
|
row = col.row(align=True)
|
||||||
row.alignment = 'CENTER'
|
row.alignment = 'CENTER'
|
||||||
row.label(text=f'Load Dim {i}')
|
row.label(text='Output')
|
||||||
|
|
||||||
row = col.row(align=True)
|
row = col.row(align=True)
|
||||||
row.prop(self, self.blfields[f'dim_{i}_name'], text='')
|
row.prop(self, self.blfields['output_name'], text='')
|
||||||
row.prop(self, self.blfields[f'dim_{i}_mathtype'], text='')
|
row.prop(self, self.blfields['output_mathtype'], text='')
|
||||||
|
|
||||||
row = col.row(align=True)
|
row = col.row(align=True)
|
||||||
row.prop(self, self.blfields[f'dim_{i}_physical_type'], text='')
|
row.prop(self, self.blfields['output_physical_type'], text='')
|
||||||
row.prop(self, self.blfields[f'dim_{i}_unit'], text='')
|
row.prop(self, self.blfields['output_unit'], text='')
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - FlowKind.Array|Func
|
# - FlowKind.Array|Func
|
||||||
|
@ -271,7 +233,8 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
||||||
'Expr',
|
'Expr',
|
||||||
kind=ct.FlowKind.Info,
|
kind=ct.FlowKind.Info,
|
||||||
# Loaded
|
# Loaded
|
||||||
props={'output_name', 'output_physical_type', 'output_unit'},
|
props={'output_name', 'output_mathtype', 'output_physical_type', 'output_unit'}
|
||||||
|
| {f'dim_{i}_name' for i in range(6)},
|
||||||
output_sockets={'Expr'},
|
output_sockets={'Expr'},
|
||||||
output_socket_kinds={'Expr': ct.FlowKind.Func},
|
output_socket_kinds={'Expr': ct.FlowKind.Func},
|
||||||
)
|
)
|
||||||
|
@ -285,32 +248,31 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
||||||
A completely empty `ParamsFlow`, ready to be composed.
|
A completely empty `ParamsFlow`, ready to be composed.
|
||||||
"""
|
"""
|
||||||
expr = output_sockets['Expr']
|
expr = output_sockets['Expr']
|
||||||
|
|
||||||
has_expr_func = not ct.FlowSignal.check(expr)
|
has_expr_func = not ct.FlowSignal.check(expr)
|
||||||
|
|
||||||
if has_expr_func:
|
if has_expr_func:
|
||||||
data = expr.func_jax()
|
data = expr.func_jax()
|
||||||
|
|
||||||
# Deduce Dimensionality
|
# Deduce Dimension Symbols
|
||||||
_shape = data.shape
|
## -> They are all chronically integer indices.
|
||||||
shape = _shape if _shape is not None else ()
|
## -> The FilterNode can be used to "steal" an index from the data.
|
||||||
dim_syms = [self.dim(i) for i in range(len(shape))]
|
shape = data.shape if data.shape is not None else ()
|
||||||
|
|
||||||
# Return InfoFlow
|
|
||||||
return ct.InfoFlow(
|
|
||||||
dims = {
|
dims = {
|
||||||
dim_sym: ct.RangeFlow(
|
sim_symbols.idx(None).update(
|
||||||
start=sp.S(0),
|
sym_name=props[f'dim_{i}_name'],
|
||||||
stop=sp.S(shape[i] - 1),
|
interval_finite_z=(0, elements),
|
||||||
steps=shape[i],
|
interval_inf=(False, False),
|
||||||
unit=self.dim(i).unit,
|
interval_closed=(True, True),
|
||||||
)
|
): [str(j) for j in range(elements)]
|
||||||
for i, dim_sym in enumerate(dim_syms)
|
for i, elements in enumerate(shape)
|
||||||
},
|
}
|
||||||
|
|
||||||
|
return ct.InfoFlow(
|
||||||
|
dims=dims,
|
||||||
output=sim_symbols.SimSymbol(
|
output=sim_symbols.SimSymbol(
|
||||||
sym_name=props['output_name'],
|
sym_name=props['output_name'],
|
||||||
mathtype=props['output_mathtype'],
|
mathtype=props['output_mathtype'],
|
||||||
physical_type=props['output_physical_type'],
|
physical_type=props['output_physical_type'],
|
||||||
|
unit=props['output_unit'],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
|
@ -259,9 +259,8 @@ class LibraryMediumNode(base.MaxwellSimNode):
|
||||||
)
|
)
|
||||||
def compute_valid_freqs_lazy(self, props) -> sp.Expr:
|
def compute_valid_freqs_lazy(self, props) -> sp.Expr:
|
||||||
return ct.RangeFlow(
|
return ct.RangeFlow(
|
||||||
start=props['freq_range'][0] / spux.THz,
|
start=spu.scale_to_unit(['freq_range'][0], spux.THz),
|
||||||
stop=props['freq_range'][1] / spux.THz,
|
stop=spu.scale_to_unit(props['freq_range'][1], spux.THz),
|
||||||
steps=0,
|
|
||||||
scaling=ct.ScalingMode.Lin,
|
scaling=ct.ScalingMode.Lin,
|
||||||
unit=spux.THz,
|
unit=spux.THz,
|
||||||
)
|
)
|
||||||
|
@ -273,9 +272,8 @@ class LibraryMediumNode(base.MaxwellSimNode):
|
||||||
)
|
)
|
||||||
def compute_valid_wls_lazy(self, props) -> sp.Expr:
|
def compute_valid_wls_lazy(self, props) -> sp.Expr:
|
||||||
return ct.RangeFlow(
|
return ct.RangeFlow(
|
||||||
start=props['wl_range'][0] / spu.nm,
|
start=spu.scale_to_unit(['wl_range'][0], spu.nm),
|
||||||
stop=props['wl_range'][0] / spu.nm,
|
stop=spu.scale_to_unit(['wl_range'][0], spu.nm),
|
||||||
steps=0,
|
|
||||||
scaling=ct.ScalingMode.Lin,
|
scaling=ct.ScalingMode.Lin,
|
||||||
unit=spu.nm,
|
unit=spu.nm,
|
||||||
)
|
)
|
||||||
|
|
|
@ -73,31 +73,90 @@ class ViewerNode(base.MaxwellSimNode):
|
||||||
####################
|
####################
|
||||||
# - Properties
|
# - Properties
|
||||||
####################
|
####################
|
||||||
print_kind: ct.FlowKind = bl_cache.BLField(ct.FlowKind.Value)
|
auto_expr: bool = bl_cache.BLField(True)
|
||||||
auto_plot: bool = bl_cache.BLField(False)
|
debug_mode: bool = bl_cache.BLField(False)
|
||||||
|
|
||||||
|
# Debug Mode
|
||||||
|
console_print_kind: ct.FlowKind = bl_cache.BLField(ct.FlowKind.Value)
|
||||||
|
auto_plot: bool = bl_cache.BLField(True)
|
||||||
auto_3d_preview: bool = bl_cache.BLField(True)
|
auto_3d_preview: bool = bl_cache.BLField(True)
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Properties: Computed FlowKinds
|
||||||
|
####################
|
||||||
|
@events.on_value_changed(
|
||||||
|
socket_name='Any',
|
||||||
|
)
|
||||||
|
def on_input_changed(self) -> None:
|
||||||
|
self.input_flow = bl_cache.Signal.InvalidateCache
|
||||||
|
|
||||||
|
@bl_cache.cached_bl_property()
|
||||||
|
def input_flow(self) -> dict[ct.FlowKind, typ.Any | None]:
|
||||||
|
input_flow = {}
|
||||||
|
|
||||||
|
for flow_kind in list(ct.FlowKind):
|
||||||
|
flow = self._compute_input('Any', kind=flow_kind)
|
||||||
|
has_flow = not ct.FlowSignal.check(flow)
|
||||||
|
|
||||||
|
if has_flow:
|
||||||
|
input_flow |= {flow_kind: flow}
|
||||||
|
else:
|
||||||
|
input_flow |= {flow_kind: None}
|
||||||
|
|
||||||
|
return input_flow
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Property: Input Expression String Lines
|
||||||
|
####################
|
||||||
|
@bl_cache.cached_bl_property(depends_on={'input_flow'})
|
||||||
|
def input_expr_str_entries(self) -> list[list[str]] | None:
|
||||||
|
value = self.input_flow.get(ct.FlowKind.Value)
|
||||||
|
|
||||||
|
def sp_pretty(v: spux.SympyExpr) -> spux.SympyExpr:
|
||||||
|
## sp.pretty makes new lines and wreaks havoc.
|
||||||
|
return spux.sp_to_str(v.n(4))
|
||||||
|
|
||||||
|
if isinstance(value, spux.SympyType):
|
||||||
|
if isinstance(value, sp.MatrixBase):
|
||||||
|
return [
|
||||||
|
[sp_pretty(value[row, col]) for col in range(value.shape[1])]
|
||||||
|
for row in range(value.shape[0])
|
||||||
|
]
|
||||||
|
|
||||||
|
return [[sp_pretty(value)]]
|
||||||
|
return None
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - UI
|
# - UI
|
||||||
####################
|
####################
|
||||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
|
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
|
||||||
layout.prop(self, self.blfields['print_kind'], text='')
|
row = layout.row(align=True)
|
||||||
|
|
||||||
|
# Debug Mode On/Off
|
||||||
|
row.prop(self, self.blfields['debug_mode'], text='Debug', toggle=True)
|
||||||
|
|
||||||
|
# Automatic Expression Printing
|
||||||
|
row.prop(self, self.blfields['auto_expr'], text='Expr', toggle=True)
|
||||||
|
|
||||||
|
# Debug Mode Operators
|
||||||
|
if self.debug_mode:
|
||||||
|
layout.prop(self, self.blfields['console_print_kind'], text='')
|
||||||
|
|
||||||
def draw_operators(self, _: bpy.types.Context, layout: bpy.types.UILayout):
|
def draw_operators(self, _: bpy.types.Context, layout: bpy.types.UILayout):
|
||||||
|
# Live Expression
|
||||||
|
if self.debug_mode:
|
||||||
|
layout.operator(ConsoleViewOperator.bl_idname, text='Console Print')
|
||||||
|
|
||||||
split = layout.split(factor=0.4)
|
split = layout.split(factor=0.4)
|
||||||
|
|
||||||
# Split LHS
|
# Split LHS
|
||||||
col = split.column(align=False)
|
col = split.column(align=False)
|
||||||
col.label(text='Console')
|
|
||||||
col.label(text='Plot')
|
col.label(text='Plot')
|
||||||
col.label(text='3D')
|
col.label(text='3D')
|
||||||
|
|
||||||
# Split RHS
|
# Split RHS
|
||||||
col = split.column(align=False)
|
col = split.column(align=False)
|
||||||
|
|
||||||
## Console Options
|
|
||||||
col.operator(ConsoleViewOperator.bl_idname, text='Print')
|
|
||||||
|
|
||||||
## Plot Options
|
## Plot Options
|
||||||
row = col.row(align=True)
|
row = col.row(align=True)
|
||||||
row.prop(self, self.blfields['auto_plot'], text='Plot', toggle=True)
|
row.prop(self, self.blfields['auto_plot'], text='Plot', toggle=True)
|
||||||
|
@ -109,7 +168,43 @@ class ViewerNode(base.MaxwellSimNode):
|
||||||
|
|
||||||
## 3D Preview Options
|
## 3D Preview Options
|
||||||
row = col.row(align=True)
|
row = col.row(align=True)
|
||||||
row.prop(self, self.blfields['auto_3d_preview'], text='3D Preview', toggle=True)
|
row.prop(
|
||||||
|
self, self.blfields['auto_3d_preview'], text='3D Preview', toggle=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout):
|
||||||
|
# Live Expression
|
||||||
|
if self.auto_expr and self.input_expr_str_entries is not None:
|
||||||
|
box = layout.box()
|
||||||
|
|
||||||
|
expr_rows = len(self.input_expr_str_entries)
|
||||||
|
expr_cols = len(self.input_expr_str_entries[0])
|
||||||
|
shape_str = (
|
||||||
|
f'({expr_rows}×{expr_cols})'
|
||||||
|
if expr_rows != 1 or expr_cols != 1
|
||||||
|
else '(Scalar)'
|
||||||
|
)
|
||||||
|
|
||||||
|
row = box.row()
|
||||||
|
row.alignment = 'CENTER'
|
||||||
|
row.label(text=f'Expr {shape_str}')
|
||||||
|
|
||||||
|
if (
|
||||||
|
len(self.input_expr_str_entries) == 1
|
||||||
|
and len(self.input_expr_str_entries[0]) == 1
|
||||||
|
):
|
||||||
|
row = box.row()
|
||||||
|
row.alignment = 'CENTER'
|
||||||
|
row.label(text=self.input_expr_str_entries[0][0])
|
||||||
|
else:
|
||||||
|
grid = box.grid_flow(
|
||||||
|
row_major=True,
|
||||||
|
columns=len(self.input_expr_str_entries[0]),
|
||||||
|
align=True,
|
||||||
|
)
|
||||||
|
for row in self.input_expr_str_entries:
|
||||||
|
for entry in row:
|
||||||
|
grid.label(text=entry)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Methods
|
# - Methods
|
||||||
|
@ -119,7 +214,7 @@ class ViewerNode(base.MaxwellSimNode):
|
||||||
return
|
return
|
||||||
|
|
||||||
log.info('Printing to Console')
|
log.info('Printing to Console')
|
||||||
data = self._compute_input('Any', kind=self.print_kind, optional=True)
|
data = self._compute_input('Any', kind=self.console_print_kind, optional=True)
|
||||||
|
|
||||||
if isinstance(data, spux.SympyType):
|
if isinstance(data, spux.SympyType):
|
||||||
console.print(sp.pretty(data, use_unicode=True))
|
console.print(sp.pretty(data, use_unicode=True))
|
||||||
|
|
|
@ -40,6 +40,8 @@ _max_e_socket_def = sockets.ExprSocketDef(
|
||||||
)
|
)
|
||||||
_offset_socket_def = sockets.ExprSocketDef(default_value=5, abs_min=2.5)
|
_offset_socket_def = sockets.ExprSocketDef(default_value=5, abs_min=2.5)
|
||||||
|
|
||||||
|
t_ps = sim_symbols.t(spu.picosecond)
|
||||||
|
|
||||||
|
|
||||||
class TemporalShapeNode(base.MaxwellSimNode):
|
class TemporalShapeNode(base.MaxwellSimNode):
|
||||||
"""Declare a source-time dependence for use in simulation source nodes."""
|
"""Declare a source-time dependence for use in simulation source nodes."""
|
||||||
|
@ -82,8 +84,8 @@ class TemporalShapeNode(base.MaxwellSimNode):
|
||||||
default_steps=100,
|
default_steps=100,
|
||||||
),
|
),
|
||||||
'Envelope': sockets.ExprSocketDef(
|
'Envelope': sockets.ExprSocketDef(
|
||||||
default_symbols=[sim_symbols.t],
|
default_symbols=[t_ps],
|
||||||
default_value=10 * sim_symbols.t.sp_symbol,
|
default_value=10 * t_ps.sp_symbol,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -593,6 +593,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
||||||
ValueError: When referencing a socket that's meant to be directly referenced.
|
ValueError: When referencing a socket that's meant to be directly referenced.
|
||||||
"""
|
"""
|
||||||
kind_data_map = {
|
kind_data_map = {
|
||||||
|
ct.FlowKind.Capabilities: lambda: self.capabilities,
|
||||||
ct.FlowKind.Value: lambda: self.value,
|
ct.FlowKind.Value: lambda: self.value,
|
||||||
ct.FlowKind.Array: lambda: self.array,
|
ct.FlowKind.Array: lambda: self.array,
|
||||||
ct.FlowKind.Func: lambda: self.lazy_func,
|
ct.FlowKind.Func: lambda: self.lazy_func,
|
||||||
|
|
|
@ -111,29 +111,38 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
bl_label = 'Expr'
|
bl_label = 'Expr'
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Properties
|
# - Socket Interface
|
||||||
####################
|
####################
|
||||||
size: spux.NumberSize1D = bl_cache.BLField(spux.NumberSize1D.Scalar)
|
size: spux.NumberSize1D = bl_cache.BLField(spux.NumberSize1D.Scalar)
|
||||||
mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real)
|
mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real)
|
||||||
physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical)
|
physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical)
|
||||||
|
|
||||||
# Symbols
|
####################
|
||||||
|
# - Symbols
|
||||||
|
####################
|
||||||
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||||
sim_symbols.SimSymbolName.Expr
|
sim_symbols.SimSymbolName.Expr
|
||||||
)
|
)
|
||||||
active_symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([])
|
symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([])
|
||||||
|
|
||||||
@property
|
|
||||||
def symbols(self) -> set[sp.Symbol]:
|
|
||||||
"""Current symbols as an unordered set."""
|
|
||||||
return {sim_symbol.sp_symbol for sim_symbol in self.active_symbols}
|
|
||||||
|
|
||||||
@bl_cache.cached_bl_property(depends_on={'symbols'})
|
@bl_cache.cached_bl_property(depends_on={'symbols'})
|
||||||
def sorted_symbols(self) -> list[sp.Symbol]:
|
def sp_symbols(self) -> set[sp.Symbol | sp.MatrixSymbol]:
|
||||||
|
"""Sympy symbols as an unordered set."""
|
||||||
|
return {sim_symbol.sp_symbol_matsym for sim_symbol in self.symbols}
|
||||||
|
|
||||||
|
@bl_cache.cached_bl_property(depends_on={'symbols'})
|
||||||
|
def sorted_symbols(self) -> list[sim_symbols.SimSymbol]:
|
||||||
"""Current symbols as a sorted list."""
|
"""Current symbols as a sorted list."""
|
||||||
return sorted(self.symbols, key=lambda sym: sym.name)
|
return sorted(self.symbols, key=lambda sym: sym.name)
|
||||||
|
|
||||||
# Unit
|
@bl_cache.cached_bl_property(depends_on={'symbols'})
|
||||||
|
def sorted_sp_symbols(self) -> list[sp.Symbol | sp.MatrixSymbol]:
|
||||||
|
"""Computes `sympy` symbols from `self.sorted_symbols`."""
|
||||||
|
return [sym.sp_symbol_matsym for sym in self.sorted_symbols]
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Units
|
||||||
|
####################
|
||||||
active_unit: enum.StrEnum = bl_cache.BLField(
|
active_unit: enum.StrEnum = bl_cache.BLField(
|
||||||
enum_cb=lambda self, _: self.search_valid_units(),
|
enum_cb=lambda self, _: self.search_valid_units(),
|
||||||
cb_depends_on={'physical_type'},
|
cb_depends_on={'physical_type'},
|
||||||
|
@ -148,6 +157,29 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
]
|
]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@bl_cache.cached_bl_property(depends_on={'active_unit'})
|
||||||
|
def unit(self) -> spux.Unit | None:
|
||||||
|
"""Gets the current active unit.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The current active `sympy` unit.
|
||||||
|
|
||||||
|
If the socket expression is unitless, this returns `None`.
|
||||||
|
"""
|
||||||
|
if self.active_unit is not None:
|
||||||
|
return spux.unit_str_to_unit(self.active_unit)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unit_factor(self) -> spux.Unit | None:
|
||||||
|
return sp.Integer(1) if self.unit is None else self.unit
|
||||||
|
|
||||||
|
prev_unit: str | None = bl_cache.BLField(None)
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - UI Values
|
||||||
|
####################
|
||||||
# UI: Value
|
# UI: Value
|
||||||
## Expression
|
## Expression
|
||||||
raw_value_spstr: str = bl_cache.BLField('0.0')
|
raw_value_spstr: str = bl_cache.BLField('0.0')
|
||||||
|
@ -186,6 +218,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
)
|
)
|
||||||
|
|
||||||
# UI: Info
|
# UI: Info
|
||||||
|
show_name_selector: bool = bl_cache.BLField(False)
|
||||||
show_func_ui: bool = bl_cache.BLField(True)
|
show_func_ui: bool = bl_cache.BLField(True)
|
||||||
show_info_columns: bool = bl_cache.BLField(False)
|
show_info_columns: bool = bl_cache.BLField(False)
|
||||||
info_columns: set[InfoDisplayCol] = bl_cache.BLField(
|
info_columns: set[InfoDisplayCol] = bl_cache.BLField(
|
||||||
|
@ -207,25 +240,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
def raw_max_sp(self) -> spux.SympyExpr:
|
def raw_max_sp(self) -> spux.SympyExpr:
|
||||||
return self._parse_expr_str(self.raw_max_spstr)
|
return self._parse_expr_str(self.raw_max_spstr)
|
||||||
|
|
||||||
####################
|
|
||||||
# - Computed Unit
|
|
||||||
####################
|
|
||||||
@bl_cache.cached_bl_property(depends_on={'active_unit'})
|
|
||||||
def unit(self) -> spux.Unit | None:
|
|
||||||
"""Gets the current active unit.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The current active `sympy` unit.
|
|
||||||
|
|
||||||
If the socket expression is unitless, this returns `None`.
|
|
||||||
"""
|
|
||||||
if self.active_unit is not None:
|
|
||||||
return spux.unit_str_to_unit(self.active_unit)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
prev_unit: str | None = bl_cache.BLField(None)
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Prop-Change Callback
|
# - Prop-Change Callback
|
||||||
####################
|
####################
|
||||||
|
@ -272,7 +286,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
# Parse Symbols
|
# Parse Symbols
|
||||||
if expr.free_symbols and not expr.free_symbols.issubset(self.symbols):
|
if expr.free_symbols and not expr.free_symbols.issubset(self.sp_symbols):
|
||||||
msg = f'Tried to set expr {expr} with free symbols {expr.free_symbols}, which is incompatible with socket symbols {self.symbols}'
|
msg = f'Tried to set expr {expr} with free symbols {expr.free_symbols}, which is incompatible with socket symbols {self.symbols}'
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
@ -320,7 +334,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
"""
|
"""
|
||||||
expr = sp.sympify(
|
expr = sp.sympify(
|
||||||
expr_spstr,
|
expr_spstr,
|
||||||
locals={sym.name: sym for sym in self.symbols},
|
locals={sym.name: sym.sp_symbol_matsym for sym in self.symbols},
|
||||||
strict=False,
|
strict=False,
|
||||||
convert_xor=True,
|
convert_xor=True,
|
||||||
).subs(spux.UNIT_BY_SYMBOL)
|
).subs(spux.UNIT_BY_SYMBOL)
|
||||||
|
@ -562,11 +576,11 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
if self.symbols:
|
if self.symbols:
|
||||||
return ct.FuncFlow(
|
return ct.FuncFlow(
|
||||||
func=sp.lambdify(
|
func=sp.lambdify(
|
||||||
self.sorted_symbols,
|
self.sorted_sp_symbols,
|
||||||
spux.scale_to_unit(self.value, self.unit),
|
spux.strip_unit_system(self.value),
|
||||||
'jax',
|
'jax',
|
||||||
),
|
),
|
||||||
func_args=[spux.MathType.from_expr(sym) for sym in self.sorted_symbols],
|
func_args=list(self.sorted_symbols),
|
||||||
supports_jax=True,
|
supports_jax=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -578,7 +592,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
return ct.FuncFlow(
|
return ct.FuncFlow(
|
||||||
func=lambda v: v,
|
func=lambda v: v,
|
||||||
func_args=[
|
func_args=[
|
||||||
self.physical_type if self.physical_type is not None else self.mathtype
|
sim_symbols.SimSymbol.from_expr(
|
||||||
|
sim_symbols.SimSymbolName.Constant, self.value, self.unit_factor
|
||||||
|
)
|
||||||
],
|
],
|
||||||
supports_jax=True,
|
supports_jax=True,
|
||||||
)
|
)
|
||||||
|
@ -597,8 +613,8 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
## -> NOTE: func_args must have the same symbol order as was lambdified.
|
## -> NOTE: func_args must have the same symbol order as was lambdified.
|
||||||
if self.symbols:
|
if self.symbols:
|
||||||
return ct.ParamsFlow(
|
return ct.ParamsFlow(
|
||||||
func_args=self.sorted_symbols,
|
func_args=[sym.sp_symbol_phy for sym in self.sorted_symbols],
|
||||||
symbols=self.symbols,
|
symbols=self.sorted_symbols,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Constant
|
# Constant
|
||||||
|
@ -618,24 +634,27 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
|
|
||||||
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
|
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
|
||||||
"""
|
"""
|
||||||
output_sim_sym = (
|
output_sym = sim_symbols.SimSymbol(
|
||||||
sim_symbols.SimSymbol(
|
|
||||||
sym_name=self.output_name,
|
sym_name=self.output_name,
|
||||||
mathtype=self.mathtype,
|
mathtype=self.mathtype,
|
||||||
physical_type=self.physical_type,
|
physical_type=self.physical_type,
|
||||||
unit=self.unit,
|
unit=self.unit,
|
||||||
rows=self.size.rows,
|
rows=self.size.rows,
|
||||||
cols=self.size.cols,
|
cols=self.size.cols,
|
||||||
),
|
|
||||||
)
|
|
||||||
if self.symbols:
|
|
||||||
return ct.InfoFlow(
|
|
||||||
dims={sim_sym: None for sim_sym in self.active_symbols},
|
|
||||||
output=output_sim_sym,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Constant
|
# Constant
|
||||||
return ct.InfoFlow(output=output_sim_sym)
|
## -> The input SimSymbols become continuous dimensional indices.
|
||||||
|
## -> All domain validity information is defined on the SimSymbol keys.
|
||||||
|
if self.symbols:
|
||||||
|
return ct.InfoFlow(
|
||||||
|
dims={sym: None for sym in self.sorted_symbols},
|
||||||
|
output=output_sym,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Constant
|
||||||
|
## -> We only need the output symbol to describe the raw data.
|
||||||
|
return ct.InfoFlow(output=output_sym)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - FlowKind: Capabilities
|
# - FlowKind: Capabilities
|
||||||
|
@ -645,6 +664,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
return ct.CapabilitiesFlow(
|
return ct.CapabilitiesFlow(
|
||||||
socket_type=self.socket_type,
|
socket_type=self.socket_type,
|
||||||
active_kind=self.active_kind,
|
active_kind=self.active_kind,
|
||||||
|
allow_out_to_in={ct.FlowKind.Func: ct.FlowKind.Value},
|
||||||
)
|
)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
@ -795,7 +815,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
col = split.column()
|
col = split.column()
|
||||||
col.alignment = 'RIGHT'
|
col.alignment = 'RIGHT'
|
||||||
for sym in self.symbols:
|
for sym in self.symbols:
|
||||||
col.label(text=spux.pretty_symbol(sym))
|
col.label(text=sym.def_label)
|
||||||
|
|
||||||
def draw_lazy_range(self, col: bpy.types.UILayout) -> None:
|
def draw_lazy_range(self, col: bpy.types.UILayout) -> None:
|
||||||
"""Draw the socket body for a simple, uniform range of values between two values/expressions.
|
"""Draw the socket body for a simple, uniform range of values between two values/expressions.
|
||||||
|
@ -840,14 +860,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
Uses `draw_value` to draw the base UI
|
Uses `draw_value` to draw the base UI
|
||||||
"""
|
"""
|
||||||
if self.show_func_ui:
|
if self.show_func_ui:
|
||||||
# Output Name Selector
|
|
||||||
## -> The name of the output
|
|
||||||
col.prop(self, self.blfields['output_name'], text='')
|
|
||||||
|
|
||||||
# Physical Type Selector
|
|
||||||
## -> Determines whether/which unit-dropdown will be shown.
|
|
||||||
col.prop(self, self.blfields['physical_type'], text='')
|
|
||||||
|
|
||||||
# Non-Symbolic: Size/Mathtype Selector
|
# Non-Symbolic: Size/Mathtype Selector
|
||||||
## -> Symbols imply str expr input.
|
## -> Symbols imply str expr input.
|
||||||
## -> For arbitrary str exprs, size/mathtype are derived from the expr.
|
## -> For arbitrary str exprs, size/mathtype are derived from the expr.
|
||||||
|
@ -861,15 +873,30 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
## -> Draws the UI appropriate for the above choice of constraints.
|
## -> Draws the UI appropriate for the above choice of constraints.
|
||||||
self.draw_value(col)
|
self.draw_value(col)
|
||||||
|
|
||||||
|
# Physical Type Selector
|
||||||
|
## -> Determines whether/which unit-dropdown will be shown.
|
||||||
|
col.prop(self, self.blfields['physical_type'], text='')
|
||||||
|
|
||||||
# Symbol UI
|
# Symbol UI
|
||||||
## -> Draws the UI appropriate for the above choice of constraints.
|
## -> Draws the UI appropriate for the above choice of constraints.
|
||||||
## -> TODO
|
## -> TODO
|
||||||
|
|
||||||
|
# Output Name Selector
|
||||||
|
## -> The name of the output
|
||||||
|
if self.show_name_selector:
|
||||||
|
row = col.row()
|
||||||
|
row.prop(self, self.blfields['output_name'], text='Name')
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - UI: InfoFlow
|
# - UI: InfoFlow
|
||||||
####################
|
####################
|
||||||
def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None:
|
def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None:
|
||||||
if self.active_kind == ct.FlowKind.Func and self.show_info_columns:
|
"""Visualize the `InfoFlow` information passing through the socket."""
|
||||||
|
if (
|
||||||
|
self.active_kind == ct.FlowKind.Func
|
||||||
|
and self.show_info_columns
|
||||||
|
and self.is_linked
|
||||||
|
):
|
||||||
row = col.row()
|
row = col.row()
|
||||||
box = row.box()
|
box = row.box()
|
||||||
grid = box.grid_flow(
|
grid = box.grid_flow(
|
||||||
|
@ -881,38 +908,23 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Dimensions
|
# Dimensions
|
||||||
for dim in info.dims:
|
for dim_name_pretty, dim_label_info in info.dim_labels.items():
|
||||||
dim_idx = info.dims[dim]
|
grid.label(text=dim_name_pretty)
|
||||||
grid.label(text=dim.name_pretty)
|
|
||||||
if InfoDisplayCol.Length in self.info_columns:
|
if InfoDisplayCol.Length in self.info_columns:
|
||||||
grid.label(text=str(len(dim_idx)))
|
grid.label(text=dim_label_info['length'])
|
||||||
if InfoDisplayCol.MathType in self.info_columns:
|
if InfoDisplayCol.MathType in self.info_columns:
|
||||||
grid.label(text=spux.MathType.to_str(dim_idx.mathtype))
|
grid.label(text=dim_label_info['mathtype'])
|
||||||
if InfoDisplayCol.Unit in self.info_columns:
|
if InfoDisplayCol.Unit in self.info_columns:
|
||||||
grid.label(text=spux.sp_to_str(dim_idx.unit))
|
grid.label(text=dim_label_info['unit'])
|
||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
grid.label(text=info.output.name_pretty)
|
grid.label(text=info.output.name_pretty)
|
||||||
if InfoDisplayCol.Length in self.info_columns:
|
if InfoDisplayCol.Length in self.info_columns:
|
||||||
grid.label(text='', icon=ct.Icon.DataSocketOutput)
|
grid.label(text='', icon=ct.Icon.DataSocketOutput)
|
||||||
if InfoDisplayCol.MathType in self.info_columns:
|
if InfoDisplayCol.MathType in self.info_columns:
|
||||||
grid.label(
|
grid.label(text=info.output.def_label)
|
||||||
text=(
|
|
||||||
spux.MathType.to_str(info.output.mathtype)
|
|
||||||
+ (
|
|
||||||
'ˣ'.join(
|
|
||||||
[
|
|
||||||
unicode_superscript(out_axis)
|
|
||||||
for out_axis in info.output.shape
|
|
||||||
]
|
|
||||||
)
|
|
||||||
if info.output.shape
|
|
||||||
else ''
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if InfoDisplayCol.Unit in self.info_columns:
|
if InfoDisplayCol.Unit in self.info_columns:
|
||||||
grid.label(text=f'{spux.sp_to_str(info.output.unit)}')
|
grid.label(text=info.output.unit_label)
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
@ -926,7 +938,7 @@ class ExprSocketDef(base.SocketDef):
|
||||||
ct.FlowKind.Array,
|
ct.FlowKind.Array,
|
||||||
ct.FlowKind.Func,
|
ct.FlowKind.Func,
|
||||||
] = ct.FlowKind.Value
|
] = ct.FlowKind.Value
|
||||||
output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName
|
output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName.Expr
|
||||||
|
|
||||||
# Socket Interface
|
# Socket Interface
|
||||||
size: spux.NumberSize1D = spux.NumberSize1D.Scalar
|
size: spux.NumberSize1D = spux.NumberSize1D.Scalar
|
||||||
|
@ -948,9 +960,15 @@ class ExprSocketDef(base.SocketDef):
|
||||||
default_scaling: ct.ScalingMode = ct.ScalingMode.Lin
|
default_scaling: ct.ScalingMode = ct.ScalingMode.Lin
|
||||||
|
|
||||||
# UI
|
# UI
|
||||||
|
show_name_selector: bool = False
|
||||||
show_func_ui: bool = True
|
show_func_ui: bool = True
|
||||||
show_info_columns: bool = False
|
show_info_columns: bool = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sp_symbols(self) -> set[sp.Symbol | sp.MatrixSymbol]:
|
||||||
|
"""Default symbols as an unordered set."""
|
||||||
|
return {sym.sp_symbol_matsym for sym in self.default_symbols}
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Parse Unit and/or Physical Type
|
# - Parse Unit and/or Physical Type
|
||||||
####################
|
####################
|
||||||
|
@ -1149,6 +1167,7 @@ class ExprSocketDef(base.SocketDef):
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
# Coerce from Infinite
|
# Coerce from Infinite
|
||||||
|
if isinstance(bound, spux.SympyType):
|
||||||
if bound.is_infinite and self.mathtype is spux.MathType.Integer:
|
if bound.is_infinite and self.mathtype is spux.MathType.Integer:
|
||||||
new_bounds[i] = sp.S(-1) if i == 0 else sp.S(1)
|
new_bounds[i] = sp.S(-1) if i == 0 else sp.S(1)
|
||||||
if bound.is_infinite and self.mathtype is spux.MathType.Rational:
|
if bound.is_infinite and self.mathtype is spux.MathType.Rational:
|
||||||
|
@ -1194,9 +1213,9 @@ class ExprSocketDef(base.SocketDef):
|
||||||
def symbols_value(self) -> typ.Self:
|
def symbols_value(self) -> typ.Self:
|
||||||
if (
|
if (
|
||||||
self.default_value.free_symbols
|
self.default_value.free_symbols
|
||||||
and not self.default_value.free_symbols.issubset(self.symbols)
|
and not self.default_value.free_symbols.issubset(self.sp_symbols)
|
||||||
):
|
):
|
||||||
msg = f'Tried to set default value {self.default_value} with free symbols {self.default_value.free_symbols}, which is incompatible with socket symbols {self.symbols}'
|
msg = f'Tried to set default value {self.default_value} with free symbols {self.default_value.free_symbols}, which is incompatible with socket symbols {self.sp_symbols}'
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
@ -1227,7 +1246,7 @@ class ExprSocketDef(base.SocketDef):
|
||||||
bl_socket.size = self.size
|
bl_socket.size = self.size
|
||||||
bl_socket.mathtype = self.mathtype
|
bl_socket.mathtype = self.mathtype
|
||||||
bl_socket.physical_type = self.physical_type
|
bl_socket.physical_type = self.physical_type
|
||||||
bl_socket.active_symbols = self.symbols
|
bl_socket.symbols = self.default_symbols
|
||||||
|
|
||||||
# FlowKind.Value
|
# FlowKind.Value
|
||||||
## -> We must take units into account when setting bl_socket.value
|
## -> We must take units into account when setting bl_socket.value
|
||||||
|
@ -1252,6 +1271,7 @@ class ExprSocketDef(base.SocketDef):
|
||||||
# UI
|
# UI
|
||||||
bl_socket.show_func_ui = self.show_func_ui
|
bl_socket.show_func_ui = self.show_func_ui
|
||||||
bl_socket.show_info_columns = self.show_info_columns
|
bl_socket.show_info_columns = self.show_info_columns
|
||||||
|
bl_socket.show_name_selector = self.show_name_selector
|
||||||
|
|
||||||
# Info Draw
|
# Info Draw
|
||||||
bl_socket.use_info_draw = True
|
bl_socket.use_info_draw = True
|
||||||
|
|
|
@ -389,14 +389,15 @@ class BLField:
|
||||||
|
|
||||||
Reset by setting the descriptor to `Signal.ResetStrSearch`.
|
Reset by setting the descriptor to `Signal.ResetStrSearch`.
|
||||||
"""
|
"""
|
||||||
cached_items = self.bl_prop_str_search.read_nonpersist(_self)
|
return self.str_cb(_self, context, edit_text)
|
||||||
if cached_items is not Signal.CacheNotReady:
|
# cached_items = self.bl_prop_str_search.read_nonpersist(_self)
|
||||||
if cached_items is Signal.CacheEmpty:
|
# if cached_items is not Signal.CacheNotReady:
|
||||||
computed_items = self.str_cb(_self, context, edit_text)
|
# if cached_items is Signal.CacheEmpty:
|
||||||
self.bl_prop_str_search.write_nonpersist(_self, computed_items)
|
# computed_items = self.str_cb(_self, context, edit_text)
|
||||||
return computed_items
|
# self.bl_prop_str_search.write_nonpersist(_self, computed_items)
|
||||||
return cached_items
|
# return computed_items
|
||||||
return []
|
# return cached_items
|
||||||
|
# return []
|
||||||
|
|
||||||
def safe_enum_cb(
|
def safe_enum_cb(
|
||||||
self, _self: bl_instance.BLInstance, context: bpy.types.Context
|
self, _self: bl_instance.BLInstance, context: bpy.types.Context
|
||||||
|
|
|
@ -28,11 +28,13 @@ Attributes:
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import functools
|
import functools
|
||||||
|
import sys
|
||||||
import typing as typ
|
import typing as typ
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import jaxtyping as jtyp
|
||||||
import pydantic as pyd
|
import pydantic as pyd
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
import sympy.physics.units as spu
|
import sympy.physics.units as spu
|
||||||
|
@ -144,6 +146,21 @@ class MathType(enum.StrEnum):
|
||||||
complex: MathType.Complex,
|
complex: MathType.Complex,
|
||||||
}[dtype]
|
}[dtype]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_jax_array(data: jtyp.Shaped[jtyp.Array, '...']) -> type:
|
||||||
|
"""Deduce the MathType corresponding to a JAX array.
|
||||||
|
|
||||||
|
We go about this by leveraging that:
|
||||||
|
- `data` is of a homogeneous type.
|
||||||
|
- `data.item(0)` returns a single element of the array w/pure-python type.
|
||||||
|
|
||||||
|
By combing this with `type()` and `MathType.from_pytype`, we can effectively deduce the `MathType` of the entire array with relative efficiency.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
Should also work with numpy arrays.
|
||||||
|
"""
|
||||||
|
return MathType.from_pytype(type(data.item(0)))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def has_mathtype(obj: typ.Any) -> typ.Literal['pytype', 'expr'] | None:
|
def has_mathtype(obj: typ.Any) -> typ.Literal['pytype', 'expr'] | None:
|
||||||
if isinstance(obj, bool | int | Fraction | float | complex):
|
if isinstance(obj, bool | int | Fraction | float | complex):
|
||||||
|
@ -173,6 +190,39 @@ class MathType(enum.StrEnum):
|
||||||
MT.Complex: sp.Complexes,
|
MT.Complex: sp.Complexes,
|
||||||
}[self]
|
}[self]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inf_finite(self) -> type:
|
||||||
|
"""Opinionated finite representation of "infinity" within this `MathType`.
|
||||||
|
|
||||||
|
These are chosen using `sys.maxsize` and `sys.float_info`.
|
||||||
|
As such, while not arbitrary, this "finite representation of infinity" certainly is opinionated.
|
||||||
|
|
||||||
|
**Note** that, in practice, most systems will have no trouble working with values that exceed those defined here.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
Values should be presumed to vary by-platform, as the `sys` attributes may be influenced by CPU architecture, OS, runtime environment, etc. .
|
||||||
|
|
||||||
|
These values can be used directly in `jax` arrays, but at the cost of an overflow warning (in part because `jax` generally only allows the use of `float32`).
|
||||||
|
In this case, the warning doesn't matter, as the value will be cast to `jnp.inf` anyway.
|
||||||
|
|
||||||
|
However, it's generally cleaner to directly use `jnp.inf` if infinite values must be defined in an array context.
|
||||||
|
"""
|
||||||
|
MT = MathType
|
||||||
|
Z = MT.Integer
|
||||||
|
R = MT.Integer
|
||||||
|
return {
|
||||||
|
MT.Integer: (-sys.maxsize, sys.maxsize),
|
||||||
|
MT.Rational: (
|
||||||
|
Fraction(Z.inf_finite[0], 1),
|
||||||
|
Fraction(Z.inf_finite[1], 1),
|
||||||
|
),
|
||||||
|
MT.Real: -(sys.float_info.min, sys.float_info.max),
|
||||||
|
MT.Complex: (
|
||||||
|
complex(R.inf_finite[0], R.inf_finite[0]),
|
||||||
|
complex(R.inf_finite[1], R.inf_finite[1]),
|
||||||
|
),
|
||||||
|
}[self]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sp_symbol_a(self) -> type:
|
def sp_symbol_a(self) -> type:
|
||||||
MT = MathType
|
MT = MathType
|
||||||
|
@ -192,6 +242,10 @@ class MathType(enum.StrEnum):
|
||||||
MathType.Complex: 'ℂ',
|
MathType.Complex: 'ℂ',
|
||||||
}[value]
|
}[value]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def label_pretty(self) -> str:
|
||||||
|
return MathType.to_str(self)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_name(value: typ.Self) -> str:
|
def to_name(value: typ.Self) -> str:
|
||||||
return MathType.to_str(value)
|
return MathType.to_str(value)
|
||||||
|
@ -819,14 +873,15 @@ def sp_to_str(sp_obj: SympyExpr) -> str:
|
||||||
A string representing the expression for human use.
|
A string representing the expression for human use.
|
||||||
_The string is not re-encodable to the expression._
|
_The string is not re-encodable to the expression._
|
||||||
"""
|
"""
|
||||||
|
## TODO: A bool flag property that does a lot of find/replace to make it super pretty
|
||||||
return _SYMPY_EXPR_PRINTER_STR.doprint(sp_obj)
|
return _SYMPY_EXPR_PRINTER_STR.doprint(sp_obj)
|
||||||
|
|
||||||
|
|
||||||
def pretty_symbol(sym: sp.Symbol) -> str:
|
def pretty_symbol(sym: sp.Symbol) -> str:
|
||||||
return f'{sym.name} ∈ ' + (
|
return f'{sym.name} ∈ ' + (
|
||||||
'ℂ'
|
'ℤ'
|
||||||
if sym.is_complex
|
if sym.is_integer
|
||||||
else ('ℝ' if sym.is_real else ('ℤ' if sym.is_integer else '?'))
|
else ('ℝ' if sym.is_real else ('ℂ' if sym.is_complex else '?'))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1039,20 +1094,24 @@ class PhysicalType(enum.StrEnum):
|
||||||
PT.LumIntensity: spu.candela,
|
PT.LumIntensity: spu.candela,
|
||||||
PT.LumFlux: spu.candela * spu.steradian,
|
PT.LumFlux: spu.candela * spu.steradian,
|
||||||
PT.Illuminance: spu.candela / spu.meter**2,
|
PT.Illuminance: spu.candela / spu.meter**2,
|
||||||
# Optics
|
|
||||||
PT.OrdinaryWaveVector: terahertz,
|
|
||||||
PT.AngularWaveVector: spu.radian * terahertz,
|
|
||||||
}[self]
|
}[self]
|
||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def valid_units(self) -> list[Unit]:
|
def valid_units(self) -> list[Unit]:
|
||||||
|
"""Retrieve an ordered (by subjective usefulness) list of units for this physical type.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
The order in which valid units are declared is the exact same order that UI dropdowns display them.
|
||||||
|
|
||||||
|
**Altering the order of units breaks backwards compatibility**.
|
||||||
|
"""
|
||||||
PT = PhysicalType
|
PT = PhysicalType
|
||||||
return {
|
return {
|
||||||
PT.NonPhysical: [None],
|
PT.NonPhysical: [None],
|
||||||
# Global
|
# Global
|
||||||
PT.Time: [
|
PT.Time: [
|
||||||
femtosecond,
|
|
||||||
spu.picosecond,
|
spu.picosecond,
|
||||||
|
femtosecond,
|
||||||
spu.nanosecond,
|
spu.nanosecond,
|
||||||
spu.microsecond,
|
spu.microsecond,
|
||||||
spu.millisecond,
|
spu.millisecond,
|
||||||
|
@ -1070,11 +1129,11 @@ class PhysicalType(enum.StrEnum):
|
||||||
],
|
],
|
||||||
PT.Freq: (
|
PT.Freq: (
|
||||||
_valid_freqs := [
|
_valid_freqs := [
|
||||||
|
terahertz,
|
||||||
spu.hertz,
|
spu.hertz,
|
||||||
kilohertz,
|
kilohertz,
|
||||||
megahertz,
|
megahertz,
|
||||||
gigahertz,
|
gigahertz,
|
||||||
terahertz,
|
|
||||||
petahertz,
|
petahertz,
|
||||||
exahertz,
|
exahertz,
|
||||||
]
|
]
|
||||||
|
@ -1083,10 +1142,10 @@ class PhysicalType(enum.StrEnum):
|
||||||
# Cartesian
|
# Cartesian
|
||||||
PT.Length: (
|
PT.Length: (
|
||||||
_valid_lens := [
|
_valid_lens := [
|
||||||
|
spu.micrometer,
|
||||||
|
spu.nanometer,
|
||||||
spu.picometer,
|
spu.picometer,
|
||||||
spu.angstrom,
|
spu.angstrom,
|
||||||
spu.nanometer,
|
|
||||||
spu.micrometer,
|
|
||||||
spu.millimeter,
|
spu.millimeter,
|
||||||
spu.centimeter,
|
spu.centimeter,
|
||||||
spu.meter,
|
spu.meter,
|
||||||
|
@ -1102,24 +1161,24 @@ class PhysicalType(enum.StrEnum):
|
||||||
PT.Vel: [_unit / spu.second for _unit in _valid_lens],
|
PT.Vel: [_unit / spu.second for _unit in _valid_lens],
|
||||||
PT.Accel: [_unit / spu.second**2 for _unit in _valid_lens],
|
PT.Accel: [_unit / spu.second**2 for _unit in _valid_lens],
|
||||||
PT.Mass: [
|
PT.Mass: [
|
||||||
|
spu.kilogram,
|
||||||
spu.electron_rest_mass,
|
spu.electron_rest_mass,
|
||||||
spu.dalton,
|
spu.dalton,
|
||||||
spu.microgram,
|
spu.microgram,
|
||||||
spu.milligram,
|
spu.milligram,
|
||||||
spu.gram,
|
spu.gram,
|
||||||
spu.kilogram,
|
|
||||||
spu.metric_ton,
|
spu.metric_ton,
|
||||||
],
|
],
|
||||||
PT.Force: [
|
PT.Force: [
|
||||||
spu.kg * spu.meter / spu.second**2,
|
|
||||||
nanonewton,
|
|
||||||
micronewton,
|
micronewton,
|
||||||
|
nanonewton,
|
||||||
millinewton,
|
millinewton,
|
||||||
spu.newton,
|
spu.newton,
|
||||||
|
spu.kg * spu.meter / spu.second**2,
|
||||||
],
|
],
|
||||||
PT.Pressure: [
|
PT.Pressure: [
|
||||||
millibar,
|
|
||||||
spu.bar,
|
spu.bar,
|
||||||
|
millibar,
|
||||||
spu.pascal,
|
spu.pascal,
|
||||||
hectopascal,
|
hectopascal,
|
||||||
spu.atmosphere,
|
spu.atmosphere,
|
||||||
|
@ -1129,8 +1188,8 @@ class PhysicalType(enum.StrEnum):
|
||||||
],
|
],
|
||||||
# Energy
|
# Energy
|
||||||
PT.Work: [
|
PT.Work: [
|
||||||
spu.electronvolt,
|
|
||||||
spu.joule,
|
spu.joule,
|
||||||
|
spu.electronvolt,
|
||||||
],
|
],
|
||||||
PT.Power: [
|
PT.Power: [
|
||||||
spu.watt,
|
spu.watt,
|
||||||
|
@ -1194,18 +1253,17 @@ class PhysicalType(enum.StrEnum):
|
||||||
PT.Illuminance: [
|
PT.Illuminance: [
|
||||||
spu.candela / spu.meter**2,
|
spu.candela / spu.meter**2,
|
||||||
],
|
],
|
||||||
# Optics
|
|
||||||
PT.OrdinaryWaveVector: _valid_freqs,
|
|
||||||
PT.AngularWaveVector: [spu.radian * _unit for _unit in _valid_freqs],
|
|
||||||
}[self]
|
}[self]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_unit(unit: Unit) -> list[Unit]:
|
def from_unit(unit: Unit, optional: bool = False) -> list[Unit] | None:
|
||||||
for physical_type in list(PhysicalType):
|
for physical_type in list(PhysicalType):
|
||||||
if unit in physical_type.valid_units:
|
if unit in physical_type.valid_units:
|
||||||
return physical_type
|
return physical_type
|
||||||
## TODO: Optimize
|
## TODO: Optimize
|
||||||
|
|
||||||
|
if optional:
|
||||||
|
return None
|
||||||
msg = f'Could not determine PhysicalType for {unit}'
|
msg = f'Could not determine PhysicalType for {unit}'
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
@ -1400,20 +1458,9 @@ def sympy_to_python(
|
||||||
####################
|
####################
|
||||||
# - Convert to Unit System
|
# - Convert to Unit System
|
||||||
####################
|
####################
|
||||||
def convert_to_unit_system(
|
def strip_unit_system(
|
||||||
sp_obj: SympyExpr, unit_system: UnitSystem | None
|
sp_obj: SympyExpr, unit_system: UnitSystem | None = None
|
||||||
) -> SympyExpr:
|
) -> SympyExpr:
|
||||||
"""Convert an expression to the units of a given unit system, with appropriate scaling."""
|
|
||||||
if unit_system is None:
|
|
||||||
return sp_obj
|
|
||||||
|
|
||||||
return spu.convert_to(
|
|
||||||
sp_obj,
|
|
||||||
{unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem | None) -> SympyExpr:
|
|
||||||
"""Strip units occurring in the given unit system from the expression.
|
"""Strip units occurring in the given unit system from the expression.
|
||||||
|
|
||||||
Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`".
|
Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`".
|
||||||
|
@ -1427,6 +1474,19 @@ def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem | None) -> Symp
|
||||||
return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None})
|
return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None})
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_unit_system(
|
||||||
|
sp_obj: SympyExpr, unit_system: UnitSystem | None
|
||||||
|
) -> SympyExpr:
|
||||||
|
"""Convert an expression to the units of a given unit system."""
|
||||||
|
if unit_system is None:
|
||||||
|
return sp_obj
|
||||||
|
|
||||||
|
return spu.convert_to(
|
||||||
|
sp_obj,
|
||||||
|
{unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def scale_to_unit_system(
|
def scale_to_unit_system(
|
||||||
sp_obj: SympyExpr, unit_system: UnitSystem | None, use_jax_array: bool = False
|
sp_obj: SympyExpr, unit_system: UnitSystem | None, use_jax_array: bool = False
|
||||||
) -> int | float | complex | tuple | jax.Array:
|
) -> int | float | complex | tuple | jax.Array:
|
||||||
|
|
|
@ -18,7 +18,6 @@
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import functools
|
import functools
|
||||||
import time
|
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
|
@ -34,7 +33,7 @@ import seaborn as sns
|
||||||
from blender_maxwell import contracts as ct
|
from blender_maxwell import contracts as ct
|
||||||
from blender_maxwell.utils import logger
|
from blender_maxwell.utils import logger
|
||||||
|
|
||||||
mplstyle.use('fast') ## TODO: Does this do anything?
|
# mplstyle.use('fast') ## TODO: Does this do anything?
|
||||||
sns.set_theme()
|
sns.set_theme()
|
||||||
|
|
||||||
log = logger.get(__name__)
|
log = logger.get(__name__)
|
||||||
|
@ -59,6 +58,9 @@ class Colormap(enum.StrEnum):
|
||||||
Viridis = enum.auto()
|
Viridis = enum.auto()
|
||||||
Grayscale = enum.auto()
|
Grayscale = enum.auto()
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - UI
|
||||||
|
####################
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_name(value: typ.Self) -> str:
|
def to_name(value: typ.Self) -> str:
|
||||||
return {
|
return {
|
||||||
|
@ -139,7 +141,9 @@ def rgba_image_from_2d_map(
|
||||||
####################
|
####################
|
||||||
@functools.lru_cache(maxsize=16)
|
@functools.lru_cache(maxsize=16)
|
||||||
def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
|
def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
|
||||||
fig = matplotlib.figure.Figure(figsize=[width_inches, height_inches], dpi=dpi)
|
fig = matplotlib.figure.Figure(
|
||||||
|
figsize=[width_inches, height_inches], dpi=dpi, layout='tight'
|
||||||
|
)
|
||||||
canvas = matplotlib.backends.backend_agg.FigureCanvasAgg(fig)
|
canvas = matplotlib.backends.backend_agg.FigureCanvasAgg(fig)
|
||||||
ax = fig.add_subplot()
|
ax = fig.add_subplot()
|
||||||
|
|
||||||
|
@ -152,66 +156,53 @@ def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
|
||||||
# - Plotters
|
# - Plotters
|
||||||
####################
|
####################
|
||||||
# (ℤ) -> ℝ
|
# (ℤ) -> ℝ
|
||||||
def plot_box_plot_1d(
|
def plot_box_plot_1d(data, ax: mpl_ax.Axis) -> None:
|
||||||
data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis
|
x_sym, y_sym = list(data.keys())
|
||||||
) -> None:
|
|
||||||
x_sym = info.last_dim
|
|
||||||
y_sym = info.output
|
|
||||||
|
|
||||||
ax.boxplot([data])
|
ax.boxplot([data[y_sym]])
|
||||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
|
||||||
ax.set_xlabel(x_sym.plot_label)
|
ax.set_xlabel(x_sym.plot_label)
|
||||||
ax.set_xlabel(y_sym.plot_label)
|
ax.set_xlabel(y_sym.plot_label)
|
||||||
|
|
||||||
|
|
||||||
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None:
|
def plot_bar(data, ax: mpl_ax.Axis) -> None:
|
||||||
x_sym = info.last_dim
|
x_sym, heights_sym = list(data.keys())
|
||||||
y_sym = info.output
|
|
||||||
|
|
||||||
p = ax.bar(info.dims[x_sym], data)
|
p = ax.bar(data[x_sym], data[heights_sym])
|
||||||
ax.bar_label(p, label_type='center')
|
ax.bar_label(p, label_type='center')
|
||||||
|
|
||||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
ax.set_title(f'{x_sym.name_pretty} -> {heights_sym.name_pretty}')
|
||||||
ax.set_xlabel(x_sym.plot_label)
|
ax.set_xlabel(x_sym.plot_label)
|
||||||
ax.set_xlabel(y_sym.plot_label)
|
ax.set_xlabel(heights_sym.plot_label)
|
||||||
|
|
||||||
|
|
||||||
# (ℝ) -> ℝ
|
# (ℝ) -> ℝ
|
||||||
def plot_curve_2d(
|
def plot_curve_2d(data, ax: mpl_ax.Axis) -> None:
|
||||||
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis
|
x_sym, y_sym = list(data.keys())
|
||||||
) -> None:
|
|
||||||
x_sym = info.last_dim
|
|
||||||
y_sym = info.output
|
|
||||||
|
|
||||||
ax.plot(info.dims[x_sym].realize_array.values, data)
|
ax.plot(data[x_sym], data[y_sym])
|
||||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
|
||||||
ax.set_xlabel(x_sym.plot_label)
|
ax.set_xlabel(x_sym.plot_label)
|
||||||
ax.set_xlabel(y_sym.plot_label)
|
ax.set_xlabel(y_sym.plot_label)
|
||||||
|
|
||||||
|
|
||||||
def plot_points_2d(
|
def plot_points_2d(data, ax: mpl_ax.Axis) -> None:
|
||||||
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis
|
x_sym, y_sym = list(data.keys())
|
||||||
) -> None:
|
|
||||||
x_sym = info.last_dim
|
|
||||||
y_sym = info.output
|
|
||||||
|
|
||||||
ax.scatter(x_sym.realize_array.values, data)
|
ax.scatter(data[x_sym], data[y_sym])
|
||||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
|
||||||
ax.set_xlabel(x_sym.plot_label)
|
ax.set_xlabel(x_sym.plot_label)
|
||||||
ax.set_xlabel(y_sym.plot_label)
|
ax.set_xlabel(y_sym.plot_label)
|
||||||
|
|
||||||
|
|
||||||
# (ℝ, ℤ) -> ℝ
|
# (ℝ, ℤ) -> ℝ
|
||||||
def plot_curves_2d(
|
def plot_curves_2d(data, ax: mpl_ax.Axis) -> None:
|
||||||
data: jtyp.Float32[jtyp.Array, 'x_size categories'], info, ax: mpl_ax.Axis
|
x_sym, label_sym, y_sym = list(data.keys())
|
||||||
) -> None:
|
|
||||||
x_sym = info.first_dim
|
|
||||||
y_sym = info.output
|
|
||||||
|
|
||||||
for i, category in enumerate(info.dims[info.last_dim]):
|
for i, label in enumerate(data[label_sym]):
|
||||||
ax.plot(info.dims[x_sym], data[:, i], label=category)
|
ax.plot(data[x_sym], data[y_sym][:, i], label=label)
|
||||||
|
|
||||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
|
||||||
ax.set_xlabel(x_sym.plot_label)
|
ax.set_xlabel(x_sym.plot_label)
|
||||||
ax.set_xlabel(y_sym.plot_label)
|
ax.set_xlabel(y_sym.plot_label)
|
||||||
ax.legend()
|
ax.legend()
|
||||||
|
@ -220,12 +211,10 @@ def plot_curves_2d(
|
||||||
def plot_filled_curves_2d(
|
def plot_filled_curves_2d(
|
||||||
data: jtyp.Float32[jtyp.Array, 'x_size 2'], info, ax: mpl_ax.Axis
|
data: jtyp.Float32[jtyp.Array, 'x_size 2'], info, ax: mpl_ax.Axis
|
||||||
) -> None:
|
) -> None:
|
||||||
x_sym = info.first_dim
|
x_sym, _, y_sym = list(data.keys())
|
||||||
y_sym = info.output
|
|
||||||
|
|
||||||
shared_x_idx = info.dims[info.last_dim]
|
ax.fill_between(data[x_sym], data[y_sym][:, 0], data[x_sym], data[y_sym][:, 1])
|
||||||
ax.fill_between(shared_x_idx, data[:, 0], shared_x_idx, data[:, 1])
|
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
|
||||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
|
||||||
ax.set_xlabel(x_sym.plot_label)
|
ax.set_xlabel(x_sym.plot_label)
|
||||||
ax.set_xlabel(y_sym.plot_label)
|
ax.set_xlabel(y_sym.plot_label)
|
||||||
ax.legend()
|
ax.legend()
|
||||||
|
@ -235,11 +224,9 @@ def plot_filled_curves_2d(
|
||||||
def plot_heatmap_2d(
|
def plot_heatmap_2d(
|
||||||
data: jtyp.Float32[jtyp.Array, 'x_size y_size'], info, ax: mpl_ax.Axis
|
data: jtyp.Float32[jtyp.Array, 'x_size y_size'], info, ax: mpl_ax.Axis
|
||||||
) -> None:
|
) -> None:
|
||||||
x_sym = info.first_dim
|
x_sym, y_sym, c_sym = list(data.keys())
|
||||||
y_sym = info.last_dim
|
|
||||||
c_sym = info.output
|
|
||||||
|
|
||||||
heatmap = ax.imshow(data, aspect='equal', interpolation='none')
|
heatmap = ax.imshow(data[c_sym], aspect='equal', interpolation='none')
|
||||||
ax.figure.colorbar(heatmap, cax=ax)
|
ax.figure.colorbar(heatmap, cax=ax)
|
||||||
|
|
||||||
ax.set_title(f'({x_sym.name_pretty}, {y_sym.name_pretty}) -> {c_sym.plot_label}')
|
ax.set_title(f'({x_sym.name_pretty}, {y_sym.name_pretty}) -> {c_sym.plot_label}')
|
||||||
|
|
|
@ -99,6 +99,7 @@ class TypeID(enum.StrEnum):
|
||||||
SympyType: str = '!type=sympytype'
|
SympyType: str = '!type=sympytype'
|
||||||
SympyExpr: str = '!type=sympyexpr'
|
SympyExpr: str = '!type=sympyexpr'
|
||||||
SocketDef: str = '!type=socketdef'
|
SocketDef: str = '!type=socketdef'
|
||||||
|
SimSymbol: str = '!type=simsymbol'
|
||||||
ManagedObj: str = '!type=managedobj'
|
ManagedObj: str = '!type=managedobj'
|
||||||
|
|
||||||
|
|
||||||
|
@ -161,11 +162,12 @@ def _dec_hook(_type: type, obj: NaivelyEncodableType) -> typ.Any:
|
||||||
return sp.sympify(obj_value).subs(spux.UNIT_BY_SYMBOL)
|
return sp.sympify(obj_value).subs(spux.UNIT_BY_SYMBOL)
|
||||||
|
|
||||||
if hasattr(_type, 'parse_as_msgspec') and (
|
if hasattr(_type, 'parse_as_msgspec') and (
|
||||||
is_representation(obj) and obj[0] in [TypeID.SocketDef, TypeID.ManagedObj]
|
is_representation(obj)
|
||||||
|
and obj[0] in [TypeID.SocketDef, TypeID.ManagedObj, TypeID.SimSymbol]
|
||||||
):
|
):
|
||||||
return _type.parse_as_msgspec(obj)
|
return _type.parse_as_msgspec(obj)
|
||||||
|
|
||||||
msg = f'Can\'t decode "{obj}" to type {type(obj)}'
|
msg = f'can\'t decode "{obj}" to type {type(obj)}'
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,36 +14,60 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
import enum
|
import enum
|
||||||
|
import functools
|
||||||
|
import string
|
||||||
import sys
|
import sys
|
||||||
import typing as typ
|
import typing as typ
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
|
|
||||||
|
import jaxtyping as jtyp
|
||||||
|
import pydantic as pyd
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
|
import sympy.physics.units as spu
|
||||||
|
|
||||||
from . import extra_sympy_units as spux
|
from . import extra_sympy_units as spux
|
||||||
|
from . import logger, serialize
|
||||||
|
|
||||||
int_min = -(2**64)
|
int_min = -(2**64)
|
||||||
int_max = 2**64
|
int_max = 2**64
|
||||||
float_min = sys.float_info.min
|
float_min = sys.float_info.min
|
||||||
float_max = sys.float_info.max
|
float_max = sys.float_info.max
|
||||||
|
|
||||||
|
log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def unicode_superscript(n: int) -> str:
|
||||||
|
"""Transform an integer into its unicode-based superscript character."""
|
||||||
|
return ''.join(['⁰¹²³⁴⁵⁶⁷⁸⁹'[ord(c) - ord('0')] for c in str(n)])
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Simulation Symbol Names
|
# - Simulation Symbol Names
|
||||||
####################
|
####################
|
||||||
|
_l = ''
|
||||||
|
_it_lower = iter(string.ascii_lowercase)
|
||||||
|
|
||||||
|
|
||||||
class SimSymbolName(enum.StrEnum):
|
class SimSymbolName(enum.StrEnum):
|
||||||
# Lower
|
# Generic
|
||||||
LowerA = enum.auto()
|
Constant = enum.auto()
|
||||||
LowerB = enum.auto()
|
Expr = enum.auto()
|
||||||
LowerC = enum.auto()
|
Data = enum.auto()
|
||||||
LowerD = enum.auto()
|
|
||||||
LowerI = enum.auto()
|
# Ascii Letters
|
||||||
LowerT = enum.auto()
|
while True:
|
||||||
LowerX = enum.auto()
|
try:
|
||||||
LowerY = enum.auto()
|
globals()['_l'] = next(globals()['_it_lower'])
|
||||||
LowerZ = enum.auto()
|
except StopIteration:
|
||||||
|
break
|
||||||
|
|
||||||
|
locals()[f'Lower{globals()["_l"].upper()}'] = enum.auto()
|
||||||
|
locals()[f'Upper{globals()["_l"].upper()}'] = enum.auto()
|
||||||
|
|
||||||
|
# Greek Letters
|
||||||
|
LowerTheta = enum.auto()
|
||||||
|
LowerPhi = enum.auto()
|
||||||
|
|
||||||
# Fields
|
# Fields
|
||||||
Ex = enum.auto()
|
Ex = enum.auto()
|
||||||
|
@ -64,18 +88,15 @@ class SimSymbolName(enum.StrEnum):
|
||||||
Wavelength = enum.auto()
|
Wavelength = enum.auto()
|
||||||
Frequency = enum.auto()
|
Frequency = enum.auto()
|
||||||
|
|
||||||
Flux = enum.auto()
|
|
||||||
|
|
||||||
PermXX = enum.auto()
|
PermXX = enum.auto()
|
||||||
PermYY = enum.auto()
|
PermYY = enum.auto()
|
||||||
PermZZ = enum.auto()
|
PermZZ = enum.auto()
|
||||||
|
|
||||||
|
Flux = enum.auto()
|
||||||
|
|
||||||
DiffOrderX = enum.auto()
|
DiffOrderX = enum.auto()
|
||||||
DiffOrderY = enum.auto()
|
DiffOrderY = enum.auto()
|
||||||
|
|
||||||
# Generic
|
|
||||||
Expr = enum.auto()
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - UI
|
# - UI
|
||||||
####################
|
####################
|
||||||
|
@ -109,17 +130,21 @@ class SimSymbolName(enum.StrEnum):
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
SSN = SimSymbolName
|
SSN = SimSymbolName
|
||||||
return {
|
return (
|
||||||
# Lower
|
# Ascii Letters
|
||||||
SSN.LowerA: 'a',
|
{SSN[f'Lower{letter.upper()}']: letter for letter in string.ascii_lowercase}
|
||||||
SSN.LowerB: 'b',
|
| {
|
||||||
SSN.LowerC: 'c',
|
SSN[f'Upper{letter.upper()}']: letter.upper()
|
||||||
SSN.LowerD: 'd',
|
for letter in string.ascii_lowercase
|
||||||
SSN.LowerI: 'i',
|
}
|
||||||
SSN.LowerT: 't',
|
| {
|
||||||
SSN.LowerX: 'x',
|
# Generic
|
||||||
SSN.LowerY: 'y',
|
SSN.Constant: 'constant',
|
||||||
SSN.LowerZ: 'z',
|
SSN.Expr: 'expr',
|
||||||
|
SSN.Data: 'data',
|
||||||
|
# Greek Letters
|
||||||
|
SSN.LowerTheta: 'theta',
|
||||||
|
SSN.LowerPhi: 'phi',
|
||||||
# Fields
|
# Fields
|
||||||
SSN.Ex: 'Ex',
|
SSN.Ex: 'Ex',
|
||||||
SSN.Ey: 'Ey',
|
SSN.Ey: 'Ey',
|
||||||
|
@ -136,22 +161,35 @@ class SimSymbolName(enum.StrEnum):
|
||||||
# Optics
|
# Optics
|
||||||
SSN.Wavelength: 'wl',
|
SSN.Wavelength: 'wl',
|
||||||
SSN.Frequency: 'freq',
|
SSN.Frequency: 'freq',
|
||||||
SSN.Flux: 'flux',
|
|
||||||
SSN.PermXX: 'eps_xx',
|
SSN.PermXX: 'eps_xx',
|
||||||
SSN.PermYY: 'eps_yy',
|
SSN.PermYY: 'eps_yy',
|
||||||
SSN.PermZZ: 'eps_zz',
|
SSN.PermZZ: 'eps_zz',
|
||||||
|
SSN.Flux: 'flux',
|
||||||
SSN.DiffOrderX: 'order_x',
|
SSN.DiffOrderX: 'order_x',
|
||||||
SSN.DiffOrderY: 'order_y',
|
SSN.DiffOrderY: 'order_y',
|
||||||
# Generic
|
}
|
||||||
SSN.Expr: 'expr',
|
)[self]
|
||||||
}[self]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name_pretty(self) -> str:
|
def name_pretty(self) -> str:
|
||||||
SSN = SimSymbolName
|
SSN = SimSymbolName
|
||||||
return {
|
return {
|
||||||
|
# Generic
|
||||||
|
# Greek Letters
|
||||||
|
SSN.LowerTheta: 'θ',
|
||||||
|
SSN.LowerPhi: 'φ',
|
||||||
|
# Fields
|
||||||
|
SSN.Etheta: 'Eθ',
|
||||||
|
SSN.Ephi: 'Eφ',
|
||||||
|
SSN.Hr: 'Hr',
|
||||||
|
SSN.Htheta: 'Hθ',
|
||||||
|
SSN.Hphi: 'Hφ',
|
||||||
|
# Optics
|
||||||
SSN.Wavelength: 'λ',
|
SSN.Wavelength: 'λ',
|
||||||
SSN.Frequency: '𝑓',
|
SSN.Frequency: '𝑓',
|
||||||
|
SSN.PermXX: 'ε_xx',
|
||||||
|
SSN.PermYY: 'ε_yy',
|
||||||
|
SSN.PermZZ: 'ε_zz',
|
||||||
}.get(self, self.name)
|
}.get(self, self.name)
|
||||||
|
|
||||||
|
|
||||||
|
@ -173,8 +211,7 @@ def mk_interval(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(kw_only=True, frozen=True)
|
class SimSymbol(pyd.BaseModel):
|
||||||
class SimSymbol:
|
|
||||||
"""A declarative representation of a symbolic variable.
|
"""A declarative representation of a symbolic variable.
|
||||||
|
|
||||||
`sympy`'s symbols aren't quite flexible enough for our needs: The symbols that we're transporting often need exact domain information, an associated unit dimension, and a great deal of determinism in checks thereof.
|
`sympy`'s symbols aren't quite flexible enough for our needs: The symbols that we're transporting often need exact domain information, an associated unit dimension, and a great deal of determinism in checks thereof.
|
||||||
|
@ -183,6 +220,8 @@ class SimSymbol:
|
||||||
It's easy to persist, easy to transport, and has many helpful properties which greatly simplify working with symbols.
|
It's easy to persist, easy to transport, and has many helpful properties which greatly simplify working with symbols.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
model_config = pyd.ConfigDict(frozen=True)
|
||||||
|
|
||||||
sym_name: SimSymbolName
|
sym_name: SimSymbolName
|
||||||
mathtype: spux.MathType = spux.MathType.Real
|
mathtype: spux.MathType = spux.MathType.Real
|
||||||
physical_type: spux.PhysicalType = spux.PhysicalType.NonPhysical
|
physical_type: spux.PhysicalType = spux.PhysicalType.NonPhysical
|
||||||
|
@ -191,6 +230,9 @@ class SimSymbol:
|
||||||
## -> 'None' indicates that no particular unit has yet been chosen.
|
## -> 'None' indicates that no particular unit has yet been chosen.
|
||||||
## -> Not exposed in the UI; must be set some other way.
|
## -> Not exposed in the UI; must be set some other way.
|
||||||
unit: spux.Unit | None = None
|
unit: spux.Unit | None = None
|
||||||
|
## -> TODO: We currently allowing units that don't match PhysicalType
|
||||||
|
## -> -- In particular, NonPhysical w/units means "unknown units".
|
||||||
|
## -> -- This is essential for the Scientific Constant Node.
|
||||||
|
|
||||||
# Size
|
# Size
|
||||||
## -> All SimSymbol sizes are "2D", but interpreted by convention.
|
## -> All SimSymbol sizes are "2D", but interpreted by convention.
|
||||||
|
@ -205,43 +247,76 @@ class SimSymbol:
|
||||||
## -> NOTE: interval_finite_*[0] must be strictly smaller than [1].
|
## -> NOTE: interval_finite_*[0] must be strictly smaller than [1].
|
||||||
## -> See self.domain.
|
## -> See self.domain.
|
||||||
## -> We have to deconstruct symbolic interval semantics a bit for UI.
|
## -> We have to deconstruct symbolic interval semantics a bit for UI.
|
||||||
|
is_constant: bool = False
|
||||||
interval_finite_z: tuple[int, int] = (0, 1)
|
interval_finite_z: tuple[int, int] = (0, 1)
|
||||||
interval_finite_q: tuple[tuple[int, int], tuple[int, int]] = ((0, 1), (1, 1))
|
interval_finite_q: tuple[tuple[int, int], tuple[int, int]] = ((0, 1), (1, 1))
|
||||||
interval_finite_re: tuple[float, float] = (0, 1)
|
interval_finite_re: tuple[float, float] = (0.0, 1.0)
|
||||||
interval_inf: tuple[bool, bool] = (True, True)
|
interval_inf: tuple[bool, bool] = (True, True)
|
||||||
interval_closed: tuple[bool, bool] = (False, False)
|
interval_closed: tuple[bool, bool] = (False, False)
|
||||||
|
|
||||||
interval_finite_im: tuple[float, float] = (0, 1)
|
interval_finite_im: tuple[float, float] = (0.0, 1.0)
|
||||||
interval_inf_im: tuple[bool, bool] = (True, True)
|
interval_inf_im: tuple[bool, bool] = (True, True)
|
||||||
interval_closed_im: tuple[bool, bool] = (False, False)
|
interval_closed_im: tuple[bool, bool] = (False, False)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Properties
|
# - Labels
|
||||||
####################
|
####################
|
||||||
@property
|
@functools.cached_property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
"""Usable name for the symbol."""
|
"""Usable name for the symbol."""
|
||||||
return self.sym_name.name
|
return self.sym_name.name
|
||||||
|
|
||||||
@property
|
@functools.cached_property
|
||||||
def name_pretty(self) -> str:
|
def name_pretty(self) -> str:
|
||||||
"""Pretty (possibly unicode) name for the thing."""
|
"""Pretty (possibly unicode) name for the thing."""
|
||||||
return self.sym_name.name_pretty
|
return self.sym_name.name_pretty
|
||||||
## TODO: Formatting conventions for bolding/etc. of vectors/mats/...
|
## TODO: Formatting conventions for bolding/etc. of vectors/mats/...
|
||||||
|
|
||||||
@property
|
@functools.cached_property
|
||||||
|
def mathtype_size_label(self) -> str:
|
||||||
|
"""Pretty label that shows both mathtype and size."""
|
||||||
|
return f'{self.mathtype.label_pretty}' + (
|
||||||
|
'ˣ'.join([unicode_superscript(out_axis) for out_axis in self.shape])
|
||||||
|
if self.shape
|
||||||
|
else ''
|
||||||
|
)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def unit_label(self) -> str:
|
||||||
|
"""Pretty unit label, which is an empty string when there is no unit."""
|
||||||
|
return spux.sp_to_str(self.unit) if self.unit is not None else ''
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def def_label(self) -> str:
|
||||||
|
"""Pretty definition label, exposing the symbol definition."""
|
||||||
|
return f'{self.name_pretty} | {self.unit_label} ∈ {self.mathtype_size_label}'
|
||||||
|
## TODO: Domain of validity from self.domain?
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
def plot_label(self) -> str:
|
def plot_label(self) -> str:
|
||||||
"""Pretty plot-oriented label."""
|
"""Pretty plot-oriented label."""
|
||||||
return f'{self.name_pretty}' + (
|
return f'{self.name_pretty}' + (
|
||||||
f'({self.unit})' if self.unit is not None else ''
|
f'({self.unit})' if self.unit is not None else ''
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
####################
|
||||||
|
# - Computed Properties
|
||||||
|
####################
|
||||||
|
@functools.cached_property
|
||||||
def unit_factor(self) -> spux.SympyExpr:
|
def unit_factor(self) -> spux.SympyExpr:
|
||||||
"""Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking."""
|
"""Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking."""
|
||||||
return self.unit if self.unit is not None else sp.S(1)
|
return self.unit if self.unit is not None else sp.S(1)
|
||||||
|
|
||||||
@property
|
@functools.cached_property
|
||||||
|
def size(self) -> tuple[int, ...] | None:
|
||||||
|
return {
|
||||||
|
(1, 1): spux.NumberSize1D.Scalar,
|
||||||
|
(2, 1): spux.NumberSize1D.Vec2,
|
||||||
|
(3, 1): spux.NumberSize1D.Vec3,
|
||||||
|
(4, 1): spux.NumberSize1D.Vec4,
|
||||||
|
}.get((self.rows, self.cols))
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
def shape(self) -> tuple[int, ...]:
|
def shape(self) -> tuple[int, ...]:
|
||||||
match (self.rows, self.cols):
|
match (self.rows, self.cols):
|
||||||
case (1, 1):
|
case (1, 1):
|
||||||
|
@ -253,7 +328,12 @@ class SimSymbol:
|
||||||
case (_, _):
|
case (_, _):
|
||||||
return (self.rows, self.cols)
|
return (self.rows, self.cols)
|
||||||
|
|
||||||
@property
|
@functools.cached_property
|
||||||
|
def shape_len(self) -> spux.SympyExpr:
|
||||||
|
"""Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking."""
|
||||||
|
return len(self.shape)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
def domain(self) -> sp.Interval | sp.Set:
|
def domain(self) -> sp.Interval | sp.Set:
|
||||||
"""Return the scalar domain of valid values for each element of the symbol.
|
"""Return the scalar domain of valid values for each element of the symbol.
|
||||||
|
|
||||||
|
@ -303,11 +383,31 @@ class SimSymbol:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def valid_domain_value(self) -> spux.SympyExpr:
|
||||||
|
"""A single value guaranteed to be conformant to this `SimSymbol` and within `self.domain`."""
|
||||||
|
match (self.domain.start.is_finite, self.domain.end.is_finite):
|
||||||
|
case (True, True):
|
||||||
|
if self.mathtype is spux.MathType.Integer:
|
||||||
|
return (self.domain.start + self.domain.end) // 2
|
||||||
|
return (self.domain.start + self.domain.end) / 2
|
||||||
|
|
||||||
|
case (True, False):
|
||||||
|
one = sp.S(self.mathtype.coerce_compatible_pyobj(-1))
|
||||||
|
return self.domain.start + one
|
||||||
|
|
||||||
|
case (False, True):
|
||||||
|
one = sp.S(self.mathtype.coerce_compatible_pyobj(-1))
|
||||||
|
return self.domain.end - one
|
||||||
|
|
||||||
|
case (False, False):
|
||||||
|
return sp.S(self.mathtype.coerce_compatible_pyobj(-1))
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Properties
|
# - Properties
|
||||||
####################
|
####################
|
||||||
@property
|
@functools.cached_property
|
||||||
def sp_symbol(self) -> sp.Symbol:
|
def sp_symbol(self) -> sp.Symbol | sp.ImmutableMatrix:
|
||||||
"""Return a symbolic variable w/unit, corresponding to this `SimSymbol`.
|
"""Return a symbolic variable w/unit, corresponding to this `SimSymbol`.
|
||||||
|
|
||||||
As much as possible, appropriate `assumptions` are set in the constructor of `sp.Symbol`, insofar as they can be determined.
|
As much as possible, appropriate `assumptions` are set in the constructor of `sp.Symbol`, insofar as they can be determined.
|
||||||
|
@ -352,7 +452,82 @@ class SimSymbol:
|
||||||
elif self.domain.right <= 0:
|
elif self.domain.right <= 0:
|
||||||
mathtype_kwargs |= {'negative': True}
|
mathtype_kwargs |= {'negative': True}
|
||||||
|
|
||||||
return sp.Symbol(self.sym_name.name, **mathtype_kwargs) * self.unit_factor
|
# Scalar: Return Symbol
|
||||||
|
if self.rows == 1 and self.cols == 1:
|
||||||
|
return sp.Symbol(self.sym_name.name, **mathtype_kwargs)
|
||||||
|
|
||||||
|
# Vector|Matrix: Return Matrix of Symbols
|
||||||
|
## -> MatrixSymbol doesn't support assumptions.
|
||||||
|
## -> This little construction does.
|
||||||
|
return sp.ImmutableMatrix(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
sp.Symbol(self.sym_name.name + f'_{row}{col}', **mathtype_kwargs)
|
||||||
|
for col in range(self.cols)
|
||||||
|
]
|
||||||
|
for row in range(self.rows)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def sp_symbol_matsym(self) -> sp.Symbol | sp.MatrixSymbol:
|
||||||
|
"""Return a symbolic variable w/unit, corresponding to this `SimSymbol`, w/variable shape support.
|
||||||
|
|
||||||
|
To preserve as many assumptions as possible, `self.sp_symbol` returns a matrix of individual `sp.Symbol`s whenever the `SimSymbol` is non-scalar.
|
||||||
|
However, this isn't always the most useful representation: For example, if the intention is to use a shaped symbolic variable as an argument to `sympy.lambdify()`, one would have to flatten each individual `sp.Symbol` and pass each matrix element as a single element, greatly complicating things like broadcasting.
|
||||||
|
|
||||||
|
For this reason, this property is provided.
|
||||||
|
Whenever the `SimSymbol` is scalar, it works identically to `self.sp_symbol`.
|
||||||
|
However, when the `SimSymbol` is shaped, an appropriate `sp.MatrixSymbol` is returned instead.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
`sp.MatrixSymbol` doesn't support assumptions.
|
||||||
|
As such, things like deduction of `MathType` from expressions involving a matrix symbol simply won't work.
|
||||||
|
"""
|
||||||
|
if self.shape_len == 0:
|
||||||
|
return self.sp_symbol
|
||||||
|
return sp.MatrixSymbol(self.sym_name.name, self.rows, self.cols)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def sp_symbol_phy(self) -> spux.SympyExpr:
|
||||||
|
"""Physical symbol containing `self.sp_symbol` multiplied by `self.unit`."""
|
||||||
|
return self.sp_symbol * self.unit_factor
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def expr_info(self) -> dict[str, typ.Any]:
|
||||||
|
"""Generate keyword arguments for an ExprSocket, whose output values will be guaranteed to conform to this `SimSymbol`.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
Before use, `active_kind=ct.FlowKind.Range` can be added to make the `ExprSocket`.
|
||||||
|
|
||||||
|
Default values are set for both `Value` and `Range`.
|
||||||
|
To this end, `self.domain` is used.
|
||||||
|
|
||||||
|
Since `ExprSocketDef` allows the use of infinite bounds for `default_min` and `default_max`, we defer the decision of how to treat finite-fallback to the `ExprSocketDef`.
|
||||||
|
"""
|
||||||
|
if self.size is not None:
|
||||||
|
if self.unit in self.physical_type.valid_units:
|
||||||
|
return {
|
||||||
|
'output_name': self.sym_name,
|
||||||
|
# Socket Interface
|
||||||
|
'size': self.size,
|
||||||
|
'mathtype': self.mathtype,
|
||||||
|
'physical_type': self.physical_type,
|
||||||
|
# Defaults: Units
|
||||||
|
'default_unit': self.unit,
|
||||||
|
'default_symbols': [],
|
||||||
|
# Defaults: FlowKind.Value
|
||||||
|
'default_value': self.conform(
|
||||||
|
self.valid_domain_value, strip_unit=True
|
||||||
|
),
|
||||||
|
# Defaults: FlowKind.Range
|
||||||
|
'default_min': self.domain.start,
|
||||||
|
'default_max': self.domain.end,
|
||||||
|
}
|
||||||
|
msg = f'Tried to generate an ExprSocket from a SymSymbol "{self.name}", but its unit ({self.unit}) is not a valid unit of its physical type ({self.physical_type}) (SimSymbol={self})'
|
||||||
|
raise NotImplementedError(msg)
|
||||||
|
msg = f'Tried to generate an ExprSocket from a SymSymbol "{self.name}", but its size ({self.rows} by {self.cols}) is incompatible with ExprSocket (SimSymbol={self})'
|
||||||
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Operations
|
# - Operations
|
||||||
|
@ -373,7 +548,7 @@ class SimSymbol:
|
||||||
cols=get_attr('cols'),
|
cols=get_attr('cols'),
|
||||||
interval_finite_z=get_attr('interval_finite_z'),
|
interval_finite_z=get_attr('interval_finite_z'),
|
||||||
interval_finite_q=get_attr('interval_finite_q'),
|
interval_finite_q=get_attr('interval_finite_q'),
|
||||||
interval_finite_re=get_attr('interval_finite_q'),
|
interval_finite_re=get_attr('interval_finite_re'),
|
||||||
interval_inf=get_attr('interval_inf'),
|
interval_inf=get_attr('interval_inf'),
|
||||||
interval_closed=get_attr('interval_closed'),
|
interval_closed=get_attr('interval_closed'),
|
||||||
interval_finite_im=get_attr('interval_finite_im'),
|
interval_finite_im=get_attr('interval_finite_im'),
|
||||||
|
@ -381,24 +556,199 @@ class SimSymbol:
|
||||||
interval_closed_im=get_attr('interval_closed_im'),
|
interval_closed_im=get_attr('interval_closed_im'),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def set_finite_domain( # noqa: PLR0913
|
||||||
|
self,
|
||||||
|
start: int | float,
|
||||||
|
end: int | float,
|
||||||
|
start_closed: bool = True,
|
||||||
|
end_closed: bool = True,
|
||||||
|
start_im: bool = float,
|
||||||
|
end_im: bool = float,
|
||||||
|
start_closed_im: bool = True,
|
||||||
|
end_closed_im: bool = True,
|
||||||
|
) -> typ.Self:
|
||||||
|
"""Update the symbol with a finite range."""
|
||||||
|
closed_re = (start_closed, end_closed)
|
||||||
|
closed_im = (start_closed_im, end_closed_im)
|
||||||
|
match self.mathtype:
|
||||||
|
case spux.MathType.Integer:
|
||||||
|
return self.update(
|
||||||
|
interval_finite_z=(start, end),
|
||||||
|
interval_inf=(False, False),
|
||||||
|
interval_closed=closed_re,
|
||||||
|
)
|
||||||
|
case spux.MathType.Rational:
|
||||||
|
return self.update(
|
||||||
|
interval_finite_q=(start, end),
|
||||||
|
interval_inf=(False, False),
|
||||||
|
interval_closed=closed_re,
|
||||||
|
)
|
||||||
|
case spux.MathType.Real:
|
||||||
|
return self.update(
|
||||||
|
interval_finite_re=(start, end),
|
||||||
|
interval_inf=(False, False),
|
||||||
|
interval_closed=closed_re,
|
||||||
|
)
|
||||||
|
case spux.MathType.Complex:
|
||||||
|
return self.update(
|
||||||
|
interval_finite_re=(start, end),
|
||||||
|
interval_finite_im=(start_im, end_im),
|
||||||
|
interval_inf=(False, False),
|
||||||
|
interval_closed=closed_re,
|
||||||
|
interval_closed_im=closed_im,
|
||||||
|
)
|
||||||
|
|
||||||
def set_size(self, rows: int, cols: int) -> typ.Self:
|
def set_size(self, rows: int, cols: int) -> typ.Self:
|
||||||
|
return self.update(rows=rows, cols=cols)
|
||||||
|
|
||||||
|
def conform(
|
||||||
|
self, sp_obj: spux.SympyType, strip_unit: bool = False
|
||||||
|
) -> spux.SympyType:
|
||||||
|
"""Conform a sympy object to the properties of this `SimSymbol`, if possible.
|
||||||
|
|
||||||
|
To achieve this, a number of operations may be performed:
|
||||||
|
|
||||||
|
- **Unit Conversion**: If the object has no units, but should, multiply by `self.unit`. If the object has units, but shouldn't, strip them. Otherwise, convert its unit to `self.unit`.
|
||||||
|
- **Broadcast Expansion**: If the object is a scalar, but the `SimSymbol` is shaped, then an `sp.ImmutableMatrix` is returned with the scalar at each position.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A transformed sympy object guaranteed usable as a particular value of this `SimSymbol` variable.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the units of `sp_obj` can't be cleanly converted to `self.unit`.
|
||||||
|
"""
|
||||||
|
res = sp_obj
|
||||||
|
|
||||||
|
# Unit Conversion
|
||||||
|
match (spux.uses_units(sp_obj), self.unit is not None):
|
||||||
|
case (True, True):
|
||||||
|
res = spux.scale_to_unit(sp_obj, self.unit) * self.unit
|
||||||
|
|
||||||
|
case (False, True):
|
||||||
|
res = sp_obj * self.unit
|
||||||
|
|
||||||
|
case (True, False):
|
||||||
|
res = spux.strip_unit_system(sp_obj)
|
||||||
|
|
||||||
|
if strip_unit:
|
||||||
|
res = spux.strip_unit_system(sp_obj)
|
||||||
|
|
||||||
|
# Broadcast Expansion
|
||||||
|
if self.rows > 1 or self.cols > 1 and not isinstance(res, spux.MatrixBase):
|
||||||
|
res = sp_obj * sp.ImmutableMatrix.ones(self.rows, self.cols)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
def scale(
|
||||||
|
self, sp_obj: spux.SympyType, use_jax_array: bool = True
|
||||||
|
) -> int | float | complex | jtyp.Inexact[jtyp.Array, '...']:
|
||||||
|
"""Remove all symbolic elements from the conformed `sp_obj`, preparing it for use in contexts that don't support unrealized symbols.
|
||||||
|
|
||||||
|
On top of `self.conform()`, a number of operations are performed.
|
||||||
|
|
||||||
|
- **Unit Stripping**: The `self.unit` of the expression returned by `self.conform()` will be stripped.
|
||||||
|
- **Sympy to Python**: The now symbol-less expression will be converted to either a pure Python type, or to a `jax` array (if `use_jax_array` is set).
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
When creating numerical functions of expressions using `.lambdify`, `self.scale()` **must be used** in place of `self.conform()` before the parameterized expression is used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A "raw" (pure Python / jax array) type guaranteed usable as a particular **numerical** value of this `SymSymbol` variable.
|
||||||
|
"""
|
||||||
|
# Conform
|
||||||
|
res = self.conform(sp_obj)
|
||||||
|
|
||||||
|
# Strip Units
|
||||||
|
res = spux.scale_to_unit(sp_obj, self.unit)
|
||||||
|
|
||||||
|
# Sympy to Python
|
||||||
|
res = spux.sympy_to_python(res, use_jax_array=use_jax_array)
|
||||||
|
|
||||||
|
return res # noqa: RET504
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_expr(
|
||||||
|
sym_name: SimSymbolName,
|
||||||
|
expr: spux.SympyExpr,
|
||||||
|
unit_expr: spux.SympyExpr,
|
||||||
|
) -> typ.Self:
|
||||||
|
"""Deduce a `SimSymbol` that matches the output of a given expression (and unit expression).
|
||||||
|
|
||||||
|
This is an essential method, allowing for the ded
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
`PhysicalType` **cannot be set** from an expression in the generic sense.
|
||||||
|
Therefore, the trick of using `NonPhysical` with non-`None` unit to denote unknown `PhysicalType` is used in the output.
|
||||||
|
|
||||||
|
All intervals are kept at their defaults.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
sym_name: The `SimSymbolName` to set to the resulting symbol.
|
||||||
|
expr: The unit-aware expression to parse and encapsulate as a symbol.
|
||||||
|
unit_expr: A dimensional analysis expression (set to `1` to make the resulting symbol unitless).
|
||||||
|
Fundamentally, units are just the variables of scalar terms.
|
||||||
|
'1' for unitless terms are, in the dimanyl sense, constants.
|
||||||
|
|
||||||
|
Doing it like this may be a little messy, but is accurate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A fresh new `SimSymbol` that tries to match the given expression (and unit expression) well enough to be usable in place of it.
|
||||||
|
"""
|
||||||
|
# MathType from Expr Assumptions
|
||||||
|
## -> All input symbols have assumptions, because we are very pedantic.
|
||||||
|
## -> Therefore, we should be able to reconstruct the MathType.
|
||||||
|
mathtype = spux.MathType.from_expr(expr)
|
||||||
|
|
||||||
|
# PhysicalType as "NonPhysical"
|
||||||
|
## -> 'unit' still applies - but we can't guarantee a PhysicalType will.
|
||||||
|
## -> Therefore, this is what we gotta do.
|
||||||
|
physical_type = spux.PhysicalType.NonPhysical
|
||||||
|
|
||||||
|
# Rows/Cols from Expr (if Matrix)
|
||||||
|
rows, cols = expr.shape if isinstance(expr, sp.MatrixBase) else (1, 1)
|
||||||
|
|
||||||
return SimSymbol(
|
return SimSymbol(
|
||||||
sym_name=self.sym_name,
|
sym_name=sym_name,
|
||||||
mathtype=self.mathtype,
|
mathtype=mathtype,
|
||||||
physical_type=self.physical_type,
|
physical_type=physical_type,
|
||||||
unit=self.unit,
|
unit=unit_expr if unit_expr != 1 else None,
|
||||||
rows=rows,
|
rows=rows,
|
||||||
cols=cols,
|
cols=cols,
|
||||||
interval_finite_z=self.interval_finite_z,
|
|
||||||
interval_finite_q=self.interval_finite_q,
|
|
||||||
interval_finite_re=self.interval_finite_re,
|
|
||||||
interval_inf=self.interval_inf,
|
|
||||||
interval_closed=self.interval_closed,
|
|
||||||
interval_finite_im=self.interval_finite_im,
|
|
||||||
interval_inf_im=self.interval_inf_im,
|
|
||||||
interval_closed_im=self.interval_closed_im,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Serialization
|
||||||
|
####################
|
||||||
|
def dump_as_msgspec(self) -> serialize.NaiveRepresentation:
|
||||||
|
"""Transforms this `SimSymbol` into an object that can be natively serialized by `msgspec`.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
Makes use of `pydantic.BaseModel.model_dump()` to cast any special fields into a serializable format.
|
||||||
|
If this method is failing, check that `pydantic` can actually cast all the fields in your model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A particular `list`, with two elements:
|
||||||
|
|
||||||
|
1. The `serialize`-provided "Type Identifier", to differentiate this list from generic list.
|
||||||
|
2. A dictionary containing simple Python types, as cast by `pydantic`.
|
||||||
|
"""
|
||||||
|
return [serialize.TypeID.SimSymbol, self.__class__.__name__, self.model_dump()]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_as_msgspec(obj: serialize.NaiveRepresentation) -> typ.Self:
|
||||||
|
"""Transforms an object made by `self.dump_as_msgspec()` into an instance of `SimSymbol`.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
The method presumes that the deserialized object produced by `msgspec` perfectly matches the object originally created by `self.dump_as_msgspec()`.
|
||||||
|
|
||||||
|
This is a **mostly robust** presumption, as `pydantic` attempts to be quite consistent in how to interpret types with almost identical semantics.
|
||||||
|
Still, yet-unknown edge cases may challenge these presumptions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new instance of `SimSymbol`, initialized using the `model_dump()` dictionary.
|
||||||
|
"""
|
||||||
|
return SimSymbol(**obj[2])
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Common Sim Symbols
|
# - Common Sim Symbols
|
||||||
|
@ -453,14 +803,10 @@ class CommonSimSymbol(enum.StrEnum):
|
||||||
Wavelength = enum.auto()
|
Wavelength = enum.auto()
|
||||||
Frequency = enum.auto()
|
Frequency = enum.auto()
|
||||||
|
|
||||||
DiffOrderX = enum.auto()
|
|
||||||
DiffOrderY = enum.auto()
|
|
||||||
|
|
||||||
Flux = enum.auto()
|
Flux = enum.auto()
|
||||||
|
|
||||||
WaveVecX = enum.auto()
|
DiffOrderX = enum.auto()
|
||||||
WaveVecY = enum.auto()
|
DiffOrderY = enum.auto()
|
||||||
WaveVecZ = enum.auto()
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - UI
|
# - UI
|
||||||
|
@ -549,10 +895,10 @@ class CommonSimSymbol(enum.StrEnum):
|
||||||
if eh == 'e'
|
if eh == 'e'
|
||||||
else spux.PhysicalType.HField,
|
else spux.PhysicalType.HField,
|
||||||
unit=unit,
|
unit=unit,
|
||||||
interval_finite_re=(0, sys.float_info.max),
|
interval_finite_re=(0, float_max),
|
||||||
interval_inf_re=(False, True),
|
interval_inf_re=(False, True),
|
||||||
interval_closed_re=(True, False),
|
interval_closed_re=(True, False),
|
||||||
interval_finite_im=(sys.float_info.min, sys.float_info.max),
|
interval_finite_im=(float_min, float_max),
|
||||||
interval_inf_im=(True, True),
|
interval_inf_im=(True, True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -575,7 +921,7 @@ class CommonSimSymbol(enum.StrEnum):
|
||||||
sym_name=self.name,
|
sym_name=self.name,
|
||||||
physical_type=spux.PhysicalType.Time,
|
physical_type=spux.PhysicalType.Time,
|
||||||
unit=unit,
|
unit=unit,
|
||||||
interval_finite_re=(0, sys.float_info.max),
|
interval_finite_re=(0, float_max),
|
||||||
interval_inf=(False, True),
|
interval_inf=(False, True),
|
||||||
interval_closed=(True, False),
|
interval_closed=(True, False),
|
||||||
),
|
),
|
||||||
|
@ -592,19 +938,13 @@ class CommonSimSymbol(enum.StrEnum):
|
||||||
CSS.FieldHr: sym_field('h'),
|
CSS.FieldHr: sym_field('h'),
|
||||||
CSS.FieldHtheta: sym_field('h'),
|
CSS.FieldHtheta: sym_field('h'),
|
||||||
CSS.FieldHphi: sym_field('h'),
|
CSS.FieldHphi: sym_field('h'),
|
||||||
CSS.Flux: SimSymbol(
|
|
||||||
sym_name=SimSymbolName.Flux,
|
|
||||||
mathtype=spux.MathType.Real,
|
|
||||||
physical_type=spux.PhysicalType.Power,
|
|
||||||
unit=unit,
|
|
||||||
),
|
|
||||||
# Optics
|
# Optics
|
||||||
CSS.Wavelength: SimSymbol(
|
CSS.Wavelength: SimSymbol(
|
||||||
sym_name=self.name,
|
sym_name=self.name,
|
||||||
mathtype=spux.MathType.Real,
|
mathtype=spux.MathType.Real,
|
||||||
physical_type=spux.PhysicalType.Length,
|
physical_type=spux.PhysicalType.Length,
|
||||||
unit=unit,
|
unit=unit,
|
||||||
interval_finite=(0, sys.float_info.max),
|
interval_finite=(0, float_max),
|
||||||
interval_inf=(False, True),
|
interval_inf=(False, True),
|
||||||
interval_closed=(False, False),
|
interval_closed=(False, False),
|
||||||
),
|
),
|
||||||
|
@ -613,10 +953,30 @@ class CommonSimSymbol(enum.StrEnum):
|
||||||
mathtype=spux.MathType.Real,
|
mathtype=spux.MathType.Real,
|
||||||
physical_type=spux.PhysicalType.Freq,
|
physical_type=spux.PhysicalType.Freq,
|
||||||
unit=unit,
|
unit=unit,
|
||||||
interval_finite=(0, sys.float_info.max),
|
interval_finite=(0, float_max),
|
||||||
interval_inf=(False, True),
|
interval_inf=(False, True),
|
||||||
interval_closed=(False, False),
|
interval_closed=(False, False),
|
||||||
),
|
),
|
||||||
|
CSS.Flux: SimSymbol(
|
||||||
|
sym_name=SimSymbolName.Flux,
|
||||||
|
mathtype=spux.MathType.Real,
|
||||||
|
physical_type=spux.PhysicalType.Power,
|
||||||
|
unit=unit,
|
||||||
|
),
|
||||||
|
CSS.DiffOrderX: SimSymbol(
|
||||||
|
sym_name=self.name,
|
||||||
|
mathtype=spux.MathType.Integer,
|
||||||
|
interval_finite=(int_min, int_max),
|
||||||
|
interval_inf=(True, True),
|
||||||
|
interval_closed=(False, False),
|
||||||
|
),
|
||||||
|
CSS.DiffOrderY: SimSymbol(
|
||||||
|
sym_name=self.name,
|
||||||
|
mathtype=spux.MathType.Integer,
|
||||||
|
interval_finite=(int_min, int_max),
|
||||||
|
interval_inf=(True, True),
|
||||||
|
interval_closed=(False, False),
|
||||||
|
),
|
||||||
}[self]
|
}[self]
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue