refactor: revamped symbolic flow (inaccurate unit conversions)

main
Sofus Albert Høgsbro Rose 2024-05-24 16:01:23 +02:00
parent 353a2c997e
commit bcba444a8b
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
28 changed files with 2397 additions and 1020 deletions

View File

@ -18,8 +18,8 @@ from .array import ArrayFlow
from .capabilities import CapabilitiesFlow from .capabilities import CapabilitiesFlow
from .flow_kinds import FlowKind from .flow_kinds import FlowKind
from .info import InfoFlow from .info import InfoFlow
from .lazy_range import RangeFlow, ScalingMode
from .lazy_func import FuncFlow from .lazy_func import FuncFlow
from .lazy_range import RangeFlow, ScalingMode
from .params import ParamsFlow from .params import ParamsFlow
from .value import ValueFlow from .value import ValueFlow

View File

@ -50,11 +50,6 @@ class ArrayFlow:
#################### ####################
# - Computed Properties # - Computed Properties
#################### ####################
@property
def is_symbolic(self) -> bool:
"""Always False, as ArrayFlows are never unrealized."""
return False
def __len__(self) -> int: def __len__(self) -> int:
"""Outer length of the contained array.""" """Outer length of the contained array."""
return len(self.values) return len(self.values)
@ -196,5 +191,11 @@ class ArrayFlow:
""" """
return self.rescale(lambda v: v, new_unit=new_unit) return self.rescale(lambda v: v, new_unit=new_unit)
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self: def rescale_to_unit_system(self, unit_system: spux.UnitSystem | None) -> typ.Self:
raise NotImplementedError if unit_system is None:
return self.values
return self.correct_unit(None).rescale(
lambda v: spux.scale_to_unit_system(v * self.unit, unit_system),
new_unit=spux.convert_to_unit_system(self.unit, unit_system),
)

View File

@ -16,6 +16,7 @@
import dataclasses import dataclasses
import typing as typ import typing as typ
from types import MappingProxyType
from ..socket_types import SocketType from ..socket_types import SocketType
from .flow_kinds import FlowKind from .flow_kinds import FlowKind
@ -25,6 +26,7 @@ from .flow_kinds import FlowKind
class CapabilitiesFlow: class CapabilitiesFlow:
socket_type: SocketType socket_type: SocketType
active_kind: FlowKind active_kind: FlowKind
allow_out_to_in: dict[FlowKind, FlowKind] = dataclasses.field(default_factory=dict)
is_universal: bool = False is_universal: bool = False
@ -40,7 +42,13 @@ class CapabilitiesFlow:
def is_compatible_with(self, other: typ.Self) -> bool: def is_compatible_with(self, other: typ.Self) -> bool:
return other.is_universal or ( return other.is_universal or (
self.socket_type == other.socket_type self.socket_type == other.socket_type
and self.active_kind == other.active_kind and (
self.active_kind == other.active_kind
or (
other.active_kind in other.allow_out_to_in
and self.active_kind == other.allow_out_to_in[other.active_kind]
)
)
# == Constraint # == Constraint
and all( and all(
name in other.must_match name in other.must_match

View File

@ -67,8 +67,9 @@ class InfoFlow:
default_factory=dict default_factory=dict
) )
# Access
@functools.cached_property @functools.cached_property
def last_dim(self) -> sim_symbols.SimSymbol | None: def first_dim(self) -> sim_symbols.SimSymbol | None:
"""The integer axis occupied by the dimension. """The integer axis occupied by the dimension.
Can be used to index `.shape` of the represented raw array. Can be used to index `.shape` of the represented raw array.
@ -87,13 +88,24 @@ class InfoFlow:
return list(self.dims.keys())[-1] return list(self.dims.keys())[-1]
return None return None
def dim_axis(self, dim: sim_symbols.SimSymbol) -> int: def dim_by_idx(self, idx: int) -> sim_symbols.SimSymbol | None:
if idx > 0 and idx < len(self.dims) - 1:
return list(self.dims.keys())[idx]
return None
def dim_by_name(self, dim_name: str) -> int:
"""The integer axis occupied by the dimension. """The integer axis occupied by the dimension.
Can be used to index `.shape` of the represented raw array. Can be used to index `.shape` of the represented raw array.
""" """
return list(self.dims.keys()).index(dim) dims_with_name = [dim for dim in self.dims if dim.name == dim_name]
if len(dims_with_name) == 1:
return dims_with_name[0]
msg = f'Dim name {dim_name} not found in InfoFlow (or >1 found)'
raise ValueError(msg)
# Information By-Dim
def has_idx_cont(self, dim: sim_symbols.SimSymbol) -> bool: def has_idx_cont(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the dim's index is continuous, and therefore index array. """Whether the dim's index is continuous, and therefore index array.
@ -114,6 +126,23 @@ class InfoFlow:
return isinstance(self.dims[dim], list) return isinstance(self.dims[dim], list)
return False return False
def is_idx_uniform(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the (int) dim has explicitly uniform indexing.
This is needed primarily to check whether a Fourier Transform can be meaningfully performed on the data over the dimension's axis.
In practice, we've decided that only `RangeFlow` really truly _guarantees_ uniform indexing.
While `ArrayFlow` may be uniform in practice, it's a very expensive to check, and it's far better to enforce that the user perform that check and opt for a `RangeFlow` instead, at the time of dimension definition.
"""
return isinstance(self.dims[dim], RangeFlow) and self.dims[dim].scaling == 'lin'
def dim_axis(self, dim: sim_symbols.SimSymbol) -> int:
"""The integer axis occupied by the dimension.
Can be used to index `.shape` of the represented raw array.
"""
return list(self.dims.keys()).index(dim)
#################### ####################
# - Output: Contravariant Value # - Output: Contravariant Value
#################### ####################
@ -128,6 +157,49 @@ class InfoFlow:
default_factory=dict default_factory=dict
) )
####################
# - Properties
####################
@functools.cached_property
def input_mathtypes(self) -> tuple[spux.MathType, ...]:
return tuple([dim.mathtype for dim in self.dims])
@functools.cached_property
def output_mathtypes(self) -> tuple[spux.MathType, int, int]:
return [self.output.mathtype for _ in range(len(self.output.shape) + 1)]
@functools.cached_property
def order(self) -> tuple[spux.MathType, ...]:
r"""The order of the tensor represented by this info.
While that sounds fancy and all, it boils down to:
$$
\texttt{dims} + |\texttt{output}.\texttt{shape}|
$$
Doing so characterizes the full dimensionality of the tensor, which also perfectly matches the length of the raw data's shape exactly.
Notes:
Corresponds to `len(raw_data.shape)`, if `raw_data` is the n-dimensional array corresponding to this `InfoFlow`.
"""
return len(self.input_mathtypes) + self.output_shape_len
####################
# - Properties
####################
@functools.cached_property
def dim_labels(self) -> dict[str, dict[str, str]]:
"""Return a dictionary mapping pretty dim names to information oriented for columnar information display."""
return {
dim.name_pretty: {
'length': str(len(dim_idx)) if dim_idx is not None else '',
'mathtype': dim.mathtype.label_pretty,
'unit': dim.unit_label,
}
for dim, dim_idx in self.dims.items()
}
#################### ####################
# - Operations: Dimensions # - Operations: Dimensions
#################### ####################
@ -147,9 +219,11 @@ class InfoFlow:
"""Slice a dimensional array by-index along a particular dimension.""" """Slice a dimensional array by-index along a particular dimension."""
return InfoFlow( return InfoFlow(
dims={ dims={
_dim: dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]] _dim: (
if _dim == dim dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
else _dim if _dim == dim
else dim_idx
)
for _dim, dim_idx in self.dims.items() for _dim, dim_idx in self.dims.items()
}, },
output=self.output, output=self.output,
@ -166,7 +240,7 @@ class InfoFlow:
return InfoFlow( return InfoFlow(
dims={ dims={
(new_dim if _dim == old_dim else _dim): ( (new_dim if _dim == old_dim else _dim): (
new_dim_idx if _dim == old_dim else _dim new_dim_idx if _dim == old_dim else dim_idx
) )
for _dim, dim_idx in self.dims.items() for _dim, dim_idx in self.dims.items()
}, },
@ -235,6 +309,26 @@ class InfoFlow:
pinned_values=self.pinned_values, pinned_values=self.pinned_values,
) )
def operate_output(
self,
other: typ.Self,
op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
unit_op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
) -> spux.SympyExpr:
if self.dims == other.dims:
sym_name = sim_symbols.SimSymbolName.Expr
expr = op(self.output.sp_symbol_phy, other.output.sp_symbol_phy)
unit_expr = unit_op(self.output.unit_factor, other.output.unit_factor)
return InfoFlow(
dims=self.dims,
output=sim_symbols.SimSymbol.from_expr(sym_name, expr, unit_expr),
pinned_values=self.pinned_values,
)
msg = f'InfoFlow: operate_output cannot be used when dimensions are not identical ({self.dims} | {other.dims}).'
raise ValueError(msg)
#################### ####################
# - Operations: Fold # - Operations: Fold
#################### ####################

View File

@ -22,7 +22,7 @@ from types import MappingProxyType
import jax import jax
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger from blender_maxwell.utils import logger, sim_symbols
from .params import ParamsFlow from .params import ParamsFlow
@ -244,17 +244,10 @@ class FuncFlow:
""" """
func: LazyFunction func: LazyFunction
func_args: list[spux.MathType | spux.PhysicalType] = dataclasses.field( func_args: list[sim_symbols.SimSymbol] = dataclasses.field(default_factory=list)
default_factory=list func_kwargs: dict[str, sim_symbols.SimSymbol] = dataclasses.field(
)
func_kwargs: dict[str, spux.MathType | spux.PhysicalType] = dataclasses.field(
default_factory=dict default_factory=dict
) )
## TODO: Use SimSymbol instead of the MathType|PT union.
## -- SimSymbol is an ideal pivot point for both, as well as valid domains.
## -- SimSymbol has more semantic meaning, including a name.
## -- If desired, SimSymbols could maybe even require a specific unit.
## It could greatly simplify a whole lot of pain associated with func_args.
supports_jax: bool = False supports_jax: bool = False
#################### ####################
@ -315,17 +308,18 @@ class FuncFlow:
def realize( def realize(
self, self,
params: ParamsFlow, params: ParamsFlow,
unit_system: spux.UnitSystem | None = None, symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}), {}
),
) -> typ.Self: ) -> typ.Self:
if self.supports_jax: if self.supports_jax:
return self.func_jax( return self.func_jax(
*params.scaled_func_args(unit_system, symbol_values), *params.scaled_func_args(self.func_args, symbol_values),
*params.scaled_func_kwargs(unit_system, symbol_values), *params.scaled_func_kwargs(self.func_args, symbol_values),
) )
return self.func( return self.func(
*params.scaled_func_args(unit_system, symbol_values), *params.scaled_func_args(self.func_kwargs, symbol_values),
*params.scaled_func_kwargs(unit_system, symbol_values), *params.scaled_func_kwargs(self.func_kwargs, symbol_values),
) )
#################### ####################

View File

@ -18,17 +18,18 @@ import dataclasses
import enum import enum
import functools import functools
import typing as typ import typing as typ
from fractions import Fraction
from types import MappingProxyType from types import MappingProxyType
import jax.numpy as jnp import jax.numpy as jnp
import jaxtyping as jtyp import jaxtyping as jtyp
import sympy as sp import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger from blender_maxwell.utils import logger, sim_symbols
from .array import ArrayFlow from .array import ArrayFlow
from .lazy_func import FuncFlow
log = logger.get(__name__) log = logger.get(__name__)
@ -62,7 +63,7 @@ class ScalingMode(enum.StrEnum):
@dataclasses.dataclass(frozen=True, kw_only=True) @dataclasses.dataclass(frozen=True, kw_only=True)
class RangeFlow: class RangeFlow:
r"""Represents a spaced array using symbolic boundary expressions. r"""Represents a finite spaced array using symbolic boundary expressions.
Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous. Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous.
@ -79,33 +80,76 @@ class RangeFlow:
Since the bounds support exact (ex. rational) calculations and symbolic manipulations (_by virtue of being symbolic expressions_), the opportunities for certain kinds of numerical instability are mitigated. Since the bounds support exact (ex. rational) calculations and symbolic manipulations (_by virtue of being symbolic expressions_), the opportunities for certain kinds of numerical instability are mitigated.
Attributes: Attributes:
start: An expression generating a scalar, unitless, complex value for the array's lower bound. start: An expression representing the unitless part of the finite, scalar, complex value for the array's lower bound.
_Integer, rational, and real values are also supported._ _Integer, rational, and real values are also supported._
stop: An expression generating a scalar, unitless, complex value for the array's upper bound. start: An expression representing the unitless part of the finite, scalar, complex value for the array's upper bound.
_Integer, rational, and real values are also supported._ _Integer, rational, and real values are also supported._
steps: The amount of steps (**inclusive**) to generate from `start` to `stop`. steps: The amount of steps (**inclusive**) to generate from `start` to `stop`.
scaling: The method of distributing `step` values between the two endpoints. scaling: The method of distributing `step` values between the two endpoints.
Generally, the linear default is sufficient.
unit: The unit of the generated array values unit: The unit to interpret the values as.
symbols: Set of variables from which `start` and/or `stop` are determined. symbols: Set of variables from which `start` and/or `stop` are determined.
""" """
start: spux.ScalarUnitlessComplexExpr start: spux.ScalarUnitlessComplexExpr
stop: spux.ScalarUnitlessComplexExpr stop: spux.ScalarUnitlessComplexExpr
steps: int steps: int = 0
scaling: ScalingMode = ScalingMode.Lin scaling: ScalingMode = ScalingMode.Lin
unit: spux.Unit | None = None unit: spux.Unit | None = None
symbols: frozenset[spux.Symbol] = frozenset() symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
# Helper Attributes
pre_fourier_ideal_midpoint: spux.ScalarUnitlessComplexExpr | None = None
#################### ####################
# - Computed Properties # - SimSymbol Interop
####################
@staticmethod
def from_sym(
sym: sim_symbols.SimSymbol,
steps: int = 50,
scaling: ScalingMode | str = ScalingMode.Lin,
) -> typ.Self:
if sym.domain.start.is_infinite or sym.domain.end.is_infinite:
use_steps = 0
else:
use_steps = steps
return RangeFlow(
start=sym.domain.start if sym.domain.start.is_finite else sp.S(-1),
stop=sym.domain.end if sym.domain.end.is_finite else sp.S(1),
steps=use_steps,
scaling=ScalingMode(scaling),
unit=sym.unit,
)
def to_sym(
self,
sym_name: sim_symbols.SimSymbolName,
) -> typ.Self:
physical_type = spux.PhysicalType.from_unit(self.unit, optional=True)
return sim_symbols.SimSymbol(
sym_name=sym_name,
mathtype=self.mathtype,
physical_type=(
physical_type
if physical_type is not None
else spux.PhysicalType.NonPhysical
),
unit=self.unit,
rows=1,
cols=1,
).set_domain(start=self.realize_start(), end=self.realize_end())
####################
# - Symbols
#################### ####################
@functools.cached_property @functools.cached_property
def sorted_symbols(self) -> list[sp.Symbol]: def sorted_symbols(self) -> list[sim_symbols.SimSymbol]:
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name. """Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
The order is guaranteed to be **deterministic**. The order is guaranteed to be **deterministic**.
@ -115,10 +159,21 @@ class RangeFlow:
""" """
return sorted(self.symbols, key=lambda sym: sym.name) return sorted(self.symbols, key=lambda sym: sym.name)
@property @functools.cached_property
def is_symbolic(self) -> bool: def sorted_sp_symbols(self) -> list[spux.Symbol]:
"""Whether the `RangeFlow` has unrealized symbols.""" """Computes `sympy` symbols from `self.sorted_symbols`.
return len(self.symbols) > 0
Returns:
All symbols valid for use in the expression.
"""
return [sym.sp_symbol for sym in self.sorted_symbols]
####################
# - Properties
####################
@functools.cached_property
def unit_factor(self) -> spux.SympyExpr:
return self.unit if self.unit is not None else sp.S(1)
def __len__(self) -> int: def __len__(self) -> int:
"""Compute the length of the array that would be realized. """Compute the length of the array that would be realized.
@ -166,6 +221,14 @@ class RangeFlow:
#################### ####################
# - Methods # - Methods
#################### ####################
@property
def ideal_midpoint(self) -> spux.SympyExpr:
return (self.stop + self.start) / 2
@property
def ideal_range(self) -> spux.SympyExpr:
return self.stop - self.start
def rescale( def rescale(
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
) -> typ.Self: ) -> typ.Self:
@ -181,8 +244,8 @@ class RangeFlow:
new_pre_start = self.start if not reverse else self.stop new_pre_start = self.start if not reverse else self.stop
new_pre_stop = self.stop if not reverse else self.start new_pre_stop = self.stop if not reverse else self.start
new_start = rescale_func(new_pre_start * self.unit) new_start = rescale_func(new_pre_start * self.unit_factor)
new_stop = rescale_func(new_pre_stop * self.unit) new_stop = rescale_func(new_pre_stop * self.unit_factor)
return RangeFlow( return RangeFlow(
start=( start=(
@ -204,6 +267,99 @@ class RangeFlow:
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int: def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
raise NotImplementedError raise NotImplementedError
@functools.cached_property
def bound_fourier_transform(self):
r"""Treat this `RangeFlow` it as an axis along which a fourier transform is being performed, such that its bounds scale according to the Nyquist Limit.
# Sampling Theory
In general, the Fourier Transform is an operator that works on infinite, periodic, continuous functions.
In this context alone it is a ideal transform (in terms of information retention), one which degrades quite gracefully in the face of practicalities like windowing (allowing them to apply analytically to non-periodic functions too).
While often used to transform an axis of time to an axis of frequency, in general this transform "simply" extracts all repeating structures from a function.
This is illustrated beautifully in the way that the output unit becomes the reciprocal of the input unit, which is the theory underlying why we say that measurements recieved as a reciprocal unit are in a "reciprocal space" (also called "k-space").
The real world is not so nice, of course, and as such we must generally make do with the Discrete Fourier Transform.
Even with bounded discrete information, we can annoy many mathematicians by defining a DFT in such a way that "structure per thing" ($\frac{1}{\texttt{unit}}$) still makes sense to us (to them, maybe not).
A DFT can still only retain the information given to it, but so long as we have enough "original structure", any "repeating structure" should be extractable with sufficient clarity to be useful.
What "sufficient clarity" means is the basis for the entire field of "sampling theory".
The theoretical maximum for the "fineness of repetition" that is "noticeable" in the fourier-transformed of some data is characterized by a theoretical upper bound called the Nyquist Frequency / Limit, which ends up being half of the sampling rate.
Thus, to determine bounds on the data, use of the Nyquist Limit is generally a good starting point.
Of course, when the discrete data comes from a discretization of a continuous signal, information from higher frequencies might still affect the discrete results.
They do little else than cause havoc, though - best causing noise, and at worst causing structured artifacts (sometimes called "aliasing").
Some of the first innovations in sampling theory were related to "anti-aliasing" filters, whose sole purpose is to try to remove informational frequencies above the Nyquist Limit of the discrete sensor.
In FDTD simulation, we're generally already ahead when it comes to aliasing, since our field values come from an already-discrete process.
That is, unless we start overly "binning" (averaging over $n$ discrete observations); in this case, care should be taken to make sure that interesting results aren't merely unfortunately structured aliasing artifacts.
For more, see <https://en.wikipedia.org/wiki/Nyquist%E2%80%93Shannon_sampling_theorem>
# Implementation
In practice, our goal in `RangeFlow` is to compute the bounds of the index array along the fourier-transformed data axis.
The reciprocal of the unit will be taken (when unitless, `1/1=`).
The raw Nyquist Limit $n_q$ will be used to bound the unitless part of the output as $[-n_q, n_q]$
Raises:
ValueError: If `self.scaling` is not linear, since the FT can only be performed on uniformly spaced data.
"""
if self.scaling is ScalingMode.Lin:
nyquist_limit = self.steps / self.ideal_range
# Return New Bounds w/Nyquist Theorem
## -> The Nyquist Limit describes "max repeated info per sample".
## -> Information can still record "faster" than the Nyquist Limit.
## -> It will just be either noise (best case), or banded artifacts.
## -> This is called "aliasing", and it's best to try and filter it.
## -> Sims generally "bin" yee cells, which sacrifices some NyqLim.
return RangeFlow(
start=-nyquist_limit,
stop=nyquist_limit,
scaling=self.scaling,
unit=1 / self.unit if self.unit is not None else None,
pre_fourier_ideal_midpoint=self.ideal_midpoint,
)
msg = f'Cant fourier-transform an index array as a boundary, when the RangeArray has a non-linear bound {self.scaling}'
raise ValueError(msg)
@functools.cached_property
def bound_inv_fourier_transform(self):
r"""Treat this `RangeFlow` as an axis along which an inverse fourier transform is being performed, such that its Nyquist-limit bounds are transformed back into values along the original unit dimension.
See `self.bound_fourier_transform` for the theoretical concepts.
Notes:
**The discrete inverse fourier transform always centers its output at $0$**.
Of course, it's entirely probable that the original signal was not centered at $0$.
For this reason, when performing a Fourier transform, `self.bound_fourier_transform` sets a special variable, `self.pre_fourier_ideal_midpoint`.
When set, it will retain the `self.ideal_midpoint` around which both `self.start` and `self.stop` should be centered after an inverse FT.
If `self.pre_fourier_ideal_midpoint` is set, then it will be used as the midpoint of the output's `start`/`stop`.
Otherwise, $0$ will be used - in which case the user should themselves, manually, shift the output if needed.
"""
if self.scaling is ScalingMode.Lin:
orig_ideal_range = self.steps / self.ideal_range
orig_start_centered = -orig_ideal_range
orig_stop_centered = orig_ideal_range
orig_ideal_midpoint = (
self.pre_fourier_ideal_midpoint
if self.pre_fourier_ideal_midpoint is not None
else sp.S(0)
)
# Return New Bounds w/Inverse of Nyquist Theorem
return RangeFlow(
start=-orig_start_centered + orig_ideal_midpoint,
stop=orig_stop_centered + orig_ideal_midpoint,
scaling=self.scaling,
unit=1 / self.unit if self.unit is not None else None,
)
msg = f'Cant fourier-transform an index array as a boundary, when the RangeArray has a non-linear bound {self.scaling}'
raise ValueError(msg)
#################### ####################
# - Exporters # - Exporters
#################### ####################
@ -237,15 +393,15 @@ class RangeFlow:
"""Create a function that can compute the non-lazy output array as a function of the symbols in the expressions for `start` and `stop`. """Create a function that can compute the non-lazy output array as a function of the symbols in the expressions for `start` and `stop`.
Notes: Notes:
The ordering of the symbols is identical to `self.symbols`, which is guaranteed to be a deterministically sorted list of symbols. The ordering of the symbols is identical to `self.sorted_symbols`, which is guaranteed to be a deterministically sorted list of symbols.
Returns: Returns:
A `FuncFlow` that, given the input symbols defined in `self.symbols`, A function that generates a 1D numerical array equivalent to the range represented in this `RangeFlow`.
""" """
# Compile JAX Functions for Start/End Expressions # Compile JAX Functions for Start/End Expressions
## -> FYI, JAX-in-JAX works perfectly fine. ## -> FYI, JAX-in-JAX works perfectly fine.
start_jax = sp.lambdify(self.sorted_symbols, self.start, 'jax') start_jax = sp.lambdify(self.sorted_sp_symbols, self.start, 'jax')
stop_jax = sp.lambdify(self.sorted_symbols, self.stop, 'jax') stop_jax = sp.lambdify(self.sorted_sp_symbols, self.stop, 'jax')
# Compile ArrayGen Function # Compile ArrayGen Function
def gen_array( def gen_array(
@ -256,54 +412,80 @@ class RangeFlow:
# Return ArrayGen Function # Return ArrayGen Function
return gen_array return gen_array
@functools.cached_property
def as_lazy_func(self) -> FuncFlow:
"""Creates a `FuncFlow` using the output of `self.as_func`.
This is useful for ex. parameterizing the first array in the node graph, without binding an entire computed array.
Notes:
The the function enclosed in the `FuncFlow` is identical to the one returned by `self.as_func`.
Returns:
A `FuncFlow` containing `self.as_func`, as well as appropriate supporting settings.
"""
return FuncFlow(
func=self.as_func,
func_args=[(spux.MathType.from_expr(sym)) for sym in self.symbols],
supports_jax=True,
)
#################### ####################
# - Realization # - Realization
#################### ####################
def realize_symbols(
self,
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
) -> dict[sp.Symbol, spux.ScalarUnitlessComplexExpr]:
"""Realize **all** input symbols to the `RangeFlow`.
Parameters:
symbol_values: A scalar, unitless, complex `sympy` expression for each symbol defined in `self.symbols`.
Returns:
A dictionary directly usable in expression substitutions using `sp.Basic.subs()`.
"""
if self.symbols == set(symbol_values.keys()):
realized_syms = {}
for sym in self.sorted_symbols:
sym_value = symbol_values[sym]
# Sympy Expression
## -> We need to conform the expression to the SimSymbol.
## -> Mainly, this is
if (
isinstance(sym_value, spux.SympyType)
and not isinstance(sym_value, sp.MatrixBase)
and not spux.uses_units(sym_value)
):
v = sym.conform(sym_value)
else:
msg = f'RangeFlow: No realization support for symbolic value {sym_value} (type={type(sym_value)})'
raise NotImplementedError(msg)
realized_syms |= {sym: v}
msg = f'RangeFlow: Not all symbols were given a value during realization (symbols={self.symbols}, symbol_values={symbol_values})'
raise ValueError(msg)
def realize_start( def realize_start(
self, self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}), symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
) -> int | float | complex: ) -> int | float | complex:
"""Realize the start-bound by inserting particular values for each symbol.""" """Realize the start-bound by inserting particular values for each symbol."""
return spux.sympy_to_python( realized_symbols = self.realize_symbols(symbol_values)
self.start.subs({sym: symbol_values[sym.name] for sym in self.symbols}) return spux.sympy_to_python(self.start.subs(realized_symbols))
)
def realize_stop( def realize_stop(
self, self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}), symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
) -> int | float | complex: ) -> int | float | complex:
"""Realize the stop-bound by inserting particular values for each symbol.""" """Realize the stop-bound by inserting particular values for each symbol."""
return spux.sympy_to_python( realized_symbols = self.realize_symbols(symbol_values)
self.stop.subs({sym: symbol_values[sym.name] for sym in self.symbols}) return spux.sympy_to_python(self.stop.subs(realized_symbols))
)
def realize_step_size( def realize_step_size(
self, self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}), symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
) -> int | float | complex: ) -> int | float | complex:
"""Realize the stop-bound by inserting particular values for each symbol.""" """Realize the stop-bound by inserting particular values for each symbol."""
if self.scaling is not ScalingMode.Lin: if self.scaling is not ScalingMode.Lin:
raise NotImplementedError('Non-linear scaling mode not yet suported') msg = 'Non-linear scaling mode not yet suported'
raise NotImplementedError(msg)
raw_step_size = (self.realize_stop() - self.realize_start() + 1) / self.steps raw_step_size = (
self.realize_stop(symbol_values) - self.realize_start(symbol_values) + 1
) / self.steps
if self.mathtype is spux.MathType.Integer and raw_step_size.is_integer(): if self.mathtype is spux.MathType.Integer and raw_step_size.is_integer():
return int(raw_step_size) return int(raw_step_size)
@ -311,7 +493,9 @@ class RangeFlow:
def realize( def realize(
self, self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}), symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
) -> ArrayFlow: ) -> ArrayFlow:
"""Realize the array represented by this `RangeFlow` by realizing each bound, then generating all intermediate values as an array. """Realize the array represented by this `RangeFlow` by realizing each bound, then generating all intermediate values as an array.
@ -324,15 +508,33 @@ class RangeFlow:
## TODO: Check symbol values for coverage. ## TODO: Check symbol values for coverage.
return ArrayFlow( return ArrayFlow(
values=self.as_func(*[symbol_values[sym] for sym in self.sorted_symbols]), values=self.as_func(
*[
spux.scale_to_unit_system(symbol_values[sym])
for sym in self.sorted_symbols
]
),
unit=self.unit, unit=self.unit,
is_sorted=True, is_sorted=True,
) )
@functools.cached_property @functools.cached_property
def realize_array(self) -> ArrayFlow: def realize_array(self) -> ArrayFlow:
"""Standardized access to `self.realize()` when there are no symbols.""" """Standardized access to `self.realize()` when there are no symbols.
return self.realize()
Raises:
ValueError: If there are symbols defined in `self.symbols`.
"""
if not self.symbols:
return self.realize()
msg = f'RangeFlow: Cannot use ".realize_array" when symbols are defined (symbols={self.symbols}, RangeFlow={self}'
raise ValueError(msg)
@property
def values(self) -> jtyp.Inexact[jtyp.Array, '...']:
"""Alias for `realize_array.values`."""
return self.realize_array.values
def __getitem__(self, subscript: slice): def __getitem__(self, subscript: slice):
"""Implement indexing and slicing in a sane way. """Implement indexing and slicing in a sane way.
@ -379,23 +581,14 @@ class RangeFlow:
Raises: Raises:
ValueError: If the existing unit is `None`, indicating that there is no unit to correct. ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
""" """
if self.unit is not None: return RangeFlow(
log.debug( start=self.start,
'%s: Corrected unit to %s', stop=self.stop,
self, steps=self.steps,
corrected_unit, scaling=self.scaling,
) unit=corrected_unit,
return RangeFlow( symbols=self.symbols,
start=self.start, )
stop=self.stop,
steps=self.steps,
scaling=self.scaling,
unit=corrected_unit,
symbols=self.symbols,
)
msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"'
raise ValueError(msg)
def rescale_to_unit(self, unit: spux.Unit) -> typ.Self: def rescale_to_unit(self, unit: spux.Unit) -> typ.Self:
"""Replaces the unit, **with** rescaling of the bounds. """Replaces the unit, **with** rescaling of the bounds.
@ -410,11 +603,6 @@ class RangeFlow:
ValueError: If the existing unit is `None`, indicating that there is no unit to correct. ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
""" """
if self.unit is not None: if self.unit is not None:
log.debug(
'%s: Scaled to unit %s',
self,
unit,
)
return RangeFlow( return RangeFlow(
start=spux.scale_to_unit(self.start * self.unit, unit), start=spux.scale_to_unit(self.start * self.unit, unit),
stop=spux.scale_to_unit(self.stop * self.unit, unit), stop=spux.scale_to_unit(self.stop * self.unit, unit),
@ -423,11 +611,18 @@ class RangeFlow:
unit=unit, unit=unit,
symbols=self.symbols, symbols=self.symbols,
) )
return RangeFlow(
start=self.start * unit,
stop=self.stop * unit,
steps=self.steps,
scaling=self.scaling,
unit=unit,
symbols=self.symbols,
)
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}' def rescale_to_unit_system(
raise ValueError(msg) self, unit_system: spux.UnitSystem | None = None
) -> typ.Self:
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self:
"""Replaces the units, **with** rescaling of the bounds. """Replaces the units, **with** rescaling of the bounds.
Parameters: Parameters:
@ -439,28 +634,11 @@ class RangeFlow:
Raises: Raises:
ValueError: If the existing unit is `None`, indicating that there is no unit to correct. ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
""" """
if self.unit is not None: return RangeFlow(
log.debug( start=spux.scale_to_unit_system(self.start * self.unit, unit_system),
'%s: Scaled to new unit system (new unit = %s)', stop=spux.scale_to_unit_system(self.stop * self.unit, unit_system),
self, steps=self.steps,
unit_system[spux.PhysicalType.from_unit(self.unit)], scaling=self.scaling,
) unit=spux.convert_to_unit_system(self.unit, unit_system),
return RangeFlow( symbols=self.symbols,
start=spux.strip_unit_system(
spux.convert_to_unit_system(self.start * self.unit, unit_system),
unit_system,
),
stop=spux.strip_unit_system(
spux.convert_to_unit_system(self.stop * self.unit, unit_system),
unit_system,
),
steps=self.steps,
scaling=self.scaling,
unit=unit_system[spux.PhysicalType.from_unit(self.unit)],
symbols=self.symbols,
)
msg = (
f'Tried to rescale unitless LazyDataValueRange to unit system {unit_system}'
) )
raise ValueError(msg)

View File

@ -17,15 +17,19 @@
import dataclasses import dataclasses
import functools import functools
import typing as typ import typing as typ
from fractions import Fraction
from types import MappingProxyType from types import MappingProxyType
import jaxtyping as jtyp
import sympy as sp import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger, sim_symbols from blender_maxwell.utils import logger, sim_symbols
from .array import ArrayFlow
from .expr_info import ExprInfo from .expr_info import ExprInfo
from .flow_kinds import FlowKind from .flow_kinds import FlowKind
from .lazy_range import RangeFlow
# from .info import InfoFlow # from .info import InfoFlow
@ -34,13 +38,22 @@ log = logger.get(__name__)
@dataclasses.dataclass(frozen=True, kw_only=True) @dataclasses.dataclass(frozen=True, kw_only=True)
class ParamsFlow: class ParamsFlow:
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
Returns:
All symbols valid for use in the expression.
"""
func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list) func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list)
func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict) func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict)
symbols: frozenset[sim_symbols.SimSymbol] = frozenset() symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
####################
# - Symbols
####################
@functools.cached_property @functools.cached_property
def sorted_symbols(self) -> list[sp.Symbol]: def sorted_symbols(self) -> list[sim_symbols.SimSymbol]:
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name. """Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
Returns: Returns:
@ -48,52 +61,179 @@ class ParamsFlow:
""" """
return sorted(self.symbols, key=lambda sym: sym.name) return sorted(self.symbols, key=lambda sym: sym.name)
@functools.cached_property
def sorted_sp_symbols(self) -> list[sp.Symbol | sp.MatrixSymbol]:
"""Computes `sympy` symbols from `self.sorted_symbols`.
When the output is shaped, a single shaped symbol (`sp.MatrixSymbol`) is used to represent the symbolic name and shaping.
This choice is made due to `MatrixSymbol`'s compatibility with `.lambdify` JIT.
Returns:
All symbols valid for use in the expression.
"""
return [sym.sp_symbol_matsym for sym in self.sorted_symbols]
####################
# - JIT'ed Callables for Numerical Function Arguments
####################
def func_args_n(
self, target_syms: list[sim_symbols.SimSymbol]
) -> list[
typ.Callable[
[int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
int | float | complex | jtyp.Inexact[jtyp.Array, '...'],
]
]:
"""Callable functions for evaluating each `self.func_args` entry numerically.
Before simplification, each `self.func_args` entry will be conformed to the corresponding (by-index) `SimSymbol` in `target_syms`.
Notes:
Before using any `sympy` expressions as arguments to the returned callablees, they **must** be fully conformed and scaled to the corresponding `self.symbols` entry using that entry's `SimSymbol.scale()` method.
This ensures conformance to the `SimSymbol` properties (like units), as well as adherance to a numerical type identity compatible with `sp.lambdify()`.
Parameters:
target_syms: `SimSymbol`s describing how a particular `ParamsFlow` function argument should be scaled when performing a purely numerical insertion.
"""
return [
sp.lambdify(
self.sorted_sp_symbols,
target_sym.conform(func_arg, strip_unit=True),
'jax',
)
for func_arg, target_sym in zip(self.func_args, target_syms, strict=True)
]
def func_kwargs_n(
self, target_syms: dict[str, sim_symbols.SimSymbol]
) -> dict[
str,
typ.Callable[
[int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
int | float | complex | jtyp.Inexact[jtyp.Array, '...'],
],
]:
"""Callable functions for evaluating each `self.func_kwargs` entry numerically.
The arguments of each function **must** be pre-treated using `SimSymbol.scale()`.
This ensures conformance to the `SimSymbol` properties, as well as adherance to a numerical type identity compatible with `sp.lambdify()`
"""
return {
func_arg_key: sp.lambdify(
self.sorted_sp_symbols,
target_syms[func_arg_key].scale(func_arg),
'jax',
)
for func_arg_key, func_arg in self.func_kwargs.items()
}
####################
# - Realization
####################
def realize_symbols(
self,
symbol_values: dict[
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
] = MappingProxyType({}),
) -> dict[
sp.Symbol,
int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...'] :,
]:
"""Fully realize all symbols by assigning them a value.
Three kinds of values for `symbol_values` are supported, fundamentally:
- **Sympy Expression**: When the value is a sympy expression with units, the unit of the `SimSymbol` key which unit the value if converted to.
If the `SimSymbol`'s unit is `None`, then the value is left as-is.
- **Range**: When the value is a `RangeFlow`, units are converted to the `SimSymbol`'s unit using `.rescale_to_unit()`.
If the `SimSymbol`'s unit is `None`, then the value is left as-is.
- **Array**: When the value is an `ArrayFlow`, units are converted to the `SimSymbol`'s unit using `.rescale_to_unit()`.
If the `SimSymbol`'s unit is `None`, then the value is left as-is.
Returns:
A dictionary almost with `.subs()`, other than `jax` arrays.
"""
if set(self.symbols) == set(symbol_values.keys()):
realized_syms = {}
for sym in self.sorted_symbols:
sym_value = symbol_values[sym]
if isinstance(sym_value, spux.SympyType):
v = sym.scale(sym_value)
elif isinstance(sym_value, ArrayFlow | RangeFlow):
v = sym_value.rescale_to_unit(sym.unit).values
## NOTE: RangeFlow must not be symbolic.
else:
msg = f'No support for symbolic value {sym_value} (type={type(sym_value)})'
raise NotImplementedError(msg)
realized_syms |= {sym: v}
return realized_syms
msg = f'ParamsFlow: Not all symbols were given a value during realization (symbols={self.symbols}, symbol_values={symbol_values})'
raise ValueError(msg)
#################### ####################
# - Realize Arguments # - Realize Arguments
#################### ####################
def scaled_func_args( def scaled_func_args(
self, self,
unit_system: spux.UnitSystem | None = None, target_syms: list[sim_symbols.SimSymbol] = (),
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType( symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{} {}
), ),
): ) -> list[
"""Realize the function arguments contained in this `ParamsFlow`, making it ready for insertion into `Func.func()`. int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...']
]:
"""Realize correctly conformed numerical arguments for `self.func_args`.
For all `arg`s in `self.func_args`, the following operations are performed. Because we allow symbols to be used in `self.func_args`, producing a numerical value that can be passed directly to a `FuncFlow` becomes a two-step process:
Notes: 1. Conform Symbols: Arbitrary `sympy` expressions passed as `symbol_values` must first be conformed to match the ex. units of `SimSymbol`s found in `self.symbols`, before they can be used.
This method is created for the purpose of being able to make this exact call in an `events.on_value_changed` method:
2. Conform Function Arguments: Arbitrary `sympy` expressions encoded in `self.func_args` must, **after** inserting the conformed numerical symbols, themselves be conformed to the expected ex. units of the function that they are to be used within.
**`ParamsFlow` doesn't contain information about the `SimSymbol`s that `self.func_args` are expected to conform to** (on purpose).
Therefore, the user is required to pass a `target_syms` with identical length to `self.func_args`, describing the `SimSymbol`s to conform the function arguments to.
Our implementation attempts to utilize simple, powerful primitives to accomplish this in roughly three steps:
1. **Realize Symbols**: Particular passed symbolic values `symbol_values`, which are arbitrary `sympy` expressions, are conformed to the definitions in `self.symbols` (ex. to match units), then cast to numerical values (pure Python / jax array).
2. **Lazy Function Arguments**: Stored function arguments `self.func_args`, which are arbitrary `sympy` expressions, are conformed to the definitions in `target_syms` (ex. to match units), then cast to numerical values (pure Python / jax array).
_Technically, this happens as part of `self.func_args_n`._
3. **Numerical Evaluation**: The numerical values for each symbol are passed as parameters to each (callable) element of `self.func_args_n`, which produces a correct numerical value for each function argument.
Parameters:
target_syms: `SimSymbol`s describing how the function arguments returned by this method are intended to be used.
**Generally**, the parallel `FuncFlow.func_args` should be inserted here, and guarantees correct results when this output is inserted into `FuncFlow.func(...)`.
symbol_values: Particular values for all symbols in `self.symbols`, which will be conformed and used to compute the function arguments (before they are conformed to `target_syms`).
""" """
if not all(sym in self.symbols for sym in symbol_values): realized_symbols = list(self.realize_symbols(symbol_values).values())
msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}"
raise ValueError(msg)
## TODO: MutableDenseMatrix causes error with 'in' check bc it isn't hashable.
return [ return [
( func_arg_n(*realized_symbols)
spux.scale_to_unit_system(arg, unit_system, use_jax_array=True) for func_arg_n in self.func_args_n(target_syms)
if arg not in symbol_values
else symbol_values[arg]
)
for arg in self.func_args
] ]
def scaled_func_kwargs( def scaled_func_kwargs(
self, self,
unit_system: spux.UnitSystem | None = None, target_syms: list[sim_symbols.SimSymbol] = (),
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}), symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
): ) -> dict[
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments.""" str, int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...']
if not all(sym in self.symbols for sym in symbol_values): ]:
msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}" """Realize correctly conformed numerical arguments for `self.func_kwargs`.
raise ValueError(msg)
Other than the `dict[str, ...]` key, the semantics are identical to `self.scaled_func_args()`.
"""
realized_symbols = self.realize_symbols(symbol_values)
return { return {
arg_name: spux.convert_to_unit_system(arg, unit_system, use_jax_array=True) func_arg_name: func_arg_n(**realized_symbols)
if arg not in symbol_values for func_arg_name, func_arg_n in self.func_kwargs_n(target_syms).items()
else symbol_values[arg]
for arg_name, arg in self.func_kwargs.items()
} }
#################### ####################
@ -129,8 +269,8 @@ class ParamsFlow:
#################### ####################
# - Generate ExprSocketDef # - Generate ExprSocketDef
#################### ####################
def sym_expr_infos(self, info, use_range: bool = False) -> dict[str, ExprInfo]: def sym_expr_infos(self, use_range: bool = False) -> dict[str, ExprInfo]:
"""Generate all information needed to define expressions that realize all symbolic parameters in this `ParamsFlow`. """Generate keyword arguments for defining all `ExprSocket`s needed to realize all `self.symbols`.
Many nodes need actual data, and as such, they require that the user select actual values for any symbols in the `ParamsFlow`. Many nodes need actual data, and as such, they require that the user select actual values for any symbols in the `ParamsFlow`.
The best way to do this is to create one `ExprSocket` for each symbol that needs realizing. The best way to do this is to create one `ExprSocket` for each symbol that needs realizing.
@ -151,35 +291,22 @@ class ParamsFlow:
The `ExprInfo`s can be directly defererenced `**expr_info`) The `ExprInfo`s can be directly defererenced `**expr_info`)
""" """
for sim_sym in self.sorted_symbols: for sym in self.sorted_symbols:
if use_range and sim_sym.mathtype is spux.MathType.Complex: if use_range and sym.mathtype is spux.MathType.Complex:
msg = 'No support for complex range in ExprInfo' msg = 'No support for complex range in ExprInfo'
raise NotImplementedError(msg) raise NotImplementedError(msg)
if use_range and (sim_sym.rows > 1 or sim_sym.cols > 1): if use_range and (sym.rows > 1 or sym.cols > 1):
msg = 'No support for non-scalar elements of range in ExprInfo' msg = 'No support for non-scalar elements of range in ExprInfo'
raise NotImplementedError(msg) raise NotImplementedError(msg)
if sim_sym.rows > 3 or sim_sym.cols > 1: if sym.rows > 3 or sym.cols > 1:
msg = 'No support for >Vec3 / Matrix values in ExprInfo' msg = 'No support for >Vec3 / Matrix values in ExprInfo'
raise NotImplementedError(msg) raise NotImplementedError(msg)
return { return {
sim_sym.name: { sym.name: {
# Declare Kind/Size
## -> Kind: Value prevents user-alteration of config.
## -> Size: Always scalar, since symbols are scalar (for now).
'active_kind': FlowKind.Value if not use_range else FlowKind.Range, 'active_kind': FlowKind.Value if not use_range else FlowKind.Range,
'size': spux.NumberSize1D.Scalar,
# Declare MathType/PhysicalType
## -> MathType: Lookup symbol name in info dimensions.
## -> PhysicalType: Same.
'mathtype': self.dims[sim_sym].mathtype,
'physical_type': self.dims[sim_sym].physical_type,
# TODO: Default Value
# FlowKind.Value: Default Value
#'default_value':
# FlowKind.Range: Default Min/Max/Steps
'default_min': sim_sym.domain.start,
'default_max': sim_sym.domain.end,
'default_steps': 50, 'default_steps': 50,
} }
for sim_sym in self.sorted_symbols | sym.expr_info
for sym in self.sorted_symbols
} }

View File

@ -29,6 +29,7 @@ import tidy3d as td
from blender_maxwell.contracts import BLEnumElement from blender_maxwell.contracts import BLEnumElement
from blender_maxwell.services import tdcloud from blender_maxwell.services import tdcloud
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger from blender_maxwell.utils import logger
from .flow_kinds.info import InfoFlow from .flow_kinds.info import InfoFlow
@ -482,6 +483,93 @@ class DataFileFormat(enum.StrEnum):
## - When sidecars aren't found, the user would "fill in the blanks". ## - When sidecars aren't found, the user would "fill in the blanks".
## - ...Thus achieving the same result as if there were a sidecar. ## - ...Thus achieving the same result as if there were a sidecar.
####################
# - Functions: DataFrame
####################
@staticmethod
def to_df(
data: jtyp.Shaped[jtyp.Array, 'x_size y_size'], info: InfoFlow
) -> pl.DataFrame:
"""Utility method to convert raw data to a `polars.DataFrame`, as guided by an `InfoFlow`.
Only works with 2D data (obviously).
Raises:
ValueError: If the data has more than two dimensions, all `info` dimensions are not discrete/labelled, or the dimensionality of `info` doesn't match.
"""
if info.order > 2: # noqa: PLR2004
msg = f'Data may not have more than two dimensions (info={info}, data.shape={data.shape})'
raise ValueError(msg)
if any(info.has_idx_cont(dim) for dim in info.dims):
msg = f'To convert data|info to a dataframe, no dimensions can have continuous indices (info={info})'
raise ValueError(msg)
data_np = np.array(data)
MT = spux.MathType
match (
info.input_mathtypes,
info.output.mathtype,
info.output.rows,
info.output.cols,
):
# (R,Z) -> Complex Scalar
## -> Polars (also pandas) doesn't have a complex type.
## -> Will be treated as (R, Z, 2) -> Real Scalar.
case ((MT.Rational | MT.Real, MT.Integer), MT.Complex, 1, 1):
row_dim = info.first_dim
col_dim = info.last_dim
return pl.DataFrame(
{row_dim.name: info.dims[row_dim]}
| {
col_label + postfix: re_im(data_np[:, col])
for col, col_label in enumerate(info.dims[col_dim])
for postfix, re_im in [('_re', np.real), ('_im', np.imag)]
}
)
# (R,Z) -> Scalar
case ((MT.Rational | MT.Real, MT.Integer), _, 1, 1):
row_dim = info.first_dim
col_dim = info.last_dim
return pl.DataFrame(
{row_dim.name: info.dims[row_dim]}
| {
col_label: data_np[:, col]
for col, col_label in enumerate(info.dims[col_dim])
}
)
# (Z) -> Complex Vector/Covector
case ((MT.Integer,), MT.Complex, r, c) if (r > 1 and c == 1) or (
r == 1 and c > 1
):
col_dim = info.last_dim
return pl.DataFrame(
{
col_label + postfix: re_im(data_np[col, :])
for col, col_label in enumerate(info.dims[col_dim])
for postfix, re_im in [('_re', np.real), ('_im', np.imag)]
}
)
# (Z) -> Real Vector
## -> Each integer index will be treated as a column index.
## -> This will effectively transpose the data.
case ((MT.Integer,), _, r, c) if (r > 1 and c == 1) or (r == 1 and c > 1):
col_dim = info.last_dim
return pl.DataFrame(
{
col_label: data_np[col, :]
for col, col_label in enumerate(info.dims[col_dim])
}
)
#################### ####################
# - Functions: Saver # - Functions: Saver
#################### ####################
@ -506,50 +594,7 @@ class DataFileFormat(enum.StrEnum):
np.savetxt(path, data) np.savetxt(path, data)
def save_csv(path, data, info): def save_csv(path, data, info):
data_np = np.array(data) df = self.to_df(data, info)
# Extract Input Coordinates
dim_columns = {
dim.name: np.array(dim_idx.realize_array)
for i, (dim, dim_idx) in enumerate(info.dims)
} ## TODO: realize_array might not be defined on some index arrays
# Declare Function to Extract Output Values
output_columns = {}
def declare_output_col(data_col, output_idx=0, use_output_idx=False):
nonlocal output_columns
# Complex: Split to Two Columns
output_idx_str = f'[{output_idx}]' if use_output_idx else ''
if bool(np.any(np.iscomplex(data_col))):
output_columns |= {
f'{info.output.name}{output_idx_str}_re': np.real(data_col),
f'{info.output.name}{output_idx_str}_im': np.imag(data_col),
}
# Else: Use Array Directly
else:
output_columns |= {
f'{info.output.name}{output_idx_str}': data_col,
}
## TODO: Maybe a check to ensure dtype!=object?
# Extract Output Values
## -> 2D: Iterate over columns by-index.
## -> 1D: Declare the array as the only column.
if len(data_np.shape) == 2:
for output_idx in data_np.shape[1]:
declare_output_col(data_np[:, output_idx], output_idx, True)
else:
declare_output_col(data_np)
# Compute DataFrame & Write CSV
df = pl.DataFrame(dim_columns | output_columns)
log.debug('Writing Polars DataFrame to CSV:')
log.debug(df)
df.write_csv(path) df.write_csv(path)
def save_npy(path, data, info): def save_npy(path, data, info):

View File

@ -264,17 +264,22 @@ class ManagedBLImage(base.ManagedObj):
# times = [time.perf_counter()] # times = [time.perf_counter()]
# Compute Plot Dimensions # Compute Plot Dimensions
aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = ( # aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = (
self.gen_image_geometry(width_inches, height_inches, dpi) # self.gen_image_geometry(width_inches, height_inches, dpi)
) # )
# times.append(['Image Geometry', time.perf_counter() - times[0]]) # times.append(['Image Geometry', time.perf_counter() - times[0]])
# log.critical(
# [aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px]
# )
# Create MPL Figure, Axes, and Compute Figure Geometry # Create MPL Figure, Axes, and Compute Figure Geometry
fig, canvas, ax = image_ops.mpl_fig_canvas_ax( # fig, canvas, ax = image_ops.mpl_fig_canvas_ax(
_width_inches, _height_inches, _dpi # _width_inches, _height_inches, _dpi
) # )
fig, canvas, ax = image_ops.mpl_fig_canvas_ax(width_inches, height_inches, dpi)
# times.append(['MPL Fig Canvas Axis', time.perf_counter() - times[0]]) # times.append(['MPL Fig Canvas Axis', time.perf_counter() - times[0]])
# fig.clear()
ax.clear() ax.clear()
# times.append(['Clear Axis', time.perf_counter() - times[0]]) # times.append(['Clear Axis', time.perf_counter() - times[0]])

View File

@ -46,6 +46,7 @@ class FilterOperation(enum.StrEnum):
""" """
# Slice # Slice
Slice = enum.auto()
SliceIdx = enum.auto() SliceIdx = enum.auto()
# Pin # Pin
@ -53,9 +54,8 @@ class FilterOperation(enum.StrEnum):
Pin = enum.auto() Pin = enum.auto()
PinIdx = enum.auto() PinIdx = enum.auto()
# Reinterpret # Dimension
Swap = enum.auto() Swap = enum.auto()
SetDim = enum.auto()
#################### ####################
# - UI # - UI
@ -65,14 +65,14 @@ class FilterOperation(enum.StrEnum):
FO = FilterOperation FO = FilterOperation
return { return {
# Slice # Slice
FO.SliceIdx: 'a[...]', FO.Slice: '=a[i:j]',
FO.SliceIdx: '≈a[v₁:v₂]',
# Pin # Pin
FO.PinLen1: 'pinₐ =1', FO.PinLen1: 'pinₐ',
FO.Pin: 'pinₐ ≈v', FO.Pin: 'pinₐ ≈v',
FO.PinIdx: 'pinₐ =a[v]', FO.PinIdx: 'pinₐ =i',
# Reinterpret # Reinterpret
FO.Swap: 'a₁ ↔ a₂', FO.Swap: 'a₁ ↔ a₂',
FO.SetDim: 'setₐ =v',
}[value] }[value]
@staticmethod @staticmethod
@ -118,11 +118,6 @@ class FilterOperation(enum.StrEnum):
if len(info.dims) >= 2: # noqa: PLR2004 if len(info.dims) >= 2: # noqa: PLR2004
operations.append(FO.Swap) operations.append(FO.Swap)
## SetDim
## -> There must be a dimension to correct.
if info.dims:
operations.append(FO.SetDim)
return operations return operations
#################### ####################
@ -145,6 +140,7 @@ class FilterOperation(enum.StrEnum):
FO = FilterOperation FO = FilterOperation
return { return {
# Slice # Slice
FO.Slice: 1,
FO.SliceIdx: 1, FO.SliceIdx: 1,
# Pin # Pin
FO.PinLen1: 1, FO.PinLen1: 1,
@ -152,40 +148,35 @@ class FilterOperation(enum.StrEnum):
FO.PinIdx: 1, FO.PinIdx: 1,
# Reinterpret # Reinterpret
FO.Swap: 2, FO.Swap: 2,
FO.SetDim: 1,
}[self] }[self]
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]: def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
FO = FilterOperation FO = FilterOperation
match self: match self:
case FO.SliceIdx | FO.Swap: # Slice
return info.dims case FO.Slice:
return [dim for dim in info.dims if not dim.has_idx_labels(dim)]
# PinLen1: Only allow dimensions with length=1. case FO.SliceIdx:
return [dim for dim in info.dims if not dim.has_idx_labels(dim)]
# Pin
case FO.PinLen1: case FO.PinLen1:
return [ return [
dim dim
for dim, dim_idx in info.dims.items() for dim, dim_idx in info.dims.items()
if dim_idx is not None and len(dim_idx) == 1 if not info.has_idx_cont(dim) and len(dim_idx) == 1
] ]
# Pin: Only allow dimensions with discrete index. case FO.Pin:
## TODO: Shouldn't 'Pin' be allowed to index continuous indices too? return info.dims
case FO.Pin | FO.PinIdx:
return [
dim
for dim, dim_idx in info.dims
if dim_idx is not None and len(dim_idx) > 0
]
case FO.SetDim: case FO.PinIdx:
return [ return [dim for dim in info.dims if not info.has_idx_cont(dim)]
dim
for dim, dim_idx in info.dims # Dimension
if dim_idx is not None case FO.Swap:
and not isinstance(dim_idx, list) return info.dims
and dim_idx.mathtype == spux.MathType.Integer
]
return [] return []
@ -193,9 +184,14 @@ class FilterOperation(enum.StrEnum):
self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None
) -> bool: ) -> bool:
"""Check whether the given dimension inputs are valid in the context of this operation, and of the information.""" """Check whether the given dimension inputs are valid in the context of this operation, and of the information."""
return (self.num_dim_inputs in [1, 2] and dim_0 in self.valid_dims(info)) or ( if self.num_dim_inputs == 1:
self.num_dim_inputs == 2 and dim_1 in self.valid_dims(info) return dim_0 in self.valid_dims(info)
)
if self.num_dim_inputs == 2: # noqa: PLR2004
valid_dims = self.valid_dims(info)
return dim_0 in valid_dims and dim_1 in valid_dims
return False
#################### ####################
# - UI # - UI
@ -209,6 +205,9 @@ class FilterOperation(enum.StrEnum):
FO = FilterOperation FO = FilterOperation
return { return {
# Pin # Pin
FO.Slice: lambda expr: jlax.slice_in_dim(
expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
),
FO.SliceIdx: lambda expr: jlax.slice_in_dim( FO.SliceIdx: lambda expr: jlax.slice_in_dim(
expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0 expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
), ),
@ -216,9 +215,8 @@ class FilterOperation(enum.StrEnum):
FO.PinLen1: lambda expr: jnp.squeeze(expr, axis_0), FO.PinLen1: lambda expr: jnp.squeeze(expr, axis_0),
FO.Pin: lambda expr, idx: jnp.take(expr, idx, axis=axis_0), FO.Pin: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0), FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
# Reinterpret # Dimension
FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1), FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1),
FO.SetDim: lambda expr: expr,
}[self] }[self]
def transform_info( def transform_info(
@ -228,10 +226,10 @@ class FilterOperation(enum.StrEnum):
dim_1: sim_symbols.SimSymbol, dim_1: sim_symbols.SimSymbol,
pin_idx: int | None = None, pin_idx: int | None = None,
slice_tuple: tuple[int, int, int] | None = None, slice_tuple: tuple[int, int, int] | None = None,
replaced_dim: tuple[str, tuple[str, ct.ArrayFlow | ct.RangeFlow]] | None = None,
): ):
FO = FilterOperation FO = FilterOperation
return { return {
FO.Slice: lambda: info.slice_dim(dim_0, slice_tuple),
FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple), FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple),
# Pin # Pin
FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx), FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
@ -239,7 +237,6 @@ class FilterOperation(enum.StrEnum):
FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx), FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
# Reinterpret # Reinterpret
FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1), FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
FO.SetDim: lambda: info.replace_dim(*replaced_dim),
}[self]() }[self]()
@ -330,8 +327,8 @@ class FilterMathNode(base.MaxwellSimNode):
def search_dims(self) -> list[ct.BLEnumElement]: def search_dims(self) -> list[ct.BLEnumElement]:
if self.expr_info is not None and self.operation is not None: if self.expr_info is not None and self.operation is not None:
return [ return [
(dim_name, dim_name, dim_name, '', i) (dim.name, dim.name_pretty, dim.name, '', i)
for i, dim_name in enumerate(self.operation.valid_dims(self.expr_info)) for i, dim in enumerate(self.operation.valid_dims(self.expr_info))
] ]
return [] return []
@ -380,8 +377,6 @@ class FilterMathNode(base.MaxwellSimNode):
# Reinterpret # Reinterpret
case FO.Swap: case FO.Swap:
return f'Filter: Swap [{self.active_dim_0}]|[{self.active_dim_1}]' return f'Filter: Swap [{self.active_dim_0}]|[{self.active_dim_1}]'
case FO.SetDim:
return f'Filter: Set [{self.active_dim_0}]'
case _: case _:
return self.bl_label return self.bl_label
@ -480,30 +475,6 @@ class FilterMathNode(base.MaxwellSimNode):
) )
} }
# Loose Sockets: Set Dim
## -> The user must provide a () -> array.
## -> It must be of identical length to the replaced axis.
elif props['operation'] is FilterOperation.SetDim and dim_0 is not None:
dim = dim_0
current_bl_socket = self.loose_input_sockets.get('Dim')
if (
current_bl_socket is None
or current_bl_socket.active_kind != ct.FlowKind.Func
or current_bl_socket.size is not spux.NumberSize1D.Scalar
or current_bl_socket.mathtype != dim.mathtype
or current_bl_socket.physical_type != dim.physical_type
):
self.loose_input_sockets = {
'Dim': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Func,
physical_type=dim.physical_type,
mathtype=dim.mathtype,
default_unit=dim.unit,
show_func_ui=False,
show_info_columns=True,
)
}
# No Loose Value: Remove Input Sockets # No Loose Value: Remove Input Sockets
elif self.loose_input_sockets: elif self.loose_input_sockets:
self.loose_input_sockets = {} self.loose_input_sockets = {}
@ -570,60 +541,11 @@ class FilterMathNode(base.MaxwellSimNode):
has_info = not ct.FlowSignal.check(info) has_info = not ct.FlowSignal.check(info)
# Dim (Op.SetDim)
dim_func = input_sockets['Dim'][ct.FlowKind.Func]
dim_params = input_sockets['Dim'][ct.FlowKind.Params]
dim_info = input_sockets['Dim'][ct.FlowKind.Info]
has_dim_func = not ct.FlowSignal.check(dim_func)
has_dim_params = not ct.FlowSignal.check(dim_params)
has_dim_info = not ct.FlowSignal.check(dim_info)
# Dimension(s) # Dimension(s)
dim_0 = props['dim_0'] dim_0 = props['dim_0']
dim_1 = props['dim_1'] dim_1 = props['dim_1']
slice_tuple = props['slice_tuple'] slice_tuple = props['slice_tuple']
if has_info and operation is not None: if has_info and operation is not None:
# Set Dimension: Retrieve Array
if props['operation'] is FilterOperation.SetDim:
new_dim = (
next(dim_info.dims.keys()) if len(dim_info.dims) >= 1 else None
)
if (
dim_0 is not None
and new_dim is not None
and has_dim_info
and has_dim_params
# Check New Dimension Index Array Sizing
and len(dim_info.dims) == 1
and dim_info.output.rows == 1
and dim_info.output.cols == 1
# Check Lack of Params Symbols
and not dim_params.symbols
# Check Expr Dim | New Dim Compatibility
and info.has_idx_discrete(dim_0)
and dim_info.has_idx_discrete(new_dim)
and len(info.dims[dim_0]) == len(dim_info.dims[new_dim])
):
# Retrieve Dimension Coordinate Array
## -> It must be strictly compatible.
values = dim_func.realize(dim_params, spux.UNITS_SI)
# Transform Info w/Corrected Dimension
## -> The existing dimension will be replaced.
new_dim_idx = ct.ArrayFlow(
values=values,
unit=spux.convert_to_unit_system(
dim_info.output.unit, spux.UNITS_SI
),
).rescale_to_unit(dim_info.output.unit)
replaced_dim = [dim_0, (dim_info.output.name, new_dim_idx)]
return operation.transform_info(
info, dim_0, dim_1, replaced_dim=replaced_dim
)
return ct.FlowSignal.FlowPending
return operation.transform_info(info, dim_0, dim_1, slice_tuple=slice_tuple) return operation.transform_info(info, dim_0, dim_1, slice_tuple=slice_tuple)
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending

View File

@ -496,7 +496,7 @@ class MapMathNode(base.MaxwellSimNode):
) )
def search_operations(self) -> list[ct.BLEnumElement]: def search_operations(self) -> list[ct.BLEnumElement]:
if self.info is not None: if self.expr_info is not None:
return [ return [
operation.bl_enum_element(i) operation.bl_enum_element(i)
for i, operation in enumerate(MapOperation.by_expr_info(self.expr_info)) for i, operation in enumerate(MapOperation.by_expr_info(self.expr_info))

View File

@ -20,8 +20,11 @@ import typing as typ
import bpy import bpy
import jax.numpy as jnp import jax.numpy as jnp
import sympy as sp import sympy as sp
import sympy.physics.quantum as spq
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, logger from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct from .... import contracts as ct
from .... import sockets from .... import sockets
@ -37,37 +40,47 @@ class BinaryOperation(enum.StrEnum):
"""Valid operations for the `OperateMathNode`. """Valid operations for the `OperateMathNode`.
Attributes: Attributes:
Add: Addition w/broadcasting. Mul: Scalar multiplication.
Sub: Subtraction w/broadcasting. Div: Scalar division.
Mul: Hadamard-product multiplication. Pow: Scalar exponentiation.
Div: Hadamard-product based division. Add: Elementwise addition.
Pow: Elementwise expontiation. Sub: Elementwise subtraction.
Atan2: Quadrant-respecting arctangent variant. HadamMul: Elementwise multiplication (hadamard product).
VecVecDot: Dot product for vectors. HadamPow: Principled shape-aware exponentiation (hadamard power).
Cross: Cross product. Atan2: Quadrant-respecting 2D arctangent.
MatVecDot: Matrix-Vector dot product. VecVecDot: Dot product for identically shaped vectors w/transpose.
Cross: Cross product between identically shaped 3D vectors.
VecVecOuter: Vector-vector outer product.
LinSolve: Solve a linear system. LinSolve: Solve a linear system.
LsqSolve: Minimize error of an underdetermined linear system. LsqSolve: Minimize error of an underdetermined linear system.
MatMatDot: Matrix-Matrix dot product. VecMatOuter: Vector-matrix outer product.
MatMatDot: Matrix-matrix dot product.
""" """
# Number | Number # Number | Number
Add = enum.auto()
Sub = enum.auto()
Mul = enum.auto() Mul = enum.auto()
Div = enum.auto() Div = enum.auto()
Pow = enum.auto() Pow = enum.auto()
# Elements | Elements
Add = enum.auto()
Sub = enum.auto()
HadamMul = enum.auto()
# HadamPow = enum.auto() ## TODO: Sympy's HadamardPower is problematic.
Atan2 = enum.auto() Atan2 = enum.auto()
# Vector | Vector # Vector | Vector
VecVecDot = enum.auto() VecVecDot = enum.auto()
Cross = enum.auto() Cross = enum.auto()
VecVecOuter = enum.auto()
# Matrix | Vector # Matrix | Vector
MatVecDot = enum.auto()
LinSolve = enum.auto() LinSolve = enum.auto()
LsqSolve = enum.auto() LsqSolve = enum.auto()
# Vector | Matrix
VecMatOuter = enum.auto()
# Matrix | Matrix # Matrix | Matrix
MatMatDot = enum.auto() MatMatDot = enum.auto()
@ -79,19 +92,24 @@ class BinaryOperation(enum.StrEnum):
BO = BinaryOperation BO = BinaryOperation
return { return {
# Number | Number # Number | Number
BO.Mul: ' · r',
BO.Div: ' / r',
BO.Pow: ' ^ r',
# Elements | Elements
BO.Add: ' + r', BO.Add: ' + r',
BO.Sub: ' - r', BO.Sub: ' - r',
BO.Mul: ' ⊙ r', ## Notation for Hadamard Product BO.HadamMul: '𝐋𝐑',
BO.Div: ' / r', # BO.HadamPow: '𝐥 ⊙^ 𝐫',
BO.Pow: 'ℓʳ', BO.Atan2: 'atan2(:x, r:y)',
BO.Atan2: 'atan2(,r)',
# Vector | Vector # Vector | Vector
BO.VecVecDot: '𝐥 · 𝐫', BO.VecVecDot: '𝐥 · 𝐫',
BO.Cross: 'cross(L,R)', BO.Cross: 'cross(𝐥,𝐫)',
BO.VecVecOuter: '𝐥𝐫',
# Matrix | Vector # Matrix | Vector
BO.MatVecDot: '𝐋 · 𝐫',
BO.LinSolve: '𝐋 𝐫', BO.LinSolve: '𝐋 𝐫',
BO.LsqSolve: 'argminₓ∥𝐋𝐱𝐫∥₂', BO.LsqSolve: 'argminₓ∥𝐋𝐱𝐫∥₂',
# Vector | Matrix
BO.VecMatOuter: '𝐋𝐫',
# Matrix | Matrix # Matrix | Matrix
BO.MatMatDot: '𝐋 · 𝐑', BO.MatMatDot: '𝐋 · 𝐑',
}[value] }[value]
@ -118,56 +136,104 @@ class BinaryOperation(enum.StrEnum):
"""Deduce valid binary operations from the shapes of the inputs.""" """Deduce valid binary operations from the shapes of the inputs."""
BO = BinaryOperation BO = BinaryOperation
ops_number_number = [ ops_el_el = [
BO.Add, BO.Add,
BO.Sub, BO.Sub,
BO.Mul, BO.HadamMul,
BO.Div, # BO.HadamPow,
BO.Pow,
BO.Atan2,
] ]
match (info_l.output_shape_len, info_r.output_shape_len): outl = info_l.output
outr = info_r.output
match (outl.shape_len, outr.shape_len):
# match (ol.shape_len, info_r.output.shape_len):
# Number | * # Number | *
## Number | Number ## Number | Number
case (0, 0): case (0, 0):
return ops_number_number ops = [
BO.Add,
BO.Sub,
BO.Mul,
BO.Div,
BO.Pow,
]
if (
info_l.output.physical_type == spux.PhysicalType.Length
and info_l.output.unit == info_r.output.unit
):
ops += [BO.Atan2]
return ops
## Number | Vector ## Number | Vector
## -> Broadcasting allows Number|Number ops to work as-is.
case (0, 1): case (0, 1):
return ops_number_number return [BO.Mul] # , BO.HadamPow]
## Number | Matrix ## Number | Matrix
## -> Broadcasting allows Number|Number ops to work as-is.
case (0, 2): case (0, 2):
return ops_number_number return [BO.Mul] # , BO.HadamPow]
# Vector | * # Vector | *
## Vector | Number ## Vector | Number
case (1, 0): case (1, 0):
return ops_number_number return [BO.Mul] # , BO.HadamPow]
## Vector | Number ## Vector | Vector
case (1, 1): case (1, 1):
return [*ops_number_number, BO.VecVecDot, BO.Cross] ops = []
# Vector | Vector
## -> Dot: Convenience; utilize special vec-vec dot w/transp.
if outl.rows > outl.cols and outr.rows > outr.cols:
ops += [BO.VecVecDot, BO.VecVecOuter]
# Covector | Vector
## -> Dot: Directly use matrix-matrix dot, as it's now correct.
if outl.rows < outl.cols and outr.rows > outr.cols:
ops += [BO.MatMatDot, BO.VecVecOuter]
# Vector | Covector
## -> Dot: Directly use matrix-matrix dot, as it's now correct.
## -> These are both the same operation, in this case.
if outl.rows > outl.cols and outr.rows < outr.cols:
ops += [BO.MatMatDot, BO.VecVecOuter]
# Covector | Covector
## -> Dot: Convenience; utilize special vec-vec dot w/transp.
if outl.rows < outl.cols and outr.rows < outr.cols:
ops += [BO.VecVecDot, BO.VecVecOuter]
# Cross Product
## -> Enforce that both are 3x1 or 1x3.
## -> See https://docs.sympy.org/latest/modules/matrices/matrices.html#sympy.matrices.matrices.MatrixBase.cross
if (outl.rows == 3 and outr.rows == 3) or (
outl.cols == 3 and outl.cols == 3
):
ops += [BO.Cross]
return ops
## Vector | Matrix ## Vector | Matrix
case (1, 2): case (1, 2):
return [] return [BO.VecMatOuter]
# Matrix | * # Matrix | *
## Matrix | Number ## Matrix | Number
case (2, 0): case (2, 0):
return [*ops_number_number, BO.MatMatDot] return [BO.Mul] # , BO.HadamPow]
## Matrix | Vector ## Matrix | Vector
case (2, 1): case (2, 1):
return [BO.MatVecDot, BO.LinSolve, BO.LsqSolve] prepend_ops = []
# Mat-Vec Dot: Enforce RHS Column Vector
if outr.rows > outl.cols:
prepend_ops += [BO.MatMatDot]
return [*ops, BO.LinSolve, BO.LsqSolve] # , BO.HadamPow]
## Matrix | Matrix ## Matrix | Matrix
case (2, 2): case (2, 2):
return [*ops_number_number, BO.MatMatDot] return [*ops_el_el, BO.MatMatDot]
return [] return []
@ -182,34 +248,86 @@ class BinaryOperation(enum.StrEnum):
## TODO: Make this compatible with sp.Matrix inputs ## TODO: Make this compatible with sp.Matrix inputs
return { return {
# Number | Number # Number | Number
BO.Add: lambda exprs: exprs[0] + exprs[1],
BO.Sub: lambda exprs: exprs[0] - exprs[1],
BO.Mul: lambda exprs: exprs[0] * exprs[1], BO.Mul: lambda exprs: exprs[0] * exprs[1],
BO.Div: lambda exprs: exprs[0] / exprs[1], BO.Div: lambda exprs: exprs[0] / exprs[1],
BO.Pow: lambda exprs: exprs[0] ** exprs[1], BO.Pow: lambda exprs: exprs[0] ** exprs[1],
# Elements | Elements
BO.Add: lambda exprs: exprs[0] + exprs[1],
BO.Sub: lambda exprs: exprs[0] - exprs[1],
BO.HadamMul: lambda exprs: sp.hadamard_product(exprs[0], exprs[1]),
# BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]), BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]),
# Vector | Vector
BO.VecVecDot: lambda exprs: (exprs[0].T @ exprs[1])[0],
BO.Cross: lambda exprs: exprs[0].cross(exprs[1]),
BO.VecVecOuter: lambda exprs: exprs[0] @ exprs[1].T,
# Matrix | Vector
BO.LinSolve: lambda exprs: exprs[0].solve(exprs[1]),
BO.LsqSolve: lambda exprs: exprs[0].solve_least_squares(exprs[1]),
# Vector | Matrix
BO.VecMatOuter: lambda exprs: spq.TensorProduct(exprs[0], exprs[1]),
# Matrix | Matrix
BO.MatMatDot: lambda exprs: exprs[0] @ exprs[1],
}[self]
@property
def unit_func(self):
"""The binary function to apply to both unit expressions, in order to deduce the unit expression of the output."""
BO = BinaryOperation
## TODO: Make this compatible with sp.Matrix inputs
return {
# Number | Number
BO.Mul: BO.Mul.sp_func,
BO.Div: BO.Div.sp_func,
BO.Pow: BO.Pow.sp_func,
# Elements | Elements
BO.Add: BO.Add.sp_func,
BO.Sub: BO.Sub.sp_func,
BO.HadamMul: BO.Mul.sp_func,
# BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
BO.Atan2: lambda _: spu.radian,
# Vector | Vector
BO.VecVecDot: BO.Mul.sp_func,
BO.Cross: BO.Mul.sp_func,
BO.VecVecOuter: BO.Mul.sp_func,
# Matrix | Vector
## -> A,b in Ax = b have units, and the equality must hold.
## -> Therefore, A \ b must have the units [b]/[A].
BO.LinSolve: lambda exprs: exprs[1] / exprs[0],
BO.LsqSolve: lambda exprs: exprs[1] / exprs[0],
# Vector | Matrix
BO.VecMatOuter: BO.Mul.sp_func,
# Matrix | Matrix
BO.MatMatDot: BO.Mul.sp_func,
}[self] }[self]
@property @property
def jax_func(self): def jax_func(self):
"""Deduce an appropriate jax-based function that implements the binary operation for array inputs.""" """Deduce an appropriate jax-based function that implements the binary operation for array inputs."""
## TODO: Scale the units of one side to the other.
BO = BinaryOperation BO = BinaryOperation
return { return {
# Number | Number # Number | Number
BO.Add: lambda exprs: exprs[0] + exprs[1],
BO.Sub: lambda exprs: exprs[0] - exprs[1],
BO.Mul: lambda exprs: exprs[0] * exprs[1], BO.Mul: lambda exprs: exprs[0] * exprs[1],
BO.Div: lambda exprs: exprs[0] / exprs[1], BO.Div: lambda exprs: exprs[0] / exprs[1],
BO.Pow: lambda exprs: exprs[0] ** exprs[1], BO.Pow: lambda exprs: exprs[0] ** exprs[1],
BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]), # Elements | Elements
BO.Add: lambda exprs: exprs[0] + exprs[1],
BO.Sub: lambda exprs: exprs[0] - exprs[1],
BO.HadamMul: lambda exprs: exprs[0] * exprs[1],
# BO.HadamPow: lambda exprs: exprs[0] ** exprs[1],
BO.Atan2: lambda exprs: jnp.atan2(exprs[1], exprs[0]),
# Vector | Vector # Vector | Vector
BO.VecVecDot: lambda exprs: jnp.dot(exprs[0], exprs[1]), BO.VecVecDot: lambda exprs: jnp.dot(exprs[0], exprs[1]),
BO.Cross: lambda exprs: jnp.cross(exprs[0], exprs[1]), BO.Cross: lambda exprs: jnp.cross(exprs[0], exprs[1]),
BO.VecVecOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
# Matrix | Vector # Matrix | Vector
BO.MatVecDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]),
BO.LinSolve: lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]), BO.LinSolve: lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]),
BO.LsqSolve: lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]), BO.LsqSolve: lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]),
# Vector | Matrix
BO.VecMatOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
# Matrix | Matrix # Matrix | Matrix
BO.MatMatDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]), BO.MatMatDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]),
}[self] }[self]
@ -218,30 +336,12 @@ class BinaryOperation(enum.StrEnum):
# - InfoFlow Transform # - InfoFlow Transform
#################### ####################
def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow): def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow):
BO = BinaryOperation """Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression."""
return info_l.operate_output(
info_largest = ( info_r,
info_l if info_l.output_shape_len > info_l.output_shape_len else info_l lambda a, b: self.sp_func([a, b]),
lambda a, b: self.unit_func([a, b]),
) )
info_any = info_largest
return {
# Number | * or * | Number
BO.Add: info_largest,
BO.Sub: info_largest,
BO.Mul: info_largest,
BO.Div: info_largest,
BO.Pow: info_largest,
BO.Atan2: info_largest,
# Vector | Vector
BO.VecVecDot: info_any,
BO.Cross: info_any,
# Matrix | Vector
BO.MatVecDot: info_r,
BO.LinSolve: info_r,
BO.LsqSolve: info_r,
# Matrix | Matrix
BO.MatMatDot: info_any,
}[self]
#################### ####################
@ -367,9 +467,9 @@ class OperateMathNode(base.MaxwellSimNode):
# Compute Sympy Function # Compute Sympy Function
## -> The operation enum directly provides the appropriate function. ## -> The operation enum directly provides the appropriate function.
if has_expr_l_value and has_expr_r_value and operation is not None: if has_expr_l_value and has_expr_r_value and operation is not None:
operation.sp_func([expr_l, expr_r]) return operation.sp_func([expr_l, expr_r])
return ct.Flowsignal.FlowPending return ct.FlowSignal.FlowPending
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
@ -396,7 +496,7 @@ class OperateMathNode(base.MaxwellSimNode):
## -> The operation enum directly provides the appropriate function. ## -> The operation enum directly provides the appropriate function.
if has_expr_l and has_expr_r: if has_expr_l and has_expr_r:
return (expr_l | expr_r).compose_within( return (expr_l | expr_r).compose_within(
operation.jax_func, enclosing_func=operation.jax_func,
supports_jax=True, supports_jax=True,
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending

View File

@ -17,13 +17,13 @@
"""Declares `TransformMathNode`.""" """Declares `TransformMathNode`."""
import enum import enum
import functools
import typing as typ import typing as typ
import bpy import bpy
import jax.numpy as jnp import jax.numpy as jnp
import jaxtyping as jtyp import jaxtyping as jtyp
import sympy as sp import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, logger, sci_constants, sim_symbols from blender_maxwell.utils import bl_cache, logger, sci_constants, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
@ -51,6 +51,9 @@ class TransformOperation(enum.StrEnum):
# Covariant Transform # Covariant Transform
FreqToVacWL = enum.auto() FreqToVacWL = enum.auto()
VacWLToFreq = enum.auto() VacWLToFreq = enum.auto()
ConvertIdxUnit = enum.auto()
SetIdxUnit = enum.auto()
FirstColToFirstIdx = enum.auto()
# Fold # Fold
IntDimToComplex = enum.auto() IntDimToComplex = enum.auto()
@ -58,8 +61,8 @@ class TransformOperation(enum.StrEnum):
DimsToMat = enum.auto() DimsToMat = enum.auto()
# Fourier # Fourier
FFT1D = enum.auto() FT1D = enum.auto()
InvFFT1D = enum.auto() InvFT1D = enum.auto()
# TODO: Affine # TODO: Affine
## TODO ## TODO
@ -74,17 +77,22 @@ class TransformOperation(enum.StrEnum):
# Covariant Transform # Covariant Transform
TO.FreqToVacWL: '𝑓 → λᵥ', TO.FreqToVacWL: '𝑓 → λᵥ',
TO.VacWLToFreq: 'λᵥ → 𝑓', TO.VacWLToFreq: 'λᵥ → 𝑓',
TO.ConvertIdxUnit: 'Convert Dim',
TO.SetIdxUnit: 'Set Dim',
TO.FirstColToFirstIdx: '1st Col → Dim',
# Fold # Fold
TO.IntDimToComplex: '', TO.IntDimToComplex: '',
TO.DimToVec: '→ Vector', TO.DimToVec: '→ Vector',
TO.DimsToMat: '→ Matrix', TO.DimsToMat: '→ Matrix',
## TODO: Vector to new last-dim integer
## TODO: Matrix to two last-dim integers
# Fourier # Fourier
TO.FFT1D: 't → 𝑓', TO.FT1D: '𝑓',
TO.InvFFT1D: '𝑓 t', TO.InvFT1D: '𝑓',
}[value] }[value]
@staticmethod @staticmethod
def to_icon(value: typ.Self) -> str: def to_icon(_: typ.Self) -> str:
return '' return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement: def bl_enum_element(self, i: int) -> ct.BLEnumElement:
@ -98,121 +106,216 @@ class TransformOperation(enum.StrEnum):
) )
#################### ####################
# - Ops from Shape # - Methods
#################### ####################
@property
def num_dim_inputs(self) -> None:
"""The number of axes that should be passed as inputs to `func_jax` when evaluating it.
Especially useful for `ParamFlow`, when deciding whether to pass an integer-axis argument based on a user-selected dimension.
"""
TO = TransformOperation
return {
# Covariant Transform
TO.FreqToVacWL: 1,
TO.VacWLToFreq: 1,
TO.ConvertIdxUnit: 1,
TO.SetIdxUnit: 1,
TO.FirstColToFirstIdx: 0,
# Fold
TO.IntDimToComplex: 0,
TO.DimToVec: 0,
TO.DimsToMat: 0,
## TODO: Vector to new last-dim integer
## TODO: Matrix to two last-dim integers
# Fourier
TO.FT1D: 1,
TO.InvFT1D: 1,
}[self]
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
TO = TransformOperation
match self:
case TO.FreqToVacWL | TO.FT1D:
return [
dim
for dim in info.dims
if dim.physical_type is spux.PhysicalType.Freq
]
case TO.VacWLToFreq | TO.InvFT1D:
return [
dim
for dim in info.dims
if dim.physical_type is spux.PhysicalType.Length
]
case TO.ConvertIdxUnit | TO.SetIdxUnit:
return [dim for dim in info.dims if not info.has_idx_labels(dim)]
## ColDimToComplex: Implicit Last Dimension
## DimToVec: Implicit Last Dimension
## DimsToMat: Implicit Last 2 Dimensions
case TO.FT1D | TO.InvFT1D:
# Filter by Axis Uniformity
## -> FT requires uniform axis (aka. must be RangeFlow).
## -> NOTE: If FT isn't popping up, check ExtractDataNode.
return [dim for dim in info.dims if info.is_idx_uniform(dim)]
return []
@staticmethod @staticmethod
def by_element_shape(info: ct.InfoFlow) -> list[typ.Self]: def by_info(info: ct.InfoFlow) -> list[typ.Self]:
TO = TransformOperation TO = TransformOperation
operations = [] operations = []
# Covariant Transform # Covariant Transform
## Freq <-> VacWL ## Freq -> VacWL
for dim in info.dims: if TO.FreqToVacWL.valid_dims(info):
if dim.physical_type == spux.PhysicalType.Freq: operations += [TO.FreqToVacWL]
operations.append(TO.FreqToVacWL)
if dim.physical_type == spux.PhysicalType.Freq: ## VacWL -> Freq
operations.append(TO.VacWLToFreq) if TO.VacWLToFreq.valid_dims(info):
operations += [TO.VacWLToFreq]
## Convert Index Unit
if TO.ConvertIdxUnit.valid_dims(info):
operations += [TO.ConvertIdxUnit]
if TO.SetIdxUnit.valid_dims(info):
operations += [TO.SetIdxUnit]
## Column to First Index (Array)
if (
len(info.dims) == 2 # noqa: PLR2004
and info.first_dim.mathtype is spux.MathType.Integer
and info.last_dim.mathtype is spux.MathType.Integer
and info.output.shape_len == 0
):
operations += [TO.FirstColToFirstIdx]
# Fold # Fold
## (Last) Int Dim (=2) to Complex ## Last Dim -> Complex
if len(info.dims) >= 1: if (
if not info.has_idx_labels(info.last_dim) and len(info.last_dim) == 2: # noqa: PLR2004 info.dims
operations.append(TO.IntDimToComplex) # Output is Int|Rat|Real
and (
info.output.mathtype
in [spux.MathType.Integer, spux.MathType.Rational, spux.MathType.Real]
)
# Last Axis is Integer of Length 2
and info.last_dim.mathtype is spux.MathType.Integer
and info.has_idx_labels(info.last_dim)
and len(info.dims[info.last_dim]) == 2 # noqa: PLR2004
):
operations += [TO.IntDimToComplex]
## To Vector ## Last Dim -> Vector
if len(info.dims) >= 1: if len(info.dims) >= 1 and info.output.shape_len == 0:
operations.append(TO.DimToVec) operations += [TO.DimToVec]
## To Matrix ## Last Dim -> Matrix
if len(info.dims) >= 2: # noqa: PLR2004 if len(info.dims) >= 2 and info.output.shape_len == 0: # noqa: PLR2004
operations.append(TO.DimsToMat) operations += [TO.DimsToMat]
# Fourier # Fourier
## 1D Fourier if TO.FT1D.valid_dims(info):
if info.dims: operations += [TO.FT1D]
last_physical_type = info.last_dim.physical_type
if last_physical_type == spux.PhysicalType.Time: if TO.InvFT1D.valid_dims(info):
operations.append(TO.FFT1D) operations += [TO.InvFT1D]
if last_physical_type == spux.PhysicalType.Freq:
operations.append(TO.InvFFT1D)
return operations return operations
#################### ####################
# - Function Properties # - Function Properties
#################### ####################
@property @functools.cached_property
def sp_func(self):
TO = TransformOperation
return {
# Covariant Transform
TO.FreqToVacWL: lambda expr: expr,
TO.VacWLToFreq: lambda expr: expr,
# Fold
# TO.IntDimToComplex: lambda expr: expr, ## TODO: Won't work?
TO.DimToVec: lambda expr: expr,
TO.DimsToMat: lambda expr: expr,
# Fourier
TO.FFT1D: lambda expr: sp.fourier_transform(
expr, sim_symbols.t, sim_symbols.freq
),
TO.InvFFT1D: lambda expr: sp.fourier_transform(
expr, sim_symbols.freq, sim_symbols.t
),
}[self]
@property
def jax_func(self): def jax_func(self):
TO = TransformOperation TO = TransformOperation
return { return {
# Covariant Transform # Covariant Transform
TO.FreqToVacWL: lambda expr: expr, ## -> Freq <-> WL is a rescale (noop) AND flip (not noop).
TO.VacWLToFreq: lambda expr: expr, TO.FreqToVacWL: lambda expr, axis: jnp.flip(expr, axis=axis),
TO.VacWLToFreq: lambda expr, axis: jnp.flip(expr, axis=axis),
TO.ConvertIdxUnit: lambda expr: expr,
TO.SetIdxUnit: lambda expr: expr,
TO.FirstColToFirstIdx: lambda expr: jnp.delete(expr, 0, axis=1),
# Fold # Fold
## -> To Complex: With a little imagination, this is a noop :) ## -> To Complex: This should generally be a no-op.
## -> **Requires** dims[-1] to be integer-indexed w/length of 2. TO.IntDimToComplex: lambda expr: jnp.squeeze(
TO.IntDimToComplex: lambda expr: expr.view(dtype=jnp.complex64).squeeze(), expr.view(dtype=jnp.complex64), axis=-1
),
TO.DimToVec: lambda expr: expr, TO.DimToVec: lambda expr: expr,
TO.DimsToMat: lambda expr: expr, TO.DimsToMat: lambda expr: expr,
# Fourier # Fourier
TO.FFT1D: lambda expr: jnp.fft(expr), TO.FT1D: lambda expr, axis: jnp.fft(expr, axis=axis),
TO.InvFFT1D: lambda expr: jnp.ifft(expr), TO.InvFT1D: lambda expr, axis: jnp.ifft(expr, axis=axis),
}[self] }[self]
def transform_info( def transform_info(
self, self,
info: ct.InfoFlow | None, info: ct.InfoFlow,
data: jtyp.Shaped[jtyp.Array, '...'] | None = None, dim: sim_symbols.SimSymbol | None = None,
data_col: jtyp.Shaped[jtyp.Array, ' size'] | None = None,
new_dim_name: str | None = None,
unit: spux.Unit | None = None, unit: spux.Unit | None = None,
) -> ct.InfoFlow | None: physical_type: spux.PhysicalType | None = None,
) -> ct.InfoFlow:
TO = TransformOperation TO = TransformOperation
if not info.dims:
return None
return { return {
# Covariant Transform # Covariant Transform
TO.FreqToVacWL: lambda: info.replace_dim( TO.FreqToVacWL: lambda: info.replace_dim(
(f_dim := info.last_dim), (f_dim := dim),
[ [
sim_symbols.wl(spu.nanometer), sim_symbols.wl(unit),
info.dims[f_dim].rescale( info.dims[f_dim].rescale(
lambda el: sci_constants.vac_speed_of_light / el, lambda el: sci_constants.vac_speed_of_light / el,
reverse=True, reverse=True,
new_unit=spu.nanometer, new_unit=unit,
), ),
], ],
), ),
TO.VacWLToFreq: lambda: info.replace_dim( TO.VacWLToFreq: lambda: info.replace_dim(
(wl_dim := info.last_dim), (wl_dim := dim),
[ [
sim_symbols.freq(spux.THz), sim_symbols.freq(unit),
info.dims[wl_dim].rescale( info.dims[wl_dim].rescale(
lambda el: sci_constants.vac_speed_of_light / el, lambda el: sci_constants.vac_speed_of_light / el,
reverse=True, reverse=True,
new_unit=spux.THz, new_unit=unit,
), ),
], ],
), ),
TO.ConvertIdxUnit: lambda: info.replace_dim(
dim,
dim.update(unit=unit),
(
info.dims[dim].rescale_to_unit(unit)
if info.has_idx_discrete(dim)
else None ## Continuous -- dim SimSymbol already scaled
),
),
TO.SetIdxUnit: lambda: info.replace_dim(
dim,
dim.update(
sym_name=new_dim_name, physical_type=physical_type, unit=unit
),
(
info.dims[dim].correct_unit(unit)
if info.has_idx_discrete(dim)
else None ## Continuous -- dim SimSymbol already scaled
),
),
TO.FirstColToFirstIdx: lambda: info.replace_dim(
info.first_dim,
info.first_dim.update(
mathtype=spux.MathType.from_jax_array(data_col),
unit=unit,
),
ct.ArrayFlow(values=data_col, unit=unit),
).slice_dim(info.last_dim, (1, len(info.dims[info.last_dim]), 1)),
# Fold # Fold
TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output( TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output(
mathtype=spux.MathType.Complex mathtype=spux.MathType.Complex
@ -220,21 +323,31 @@ class TransformOperation(enum.StrEnum):
TO.DimToVec: lambda: info.fold_last_input(), TO.DimToVec: lambda: info.fold_last_input(),
TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(), TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(),
# Fourier # Fourier
TO.FFT1D: lambda: info.replace_dim( TO.FT1D: lambda: info.replace_dim(
info.last_dim, dim,
[ [
sim_symbols.freq(spux.THz), # FT'ed Unit: Reciprocal of the Original Unit
None, dim.update(
unit=1 / dim.unit if dim.unit is not None else 1
), ## TODO: Okay to not scale interval?
# FT'ed Bounds: Reciprocal of the Original Unit
info.dims[dim].bound_fourier_transform,
], ],
), ),
TO.InvFFT1D: info.replace_dim( TO.InvFT1D: lambda: info.replace_dim(
info.last_dim, info.last_dim,
[ [
sim_symbols.t(spu.second), # FT'ed Unit: Reciprocal of the Original Unit
None, dim.update(
unit=1 / dim.unit if dim.unit is not None else 1
), ## TODO: Okay to not scale interval?
# FT'ed Bounds: Reciprocal of the Original Unit
## -> Note the midpoint may revert to 0.
## -> See docs for `RangeFlow.bound_inv_fourier_transform` for more.
info.dims[dim].bound_inv_fourier_transform,
], ],
), ),
}.get(self, lambda: info)() }[self]()
#################### ####################
@ -274,7 +387,6 @@ class TransformMathNode(base.MaxwellSimNode):
) )
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102 def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
has_info = not ct.FlowSignal.check(input_sockets['Expr']) has_info = not ct.FlowSignal.check(input_sockets['Expr'])
info_pending = ct.FlowSignal.check_single( info_pending = ct.FlowSignal.check_single(
input_sockets['Expr'], ct.FlowSignal.FlowPending input_sockets['Expr'], ct.FlowSignal.FlowPending
) )
@ -304,45 +416,125 @@ class TransformMathNode(base.MaxwellSimNode):
return [ return [
operation.bl_enum_element(i) operation.bl_enum_element(i)
for i, operation in enumerate( for i, operation in enumerate(
TransformOperation.by_element_shape(self.expr_info) TransformOperation.by_info(self.expr_info)
) )
] ]
return [] return []
####################
# - Properties: Dimension Selection
####################
active_dim: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_dims(),
cb_depends_on={'operation', 'expr_info'},
)
def search_dims(self) -> list[ct.BLEnumElement]:
if self.expr_info is not None and self.operation is not None:
return [
(dim.name, dim.name_pretty, dim.name, '', i)
for i, dim in enumerate(self.operation.valid_dims(self.expr_info))
]
return []
@bl_cache.cached_bl_property(depends_on={'expr_info', 'active_dim'})
def dim(self) -> sim_symbols.SimSymbol | None:
if self.expr_info is not None and self.active_dim is not None:
return self.expr_info.dim_by_name(self.active_dim)
return None
####################
# - Properties: New Dimension Properties
####################
new_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.Expr
)
new_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
active_new_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(),
cb_depends_on={'dim', 'new_physical_type'},
)
def search_units(self) -> list[ct.BLEnumElement]:
if self.dim is not None:
if self.dim.physical_type is not spux.PhysicalType.NonPhysical:
unit_name = sp.sstr(self.dim.unit)
return [
(
sp.sstr(unit),
spux.sp_to_str(unit),
sp.sstr(unit),
'',
0,
)
for unit in self.dim.physical_type.valid_units
]
if self.dim.unit is not None:
unit_name = sp.sstr(self.dim.unit)
return [
(
unit_name,
spux.sp_to_str(self.dim.unit),
unit_name,
'',
0,
)
]
if self.new_physical_type is not spux.PhysicalType.NonPhysical:
return [
(
sp.sstr(unit),
spux.sp_to_str(unit),
sp.sstr(unit),
'',
i,
)
for i, unit in enumerate(self.new_physical_type.valid_units)
]
return []
@bl_cache.cached_bl_property(depends_on={'active_new_unit'})
def new_unit(self) -> spux.Unit:
if self.active_new_unit is not None:
return spux.unit_str_to_unit(self.active_new_unit)
return None
#################### ####################
# - UI # - UI
#################### ####################
def draw_label(self): def draw_label(self):
if self.operation is not None: if self.operation is not None:
return 'Transform: ' + TransformOperation.to_name(self.operation) return 'T: ' + TransformOperation.to_name(self.operation)
return self.bl_label return self.bl_label
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, self.blfields['operation'], text='') layout.prop(self, self.blfields['operation'], text='')
if self.operation is not None and self.operation.num_dim_inputs == 1:
TO = TransformOperation
layout.prop(self, self.blfields['active_dim'], text='')
if self.operation in [TO.ConvertIdxUnit, TO.SetIdxUnit]:
col = layout.column(align=True)
if self.operation is TransformOperation.ConvertIdxUnit:
col.prop(self, self.blfields['active_new_unit'], text='')
if self.operation is TransformOperation.SetIdxUnit:
col.prop(self, self.blfields['new_physical_type'], text='')
row = col.row(align=True)
row.prop(self, self.blfields['new_name'], text='')
row.prop(self, self.blfields['active_new_unit'], text='')
#################### ####################
# - Compute: Func / Array # - Compute: Func / Array
#################### ####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Value,
props={'operation'},
input_sockets={'Expr'},
)
def compute_value(self, props, input_sockets) -> ct.ValueFlow | ct.FlowSignal:
operation = props['operation']
expr = input_sockets['Expr']
has_expr_value = not ct.FlowSignal.check(expr)
# Compute Sympy Function
## -> The operation enum directly provides the appropriate function.
if has_expr_value and operation is not None:
return operation.sp_func(expr)
return ct.Flowsignal.FlowPending
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.Func, kind=ct.FlowKind.Func,
@ -354,54 +546,103 @@ class TransformMathNode(base.MaxwellSimNode):
) )
def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal: def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal:
operation = props['operation'] operation = props['operation']
expr = input_sockets['Expr'] lazy_func = input_sockets['Expr']
has_expr = not ct.FlowSignal.check(expr) has_lazy_func = not ct.FlowSignal.check(lazy_func)
if has_expr and operation is not None: if has_lazy_func and operation is not None:
return expr.compose_within( return lazy_func.compose_within(
operation.jax_func, operation.jax_func,
supports_jax=True, supports_jax=True,
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
#################### ####################
# - FlowKind.Info|Params # - FlowKind.Info
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.Info, kind=ct.FlowKind.Info,
props={'operation'}, props={'operation', 'dim', 'new_name', 'new_unit', 'new_physical_type'},
input_sockets={'Expr'}, input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info}, input_socket_kinds={
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
},
) )
def compute_info( def compute_info(
self, props: dict, input_sockets: dict self, props: dict, input_sockets: dict
) -> ct.InfoFlow | typ.Literal[ct.FlowSignal.FlowPending]: ) -> ct.InfoFlow | typ.Literal[ct.FlowSignal.FlowPending]:
operation = props['operation'] operation = props['operation']
info = input_sockets['Expr'] info = input_sockets['Expr'][ct.FlowKind.Info]
has_info = not ct.FlowSignal.check(info) has_info = not ct.FlowSignal.check(info)
dim = props['dim']
new_name = props['new_name']
new_unit = props['new_unit']
new_physical_type = props['new_physical_type']
if has_info and operation is not None: if has_info and operation is not None:
transformed_info = operation.transform_info(info) # First Column to First Index
## -> We have to evaluate the lazy function at this point.
## -> It's the only way to get at the column data.
if operation is TransformOperation.FirstColToFirstIdx:
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
params = input_sockets['Expr'][ct.FlowKind.Params]
has_lazy_func = not ct.FlowSignal.check(lazy_func)
has_params = not ct.FlowSignal.check(lazy_func)
if transformed_info is None: if has_lazy_func and has_params and not params.symbols:
data = lazy_func.realize(params)
if data.shape is not None and len(data.shape) == 2:
data_col = data[:, 0]
return operation.transform_info(info, data_col=data_col)
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
return transformed_info
# Check Not-Yet-Updated Dimension
## - Operation changes before dimensions.
## - If InfoFlow is requested in this interim, big problem.
if dim is None and operation.num_dim_inputs > 0:
return ct.FlowSignal.FlowPending
return operation.transform_info(
info,
dim=dim,
new_dim_name=new_name,
unit=new_unit,
physical_type=new_physical_type,
)
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.Params, kind=ct.FlowKind.Params,
props={'operation', 'dim'},
input_sockets={'Expr'}, input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Params}, input_socket_kinds={'Expr': {ct.FlowKind.Params, ct.FlowKind.Info}},
) )
def compute_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal: def compute_params(self, props, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
has_params = not ct.FlowSignal.check(input_sockets['Expr']) info = input_sockets['Expr'][ct.FlowKind.Info]
if has_params: params = input_sockets['Expr'][ct.FlowKind.Params]
return input_sockets['Expr']
has_info = not ct.FlowSignal.check(info)
has_params = not ct.FlowSignal.check(params)
operation = props['operation']
dim = props['dim']
if has_info and has_params and operation is not None:
# Axis Required: Insert by-Dimension
## -> Some transformations ex. FT require setting an axis.
## -> The user selects which dimension the op should be done along.
## -> This dimension is converted to an axis integer.
## -> Finally, we pass the argument via params.
if operation.num_dim_inputs == 1:
axis = info.dim_axis(dim) if dim is not None else None
return params.compose_within(enclosing_func_args=[axis])
return params
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending

View File

@ -21,6 +21,7 @@ import bpy
import jaxtyping as jtyp import jaxtyping as jtyp
import matplotlib.axis as mpl_ax import matplotlib.axis as mpl_ax
import sympy as sp import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, image_ops, logger, sim_symbols from blender_maxwell.utils import bl_cache, image_ops, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
@ -74,8 +75,33 @@ class VizMode(enum.StrEnum):
SqueezedHeatmap2D = enum.auto() SqueezedHeatmap2D = enum.auto()
Heatmap3D = enum.auto() Heatmap3D = enum.auto()
####################
# - UI
####################
@staticmethod @staticmethod
def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None: def to_name(value: typ.Self) -> str:
return {
VizMode.BoxPlot1D: 'Box Plot',
VizMode.Curve2D: 'Curve',
VizMode.Points2D: 'Points',
VizMode.Bar: 'Bar',
VizMode.Curves2D: 'Curves',
VizMode.FilledCurves2D: 'Filled Curves',
VizMode.Heatmap2D: 'Heatmap',
VizMode.SqueezedHeatmap2D: 'Heatmap (Squeezed)',
VizMode.Heatmap3D: 'Heatmap (3D)',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> ct.BLIcon:
return ''
####################
# - Validity
####################
@staticmethod
def by_info(info: ct.InfoFlow) -> list[typ.Self] | None:
"""Given the input `InfoFlow`, deduce which visualization modes are valid to use with the described data."""
Z = spux.MathType.Integer Z = spux.MathType.Integer
R = spux.MathType.Real R = spux.MathType.Real
VM = VizMode VM = VizMode
@ -102,15 +128,18 @@ class VizMode(enum.StrEnum):
], ],
}.get( }.get(
( (
tuple([dim.mathtype for dim in info.dims.values()]), tuple([dim.mathtype for dim in info.dims]),
(info.output.rows, info.output.cols, info.output.mathtype), (info.output.rows, info.output.cols, info.output.mathtype),
), ),
[], [],
) )
@staticmethod ####################
def to_plotter( # - Properties
value: typ.Self, ####################
@property
def mpl_plotter(
self,
) -> typ.Callable[ ) -> typ.Callable[
[jtyp.Float32[jtyp.Array, '...'], ct.InfoFlow, mpl_ax.Axis], None [jtyp.Float32[jtyp.Array, '...'], ct.InfoFlow, mpl_ax.Axis], None
]: ]:
@ -124,25 +153,7 @@ class VizMode(enum.StrEnum):
VizMode.Heatmap2D: image_ops.plot_heatmap_2d, VizMode.Heatmap2D: image_ops.plot_heatmap_2d,
# NO PLOTTER: VizMode.SqueezedHeatmap2D # NO PLOTTER: VizMode.SqueezedHeatmap2D
# NO PLOTTER: VizMode.Heatmap3D # NO PLOTTER: VizMode.Heatmap3D
}[value] }[self]
@staticmethod
def to_name(value: typ.Self) -> str:
return {
VizMode.BoxPlot1D: 'Box Plot',
VizMode.Curve2D: 'Curve',
VizMode.Points2D: 'Points',
VizMode.Bar: 'Bar',
VizMode.Curves2D: 'Curves',
VizMode.FilledCurves2D: 'Filled Curves',
VizMode.Heatmap2D: 'Heatmap',
VizMode.SqueezedHeatmap2D: 'Heatmap (Squeezed)',
VizMode.Heatmap3D: 'Heatmap (3D)',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> ct.BLIcon:
return ''
class VizTarget(enum.StrEnum): class VizTarget(enum.StrEnum):
@ -181,6 +192,10 @@ class VizTarget(enum.StrEnum):
return '' return ''
sym_x_um = sim_symbols.space_x(spu.um)
x_um = sym_x_um.sp_symbol
class VizNode(base.MaxwellSimNode): class VizNode(base.MaxwellSimNode):
"""Node for visualizing simulation data, by querying its monitors. """Node for visualizing simulation data, by querying its monitors.
@ -188,7 +203,6 @@ class VizNode(base.MaxwellSimNode):
Attributes: Attributes:
colormap: Colormap to apply to 0..1 output. colormap: Colormap to apply to 0..1 output.
""" """
node_type = ct.NodeType.Viz node_type = ct.NodeType.Viz
@ -201,8 +215,8 @@ class VizNode(base.MaxwellSimNode):
input_sockets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef( 'Expr': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Func, active_kind=ct.FlowKind.Func,
default_symbols=[sim_symbols.x], default_symbols=[sym_x_um],
default_value=2 * sim_symbols.x.sp_symbol, default_value=sp.exp(-(x_um**2)),
), ),
} }
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {
@ -240,16 +254,21 @@ class VizNode(base.MaxwellSimNode):
return None return None
viz_mode: enum.StrEnum = bl_cache.BLField( viz_mode: VizMode = bl_cache.BLField(
enum_cb=lambda self, _: self.search_viz_modes(), enum_cb=lambda self, _: self.search_viz_modes(),
cb_depends_on={'expr_info'}, cb_depends_on={'expr_info'},
) )
viz_target: enum.StrEnum = bl_cache.BLField( viz_target: VizTarget = bl_cache.BLField(
enum_cb=lambda self, _: self.search_targets(), enum_cb=lambda self, _: self.search_targets(),
cb_depends_on={'viz_mode'}, cb_depends_on={'viz_mode'},
) )
# Mode-Dependent Properties # Plot
plot_width: float = bl_cache.BLField(6.0, abs_min=0.1)
plot_height: float = bl_cache.BLField(3.0, abs_min=0.1)
plot_dpi: int = bl_cache.BLField(150, abs_min=25)
# Pixels
colormap: image_ops.Colormap = bl_cache.BLField( colormap: image_ops.Colormap = bl_cache.BLField(
image_ops.Colormap.Viridis, image_ops.Colormap.Viridis,
) )
@ -267,7 +286,7 @@ class VizNode(base.MaxwellSimNode):
VizMode.to_icon(viz_mode), VizMode.to_icon(viz_mode),
i, i,
) )
for i, viz_mode in enumerate(VizMode.valid_modes_for(self.expr_info)) for i, viz_mode in enumerate(VizMode.by_info(self.expr_info))
] ]
return [] return []
@ -300,9 +319,22 @@ class VizNode(base.MaxwellSimNode):
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout): def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
col.prop(self, self.blfields['viz_mode'], text='') col.prop(self, self.blfields['viz_mode'], text='')
col.prop(self, self.blfields['viz_target'], text='') col.prop(self, self.blfields['viz_target'], text='')
if self.viz_target in [VizTarget.Pixels, VizTarget.PixelsPlane]: if self.viz_target in [VizTarget.Pixels, VizTarget.PixelsPlane]:
col.prop(self, self.blfields['colormap'], text='') col.prop(self, self.blfields['colormap'], text='')
if self.viz_target is VizTarget.Plot2D:
row = col.row(align=True)
row.alignment = 'CENTER'
row.label(text='Width/Height/DPI')
row = col.row(align=True)
row.prop(self, self.blfields['plot_width'], text='')
row.prop(self, self.blfields['plot_height'], text='')
row = col.row(align=True)
col.prop(self, self.blfields['plot_dpi'], text='')
#################### ####################
# - Events # - Events
#################### ####################
@ -320,16 +352,16 @@ class VizNode(base.MaxwellSimNode):
has_info = not ct.FlowSignal.check(info) has_info = not ct.FlowSignal.check(info)
has_params = not ct.FlowSignal.check(params) has_params = not ct.FlowSignal.check(params)
# Provide Sockets for Symbol Realization # Declare Loose Sockets that Realize Symbols
## -> This happens if Params contains not-yet-realized symbols. ## -> This happens if Params contains not-yet-realized symbols.
if has_info and has_params and params.symbols: if has_info and has_params and params.symbols:
if set(self.loose_input_sockets) != { if set(self.loose_input_sockets) != {
dim.name for dim in params.symbols if dim in info.dims sym.name for sym in params.symbols if sym in info.dims
}: }:
self.loose_input_sockets = { self.loose_input_sockets = {
dim_name: sockets.ExprSocketDef(**expr_info) dim_name: sockets.ExprSocketDef(**expr_info)
for dim_name, expr_info in params.sym_expr_infos( for dim_name, expr_info in params.sym_expr_infos(
info, use_range=True use_range=True
).items() ).items()
} }
@ -343,7 +375,14 @@ class VizNode(base.MaxwellSimNode):
'Preview', 'Preview',
kind=ct.FlowKind.Value, kind=ct.FlowKind.Value,
# Loaded # Loaded
props={'viz_mode', 'viz_target', 'colormap'}, props={
'viz_mode',
'viz_target',
'colormap',
'plot_width',
'plot_height',
'plot_dpi',
},
input_sockets={'Expr'}, input_sockets={'Expr'},
input_socket_kinds={ input_socket_kinds={
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params} 'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
@ -359,7 +398,14 @@ class VizNode(base.MaxwellSimNode):
##################### #####################
@events.on_show_plot( @events.on_show_plot(
managed_objs={'plot'}, managed_objs={'plot'},
props={'viz_mode', 'viz_target', 'colormap'}, props={
'viz_mode',
'viz_target',
'colormap',
'plot_width',
'plot_height',
'plot_dpi',
},
input_sockets={'Expr'}, input_sockets={'Expr'},
input_socket_kinds={ input_socket_kinds={
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params} 'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
@ -370,7 +416,6 @@ class VizNode(base.MaxwellSimNode):
def on_show_plot( def on_show_plot(
self, managed_objs, props, input_sockets, loose_input_sockets self, managed_objs, props, input_sockets, loose_input_sockets
) -> None: ) -> None:
# Retrieve Inputs
lazy_func = input_sockets['Expr'][ct.FlowKind.Func] lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
info = input_sockets['Expr'][ct.FlowKind.Info] info = input_sockets['Expr'][ct.FlowKind.Info]
params = input_sockets['Expr'][ct.FlowKind.Params] params = input_sockets['Expr'][ct.FlowKind.Params]
@ -378,51 +423,59 @@ class VizNode(base.MaxwellSimNode):
has_info = not ct.FlowSignal.check(info) has_info = not ct.FlowSignal.check(info)
has_params = not ct.FlowSignal.check(params) has_params = not ct.FlowSignal.check(params)
if ( plot = managed_objs['plot']
not has_info viz_mode = props['viz_mode']
or not has_params viz_target = props['viz_target']
or props['viz_mode'] is None if has_info and has_params and viz_mode is not None and viz_target is not None:
or props['viz_target'] is None # Realize Data w/Realized Symbols
): ## -> The loose input socket values are user-selected symbol values.
return ## -> These expressions are used to realize the lazy data.
## -> `.realize()` ensures all ex. units are correctly conformed.
realized_syms = {
sym: loose_input_sockets[sym.name] for sym in params.sorted_symbols
}
output_data = lazy_func.realize(params, symbol_values=realized_syms)
# Compute Ranges for Symbols from Loose Sockets data = {
## -> In a quite nice turn of events, all this is cached lookups. dim: (
## -> ...Unless something changed, in which case, well. It changed. realized_syms[dim].values
symbol_array_values = { if dim in realized_syms
sim_syms: ( else info.dims[dim]
loose_input_sockets[sim_syms]
.rescale_to_unit(sim_syms.unit)
.realize_array
)
for sim_syms in params.sorted_symbols
}
data = lazy_func.realize(params, symbol_values=symbol_array_values)
# Replace InfoFlow Indices w/Realized Symbolic Ranges
## -> This ensures correct axis scaling.
if params.symbols:
info = info.replace_dims(symbol_array_values)
match props['viz_target']:
case VizTarget.Plot2D:
managed_objs['plot'].mpl_plot_to_image(
lambda ax: VizMode.to_plotter(props['viz_mode'])(data, info, ax),
bl_select=True,
) )
for dim in info.dims
} | {info.output: output_data}
case VizTarget.Pixels: # Match Viz Type & Perform Visualization
managed_objs['plot'].map_2d_to_image( ## -> Viz Target determines how to plot.
data, ## -> Viz Mode may help select a particular plotting method.
colormap=props['colormap'], ## -> Other parameters may be uses, depending on context.
bl_select=True, match viz_target:
) case VizTarget.Plot2D:
plot_width = props['plot_width']
plot_height = props['plot_height']
plot_dpi = props['plot_dpi']
plot.mpl_plot_to_image(
lambda ax: viz_mode.mpl_plotter(data, ax),
width_inches=plot_width,
height_inches=plot_height,
dpi=plot_dpi,
bl_select=True,
)
case VizTarget.PixelsPlane: case VizTarget.Pixels:
raise NotImplementedError colormap = props['colormap']
if colormap is not None:
plot.map_2d_to_image(
data,
colormap=colormap,
bl_select=True,
)
case VizTarget.Voxels: case VizTarget.PixelsPlane:
raise NotImplementedError raise NotImplementedError
case VizTarget.Voxels:
raise NotImplementedError
#################### ####################

View File

@ -610,8 +610,8 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
## -> Anyone needing results will need to wait on preinit(). ## -> Anyone needing results will need to wait on preinit().
return ct.FlowSignal.FlowInitializing return ct.FlowSignal.FlowInitializing
if optional: # if optional:
return ct.FlowSignal.NoFlow return ct.FlowSignal.NoFlow
msg = f'{self.sim_node_name}: Input socket "{input_socket_name}" cannot be computed, as it is not an active input socket' msg = f'{self.sim_node_name}: Input socket "{input_socket_name}" cannot be computed, as it is not an active input socket'
raise ValueError(msg) raise ValueError(msg)
@ -659,11 +659,12 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
return output_socket_methods[0](self) return output_socket_methods[0](self)
# Auxiliary Fallbacks # Auxiliary Fallbacks
if optional or kind in [ct.FlowKind.Info, ct.FlowKind.Params]: return ct.FlowSignal.NoFlow
return ct.FlowSignal.NoFlow # if optional or kind in [ct.FlowKind.Info, ct.FlowKind.Params]:
# return ct.FlowSignal.NoFlow
msg = f'No output method for ({output_socket_name}, {kind})' # msg = f'No output method for ({output_socket_name}, {kind})'
raise ValueError(msg) # raise ValueError(msg)
#################### ####################
# - Event Trigger # - Event Trigger

View File

@ -30,6 +30,7 @@ class ExprConstantNode(base.MaxwellSimNode):
input_sockets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef( 'Expr': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Func, active_kind=ct.FlowKind.Func,
show_name_selector=True,
), ),
} }
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {

View File

@ -17,8 +17,11 @@
import typing as typ import typing as typ
import bpy import bpy
import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, sci_constants from blender_maxwell.utils import bl_cache, sci_constants, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct from .... import contracts as ct
from .... import sockets from .... import sockets
@ -30,15 +33,19 @@ class ScientificConstantNode(base.MaxwellSimNode):
bl_label = 'Scientific Constant' bl_label = 'Scientific Constant'
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {
'Value': sockets.ExprSocketDef(), 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
} }
#################### ####################
# - Properties # - Properties
#################### ####################
sci_constant: str = bl_cache.BLField( use_symbol: bool = bl_cache.BLField(False)
sci_constant_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerU
)
sci_constant_str: str = bl_cache.BLField(
'', '',
prop_ui=True,
str_cb=lambda self, _, edit_text: self.search_sci_constants(edit_text), str_cb=lambda self, _, edit_text: self.search_sci_constants(edit_text),
) )
@ -52,27 +59,139 @@ class ScientificConstantNode(base.MaxwellSimNode):
if edit_text.lower() in name.lower() if edit_text.lower() in name.lower()
] ]
@bl_cache.cached_bl_property(depends_on={'sci_constant_str'})
def sci_constant(self) -> spux.SympyExpr | None:
"""Retrieve the expression for the scientific constant."""
return sci_constants.SCI_CONSTANTS.get(self.sci_constant_str)
@bl_cache.cached_bl_property(depends_on={'sci_constant_str'})
def sci_constant_info(self) -> spux.SympyExpr | None:
"""Retrieve the information for the selected scientific constant."""
return sci_constants.SCI_CONSTANTS_INFO.get(self.sci_constant_str)
@bl_cache.cached_bl_property(
depends_on={'sci_constant', 'sci_constant_info', 'sci_constant_name'}
)
def sci_constant_sym(self) -> spux.SympyExpr | None:
"""Retrieve a symbol for the scientific constant."""
if self.sci_constant is not None and self.sci_constant_info is not None:
unit = self.sci_constant_info['units']
return sim_symbols.SimSymbol(
sym_name=self.sci_constant_name,
mathtype=spux.MathType.from_expr(self.sci_constant),
# physical_type= ## TODO: Formalize unit w/o physical_type
unit=unit,
is_constant=True,
)
return None
#################### ####################
# - UI # - UI
#################### ####################
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None: def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
col.prop(self, self.blfields['sci_constant'], text='') col.prop(self, self.blfields['sci_constant_str'], text='')
row = col.row(align=True)
row.alignment = 'CENTER'
row.label(text='Assign Symbol')
col.prop(self, self.blfields['sci_constant_name'], text='')
col.prop(self, self.blfields['use_symbol'], text='Use Symbol', toggle=True)
def draw_info(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None: def draw_info(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
if self.sci_constant: box = col.box()
col.label( split = box.split(factor=0.25, align=True)
text=f'Units: {sci_constants.SCI_CONSTANTS_INFO[self.sci_constant]["units"]}'
) # Left: Units
col.label( _col = split.column(align=True)
text=f'Uncertainty: {sci_constants.SCI_CONSTANTS_INFO[self.sci_constant]["uncertainty"]}' row = _col.row(align=True)
) # row.alignment = 'CENTER'
row.label(text='Src')
if self.sci_constant_info:
row = _col.row(align=True)
# row.alignment = 'CENTER'
row.label(text='Unit')
row = _col.row(align=True)
# row.alignment = 'CENTER'
row.label(text='Err')
# Right: Values
_col = split.column(align=True)
row = _col.row(align=True)
# row.alignment = 'CENTER'
row.label(text='CODATA2018')
if self.sci_constant_info:
row = _col.row(align=True)
# row.alignment = 'CENTER'
row.label(text=f'{self.sci_constant_info["units"]}')
row = _col.row(align=True)
# row.alignment = 'CENTER'
row.label(text=f'{self.sci_constant_info["uncertainty"]}')
#################### ####################
# - Output # - Output
#################### ####################
@events.computes_output_socket('Value', props={'sci_constant'}) @events.computes_output_socket(
def compute_value(self, props: dict) -> typ.Any: 'Expr',
return sci_constants.SCI_CONSTANTS[props['sci_constant']] props={'use_symbol', 'sci_constant', 'sci_constant_sym'},
)
def compute_value(self, props) -> typ.Any:
sci_constant = props['sci_constant']
sci_constant_sym = props['sci_constant_sym']
if props['use_symbol'] and sci_constant_sym is not None:
return sci_constant_sym.sp_symbol
if sci_constant is not None:
return sci_constant
return ct.FlowSignal.FlowPending
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Func,
props={'sci_constant', 'sci_constant_sym'},
)
def compute_lazy_func(self, props) -> typ.Any:
sci_constant = props['sci_constant']
sci_constant_sym = props['sci_constant_sym']
if sci_constant is not None:
return ct.FuncFlow(
func=sp.lambdify(
[sci_constant_sym.sp_symbol], sci_constant_sym.sp_symbol, 'jax'
),
func_args=[sci_constant_sym],
supports_jax=True,
)
return ct.FlowSignal.FlowPending
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Info,
props={'sci_constant_sym'},
)
def compute_info(self, props: dict) -> typ.Any:
sci_constant_sym = props['sci_constant_sym']
if sci_constant_sym is not None:
return ct.InfoFlow(output=sci_constant_sym)
return ct.FlowSignal.FlowPending
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Params,
props={'sci_constant'},
)
def compute_params(self, props: dict) -> typ.Any:
sci_constant = props['sci_constant']
if sci_constant is not None:
return ct.ParamsFlow(func_args=[sci_constant])
return ct.FlowSignal.FlowPending
#################### ####################

View File

@ -95,62 +95,35 @@ class DataFileImporterNode(base.MaxwellSimNode):
#################### ####################
# - Info Guides # - Info Guides
#################### ####################
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(sim_symbols.SimSymbolName) output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
output_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real) sim_symbols.SimSymbolName.Data
)
output_mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real)
output_physical_type: spux.PhysicalType = bl_cache.BLField( output_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical spux.PhysicalType.NonPhysical
) )
output_unit: enum.StrEnum = bl_cache.BLField( output_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type), enum_cb=lambda self, _: self.search_units(self.output_physical_type),
cb_depends_on={'output_physical_type'}, cb_depends_on={'output_physical_type'},
) )
dim_0_name: sim_symbols.SimSymbolName = bl_cache.BLField( dim_0_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerA sim_symbols.SimSymbolName.LowerA
) )
dim_0_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_0_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_0_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type),
cb_depends_on={'dim_0_physical_type'},
)
dim_1_name: sim_symbols.SimSymbolName = bl_cache.BLField( dim_1_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerB sim_symbols.SimSymbolName.LowerB
) )
dim_1_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_1_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_1_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_1_physical_type),
cb_depends_on={'dim_1_physical_type'},
)
dim_2_name: sim_symbols.SimSymbolName = bl_cache.BLField( dim_2_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerC sim_symbols.SimSymbolName.LowerC
) )
dim_2_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_2_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_2_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_2_physical_type),
cb_depends_on={'dim_2_physical_type'},
)
dim_3_name: sim_symbols.SimSymbolName = bl_cache.BLField( dim_3_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerD sim_symbols.SimSymbolName.LowerD
) )
dim_3_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real) dim_4_name: sim_symbols.SimSymbolName = bl_cache.BLField(
dim_3_physical_type: spux.PhysicalType = bl_cache.BLField( sim_symbols.SimSymbolName.LowerE
spux.PhysicalType.NonPhysical
) )
dim_3_unit: enum.StrEnum = bl_cache.BLField( dim_5_name: sim_symbols.SimSymbolName = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_3_physical_type), sim_symbols.SimSymbolName.LowerF
cb_depends_on={'dim_3_physical_type'},
) )
def search_units(self, physical_type: spux.PhysicalType) -> list[ct.BLEnumElement]: def search_units(self, physical_type: spux.PhysicalType) -> list[ct.BLEnumElement]:
@ -161,19 +134,6 @@ class DataFileImporterNode(base.MaxwellSimNode):
] ]
return [] return []
def dim(self, i: int):
dim_name = getattr(self, f'dim_{i}_name')
dim_mathtype = getattr(self, f'dim_{i}_mathtype')
dim_physical_type = getattr(self, f'dim_{i}_physical_type')
dim_unit = getattr(self, f'dim_{i}_unit')
return sim_symbols.SimSymbol(
sym_name=dim_name,
mathtype=dim_mathtype,
physical_type=dim_physical_type,
unit=spux.unit_str_to_unit(dim_unit),
)
#################### ####################
# - UI # - UI
#################### ####################
@ -202,19 +162,21 @@ class DataFileImporterNode(base.MaxwellSimNode):
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
"""Draw loaded properties.""" """Draw loaded properties."""
for i in range(len(self.expr_info.dims)): col = layout.column(align=True)
col = layout.column(align=True) if self.expr_info is not None:
for i in range(len(self.expr_info.dims)):
col.prop(self, self.blfields[f'dim_{i}_name'], text=f'Dim {i}')
row = col.row(align=True) row = col.row(align=True)
row.alignment = 'CENTER' row.alignment = 'CENTER'
row.label(text=f'Load Dim {i}') row.label(text='Output')
row = col.row(align=True) row = col.row(align=True)
row.prop(self, self.blfields[f'dim_{i}_name'], text='') row.prop(self, self.blfields['output_name'], text='')
row.prop(self, self.blfields[f'dim_{i}_mathtype'], text='') row.prop(self, self.blfields['output_mathtype'], text='')
row = col.row(align=True) row = col.row(align=True)
row.prop(self, self.blfields[f'dim_{i}_physical_type'], text='') row.prop(self, self.blfields['output_physical_type'], text='')
row.prop(self, self.blfields[f'dim_{i}_unit'], text='') row.prop(self, self.blfields['output_unit'], text='')
#################### ####################
# - FlowKind.Array|Func # - FlowKind.Array|Func
@ -271,7 +233,8 @@ class DataFileImporterNode(base.MaxwellSimNode):
'Expr', 'Expr',
kind=ct.FlowKind.Info, kind=ct.FlowKind.Info,
# Loaded # Loaded
props={'output_name', 'output_physical_type', 'output_unit'}, props={'output_name', 'output_mathtype', 'output_physical_type', 'output_unit'}
| {f'dim_{i}_name' for i in range(6)},
output_sockets={'Expr'}, output_sockets={'Expr'},
output_socket_kinds={'Expr': ct.FlowKind.Func}, output_socket_kinds={'Expr': ct.FlowKind.Func},
) )
@ -285,32 +248,31 @@ class DataFileImporterNode(base.MaxwellSimNode):
A completely empty `ParamsFlow`, ready to be composed. A completely empty `ParamsFlow`, ready to be composed.
""" """
expr = output_sockets['Expr'] expr = output_sockets['Expr']
has_expr_func = not ct.FlowSignal.check(expr) has_expr_func = not ct.FlowSignal.check(expr)
if has_expr_func: if has_expr_func:
data = expr.func_jax() data = expr.func_jax()
# Deduce Dimensionality # Deduce Dimension Symbols
_shape = data.shape ## -> They are all chronically integer indices.
shape = _shape if _shape is not None else () ## -> The FilterNode can be used to "steal" an index from the data.
dim_syms = [self.dim(i) for i in range(len(shape))] shape = data.shape if data.shape is not None else ()
dims = {
sim_symbols.idx(None).update(
sym_name=props[f'dim_{i}_name'],
interval_finite_z=(0, elements),
interval_inf=(False, False),
interval_closed=(True, True),
): [str(j) for j in range(elements)]
for i, elements in enumerate(shape)
}
# Return InfoFlow
return ct.InfoFlow( return ct.InfoFlow(
dims={ dims=dims,
dim_sym: ct.RangeFlow(
start=sp.S(0),
stop=sp.S(shape[i] - 1),
steps=shape[i],
unit=self.dim(i).unit,
)
for i, dim_sym in enumerate(dim_syms)
},
output=sim_symbols.SimSymbol( output=sim_symbols.SimSymbol(
sym_name=props['output_name'], sym_name=props['output_name'],
mathtype=props['output_mathtype'], mathtype=props['output_mathtype'],
physical_type=props['output_physical_type'], physical_type=props['output_physical_type'],
unit=props['output_unit'],
), ),
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending

View File

@ -259,9 +259,8 @@ class LibraryMediumNode(base.MaxwellSimNode):
) )
def compute_valid_freqs_lazy(self, props) -> sp.Expr: def compute_valid_freqs_lazy(self, props) -> sp.Expr:
return ct.RangeFlow( return ct.RangeFlow(
start=props['freq_range'][0] / spux.THz, start=spu.scale_to_unit(['freq_range'][0], spux.THz),
stop=props['freq_range'][1] / spux.THz, stop=spu.scale_to_unit(props['freq_range'][1], spux.THz),
steps=0,
scaling=ct.ScalingMode.Lin, scaling=ct.ScalingMode.Lin,
unit=spux.THz, unit=spux.THz,
) )
@ -273,9 +272,8 @@ class LibraryMediumNode(base.MaxwellSimNode):
) )
def compute_valid_wls_lazy(self, props) -> sp.Expr: def compute_valid_wls_lazy(self, props) -> sp.Expr:
return ct.RangeFlow( return ct.RangeFlow(
start=props['wl_range'][0] / spu.nm, start=spu.scale_to_unit(['wl_range'][0], spu.nm),
stop=props['wl_range'][0] / spu.nm, stop=spu.scale_to_unit(['wl_range'][0], spu.nm),
steps=0,
scaling=ct.ScalingMode.Lin, scaling=ct.ScalingMode.Lin,
unit=spu.nm, unit=spu.nm,
) )

View File

@ -73,43 +73,138 @@ class ViewerNode(base.MaxwellSimNode):
#################### ####################
# - Properties # - Properties
#################### ####################
print_kind: ct.FlowKind = bl_cache.BLField(ct.FlowKind.Value) auto_expr: bool = bl_cache.BLField(True)
auto_plot: bool = bl_cache.BLField(False) debug_mode: bool = bl_cache.BLField(False)
# Debug Mode
console_print_kind: ct.FlowKind = bl_cache.BLField(ct.FlowKind.Value)
auto_plot: bool = bl_cache.BLField(True)
auto_3d_preview: bool = bl_cache.BLField(True) auto_3d_preview: bool = bl_cache.BLField(True)
####################
# - Properties: Computed FlowKinds
####################
@events.on_value_changed(
socket_name='Any',
)
def on_input_changed(self) -> None:
self.input_flow = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def input_flow(self) -> dict[ct.FlowKind, typ.Any | None]:
input_flow = {}
for flow_kind in list(ct.FlowKind):
flow = self._compute_input('Any', kind=flow_kind)
has_flow = not ct.FlowSignal.check(flow)
if has_flow:
input_flow |= {flow_kind: flow}
else:
input_flow |= {flow_kind: None}
return input_flow
####################
# - Property: Input Expression String Lines
####################
@bl_cache.cached_bl_property(depends_on={'input_flow'})
def input_expr_str_entries(self) -> list[list[str]] | None:
value = self.input_flow.get(ct.FlowKind.Value)
def sp_pretty(v: spux.SympyExpr) -> spux.SympyExpr:
## sp.pretty makes new lines and wreaks havoc.
return spux.sp_to_str(v.n(4))
if isinstance(value, spux.SympyType):
if isinstance(value, sp.MatrixBase):
return [
[sp_pretty(value[row, col]) for col in range(value.shape[1])]
for row in range(value.shape[0])
]
return [[sp_pretty(value)]]
return None
#################### ####################
# - UI # - UI
#################### ####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout): def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
layout.prop(self, self.blfields['print_kind'], text='') row = layout.row(align=True)
# Debug Mode On/Off
row.prop(self, self.blfields['debug_mode'], text='Debug', toggle=True)
# Automatic Expression Printing
row.prop(self, self.blfields['auto_expr'], text='Expr', toggle=True)
# Debug Mode Operators
if self.debug_mode:
layout.prop(self, self.blfields['console_print_kind'], text='')
def draw_operators(self, _: bpy.types.Context, layout: bpy.types.UILayout): def draw_operators(self, _: bpy.types.Context, layout: bpy.types.UILayout):
split = layout.split(factor=0.4) # Live Expression
if self.debug_mode:
layout.operator(ConsoleViewOperator.bl_idname, text='Console Print')
# Split LHS split = layout.split(factor=0.4)
col = split.column(align=False)
col.label(text='Console')
col.label(text='Plot')
col.label(text='3D')
# Split RHS # Split LHS
col = split.column(align=False) col = split.column(align=False)
col.label(text='Plot')
col.label(text='3D')
## Console Options # Split RHS
col.operator(ConsoleViewOperator.bl_idname, text='Print') col = split.column(align=False)
## Plot Options ## Plot Options
row = col.row(align=True) row = col.row(align=True)
row.prop(self, self.blfields['auto_plot'], text='Plot', toggle=True) row.prop(self, self.blfields['auto_plot'], text='Plot', toggle=True)
row.operator( row.operator(
RefreshPlotViewOperator.bl_idname, RefreshPlotViewOperator.bl_idname,
text='', text='',
icon='FILE_REFRESH', icon='FILE_REFRESH',
) )
## 3D Preview Options ## 3D Preview Options
row = col.row(align=True) row = col.row(align=True)
row.prop(self, self.blfields['auto_3d_preview'], text='3D Preview', toggle=True) row.prop(
self, self.blfields['auto_3d_preview'], text='3D Preview', toggle=True
)
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout):
# Live Expression
if self.auto_expr and self.input_expr_str_entries is not None:
box = layout.box()
expr_rows = len(self.input_expr_str_entries)
expr_cols = len(self.input_expr_str_entries[0])
shape_str = (
f'({expr_rows}×{expr_cols})'
if expr_rows != 1 or expr_cols != 1
else '(Scalar)'
)
row = box.row()
row.alignment = 'CENTER'
row.label(text=f'Expr {shape_str}')
if (
len(self.input_expr_str_entries) == 1
and len(self.input_expr_str_entries[0]) == 1
):
row = box.row()
row.alignment = 'CENTER'
row.label(text=self.input_expr_str_entries[0][0])
else:
grid = box.grid_flow(
row_major=True,
columns=len(self.input_expr_str_entries[0]),
align=True,
)
for row in self.input_expr_str_entries:
for entry in row:
grid.label(text=entry)
#################### ####################
# - Methods # - Methods
@ -119,7 +214,7 @@ class ViewerNode(base.MaxwellSimNode):
return return
log.info('Printing to Console') log.info('Printing to Console')
data = self._compute_input('Any', kind=self.print_kind, optional=True) data = self._compute_input('Any', kind=self.console_print_kind, optional=True)
if isinstance(data, spux.SympyType): if isinstance(data, spux.SympyType):
console.print(sp.pretty(data, use_unicode=True)) console.print(sp.pretty(data, use_unicode=True))

View File

@ -40,6 +40,8 @@ _max_e_socket_def = sockets.ExprSocketDef(
) )
_offset_socket_def = sockets.ExprSocketDef(default_value=5, abs_min=2.5) _offset_socket_def = sockets.ExprSocketDef(default_value=5, abs_min=2.5)
t_ps = sim_symbols.t(spu.picosecond)
class TemporalShapeNode(base.MaxwellSimNode): class TemporalShapeNode(base.MaxwellSimNode):
"""Declare a source-time dependence for use in simulation source nodes.""" """Declare a source-time dependence for use in simulation source nodes."""
@ -82,8 +84,8 @@ class TemporalShapeNode(base.MaxwellSimNode):
default_steps=100, default_steps=100,
), ),
'Envelope': sockets.ExprSocketDef( 'Envelope': sockets.ExprSocketDef(
default_symbols=[sim_symbols.t], default_symbols=[t_ps],
default_value=10 * sim_symbols.t.sp_symbol, default_value=10 * t_ps.sp_symbol,
), ),
}, },
} }

View File

@ -593,6 +593,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
ValueError: When referencing a socket that's meant to be directly referenced. ValueError: When referencing a socket that's meant to be directly referenced.
""" """
kind_data_map = { kind_data_map = {
ct.FlowKind.Capabilities: lambda: self.capabilities,
ct.FlowKind.Value: lambda: self.value, ct.FlowKind.Value: lambda: self.value,
ct.FlowKind.Array: lambda: self.array, ct.FlowKind.Array: lambda: self.array,
ct.FlowKind.Func: lambda: self.lazy_func, ct.FlowKind.Func: lambda: self.lazy_func,

View File

@ -111,29 +111,38 @@ class ExprBLSocket(base.MaxwellSimSocket):
bl_label = 'Expr' bl_label = 'Expr'
#################### ####################
# - Properties # - Socket Interface
#################### ####################
size: spux.NumberSize1D = bl_cache.BLField(spux.NumberSize1D.Scalar) size: spux.NumberSize1D = bl_cache.BLField(spux.NumberSize1D.Scalar)
mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real) mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real)
physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical) physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical)
# Symbols ####################
# - Symbols
####################
output_name: sim_symbols.SimSymbolName = bl_cache.BLField( output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.Expr sim_symbols.SimSymbolName.Expr
) )
active_symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([]) symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([])
@property
def symbols(self) -> set[sp.Symbol]:
"""Current symbols as an unordered set."""
return {sim_symbol.sp_symbol for sim_symbol in self.active_symbols}
@bl_cache.cached_bl_property(depends_on={'symbols'}) @bl_cache.cached_bl_property(depends_on={'symbols'})
def sorted_symbols(self) -> list[sp.Symbol]: def sp_symbols(self) -> set[sp.Symbol | sp.MatrixSymbol]:
"""Sympy symbols as an unordered set."""
return {sim_symbol.sp_symbol_matsym for sim_symbol in self.symbols}
@bl_cache.cached_bl_property(depends_on={'symbols'})
def sorted_symbols(self) -> list[sim_symbols.SimSymbol]:
"""Current symbols as a sorted list.""" """Current symbols as a sorted list."""
return sorted(self.symbols, key=lambda sym: sym.name) return sorted(self.symbols, key=lambda sym: sym.name)
# Unit @bl_cache.cached_bl_property(depends_on={'symbols'})
def sorted_sp_symbols(self) -> list[sp.Symbol | sp.MatrixSymbol]:
"""Computes `sympy` symbols from `self.sorted_symbols`."""
return [sym.sp_symbol_matsym for sym in self.sorted_symbols]
####################
# - Units
####################
active_unit: enum.StrEnum = bl_cache.BLField( active_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_valid_units(), enum_cb=lambda self, _: self.search_valid_units(),
cb_depends_on={'physical_type'}, cb_depends_on={'physical_type'},
@ -148,6 +157,29 @@ class ExprBLSocket(base.MaxwellSimSocket):
] ]
return [] return []
@bl_cache.cached_bl_property(depends_on={'active_unit'})
def unit(self) -> spux.Unit | None:
"""Gets the current active unit.
Returns:
The current active `sympy` unit.
If the socket expression is unitless, this returns `None`.
"""
if self.active_unit is not None:
return spux.unit_str_to_unit(self.active_unit)
return None
@property
def unit_factor(self) -> spux.Unit | None:
return sp.Integer(1) if self.unit is None else self.unit
prev_unit: str | None = bl_cache.BLField(None)
####################
# - UI Values
####################
# UI: Value # UI: Value
## Expression ## Expression
raw_value_spstr: str = bl_cache.BLField('0.0') raw_value_spstr: str = bl_cache.BLField('0.0')
@ -186,6 +218,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
) )
# UI: Info # UI: Info
show_name_selector: bool = bl_cache.BLField(False)
show_func_ui: bool = bl_cache.BLField(True) show_func_ui: bool = bl_cache.BLField(True)
show_info_columns: bool = bl_cache.BLField(False) show_info_columns: bool = bl_cache.BLField(False)
info_columns: set[InfoDisplayCol] = bl_cache.BLField( info_columns: set[InfoDisplayCol] = bl_cache.BLField(
@ -207,25 +240,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
def raw_max_sp(self) -> spux.SympyExpr: def raw_max_sp(self) -> spux.SympyExpr:
return self._parse_expr_str(self.raw_max_spstr) return self._parse_expr_str(self.raw_max_spstr)
####################
# - Computed Unit
####################
@bl_cache.cached_bl_property(depends_on={'active_unit'})
def unit(self) -> spux.Unit | None:
"""Gets the current active unit.
Returns:
The current active `sympy` unit.
If the socket expression is unitless, this returns `None`.
"""
if self.active_unit is not None:
return spux.unit_str_to_unit(self.active_unit)
return None
prev_unit: str | None = bl_cache.BLField(None)
#################### ####################
# - Prop-Change Callback # - Prop-Change Callback
#################### ####################
@ -272,7 +286,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
raise ValueError(msg) raise ValueError(msg)
# Parse Symbols # Parse Symbols
if expr.free_symbols and not expr.free_symbols.issubset(self.symbols): if expr.free_symbols and not expr.free_symbols.issubset(self.sp_symbols):
msg = f'Tried to set expr {expr} with free symbols {expr.free_symbols}, which is incompatible with socket symbols {self.symbols}' msg = f'Tried to set expr {expr} with free symbols {expr.free_symbols}, which is incompatible with socket symbols {self.symbols}'
raise ValueError(msg) raise ValueError(msg)
@ -320,7 +334,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
""" """
expr = sp.sympify( expr = sp.sympify(
expr_spstr, expr_spstr,
locals={sym.name: sym for sym in self.symbols}, locals={sym.name: sym.sp_symbol_matsym for sym in self.symbols},
strict=False, strict=False,
convert_xor=True, convert_xor=True,
).subs(spux.UNIT_BY_SYMBOL) ).subs(spux.UNIT_BY_SYMBOL)
@ -562,11 +576,11 @@ class ExprBLSocket(base.MaxwellSimSocket):
if self.symbols: if self.symbols:
return ct.FuncFlow( return ct.FuncFlow(
func=sp.lambdify( func=sp.lambdify(
self.sorted_symbols, self.sorted_sp_symbols,
spux.scale_to_unit(self.value, self.unit), spux.strip_unit_system(self.value),
'jax', 'jax',
), ),
func_args=[spux.MathType.from_expr(sym) for sym in self.sorted_symbols], func_args=list(self.sorted_symbols),
supports_jax=True, supports_jax=True,
) )
@ -578,7 +592,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
return ct.FuncFlow( return ct.FuncFlow(
func=lambda v: v, func=lambda v: v,
func_args=[ func_args=[
self.physical_type if self.physical_type is not None else self.mathtype sim_symbols.SimSymbol.from_expr(
sim_symbols.SimSymbolName.Constant, self.value, self.unit_factor
)
], ],
supports_jax=True, supports_jax=True,
) )
@ -597,8 +613,8 @@ class ExprBLSocket(base.MaxwellSimSocket):
## -> NOTE: func_args must have the same symbol order as was lambdified. ## -> NOTE: func_args must have the same symbol order as was lambdified.
if self.symbols: if self.symbols:
return ct.ParamsFlow( return ct.ParamsFlow(
func_args=self.sorted_symbols, func_args=[sym.sp_symbol_phy for sym in self.sorted_symbols],
symbols=self.symbols, symbols=self.sorted_symbols,
) )
# Constant # Constant
@ -618,24 +634,27 @@ class ExprBLSocket(base.MaxwellSimSocket):
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along. Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
""" """
output_sim_sym = ( output_sym = sim_symbols.SimSymbol(
sim_symbols.SimSymbol( sym_name=self.output_name,
sym_name=self.output_name, mathtype=self.mathtype,
mathtype=self.mathtype, physical_type=self.physical_type,
physical_type=self.physical_type, unit=self.unit,
unit=self.unit, rows=self.size.rows,
rows=self.size.rows, cols=self.size.cols,
cols=self.size.cols,
),
) )
# Constant
## -> The input SimSymbols become continuous dimensional indices.
## -> All domain validity information is defined on the SimSymbol keys.
if self.symbols: if self.symbols:
return ct.InfoFlow( return ct.InfoFlow(
dims={sim_sym: None for sim_sym in self.active_symbols}, dims={sym: None for sym in self.sorted_symbols},
output=output_sim_sym, output=output_sym,
) )
# Constant # Constant
return ct.InfoFlow(output=output_sim_sym) ## -> We only need the output symbol to describe the raw data.
return ct.InfoFlow(output=output_sym)
#################### ####################
# - FlowKind: Capabilities # - FlowKind: Capabilities
@ -645,6 +664,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
return ct.CapabilitiesFlow( return ct.CapabilitiesFlow(
socket_type=self.socket_type, socket_type=self.socket_type,
active_kind=self.active_kind, active_kind=self.active_kind,
allow_out_to_in={ct.FlowKind.Func: ct.FlowKind.Value},
) )
#################### ####################
@ -795,7 +815,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
col = split.column() col = split.column()
col.alignment = 'RIGHT' col.alignment = 'RIGHT'
for sym in self.symbols: for sym in self.symbols:
col.label(text=spux.pretty_symbol(sym)) col.label(text=sym.def_label)
def draw_lazy_range(self, col: bpy.types.UILayout) -> None: def draw_lazy_range(self, col: bpy.types.UILayout) -> None:
"""Draw the socket body for a simple, uniform range of values between two values/expressions. """Draw the socket body for a simple, uniform range of values between two values/expressions.
@ -840,14 +860,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
Uses `draw_value` to draw the base UI Uses `draw_value` to draw the base UI
""" """
if self.show_func_ui: if self.show_func_ui:
# Output Name Selector
## -> The name of the output
col.prop(self, self.blfields['output_name'], text='')
# Physical Type Selector
## -> Determines whether/which unit-dropdown will be shown.
col.prop(self, self.blfields['physical_type'], text='')
# Non-Symbolic: Size/Mathtype Selector # Non-Symbolic: Size/Mathtype Selector
## -> Symbols imply str expr input. ## -> Symbols imply str expr input.
## -> For arbitrary str exprs, size/mathtype are derived from the expr. ## -> For arbitrary str exprs, size/mathtype are derived from the expr.
@ -861,15 +873,30 @@ class ExprBLSocket(base.MaxwellSimSocket):
## -> Draws the UI appropriate for the above choice of constraints. ## -> Draws the UI appropriate for the above choice of constraints.
self.draw_value(col) self.draw_value(col)
# Physical Type Selector
## -> Determines whether/which unit-dropdown will be shown.
col.prop(self, self.blfields['physical_type'], text='')
# Symbol UI # Symbol UI
## -> Draws the UI appropriate for the above choice of constraints. ## -> Draws the UI appropriate for the above choice of constraints.
## -> TODO ## -> TODO
# Output Name Selector
## -> The name of the output
if self.show_name_selector:
row = col.row()
row.prop(self, self.blfields['output_name'], text='Name')
#################### ####################
# - UI: InfoFlow # - UI: InfoFlow
#################### ####################
def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None: def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None:
if self.active_kind == ct.FlowKind.Func and self.show_info_columns: """Visualize the `InfoFlow` information passing through the socket."""
if (
self.active_kind == ct.FlowKind.Func
and self.show_info_columns
and self.is_linked
):
row = col.row() row = col.row()
box = row.box() box = row.box()
grid = box.grid_flow( grid = box.grid_flow(
@ -881,38 +908,23 @@ class ExprBLSocket(base.MaxwellSimSocket):
) )
# Dimensions # Dimensions
for dim in info.dims: for dim_name_pretty, dim_label_info in info.dim_labels.items():
dim_idx = info.dims[dim] grid.label(text=dim_name_pretty)
grid.label(text=dim.name_pretty)
if InfoDisplayCol.Length in self.info_columns: if InfoDisplayCol.Length in self.info_columns:
grid.label(text=str(len(dim_idx))) grid.label(text=dim_label_info['length'])
if InfoDisplayCol.MathType in self.info_columns: if InfoDisplayCol.MathType in self.info_columns:
grid.label(text=spux.MathType.to_str(dim_idx.mathtype)) grid.label(text=dim_label_info['mathtype'])
if InfoDisplayCol.Unit in self.info_columns: if InfoDisplayCol.Unit in self.info_columns:
grid.label(text=spux.sp_to_str(dim_idx.unit)) grid.label(text=dim_label_info['unit'])
# Outputs # Outputs
grid.label(text=info.output.name_pretty) grid.label(text=info.output.name_pretty)
if InfoDisplayCol.Length in self.info_columns: if InfoDisplayCol.Length in self.info_columns:
grid.label(text='', icon=ct.Icon.DataSocketOutput) grid.label(text='', icon=ct.Icon.DataSocketOutput)
if InfoDisplayCol.MathType in self.info_columns: if InfoDisplayCol.MathType in self.info_columns:
grid.label( grid.label(text=info.output.def_label)
text=(
spux.MathType.to_str(info.output.mathtype)
+ (
'ˣ'.join(
[
unicode_superscript(out_axis)
for out_axis in info.output.shape
]
)
if info.output.shape
else ''
)
)
)
if InfoDisplayCol.Unit in self.info_columns: if InfoDisplayCol.Unit in self.info_columns:
grid.label(text=f'{spux.sp_to_str(info.output.unit)}') grid.label(text=info.output.unit_label)
#################### ####################
@ -926,7 +938,7 @@ class ExprSocketDef(base.SocketDef):
ct.FlowKind.Array, ct.FlowKind.Array,
ct.FlowKind.Func, ct.FlowKind.Func,
] = ct.FlowKind.Value ] = ct.FlowKind.Value
output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName.Expr
# Socket Interface # Socket Interface
size: spux.NumberSize1D = spux.NumberSize1D.Scalar size: spux.NumberSize1D = spux.NumberSize1D.Scalar
@ -948,9 +960,15 @@ class ExprSocketDef(base.SocketDef):
default_scaling: ct.ScalingMode = ct.ScalingMode.Lin default_scaling: ct.ScalingMode = ct.ScalingMode.Lin
# UI # UI
show_name_selector: bool = False
show_func_ui: bool = True show_func_ui: bool = True
show_info_columns: bool = False show_info_columns: bool = False
@property
def sp_symbols(self) -> set[sp.Symbol | sp.MatrixSymbol]:
"""Default symbols as an unordered set."""
return {sym.sp_symbol_matsym for sym in self.default_symbols}
#################### ####################
# - Parse Unit and/or Physical Type # - Parse Unit and/or Physical Type
#################### ####################
@ -1149,12 +1167,13 @@ class ExprSocketDef(base.SocketDef):
raise ValueError(msg) raise ValueError(msg)
# Coerce from Infinite # Coerce from Infinite
if bound.is_infinite and self.mathtype is spux.MathType.Integer: if isinstance(bound, spux.SympyType):
new_bounds[i] = sp.S(-1) if i == 0 else sp.S(1) if bound.is_infinite and self.mathtype is spux.MathType.Integer:
if bound.is_infinite and self.mathtype is spux.MathType.Rational: new_bounds[i] = sp.S(-1) if i == 0 else sp.S(1)
new_bounds[i] = sp.Rational(-1, 1) if i == 0 else sp.Rational(1, 1) if bound.is_infinite and self.mathtype is spux.MathType.Rational:
if bound.is_infinite and self.mathtype is spux.MathType.Real: new_bounds[i] = sp.Rational(-1, 1) if i == 0 else sp.Rational(1, 1)
new_bounds[i] = sp.S(-1.0) if i == 0 else sp.S(1.0) if bound.is_infinite and self.mathtype is spux.MathType.Real:
new_bounds[i] = sp.S(-1.0) if i == 0 else sp.S(1.0)
if new_bounds[0] is not None: if new_bounds[0] is not None:
self.default_min = new_bounds[0] self.default_min = new_bounds[0]
@ -1194,9 +1213,9 @@ class ExprSocketDef(base.SocketDef):
def symbols_value(self) -> typ.Self: def symbols_value(self) -> typ.Self:
if ( if (
self.default_value.free_symbols self.default_value.free_symbols
and not self.default_value.free_symbols.issubset(self.symbols) and not self.default_value.free_symbols.issubset(self.sp_symbols)
): ):
msg = f'Tried to set default value {self.default_value} with free symbols {self.default_value.free_symbols}, which is incompatible with socket symbols {self.symbols}' msg = f'Tried to set default value {self.default_value} with free symbols {self.default_value.free_symbols}, which is incompatible with socket symbols {self.sp_symbols}'
raise ValueError(msg) raise ValueError(msg)
return self return self
@ -1227,7 +1246,7 @@ class ExprSocketDef(base.SocketDef):
bl_socket.size = self.size bl_socket.size = self.size
bl_socket.mathtype = self.mathtype bl_socket.mathtype = self.mathtype
bl_socket.physical_type = self.physical_type bl_socket.physical_type = self.physical_type
bl_socket.active_symbols = self.symbols bl_socket.symbols = self.default_symbols
# FlowKind.Value # FlowKind.Value
## -> We must take units into account when setting bl_socket.value ## -> We must take units into account when setting bl_socket.value
@ -1252,6 +1271,7 @@ class ExprSocketDef(base.SocketDef):
# UI # UI
bl_socket.show_func_ui = self.show_func_ui bl_socket.show_func_ui = self.show_func_ui
bl_socket.show_info_columns = self.show_info_columns bl_socket.show_info_columns = self.show_info_columns
bl_socket.show_name_selector = self.show_name_selector
# Info Draw # Info Draw
bl_socket.use_info_draw = True bl_socket.use_info_draw = True

View File

@ -389,14 +389,15 @@ class BLField:
Reset by setting the descriptor to `Signal.ResetStrSearch`. Reset by setting the descriptor to `Signal.ResetStrSearch`.
""" """
cached_items = self.bl_prop_str_search.read_nonpersist(_self) return self.str_cb(_self, context, edit_text)
if cached_items is not Signal.CacheNotReady: # cached_items = self.bl_prop_str_search.read_nonpersist(_self)
if cached_items is Signal.CacheEmpty: # if cached_items is not Signal.CacheNotReady:
computed_items = self.str_cb(_self, context, edit_text) # if cached_items is Signal.CacheEmpty:
self.bl_prop_str_search.write_nonpersist(_self, computed_items) # computed_items = self.str_cb(_self, context, edit_text)
return computed_items # self.bl_prop_str_search.write_nonpersist(_self, computed_items)
return cached_items # return computed_items
return [] # return cached_items
# return []
def safe_enum_cb( def safe_enum_cb(
self, _self: bl_instance.BLInstance, context: bpy.types.Context self, _self: bl_instance.BLInstance, context: bpy.types.Context

View File

@ -28,11 +28,13 @@ Attributes:
import enum import enum
import functools import functools
import sys
import typing as typ import typing as typ
from fractions import Fraction from fractions import Fraction
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jaxtyping as jtyp
import pydantic as pyd import pydantic as pyd
import sympy as sp import sympy as sp
import sympy.physics.units as spu import sympy.physics.units as spu
@ -144,6 +146,21 @@ class MathType(enum.StrEnum):
complex: MathType.Complex, complex: MathType.Complex,
}[dtype] }[dtype]
@staticmethod
def from_jax_array(data: jtyp.Shaped[jtyp.Array, '...']) -> type:
"""Deduce the MathType corresponding to a JAX array.
We go about this by leveraging that:
- `data` is of a homogeneous type.
- `data.item(0)` returns a single element of the array w/pure-python type.
By combing this with `type()` and `MathType.from_pytype`, we can effectively deduce the `MathType` of the entire array with relative efficiency.
Notes:
Should also work with numpy arrays.
"""
return MathType.from_pytype(type(data.item(0)))
@staticmethod @staticmethod
def has_mathtype(obj: typ.Any) -> typ.Literal['pytype', 'expr'] | None: def has_mathtype(obj: typ.Any) -> typ.Literal['pytype', 'expr'] | None:
if isinstance(obj, bool | int | Fraction | float | complex): if isinstance(obj, bool | int | Fraction | float | complex):
@ -173,6 +190,39 @@ class MathType(enum.StrEnum):
MT.Complex: sp.Complexes, MT.Complex: sp.Complexes,
}[self] }[self]
@property
def inf_finite(self) -> type:
"""Opinionated finite representation of "infinity" within this `MathType`.
These are chosen using `sys.maxsize` and `sys.float_info`.
As such, while not arbitrary, this "finite representation of infinity" certainly is opinionated.
**Note** that, in practice, most systems will have no trouble working with values that exceed those defined here.
Notes:
Values should be presumed to vary by-platform, as the `sys` attributes may be influenced by CPU architecture, OS, runtime environment, etc. .
These values can be used directly in `jax` arrays, but at the cost of an overflow warning (in part because `jax` generally only allows the use of `float32`).
In this case, the warning doesn't matter, as the value will be cast to `jnp.inf` anyway.
However, it's generally cleaner to directly use `jnp.inf` if infinite values must be defined in an array context.
"""
MT = MathType
Z = MT.Integer
R = MT.Integer
return {
MT.Integer: (-sys.maxsize, sys.maxsize),
MT.Rational: (
Fraction(Z.inf_finite[0], 1),
Fraction(Z.inf_finite[1], 1),
),
MT.Real: -(sys.float_info.min, sys.float_info.max),
MT.Complex: (
complex(R.inf_finite[0], R.inf_finite[0]),
complex(R.inf_finite[1], R.inf_finite[1]),
),
}[self]
@property @property
def sp_symbol_a(self) -> type: def sp_symbol_a(self) -> type:
MT = MathType MT = MathType
@ -192,6 +242,10 @@ class MathType(enum.StrEnum):
MathType.Complex: '', MathType.Complex: '',
}[value] }[value]
@property
def label_pretty(self) -> str:
return MathType.to_str(self)
@staticmethod @staticmethod
def to_name(value: typ.Self) -> str: def to_name(value: typ.Self) -> str:
return MathType.to_str(value) return MathType.to_str(value)
@ -819,14 +873,15 @@ def sp_to_str(sp_obj: SympyExpr) -> str:
A string representing the expression for human use. A string representing the expression for human use.
_The string is not re-encodable to the expression._ _The string is not re-encodable to the expression._
""" """
## TODO: A bool flag property that does a lot of find/replace to make it super pretty
return _SYMPY_EXPR_PRINTER_STR.doprint(sp_obj) return _SYMPY_EXPR_PRINTER_STR.doprint(sp_obj)
def pretty_symbol(sym: sp.Symbol) -> str: def pretty_symbol(sym: sp.Symbol) -> str:
return f'{sym.name}' + ( return f'{sym.name}' + (
'' ''
if sym.is_complex if sym.is_integer
else ('' if sym.is_real else ('' if sym.is_integer else '?')) else ('' if sym.is_real else ('' if sym.is_complex else '?'))
) )
@ -1039,20 +1094,24 @@ class PhysicalType(enum.StrEnum):
PT.LumIntensity: spu.candela, PT.LumIntensity: spu.candela,
PT.LumFlux: spu.candela * spu.steradian, PT.LumFlux: spu.candela * spu.steradian,
PT.Illuminance: spu.candela / spu.meter**2, PT.Illuminance: spu.candela / spu.meter**2,
# Optics
PT.OrdinaryWaveVector: terahertz,
PT.AngularWaveVector: spu.radian * terahertz,
}[self] }[self]
@functools.cached_property @functools.cached_property
def valid_units(self) -> list[Unit]: def valid_units(self) -> list[Unit]:
"""Retrieve an ordered (by subjective usefulness) list of units for this physical type.
Notes:
The order in which valid units are declared is the exact same order that UI dropdowns display them.
**Altering the order of units breaks backwards compatibility**.
"""
PT = PhysicalType PT = PhysicalType
return { return {
PT.NonPhysical: [None], PT.NonPhysical: [None],
# Global # Global
PT.Time: [ PT.Time: [
femtosecond,
spu.picosecond, spu.picosecond,
femtosecond,
spu.nanosecond, spu.nanosecond,
spu.microsecond, spu.microsecond,
spu.millisecond, spu.millisecond,
@ -1070,11 +1129,11 @@ class PhysicalType(enum.StrEnum):
], ],
PT.Freq: ( PT.Freq: (
_valid_freqs := [ _valid_freqs := [
terahertz,
spu.hertz, spu.hertz,
kilohertz, kilohertz,
megahertz, megahertz,
gigahertz, gigahertz,
terahertz,
petahertz, petahertz,
exahertz, exahertz,
] ]
@ -1083,10 +1142,10 @@ class PhysicalType(enum.StrEnum):
# Cartesian # Cartesian
PT.Length: ( PT.Length: (
_valid_lens := [ _valid_lens := [
spu.micrometer,
spu.nanometer,
spu.picometer, spu.picometer,
spu.angstrom, spu.angstrom,
spu.nanometer,
spu.micrometer,
spu.millimeter, spu.millimeter,
spu.centimeter, spu.centimeter,
spu.meter, spu.meter,
@ -1102,24 +1161,24 @@ class PhysicalType(enum.StrEnum):
PT.Vel: [_unit / spu.second for _unit in _valid_lens], PT.Vel: [_unit / spu.second for _unit in _valid_lens],
PT.Accel: [_unit / spu.second**2 for _unit in _valid_lens], PT.Accel: [_unit / spu.second**2 for _unit in _valid_lens],
PT.Mass: [ PT.Mass: [
spu.kilogram,
spu.electron_rest_mass, spu.electron_rest_mass,
spu.dalton, spu.dalton,
spu.microgram, spu.microgram,
spu.milligram, spu.milligram,
spu.gram, spu.gram,
spu.kilogram,
spu.metric_ton, spu.metric_ton,
], ],
PT.Force: [ PT.Force: [
spu.kg * spu.meter / spu.second**2,
nanonewton,
micronewton, micronewton,
nanonewton,
millinewton, millinewton,
spu.newton, spu.newton,
spu.kg * spu.meter / spu.second**2,
], ],
PT.Pressure: [ PT.Pressure: [
millibar,
spu.bar, spu.bar,
millibar,
spu.pascal, spu.pascal,
hectopascal, hectopascal,
spu.atmosphere, spu.atmosphere,
@ -1129,8 +1188,8 @@ class PhysicalType(enum.StrEnum):
], ],
# Energy # Energy
PT.Work: [ PT.Work: [
spu.electronvolt,
spu.joule, spu.joule,
spu.electronvolt,
], ],
PT.Power: [ PT.Power: [
spu.watt, spu.watt,
@ -1194,18 +1253,17 @@ class PhysicalType(enum.StrEnum):
PT.Illuminance: [ PT.Illuminance: [
spu.candela / spu.meter**2, spu.candela / spu.meter**2,
], ],
# Optics
PT.OrdinaryWaveVector: _valid_freqs,
PT.AngularWaveVector: [spu.radian * _unit for _unit in _valid_freqs],
}[self] }[self]
@staticmethod @staticmethod
def from_unit(unit: Unit) -> list[Unit]: def from_unit(unit: Unit, optional: bool = False) -> list[Unit] | None:
for physical_type in list(PhysicalType): for physical_type in list(PhysicalType):
if unit in physical_type.valid_units: if unit in physical_type.valid_units:
return physical_type return physical_type
## TODO: Optimize ## TODO: Optimize
if optional:
return None
msg = f'Could not determine PhysicalType for {unit}' msg = f'Could not determine PhysicalType for {unit}'
raise ValueError(msg) raise ValueError(msg)
@ -1400,20 +1458,9 @@ def sympy_to_python(
#################### ####################
# - Convert to Unit System # - Convert to Unit System
#################### ####################
def convert_to_unit_system( def strip_unit_system(
sp_obj: SympyExpr, unit_system: UnitSystem | None sp_obj: SympyExpr, unit_system: UnitSystem | None = None
) -> SympyExpr: ) -> SympyExpr:
"""Convert an expression to the units of a given unit system, with appropriate scaling."""
if unit_system is None:
return sp_obj
return spu.convert_to(
sp_obj,
{unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)},
)
def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem | None) -> SympyExpr:
"""Strip units occurring in the given unit system from the expression. """Strip units occurring in the given unit system from the expression.
Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`". Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`".
@ -1427,6 +1474,19 @@ def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem | None) -> Symp
return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None}) return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None})
def convert_to_unit_system(
sp_obj: SympyExpr, unit_system: UnitSystem | None
) -> SympyExpr:
"""Convert an expression to the units of a given unit system."""
if unit_system is None:
return sp_obj
return spu.convert_to(
sp_obj,
{unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)},
)
def scale_to_unit_system( def scale_to_unit_system(
sp_obj: SympyExpr, unit_system: UnitSystem | None, use_jax_array: bool = False sp_obj: SympyExpr, unit_system: UnitSystem | None, use_jax_array: bool = False
) -> int | float | complex | tuple | jax.Array: ) -> int | float | complex | tuple | jax.Array:

View File

@ -18,7 +18,6 @@
import enum import enum
import functools import functools
import time
import typing as typ import typing as typ
import jax import jax
@ -34,7 +33,7 @@ import seaborn as sns
from blender_maxwell import contracts as ct from blender_maxwell import contracts as ct
from blender_maxwell.utils import logger from blender_maxwell.utils import logger
mplstyle.use('fast') ## TODO: Does this do anything? # mplstyle.use('fast') ## TODO: Does this do anything?
sns.set_theme() sns.set_theme()
log = logger.get(__name__) log = logger.get(__name__)
@ -59,6 +58,9 @@ class Colormap(enum.StrEnum):
Viridis = enum.auto() Viridis = enum.auto()
Grayscale = enum.auto() Grayscale = enum.auto()
####################
# - UI
####################
@staticmethod @staticmethod
def to_name(value: typ.Self) -> str: def to_name(value: typ.Self) -> str:
return { return {
@ -139,7 +141,9 @@ def rgba_image_from_2d_map(
#################### ####################
@functools.lru_cache(maxsize=16) @functools.lru_cache(maxsize=16)
def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int): def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
fig = matplotlib.figure.Figure(figsize=[width_inches, height_inches], dpi=dpi) fig = matplotlib.figure.Figure(
figsize=[width_inches, height_inches], dpi=dpi, layout='tight'
)
canvas = matplotlib.backends.backend_agg.FigureCanvasAgg(fig) canvas = matplotlib.backends.backend_agg.FigureCanvasAgg(fig)
ax = fig.add_subplot() ax = fig.add_subplot()
@ -152,66 +156,53 @@ def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
# - Plotters # - Plotters
#################### ####################
# () -> # () ->
def plot_box_plot_1d( def plot_box_plot_1d(data, ax: mpl_ax.Axis) -> None:
data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis x_sym, y_sym = list(data.keys())
) -> None:
x_sym = info.last_dim
y_sym = info.output
ax.boxplot([data]) ax.boxplot([data[y_sym]])
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}') ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
ax.set_xlabel(x_sym.plot_label) ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label) ax.set_xlabel(y_sym.plot_label)
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None: def plot_bar(data, ax: mpl_ax.Axis) -> None:
x_sym = info.last_dim x_sym, heights_sym = list(data.keys())
y_sym = info.output
p = ax.bar(info.dims[x_sym], data) p = ax.bar(data[x_sym], data[heights_sym])
ax.bar_label(p, label_type='center') ax.bar_label(p, label_type='center')
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}') ax.set_title(f'{x_sym.name_pretty} -> {heights_sym.name_pretty}')
ax.set_xlabel(x_sym.plot_label) ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label) ax.set_xlabel(heights_sym.plot_label)
# () -> # () ->
def plot_curve_2d( def plot_curve_2d(data, ax: mpl_ax.Axis) -> None:
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis x_sym, y_sym = list(data.keys())
) -> None:
x_sym = info.last_dim
y_sym = info.output
ax.plot(info.dims[x_sym].realize_array.values, data) ax.plot(data[x_sym], data[y_sym])
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}') ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
ax.set_xlabel(x_sym.plot_label) ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label) ax.set_xlabel(y_sym.plot_label)
def plot_points_2d( def plot_points_2d(data, ax: mpl_ax.Axis) -> None:
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis x_sym, y_sym = list(data.keys())
) -> None:
x_sym = info.last_dim
y_sym = info.output
ax.scatter(x_sym.realize_array.values, data) ax.scatter(data[x_sym], data[y_sym])
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}') ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
ax.set_xlabel(x_sym.plot_label) ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label) ax.set_xlabel(y_sym.plot_label)
# (, ) -> # (, ) ->
def plot_curves_2d( def plot_curves_2d(data, ax: mpl_ax.Axis) -> None:
data: jtyp.Float32[jtyp.Array, 'x_size categories'], info, ax: mpl_ax.Axis x_sym, label_sym, y_sym = list(data.keys())
) -> None:
x_sym = info.first_dim
y_sym = info.output
for i, category in enumerate(info.dims[info.last_dim]): for i, label in enumerate(data[label_sym]):
ax.plot(info.dims[x_sym], data[:, i], label=category) ax.plot(data[x_sym], data[y_sym][:, i], label=label)
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}') ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
ax.set_xlabel(x_sym.plot_label) ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label) ax.set_xlabel(y_sym.plot_label)
ax.legend() ax.legend()
@ -220,12 +211,10 @@ def plot_curves_2d(
def plot_filled_curves_2d( def plot_filled_curves_2d(
data: jtyp.Float32[jtyp.Array, 'x_size 2'], info, ax: mpl_ax.Axis data: jtyp.Float32[jtyp.Array, 'x_size 2'], info, ax: mpl_ax.Axis
) -> None: ) -> None:
x_sym = info.first_dim x_sym, _, y_sym = list(data.keys())
y_sym = info.output
shared_x_idx = info.dims[info.last_dim] ax.fill_between(data[x_sym], data[y_sym][:, 0], data[x_sym], data[y_sym][:, 1])
ax.fill_between(shared_x_idx, data[:, 0], shared_x_idx, data[:, 1]) ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_xlabel(x_sym.plot_label) ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label) ax.set_xlabel(y_sym.plot_label)
ax.legend() ax.legend()
@ -235,11 +224,9 @@ def plot_filled_curves_2d(
def plot_heatmap_2d( def plot_heatmap_2d(
data: jtyp.Float32[jtyp.Array, 'x_size y_size'], info, ax: mpl_ax.Axis data: jtyp.Float32[jtyp.Array, 'x_size y_size'], info, ax: mpl_ax.Axis
) -> None: ) -> None:
x_sym = info.first_dim x_sym, y_sym, c_sym = list(data.keys())
y_sym = info.last_dim
c_sym = info.output
heatmap = ax.imshow(data, aspect='equal', interpolation='none') heatmap = ax.imshow(data[c_sym], aspect='equal', interpolation='none')
ax.figure.colorbar(heatmap, cax=ax) ax.figure.colorbar(heatmap, cax=ax)
ax.set_title(f'({x_sym.name_pretty}, {y_sym.name_pretty}) -> {c_sym.plot_label}') ax.set_title(f'({x_sym.name_pretty}, {y_sym.name_pretty}) -> {c_sym.plot_label}')

View File

@ -99,6 +99,7 @@ class TypeID(enum.StrEnum):
SympyType: str = '!type=sympytype' SympyType: str = '!type=sympytype'
SympyExpr: str = '!type=sympyexpr' SympyExpr: str = '!type=sympyexpr'
SocketDef: str = '!type=socketdef' SocketDef: str = '!type=socketdef'
SimSymbol: str = '!type=simsymbol'
ManagedObj: str = '!type=managedobj' ManagedObj: str = '!type=managedobj'
@ -161,11 +162,12 @@ def _dec_hook(_type: type, obj: NaivelyEncodableType) -> typ.Any:
return sp.sympify(obj_value).subs(spux.UNIT_BY_SYMBOL) return sp.sympify(obj_value).subs(spux.UNIT_BY_SYMBOL)
if hasattr(_type, 'parse_as_msgspec') and ( if hasattr(_type, 'parse_as_msgspec') and (
is_representation(obj) and obj[0] in [TypeID.SocketDef, TypeID.ManagedObj] is_representation(obj)
and obj[0] in [TypeID.SocketDef, TypeID.ManagedObj, TypeID.SimSymbol]
): ):
return _type.parse_as_msgspec(obj) return _type.parse_as_msgspec(obj)
msg = f'Can\'t decode "{obj}" to type {type(obj)}' msg = f'can\'t decode "{obj}" to type {type(obj)}'
raise NotImplementedError(msg) raise NotImplementedError(msg)

View File

@ -14,36 +14,60 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import dataclasses
import enum import enum
import functools
import string
import sys import sys
import typing as typ import typing as typ
from fractions import Fraction from fractions import Fraction
import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp import sympy as sp
import sympy.physics.units as spu
from . import extra_sympy_units as spux from . import extra_sympy_units as spux
from . import logger, serialize
int_min = -(2**64) int_min = -(2**64)
int_max = 2**64 int_max = 2**64
float_min = sys.float_info.min float_min = sys.float_info.min
float_max = sys.float_info.max float_max = sys.float_info.max
log = logger.get(__name__)
def unicode_superscript(n: int) -> str:
"""Transform an integer into its unicode-based superscript character."""
return ''.join(['⁰¹²³⁴⁵⁶⁷⁸⁹'[ord(c) - ord('0')] for c in str(n)])
#################### ####################
# - Simulation Symbol Names # - Simulation Symbol Names
#################### ####################
_l = ''
_it_lower = iter(string.ascii_lowercase)
class SimSymbolName(enum.StrEnum): class SimSymbolName(enum.StrEnum):
# Lower # Generic
LowerA = enum.auto() Constant = enum.auto()
LowerB = enum.auto() Expr = enum.auto()
LowerC = enum.auto() Data = enum.auto()
LowerD = enum.auto()
LowerI = enum.auto() # Ascii Letters
LowerT = enum.auto() while True:
LowerX = enum.auto() try:
LowerY = enum.auto() globals()['_l'] = next(globals()['_it_lower'])
LowerZ = enum.auto() except StopIteration:
break
locals()[f'Lower{globals()["_l"].upper()}'] = enum.auto()
locals()[f'Upper{globals()["_l"].upper()}'] = enum.auto()
# Greek Letters
LowerTheta = enum.auto()
LowerPhi = enum.auto()
# Fields # Fields
Ex = enum.auto() Ex = enum.auto()
@ -64,18 +88,15 @@ class SimSymbolName(enum.StrEnum):
Wavelength = enum.auto() Wavelength = enum.auto()
Frequency = enum.auto() Frequency = enum.auto()
Flux = enum.auto()
PermXX = enum.auto() PermXX = enum.auto()
PermYY = enum.auto() PermYY = enum.auto()
PermZZ = enum.auto() PermZZ = enum.auto()
Flux = enum.auto()
DiffOrderX = enum.auto() DiffOrderX = enum.auto()
DiffOrderY = enum.auto() DiffOrderY = enum.auto()
# Generic
Expr = enum.auto()
#################### ####################
# - UI # - UI
#################### ####################
@ -109,49 +130,66 @@ class SimSymbolName(enum.StrEnum):
@property @property
def name(self) -> str: def name(self) -> str:
SSN = SimSymbolName SSN = SimSymbolName
return { return (
# Lower # Ascii Letters
SSN.LowerA: 'a', {SSN[f'Lower{letter.upper()}']: letter for letter in string.ascii_lowercase}
SSN.LowerB: 'b', | {
SSN.LowerC: 'c', SSN[f'Upper{letter.upper()}']: letter.upper()
SSN.LowerD: 'd', for letter in string.ascii_lowercase
SSN.LowerI: 'i', }
SSN.LowerT: 't', | {
SSN.LowerX: 'x', # Generic
SSN.LowerY: 'y', SSN.Constant: 'constant',
SSN.LowerZ: 'z', SSN.Expr: 'expr',
# Fields SSN.Data: 'data',
SSN.Ex: 'Ex', # Greek Letters
SSN.Ey: 'Ey', SSN.LowerTheta: 'theta',
SSN.Ez: 'Ez', SSN.LowerPhi: 'phi',
SSN.Hx: 'Hx', # Fields
SSN.Hy: 'Hy', SSN.Ex: 'Ex',
SSN.Hz: 'Hz', SSN.Ey: 'Ey',
SSN.Er: 'Ex', SSN.Ez: 'Ez',
SSN.Etheta: 'Ey', SSN.Hx: 'Hx',
SSN.Ephi: 'Ez', SSN.Hy: 'Hy',
SSN.Hr: 'Hx', SSN.Hz: 'Hz',
SSN.Htheta: 'Hy', SSN.Er: 'Ex',
SSN.Hphi: 'Hz', SSN.Etheta: 'Ey',
# Optics SSN.Ephi: 'Ez',
SSN.Wavelength: 'wl', SSN.Hr: 'Hx',
SSN.Frequency: 'freq', SSN.Htheta: 'Hy',
SSN.Flux: 'flux', SSN.Hphi: 'Hz',
SSN.PermXX: 'eps_xx', # Optics
SSN.PermYY: 'eps_yy', SSN.Wavelength: 'wl',
SSN.PermZZ: 'eps_zz', SSN.Frequency: 'freq',
SSN.DiffOrderX: 'order_x', SSN.PermXX: 'eps_xx',
SSN.DiffOrderY: 'order_y', SSN.PermYY: 'eps_yy',
# Generic SSN.PermZZ: 'eps_zz',
SSN.Expr: 'expr', SSN.Flux: 'flux',
}[self] SSN.DiffOrderX: 'order_x',
SSN.DiffOrderY: 'order_y',
}
)[self]
@property @property
def name_pretty(self) -> str: def name_pretty(self) -> str:
SSN = SimSymbolName SSN = SimSymbolName
return { return {
# Generic
# Greek Letters
SSN.LowerTheta: 'θ',
SSN.LowerPhi: 'φ',
# Fields
SSN.Etheta: '',
SSN.Ephi: '',
SSN.Hr: 'Hr',
SSN.Htheta: '',
SSN.Hphi: '',
# Optics
SSN.Wavelength: 'λ', SSN.Wavelength: 'λ',
SSN.Frequency: '𝑓', SSN.Frequency: '𝑓',
SSN.PermXX: 'ε_xx',
SSN.PermYY: 'ε_yy',
SSN.PermZZ: 'ε_zz',
}.get(self, self.name) }.get(self, self.name)
@ -173,8 +211,7 @@ def mk_interval(
) )
@dataclasses.dataclass(kw_only=True, frozen=True) class SimSymbol(pyd.BaseModel):
class SimSymbol:
"""A declarative representation of a symbolic variable. """A declarative representation of a symbolic variable.
`sympy`'s symbols aren't quite flexible enough for our needs: The symbols that we're transporting often need exact domain information, an associated unit dimension, and a great deal of determinism in checks thereof. `sympy`'s symbols aren't quite flexible enough for our needs: The symbols that we're transporting often need exact domain information, an associated unit dimension, and a great deal of determinism in checks thereof.
@ -183,6 +220,8 @@ class SimSymbol:
It's easy to persist, easy to transport, and has many helpful properties which greatly simplify working with symbols. It's easy to persist, easy to transport, and has many helpful properties which greatly simplify working with symbols.
""" """
model_config = pyd.ConfigDict(frozen=True)
sym_name: SimSymbolName sym_name: SimSymbolName
mathtype: spux.MathType = spux.MathType.Real mathtype: spux.MathType = spux.MathType.Real
physical_type: spux.PhysicalType = spux.PhysicalType.NonPhysical physical_type: spux.PhysicalType = spux.PhysicalType.NonPhysical
@ -191,6 +230,9 @@ class SimSymbol:
## -> 'None' indicates that no particular unit has yet been chosen. ## -> 'None' indicates that no particular unit has yet been chosen.
## -> Not exposed in the UI; must be set some other way. ## -> Not exposed in the UI; must be set some other way.
unit: spux.Unit | None = None unit: spux.Unit | None = None
## -> TODO: We currently allowing units that don't match PhysicalType
## -> -- In particular, NonPhysical w/units means "unknown units".
## -> -- This is essential for the Scientific Constant Node.
# Size # Size
## -> All SimSymbol sizes are "2D", but interpreted by convention. ## -> All SimSymbol sizes are "2D", but interpreted by convention.
@ -205,43 +247,76 @@ class SimSymbol:
## -> NOTE: interval_finite_*[0] must be strictly smaller than [1]. ## -> NOTE: interval_finite_*[0] must be strictly smaller than [1].
## -> See self.domain. ## -> See self.domain.
## -> We have to deconstruct symbolic interval semantics a bit for UI. ## -> We have to deconstruct symbolic interval semantics a bit for UI.
is_constant: bool = False
interval_finite_z: tuple[int, int] = (0, 1) interval_finite_z: tuple[int, int] = (0, 1)
interval_finite_q: tuple[tuple[int, int], tuple[int, int]] = ((0, 1), (1, 1)) interval_finite_q: tuple[tuple[int, int], tuple[int, int]] = ((0, 1), (1, 1))
interval_finite_re: tuple[float, float] = (0, 1) interval_finite_re: tuple[float, float] = (0.0, 1.0)
interval_inf: tuple[bool, bool] = (True, True) interval_inf: tuple[bool, bool] = (True, True)
interval_closed: tuple[bool, bool] = (False, False) interval_closed: tuple[bool, bool] = (False, False)
interval_finite_im: tuple[float, float] = (0, 1) interval_finite_im: tuple[float, float] = (0.0, 1.0)
interval_inf_im: tuple[bool, bool] = (True, True) interval_inf_im: tuple[bool, bool] = (True, True)
interval_closed_im: tuple[bool, bool] = (False, False) interval_closed_im: tuple[bool, bool] = (False, False)
#################### ####################
# - Properties # - Labels
#################### ####################
@property @functools.cached_property
def name(self) -> str: def name(self) -> str:
"""Usable name for the symbol.""" """Usable name for the symbol."""
return self.sym_name.name return self.sym_name.name
@property @functools.cached_property
def name_pretty(self) -> str: def name_pretty(self) -> str:
"""Pretty (possibly unicode) name for the thing.""" """Pretty (possibly unicode) name for the thing."""
return self.sym_name.name_pretty return self.sym_name.name_pretty
## TODO: Formatting conventions for bolding/etc. of vectors/mats/... ## TODO: Formatting conventions for bolding/etc. of vectors/mats/...
@property @functools.cached_property
def mathtype_size_label(self) -> str:
"""Pretty label that shows both mathtype and size."""
return f'{self.mathtype.label_pretty}' + (
'ˣ'.join([unicode_superscript(out_axis) for out_axis in self.shape])
if self.shape
else ''
)
@functools.cached_property
def unit_label(self) -> str:
"""Pretty unit label, which is an empty string when there is no unit."""
return spux.sp_to_str(self.unit) if self.unit is not None else ''
@functools.cached_property
def def_label(self) -> str:
"""Pretty definition label, exposing the symbol definition."""
return f'{self.name_pretty} | {self.unit_label}{self.mathtype_size_label}'
## TODO: Domain of validity from self.domain?
@functools.cached_property
def plot_label(self) -> str: def plot_label(self) -> str:
"""Pretty plot-oriented label.""" """Pretty plot-oriented label."""
return f'{self.name_pretty}' + ( return f'{self.name_pretty}' + (
f'({self.unit})' if self.unit is not None else '' f'({self.unit})' if self.unit is not None else ''
) )
@property ####################
# - Computed Properties
####################
@functools.cached_property
def unit_factor(self) -> spux.SympyExpr: def unit_factor(self) -> spux.SympyExpr:
"""Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking.""" """Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking."""
return self.unit if self.unit is not None else sp.S(1) return self.unit if self.unit is not None else sp.S(1)
@property @functools.cached_property
def size(self) -> tuple[int, ...] | None:
return {
(1, 1): spux.NumberSize1D.Scalar,
(2, 1): spux.NumberSize1D.Vec2,
(3, 1): spux.NumberSize1D.Vec3,
(4, 1): spux.NumberSize1D.Vec4,
}.get((self.rows, self.cols))
@functools.cached_property
def shape(self) -> tuple[int, ...]: def shape(self) -> tuple[int, ...]:
match (self.rows, self.cols): match (self.rows, self.cols):
case (1, 1): case (1, 1):
@ -253,7 +328,12 @@ class SimSymbol:
case (_, _): case (_, _):
return (self.rows, self.cols) return (self.rows, self.cols)
@property @functools.cached_property
def shape_len(self) -> spux.SympyExpr:
"""Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking."""
return len(self.shape)
@functools.cached_property
def domain(self) -> sp.Interval | sp.Set: def domain(self) -> sp.Interval | sp.Set:
"""Return the scalar domain of valid values for each element of the symbol. """Return the scalar domain of valid values for each element of the symbol.
@ -303,11 +383,31 @@ class SimSymbol:
), ),
) )
@functools.cached_property
def valid_domain_value(self) -> spux.SympyExpr:
"""A single value guaranteed to be conformant to this `SimSymbol` and within `self.domain`."""
match (self.domain.start.is_finite, self.domain.end.is_finite):
case (True, True):
if self.mathtype is spux.MathType.Integer:
return (self.domain.start + self.domain.end) // 2
return (self.domain.start + self.domain.end) / 2
case (True, False):
one = sp.S(self.mathtype.coerce_compatible_pyobj(-1))
return self.domain.start + one
case (False, True):
one = sp.S(self.mathtype.coerce_compatible_pyobj(-1))
return self.domain.end - one
case (False, False):
return sp.S(self.mathtype.coerce_compatible_pyobj(-1))
#################### ####################
# - Properties # - Properties
#################### ####################
@property @functools.cached_property
def sp_symbol(self) -> sp.Symbol: def sp_symbol(self) -> sp.Symbol | sp.ImmutableMatrix:
"""Return a symbolic variable w/unit, corresponding to this `SimSymbol`. """Return a symbolic variable w/unit, corresponding to this `SimSymbol`.
As much as possible, appropriate `assumptions` are set in the constructor of `sp.Symbol`, insofar as they can be determined. As much as possible, appropriate `assumptions` are set in the constructor of `sp.Symbol`, insofar as they can be determined.
@ -352,7 +452,82 @@ class SimSymbol:
elif self.domain.right <= 0: elif self.domain.right <= 0:
mathtype_kwargs |= {'negative': True} mathtype_kwargs |= {'negative': True}
return sp.Symbol(self.sym_name.name, **mathtype_kwargs) * self.unit_factor # Scalar: Return Symbol
if self.rows == 1 and self.cols == 1:
return sp.Symbol(self.sym_name.name, **mathtype_kwargs)
# Vector|Matrix: Return Matrix of Symbols
## -> MatrixSymbol doesn't support assumptions.
## -> This little construction does.
return sp.ImmutableMatrix(
[
[
sp.Symbol(self.sym_name.name + f'_{row}{col}', **mathtype_kwargs)
for col in range(self.cols)
]
for row in range(self.rows)
]
)
@functools.cached_property
def sp_symbol_matsym(self) -> sp.Symbol | sp.MatrixSymbol:
"""Return a symbolic variable w/unit, corresponding to this `SimSymbol`, w/variable shape support.
To preserve as many assumptions as possible, `self.sp_symbol` returns a matrix of individual `sp.Symbol`s whenever the `SimSymbol` is non-scalar.
However, this isn't always the most useful representation: For example, if the intention is to use a shaped symbolic variable as an argument to `sympy.lambdify()`, one would have to flatten each individual `sp.Symbol` and pass each matrix element as a single element, greatly complicating things like broadcasting.
For this reason, this property is provided.
Whenever the `SimSymbol` is scalar, it works identically to `self.sp_symbol`.
However, when the `SimSymbol` is shaped, an appropriate `sp.MatrixSymbol` is returned instead.
Notes:
`sp.MatrixSymbol` doesn't support assumptions.
As such, things like deduction of `MathType` from expressions involving a matrix symbol simply won't work.
"""
if self.shape_len == 0:
return self.sp_symbol
return sp.MatrixSymbol(self.sym_name.name, self.rows, self.cols)
@functools.cached_property
def sp_symbol_phy(self) -> spux.SympyExpr:
"""Physical symbol containing `self.sp_symbol` multiplied by `self.unit`."""
return self.sp_symbol * self.unit_factor
@functools.cached_property
def expr_info(self) -> dict[str, typ.Any]:
"""Generate keyword arguments for an ExprSocket, whose output values will be guaranteed to conform to this `SimSymbol`.
Notes:
Before use, `active_kind=ct.FlowKind.Range` can be added to make the `ExprSocket`.
Default values are set for both `Value` and `Range`.
To this end, `self.domain` is used.
Since `ExprSocketDef` allows the use of infinite bounds for `default_min` and `default_max`, we defer the decision of how to treat finite-fallback to the `ExprSocketDef`.
"""
if self.size is not None:
if self.unit in self.physical_type.valid_units:
return {
'output_name': self.sym_name,
# Socket Interface
'size': self.size,
'mathtype': self.mathtype,
'physical_type': self.physical_type,
# Defaults: Units
'default_unit': self.unit,
'default_symbols': [],
# Defaults: FlowKind.Value
'default_value': self.conform(
self.valid_domain_value, strip_unit=True
),
# Defaults: FlowKind.Range
'default_min': self.domain.start,
'default_max': self.domain.end,
}
msg = f'Tried to generate an ExprSocket from a SymSymbol "{self.name}", but its unit ({self.unit}) is not a valid unit of its physical type ({self.physical_type}) (SimSymbol={self})'
raise NotImplementedError(msg)
msg = f'Tried to generate an ExprSocket from a SymSymbol "{self.name}", but its size ({self.rows} by {self.cols}) is incompatible with ExprSocket (SimSymbol={self})'
raise NotImplementedError(msg)
#################### ####################
# - Operations # - Operations
@ -373,7 +548,7 @@ class SimSymbol:
cols=get_attr('cols'), cols=get_attr('cols'),
interval_finite_z=get_attr('interval_finite_z'), interval_finite_z=get_attr('interval_finite_z'),
interval_finite_q=get_attr('interval_finite_q'), interval_finite_q=get_attr('interval_finite_q'),
interval_finite_re=get_attr('interval_finite_q'), interval_finite_re=get_attr('interval_finite_re'),
interval_inf=get_attr('interval_inf'), interval_inf=get_attr('interval_inf'),
interval_closed=get_attr('interval_closed'), interval_closed=get_attr('interval_closed'),
interval_finite_im=get_attr('interval_finite_im'), interval_finite_im=get_attr('interval_finite_im'),
@ -381,24 +556,199 @@ class SimSymbol:
interval_closed_im=get_attr('interval_closed_im'), interval_closed_im=get_attr('interval_closed_im'),
) )
def set_finite_domain( # noqa: PLR0913
self,
start: int | float,
end: int | float,
start_closed: bool = True,
end_closed: bool = True,
start_im: bool = float,
end_im: bool = float,
start_closed_im: bool = True,
end_closed_im: bool = True,
) -> typ.Self:
"""Update the symbol with a finite range."""
closed_re = (start_closed, end_closed)
closed_im = (start_closed_im, end_closed_im)
match self.mathtype:
case spux.MathType.Integer:
return self.update(
interval_finite_z=(start, end),
interval_inf=(False, False),
interval_closed=closed_re,
)
case spux.MathType.Rational:
return self.update(
interval_finite_q=(start, end),
interval_inf=(False, False),
interval_closed=closed_re,
)
case spux.MathType.Real:
return self.update(
interval_finite_re=(start, end),
interval_inf=(False, False),
interval_closed=closed_re,
)
case spux.MathType.Complex:
return self.update(
interval_finite_re=(start, end),
interval_finite_im=(start_im, end_im),
interval_inf=(False, False),
interval_closed=closed_re,
interval_closed_im=closed_im,
)
def set_size(self, rows: int, cols: int) -> typ.Self: def set_size(self, rows: int, cols: int) -> typ.Self:
return self.update(rows=rows, cols=cols)
def conform(
self, sp_obj: spux.SympyType, strip_unit: bool = False
) -> spux.SympyType:
"""Conform a sympy object to the properties of this `SimSymbol`, if possible.
To achieve this, a number of operations may be performed:
- **Unit Conversion**: If the object has no units, but should, multiply by `self.unit`. If the object has units, but shouldn't, strip them. Otherwise, convert its unit to `self.unit`.
- **Broadcast Expansion**: If the object is a scalar, but the `SimSymbol` is shaped, then an `sp.ImmutableMatrix` is returned with the scalar at each position.
Returns:
A transformed sympy object guaranteed usable as a particular value of this `SimSymbol` variable.
Raises:
ValueError: If the units of `sp_obj` can't be cleanly converted to `self.unit`.
"""
res = sp_obj
# Unit Conversion
match (spux.uses_units(sp_obj), self.unit is not None):
case (True, True):
res = spux.scale_to_unit(sp_obj, self.unit) * self.unit
case (False, True):
res = sp_obj * self.unit
case (True, False):
res = spux.strip_unit_system(sp_obj)
if strip_unit:
res = spux.strip_unit_system(sp_obj)
# Broadcast Expansion
if self.rows > 1 or self.cols > 1 and not isinstance(res, spux.MatrixBase):
res = sp_obj * sp.ImmutableMatrix.ones(self.rows, self.cols)
return res
def scale(
self, sp_obj: spux.SympyType, use_jax_array: bool = True
) -> int | float | complex | jtyp.Inexact[jtyp.Array, '...']:
"""Remove all symbolic elements from the conformed `sp_obj`, preparing it for use in contexts that don't support unrealized symbols.
On top of `self.conform()`, a number of operations are performed.
- **Unit Stripping**: The `self.unit` of the expression returned by `self.conform()` will be stripped.
- **Sympy to Python**: The now symbol-less expression will be converted to either a pure Python type, or to a `jax` array (if `use_jax_array` is set).
Notes:
When creating numerical functions of expressions using `.lambdify`, `self.scale()` **must be used** in place of `self.conform()` before the parameterized expression is used.
Returns:
A "raw" (pure Python / jax array) type guaranteed usable as a particular **numerical** value of this `SymSymbol` variable.
"""
# Conform
res = self.conform(sp_obj)
# Strip Units
res = spux.scale_to_unit(sp_obj, self.unit)
# Sympy to Python
res = spux.sympy_to_python(res, use_jax_array=use_jax_array)
return res # noqa: RET504
@staticmethod
def from_expr(
sym_name: SimSymbolName,
expr: spux.SympyExpr,
unit_expr: spux.SympyExpr,
) -> typ.Self:
"""Deduce a `SimSymbol` that matches the output of a given expression (and unit expression).
This is an essential method, allowing for the ded
Notes:
`PhysicalType` **cannot be set** from an expression in the generic sense.
Therefore, the trick of using `NonPhysical` with non-`None` unit to denote unknown `PhysicalType` is used in the output.
All intervals are kept at their defaults.
Parameters:
sym_name: The `SimSymbolName` to set to the resulting symbol.
expr: The unit-aware expression to parse and encapsulate as a symbol.
unit_expr: A dimensional analysis expression (set to `1` to make the resulting symbol unitless).
Fundamentally, units are just the variables of scalar terms.
'1' for unitless terms are, in the dimanyl sense, constants.
Doing it like this may be a little messy, but is accurate.
Returns:
A fresh new `SimSymbol` that tries to match the given expression (and unit expression) well enough to be usable in place of it.
"""
# MathType from Expr Assumptions
## -> All input symbols have assumptions, because we are very pedantic.
## -> Therefore, we should be able to reconstruct the MathType.
mathtype = spux.MathType.from_expr(expr)
# PhysicalType as "NonPhysical"
## -> 'unit' still applies - but we can't guarantee a PhysicalType will.
## -> Therefore, this is what we gotta do.
physical_type = spux.PhysicalType.NonPhysical
# Rows/Cols from Expr (if Matrix)
rows, cols = expr.shape if isinstance(expr, sp.MatrixBase) else (1, 1)
return SimSymbol( return SimSymbol(
sym_name=self.sym_name, sym_name=sym_name,
mathtype=self.mathtype, mathtype=mathtype,
physical_type=self.physical_type, physical_type=physical_type,
unit=self.unit, unit=unit_expr if unit_expr != 1 else None,
rows=rows, rows=rows,
cols=cols, cols=cols,
interval_finite_z=self.interval_finite_z,
interval_finite_q=self.interval_finite_q,
interval_finite_re=self.interval_finite_re,
interval_inf=self.interval_inf,
interval_closed=self.interval_closed,
interval_finite_im=self.interval_finite_im,
interval_inf_im=self.interval_inf_im,
interval_closed_im=self.interval_closed_im,
) )
####################
# - Serialization
####################
def dump_as_msgspec(self) -> serialize.NaiveRepresentation:
"""Transforms this `SimSymbol` into an object that can be natively serialized by `msgspec`.
Notes:
Makes use of `pydantic.BaseModel.model_dump()` to cast any special fields into a serializable format.
If this method is failing, check that `pydantic` can actually cast all the fields in your model.
Returns:
A particular `list`, with two elements:
1. The `serialize`-provided "Type Identifier", to differentiate this list from generic list.
2. A dictionary containing simple Python types, as cast by `pydantic`.
"""
return [serialize.TypeID.SimSymbol, self.__class__.__name__, self.model_dump()]
@staticmethod
def parse_as_msgspec(obj: serialize.NaiveRepresentation) -> typ.Self:
"""Transforms an object made by `self.dump_as_msgspec()` into an instance of `SimSymbol`.
Notes:
The method presumes that the deserialized object produced by `msgspec` perfectly matches the object originally created by `self.dump_as_msgspec()`.
This is a **mostly robust** presumption, as `pydantic` attempts to be quite consistent in how to interpret types with almost identical semantics.
Still, yet-unknown edge cases may challenge these presumptions.
Returns:
A new instance of `SimSymbol`, initialized using the `model_dump()` dictionary.
"""
return SimSymbol(**obj[2])
#################### ####################
# - Common Sim Symbols # - Common Sim Symbols
@ -453,14 +803,10 @@ class CommonSimSymbol(enum.StrEnum):
Wavelength = enum.auto() Wavelength = enum.auto()
Frequency = enum.auto() Frequency = enum.auto()
DiffOrderX = enum.auto()
DiffOrderY = enum.auto()
Flux = enum.auto() Flux = enum.auto()
WaveVecX = enum.auto() DiffOrderX = enum.auto()
WaveVecY = enum.auto() DiffOrderY = enum.auto()
WaveVecZ = enum.auto()
#################### ####################
# - UI # - UI
@ -549,10 +895,10 @@ class CommonSimSymbol(enum.StrEnum):
if eh == 'e' if eh == 'e'
else spux.PhysicalType.HField, else spux.PhysicalType.HField,
unit=unit, unit=unit,
interval_finite_re=(0, sys.float_info.max), interval_finite_re=(0, float_max),
interval_inf_re=(False, True), interval_inf_re=(False, True),
interval_closed_re=(True, False), interval_closed_re=(True, False),
interval_finite_im=(sys.float_info.min, sys.float_info.max), interval_finite_im=(float_min, float_max),
interval_inf_im=(True, True), interval_inf_im=(True, True),
) )
@ -575,7 +921,7 @@ class CommonSimSymbol(enum.StrEnum):
sym_name=self.name, sym_name=self.name,
physical_type=spux.PhysicalType.Time, physical_type=spux.PhysicalType.Time,
unit=unit, unit=unit,
interval_finite_re=(0, sys.float_info.max), interval_finite_re=(0, float_max),
interval_inf=(False, True), interval_inf=(False, True),
interval_closed=(True, False), interval_closed=(True, False),
), ),
@ -592,19 +938,13 @@ class CommonSimSymbol(enum.StrEnum):
CSS.FieldHr: sym_field('h'), CSS.FieldHr: sym_field('h'),
CSS.FieldHtheta: sym_field('h'), CSS.FieldHtheta: sym_field('h'),
CSS.FieldHphi: sym_field('h'), CSS.FieldHphi: sym_field('h'),
CSS.Flux: SimSymbol(
sym_name=SimSymbolName.Flux,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Power,
unit=unit,
),
# Optics # Optics
CSS.Wavelength: SimSymbol( CSS.Wavelength: SimSymbol(
sym_name=self.name, sym_name=self.name,
mathtype=spux.MathType.Real, mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Length, physical_type=spux.PhysicalType.Length,
unit=unit, unit=unit,
interval_finite=(0, sys.float_info.max), interval_finite=(0, float_max),
interval_inf=(False, True), interval_inf=(False, True),
interval_closed=(False, False), interval_closed=(False, False),
), ),
@ -613,10 +953,30 @@ class CommonSimSymbol(enum.StrEnum):
mathtype=spux.MathType.Real, mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Freq, physical_type=spux.PhysicalType.Freq,
unit=unit, unit=unit,
interval_finite=(0, sys.float_info.max), interval_finite=(0, float_max),
interval_inf=(False, True), interval_inf=(False, True),
interval_closed=(False, False), interval_closed=(False, False),
), ),
CSS.Flux: SimSymbol(
sym_name=SimSymbolName.Flux,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Power,
unit=unit,
),
CSS.DiffOrderX: SimSymbol(
sym_name=self.name,
mathtype=spux.MathType.Integer,
interval_finite=(int_min, int_max),
interval_inf=(True, True),
interval_closed=(False, False),
),
CSS.DiffOrderY: SimSymbol(
sym_name=self.name,
mathtype=spux.MathType.Integer,
interval_finite=(int_min, int_max),
interval_inf=(True, True),
interval_closed=(False, False),
),
}[self] }[self]