refactor: revamped symbolic flow (inaccurate unit conversions)

main
Sofus Albert Høgsbro Rose 2024-05-24 16:01:23 +02:00
parent 353a2c997e
commit bcba444a8b
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
28 changed files with 2397 additions and 1020 deletions

View File

@ -18,8 +18,8 @@ from .array import ArrayFlow
from .capabilities import CapabilitiesFlow
from .flow_kinds import FlowKind
from .info import InfoFlow
from .lazy_range import RangeFlow, ScalingMode
from .lazy_func import FuncFlow
from .lazy_range import RangeFlow, ScalingMode
from .params import ParamsFlow
from .value import ValueFlow

View File

@ -50,11 +50,6 @@ class ArrayFlow:
####################
# - Computed Properties
####################
@property
def is_symbolic(self) -> bool:
"""Always False, as ArrayFlows are never unrealized."""
return False
def __len__(self) -> int:
"""Outer length of the contained array."""
return len(self.values)
@ -196,5 +191,11 @@ class ArrayFlow:
"""
return self.rescale(lambda v: v, new_unit=new_unit)
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self:
raise NotImplementedError
def rescale_to_unit_system(self, unit_system: spux.UnitSystem | None) -> typ.Self:
if unit_system is None:
return self.values
return self.correct_unit(None).rescale(
lambda v: spux.scale_to_unit_system(v * self.unit, unit_system),
new_unit=spux.convert_to_unit_system(self.unit, unit_system),
)

View File

@ -16,6 +16,7 @@
import dataclasses
import typing as typ
from types import MappingProxyType
from ..socket_types import SocketType
from .flow_kinds import FlowKind
@ -25,6 +26,7 @@ from .flow_kinds import FlowKind
class CapabilitiesFlow:
socket_type: SocketType
active_kind: FlowKind
allow_out_to_in: dict[FlowKind, FlowKind] = dataclasses.field(default_factory=dict)
is_universal: bool = False
@ -40,7 +42,13 @@ class CapabilitiesFlow:
def is_compatible_with(self, other: typ.Self) -> bool:
return other.is_universal or (
self.socket_type == other.socket_type
and self.active_kind == other.active_kind
and (
self.active_kind == other.active_kind
or (
other.active_kind in other.allow_out_to_in
and self.active_kind == other.allow_out_to_in[other.active_kind]
)
)
# == Constraint
and all(
name in other.must_match

View File

@ -67,8 +67,9 @@ class InfoFlow:
default_factory=dict
)
# Access
@functools.cached_property
def last_dim(self) -> sim_symbols.SimSymbol | None:
def first_dim(self) -> sim_symbols.SimSymbol | None:
"""The integer axis occupied by the dimension.
Can be used to index `.shape` of the represented raw array.
@ -87,13 +88,24 @@ class InfoFlow:
return list(self.dims.keys())[-1]
return None
def dim_axis(self, dim: sim_symbols.SimSymbol) -> int:
def dim_by_idx(self, idx: int) -> sim_symbols.SimSymbol | None:
if idx > 0 and idx < len(self.dims) - 1:
return list(self.dims.keys())[idx]
return None
def dim_by_name(self, dim_name: str) -> int:
"""The integer axis occupied by the dimension.
Can be used to index `.shape` of the represented raw array.
"""
return list(self.dims.keys()).index(dim)
dims_with_name = [dim for dim in self.dims if dim.name == dim_name]
if len(dims_with_name) == 1:
return dims_with_name[0]
msg = f'Dim name {dim_name} not found in InfoFlow (or >1 found)'
raise ValueError(msg)
# Information By-Dim
def has_idx_cont(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the dim's index is continuous, and therefore index array.
@ -114,6 +126,23 @@ class InfoFlow:
return isinstance(self.dims[dim], list)
return False
def is_idx_uniform(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the (int) dim has explicitly uniform indexing.
This is needed primarily to check whether a Fourier Transform can be meaningfully performed on the data over the dimension's axis.
In practice, we've decided that only `RangeFlow` really truly _guarantees_ uniform indexing.
While `ArrayFlow` may be uniform in practice, it's a very expensive to check, and it's far better to enforce that the user perform that check and opt for a `RangeFlow` instead, at the time of dimension definition.
"""
return isinstance(self.dims[dim], RangeFlow) and self.dims[dim].scaling == 'lin'
def dim_axis(self, dim: sim_symbols.SimSymbol) -> int:
"""The integer axis occupied by the dimension.
Can be used to index `.shape` of the represented raw array.
"""
return list(self.dims.keys()).index(dim)
####################
# - Output: Contravariant Value
####################
@ -128,6 +157,49 @@ class InfoFlow:
default_factory=dict
)
####################
# - Properties
####################
@functools.cached_property
def input_mathtypes(self) -> tuple[spux.MathType, ...]:
return tuple([dim.mathtype for dim in self.dims])
@functools.cached_property
def output_mathtypes(self) -> tuple[spux.MathType, int, int]:
return [self.output.mathtype for _ in range(len(self.output.shape) + 1)]
@functools.cached_property
def order(self) -> tuple[spux.MathType, ...]:
r"""The order of the tensor represented by this info.
While that sounds fancy and all, it boils down to:
$$
\texttt{dims} + |\texttt{output}.\texttt{shape}|
$$
Doing so characterizes the full dimensionality of the tensor, which also perfectly matches the length of the raw data's shape exactly.
Notes:
Corresponds to `len(raw_data.shape)`, if `raw_data` is the n-dimensional array corresponding to this `InfoFlow`.
"""
return len(self.input_mathtypes) + self.output_shape_len
####################
# - Properties
####################
@functools.cached_property
def dim_labels(self) -> dict[str, dict[str, str]]:
"""Return a dictionary mapping pretty dim names to information oriented for columnar information display."""
return {
dim.name_pretty: {
'length': str(len(dim_idx)) if dim_idx is not None else '',
'mathtype': dim.mathtype.label_pretty,
'unit': dim.unit_label,
}
for dim, dim_idx in self.dims.items()
}
####################
# - Operations: Dimensions
####################
@ -147,9 +219,11 @@ class InfoFlow:
"""Slice a dimensional array by-index along a particular dimension."""
return InfoFlow(
dims={
_dim: dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
if _dim == dim
else _dim
_dim: (
dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
if _dim == dim
else dim_idx
)
for _dim, dim_idx in self.dims.items()
},
output=self.output,
@ -166,7 +240,7 @@ class InfoFlow:
return InfoFlow(
dims={
(new_dim if _dim == old_dim else _dim): (
new_dim_idx if _dim == old_dim else _dim
new_dim_idx if _dim == old_dim else dim_idx
)
for _dim, dim_idx in self.dims.items()
},
@ -235,6 +309,26 @@ class InfoFlow:
pinned_values=self.pinned_values,
)
def operate_output(
self,
other: typ.Self,
op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
unit_op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
) -> spux.SympyExpr:
if self.dims == other.dims:
sym_name = sim_symbols.SimSymbolName.Expr
expr = op(self.output.sp_symbol_phy, other.output.sp_symbol_phy)
unit_expr = unit_op(self.output.unit_factor, other.output.unit_factor)
return InfoFlow(
dims=self.dims,
output=sim_symbols.SimSymbol.from_expr(sym_name, expr, unit_expr),
pinned_values=self.pinned_values,
)
msg = f'InfoFlow: operate_output cannot be used when dimensions are not identical ({self.dims} | {other.dims}).'
raise ValueError(msg)
####################
# - Operations: Fold
####################

View File

@ -22,7 +22,7 @@ from types import MappingProxyType
import jax
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils import logger, sim_symbols
from .params import ParamsFlow
@ -244,17 +244,10 @@ class FuncFlow:
"""
func: LazyFunction
func_args: list[spux.MathType | spux.PhysicalType] = dataclasses.field(
default_factory=list
)
func_kwargs: dict[str, spux.MathType | spux.PhysicalType] = dataclasses.field(
func_args: list[sim_symbols.SimSymbol] = dataclasses.field(default_factory=list)
func_kwargs: dict[str, sim_symbols.SimSymbol] = dataclasses.field(
default_factory=dict
)
## TODO: Use SimSymbol instead of the MathType|PT union.
## -- SimSymbol is an ideal pivot point for both, as well as valid domains.
## -- SimSymbol has more semantic meaning, including a name.
## -- If desired, SimSymbols could maybe even require a specific unit.
## It could greatly simplify a whole lot of pain associated with func_args.
supports_jax: bool = False
####################
@ -315,17 +308,18 @@ class FuncFlow:
def realize(
self,
params: ParamsFlow,
unit_system: spux.UnitSystem | None = None,
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
) -> typ.Self:
if self.supports_jax:
return self.func_jax(
*params.scaled_func_args(unit_system, symbol_values),
*params.scaled_func_kwargs(unit_system, symbol_values),
*params.scaled_func_args(self.func_args, symbol_values),
*params.scaled_func_kwargs(self.func_args, symbol_values),
)
return self.func(
*params.scaled_func_args(unit_system, symbol_values),
*params.scaled_func_kwargs(unit_system, symbol_values),
*params.scaled_func_args(self.func_kwargs, symbol_values),
*params.scaled_func_kwargs(self.func_kwargs, symbol_values),
)
####################

View File

@ -18,17 +18,18 @@ import dataclasses
import enum
import functools
import typing as typ
from fractions import Fraction
from types import MappingProxyType
import jax.numpy as jnp
import jaxtyping as jtyp
import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils import logger, sim_symbols
from .array import ArrayFlow
from .lazy_func import FuncFlow
log = logger.get(__name__)
@ -62,7 +63,7 @@ class ScalingMode(enum.StrEnum):
@dataclasses.dataclass(frozen=True, kw_only=True)
class RangeFlow:
r"""Represents a spaced array using symbolic boundary expressions.
r"""Represents a finite spaced array using symbolic boundary expressions.
Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous.
@ -79,33 +80,76 @@ class RangeFlow:
Since the bounds support exact (ex. rational) calculations and symbolic manipulations (_by virtue of being symbolic expressions_), the opportunities for certain kinds of numerical instability are mitigated.
Attributes:
start: An expression generating a scalar, unitless, complex value for the array's lower bound.
start: An expression representing the unitless part of the finite, scalar, complex value for the array's lower bound.
_Integer, rational, and real values are also supported._
stop: An expression generating a scalar, unitless, complex value for the array's upper bound.
start: An expression representing the unitless part of the finite, scalar, complex value for the array's upper bound.
_Integer, rational, and real values are also supported._
steps: The amount of steps (**inclusive**) to generate from `start` to `stop`.
scaling: The method of distributing `step` values between the two endpoints.
Generally, the linear default is sufficient.
unit: The unit of the generated array values
unit: The unit to interpret the values as.
symbols: Set of variables from which `start` and/or `stop` are determined.
"""
start: spux.ScalarUnitlessComplexExpr
stop: spux.ScalarUnitlessComplexExpr
steps: int
steps: int = 0
scaling: ScalingMode = ScalingMode.Lin
unit: spux.Unit | None = None
symbols: frozenset[spux.Symbol] = frozenset()
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
# Helper Attributes
pre_fourier_ideal_midpoint: spux.ScalarUnitlessComplexExpr | None = None
####################
# - Computed Properties
# - SimSymbol Interop
####################
@staticmethod
def from_sym(
sym: sim_symbols.SimSymbol,
steps: int = 50,
scaling: ScalingMode | str = ScalingMode.Lin,
) -> typ.Self:
if sym.domain.start.is_infinite or sym.domain.end.is_infinite:
use_steps = 0
else:
use_steps = steps
return RangeFlow(
start=sym.domain.start if sym.domain.start.is_finite else sp.S(-1),
stop=sym.domain.end if sym.domain.end.is_finite else sp.S(1),
steps=use_steps,
scaling=ScalingMode(scaling),
unit=sym.unit,
)
def to_sym(
self,
sym_name: sim_symbols.SimSymbolName,
) -> typ.Self:
physical_type = spux.PhysicalType.from_unit(self.unit, optional=True)
return sim_symbols.SimSymbol(
sym_name=sym_name,
mathtype=self.mathtype,
physical_type=(
physical_type
if physical_type is not None
else spux.PhysicalType.NonPhysical
),
unit=self.unit,
rows=1,
cols=1,
).set_domain(start=self.realize_start(), end=self.realize_end())
####################
# - Symbols
####################
@functools.cached_property
def sorted_symbols(self) -> list[sp.Symbol]:
def sorted_symbols(self) -> list[sim_symbols.SimSymbol]:
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
The order is guaranteed to be **deterministic**.
@ -115,10 +159,21 @@ class RangeFlow:
"""
return sorted(self.symbols, key=lambda sym: sym.name)
@property
def is_symbolic(self) -> bool:
"""Whether the `RangeFlow` has unrealized symbols."""
return len(self.symbols) > 0
@functools.cached_property
def sorted_sp_symbols(self) -> list[spux.Symbol]:
"""Computes `sympy` symbols from `self.sorted_symbols`.
Returns:
All symbols valid for use in the expression.
"""
return [sym.sp_symbol for sym in self.sorted_symbols]
####################
# - Properties
####################
@functools.cached_property
def unit_factor(self) -> spux.SympyExpr:
return self.unit if self.unit is not None else sp.S(1)
def __len__(self) -> int:
"""Compute the length of the array that would be realized.
@ -166,6 +221,14 @@ class RangeFlow:
####################
# - Methods
####################
@property
def ideal_midpoint(self) -> spux.SympyExpr:
return (self.stop + self.start) / 2
@property
def ideal_range(self) -> spux.SympyExpr:
return self.stop - self.start
def rescale(
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
) -> typ.Self:
@ -181,8 +244,8 @@ class RangeFlow:
new_pre_start = self.start if not reverse else self.stop
new_pre_stop = self.stop if not reverse else self.start
new_start = rescale_func(new_pre_start * self.unit)
new_stop = rescale_func(new_pre_stop * self.unit)
new_start = rescale_func(new_pre_start * self.unit_factor)
new_stop = rescale_func(new_pre_stop * self.unit_factor)
return RangeFlow(
start=(
@ -204,6 +267,99 @@ class RangeFlow:
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
raise NotImplementedError
@functools.cached_property
def bound_fourier_transform(self):
r"""Treat this `RangeFlow` it as an axis along which a fourier transform is being performed, such that its bounds scale according to the Nyquist Limit.
# Sampling Theory
In general, the Fourier Transform is an operator that works on infinite, periodic, continuous functions.
In this context alone it is a ideal transform (in terms of information retention), one which degrades quite gracefully in the face of practicalities like windowing (allowing them to apply analytically to non-periodic functions too).
While often used to transform an axis of time to an axis of frequency, in general this transform "simply" extracts all repeating structures from a function.
This is illustrated beautifully in the way that the output unit becomes the reciprocal of the input unit, which is the theory underlying why we say that measurements recieved as a reciprocal unit are in a "reciprocal space" (also called "k-space").
The real world is not so nice, of course, and as such we must generally make do with the Discrete Fourier Transform.
Even with bounded discrete information, we can annoy many mathematicians by defining a DFT in such a way that "structure per thing" ($\frac{1}{\texttt{unit}}$) still makes sense to us (to them, maybe not).
A DFT can still only retain the information given to it, but so long as we have enough "original structure", any "repeating structure" should be extractable with sufficient clarity to be useful.
What "sufficient clarity" means is the basis for the entire field of "sampling theory".
The theoretical maximum for the "fineness of repetition" that is "noticeable" in the fourier-transformed of some data is characterized by a theoretical upper bound called the Nyquist Frequency / Limit, which ends up being half of the sampling rate.
Thus, to determine bounds on the data, use of the Nyquist Limit is generally a good starting point.
Of course, when the discrete data comes from a discretization of a continuous signal, information from higher frequencies might still affect the discrete results.
They do little else than cause havoc, though - best causing noise, and at worst causing structured artifacts (sometimes called "aliasing").
Some of the first innovations in sampling theory were related to "anti-aliasing" filters, whose sole purpose is to try to remove informational frequencies above the Nyquist Limit of the discrete sensor.
In FDTD simulation, we're generally already ahead when it comes to aliasing, since our field values come from an already-discrete process.
That is, unless we start overly "binning" (averaging over $n$ discrete observations); in this case, care should be taken to make sure that interesting results aren't merely unfortunately structured aliasing artifacts.
For more, see <https://en.wikipedia.org/wiki/Nyquist%E2%80%93Shannon_sampling_theorem>
# Implementation
In practice, our goal in `RangeFlow` is to compute the bounds of the index array along the fourier-transformed data axis.
The reciprocal of the unit will be taken (when unitless, `1/1=`).
The raw Nyquist Limit $n_q$ will be used to bound the unitless part of the output as $[-n_q, n_q]$
Raises:
ValueError: If `self.scaling` is not linear, since the FT can only be performed on uniformly spaced data.
"""
if self.scaling is ScalingMode.Lin:
nyquist_limit = self.steps / self.ideal_range
# Return New Bounds w/Nyquist Theorem
## -> The Nyquist Limit describes "max repeated info per sample".
## -> Information can still record "faster" than the Nyquist Limit.
## -> It will just be either noise (best case), or banded artifacts.
## -> This is called "aliasing", and it's best to try and filter it.
## -> Sims generally "bin" yee cells, which sacrifices some NyqLim.
return RangeFlow(
start=-nyquist_limit,
stop=nyquist_limit,
scaling=self.scaling,
unit=1 / self.unit if self.unit is not None else None,
pre_fourier_ideal_midpoint=self.ideal_midpoint,
)
msg = f'Cant fourier-transform an index array as a boundary, when the RangeArray has a non-linear bound {self.scaling}'
raise ValueError(msg)
@functools.cached_property
def bound_inv_fourier_transform(self):
r"""Treat this `RangeFlow` as an axis along which an inverse fourier transform is being performed, such that its Nyquist-limit bounds are transformed back into values along the original unit dimension.
See `self.bound_fourier_transform` for the theoretical concepts.
Notes:
**The discrete inverse fourier transform always centers its output at $0$**.
Of course, it's entirely probable that the original signal was not centered at $0$.
For this reason, when performing a Fourier transform, `self.bound_fourier_transform` sets a special variable, `self.pre_fourier_ideal_midpoint`.
When set, it will retain the `self.ideal_midpoint` around which both `self.start` and `self.stop` should be centered after an inverse FT.
If `self.pre_fourier_ideal_midpoint` is set, then it will be used as the midpoint of the output's `start`/`stop`.
Otherwise, $0$ will be used - in which case the user should themselves, manually, shift the output if needed.
"""
if self.scaling is ScalingMode.Lin:
orig_ideal_range = self.steps / self.ideal_range
orig_start_centered = -orig_ideal_range
orig_stop_centered = orig_ideal_range
orig_ideal_midpoint = (
self.pre_fourier_ideal_midpoint
if self.pre_fourier_ideal_midpoint is not None
else sp.S(0)
)
# Return New Bounds w/Inverse of Nyquist Theorem
return RangeFlow(
start=-orig_start_centered + orig_ideal_midpoint,
stop=orig_stop_centered + orig_ideal_midpoint,
scaling=self.scaling,
unit=1 / self.unit if self.unit is not None else None,
)
msg = f'Cant fourier-transform an index array as a boundary, when the RangeArray has a non-linear bound {self.scaling}'
raise ValueError(msg)
####################
# - Exporters
####################
@ -237,15 +393,15 @@ class RangeFlow:
"""Create a function that can compute the non-lazy output array as a function of the symbols in the expressions for `start` and `stop`.
Notes:
The ordering of the symbols is identical to `self.symbols`, which is guaranteed to be a deterministically sorted list of symbols.
The ordering of the symbols is identical to `self.sorted_symbols`, which is guaranteed to be a deterministically sorted list of symbols.
Returns:
A `FuncFlow` that, given the input symbols defined in `self.symbols`,
A function that generates a 1D numerical array equivalent to the range represented in this `RangeFlow`.
"""
# Compile JAX Functions for Start/End Expressions
## -> FYI, JAX-in-JAX works perfectly fine.
start_jax = sp.lambdify(self.sorted_symbols, self.start, 'jax')
stop_jax = sp.lambdify(self.sorted_symbols, self.stop, 'jax')
start_jax = sp.lambdify(self.sorted_sp_symbols, self.start, 'jax')
stop_jax = sp.lambdify(self.sorted_sp_symbols, self.stop, 'jax')
# Compile ArrayGen Function
def gen_array(
@ -256,54 +412,80 @@ class RangeFlow:
# Return ArrayGen Function
return gen_array
@functools.cached_property
def as_lazy_func(self) -> FuncFlow:
"""Creates a `FuncFlow` using the output of `self.as_func`.
This is useful for ex. parameterizing the first array in the node graph, without binding an entire computed array.
Notes:
The the function enclosed in the `FuncFlow` is identical to the one returned by `self.as_func`.
Returns:
A `FuncFlow` containing `self.as_func`, as well as appropriate supporting settings.
"""
return FuncFlow(
func=self.as_func,
func_args=[(spux.MathType.from_expr(sym)) for sym in self.symbols],
supports_jax=True,
)
####################
# - Realization
####################
def realize_symbols(
self,
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
) -> dict[sp.Symbol, spux.ScalarUnitlessComplexExpr]:
"""Realize **all** input symbols to the `RangeFlow`.
Parameters:
symbol_values: A scalar, unitless, complex `sympy` expression for each symbol defined in `self.symbols`.
Returns:
A dictionary directly usable in expression substitutions using `sp.Basic.subs()`.
"""
if self.symbols == set(symbol_values.keys()):
realized_syms = {}
for sym in self.sorted_symbols:
sym_value = symbol_values[sym]
# Sympy Expression
## -> We need to conform the expression to the SimSymbol.
## -> Mainly, this is
if (
isinstance(sym_value, spux.SympyType)
and not isinstance(sym_value, sp.MatrixBase)
and not spux.uses_units(sym_value)
):
v = sym.conform(sym_value)
else:
msg = f'RangeFlow: No realization support for symbolic value {sym_value} (type={type(sym_value)})'
raise NotImplementedError(msg)
realized_syms |= {sym: v}
msg = f'RangeFlow: Not all symbols were given a value during realization (symbols={self.symbols}, symbol_values={symbol_values})'
raise ValueError(msg)
def realize_start(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
) -> int | float | complex:
"""Realize the start-bound by inserting particular values for each symbol."""
return spux.sympy_to_python(
self.start.subs({sym: symbol_values[sym.name] for sym in self.symbols})
)
realized_symbols = self.realize_symbols(symbol_values)
return spux.sympy_to_python(self.start.subs(realized_symbols))
def realize_stop(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
) -> int | float | complex:
"""Realize the stop-bound by inserting particular values for each symbol."""
return spux.sympy_to_python(
self.stop.subs({sym: symbol_values[sym.name] for sym in self.symbols})
)
realized_symbols = self.realize_symbols(symbol_values)
return spux.sympy_to_python(self.stop.subs(realized_symbols))
def realize_step_size(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
) -> int | float | complex:
"""Realize the stop-bound by inserting particular values for each symbol."""
if self.scaling is not ScalingMode.Lin:
raise NotImplementedError('Non-linear scaling mode not yet suported')
msg = 'Non-linear scaling mode not yet suported'
raise NotImplementedError(msg)
raw_step_size = (self.realize_stop() - self.realize_start() + 1) / self.steps
raw_step_size = (
self.realize_stop(symbol_values) - self.realize_start(symbol_values) + 1
) / self.steps
if self.mathtype is spux.MathType.Integer and raw_step_size.is_integer():
return int(raw_step_size)
@ -311,7 +493,9 @@ class RangeFlow:
def realize(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
) -> ArrayFlow:
"""Realize the array represented by this `RangeFlow` by realizing each bound, then generating all intermediate values as an array.
@ -324,15 +508,33 @@ class RangeFlow:
## TODO: Check symbol values for coverage.
return ArrayFlow(
values=self.as_func(*[symbol_values[sym] for sym in self.sorted_symbols]),
values=self.as_func(
*[
spux.scale_to_unit_system(symbol_values[sym])
for sym in self.sorted_symbols
]
),
unit=self.unit,
is_sorted=True,
)
@functools.cached_property
def realize_array(self) -> ArrayFlow:
"""Standardized access to `self.realize()` when there are no symbols."""
return self.realize()
"""Standardized access to `self.realize()` when there are no symbols.
Raises:
ValueError: If there are symbols defined in `self.symbols`.
"""
if not self.symbols:
return self.realize()
msg = f'RangeFlow: Cannot use ".realize_array" when symbols are defined (symbols={self.symbols}, RangeFlow={self}'
raise ValueError(msg)
@property
def values(self) -> jtyp.Inexact[jtyp.Array, '...']:
"""Alias for `realize_array.values`."""
return self.realize_array.values
def __getitem__(self, subscript: slice):
"""Implement indexing and slicing in a sane way.
@ -379,23 +581,14 @@ class RangeFlow:
Raises:
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
"""
if self.unit is not None:
log.debug(
'%s: Corrected unit to %s',
self,
corrected_unit,
)
return RangeFlow(
start=self.start,
stop=self.stop,
steps=self.steps,
scaling=self.scaling,
unit=corrected_unit,
symbols=self.symbols,
)
msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"'
raise ValueError(msg)
return RangeFlow(
start=self.start,
stop=self.stop,
steps=self.steps,
scaling=self.scaling,
unit=corrected_unit,
symbols=self.symbols,
)
def rescale_to_unit(self, unit: spux.Unit) -> typ.Self:
"""Replaces the unit, **with** rescaling of the bounds.
@ -410,11 +603,6 @@ class RangeFlow:
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
"""
if self.unit is not None:
log.debug(
'%s: Scaled to unit %s',
self,
unit,
)
return RangeFlow(
start=spux.scale_to_unit(self.start * self.unit, unit),
stop=spux.scale_to_unit(self.stop * self.unit, unit),
@ -423,11 +611,18 @@ class RangeFlow:
unit=unit,
symbols=self.symbols,
)
return RangeFlow(
start=self.start * unit,
stop=self.stop * unit,
steps=self.steps,
scaling=self.scaling,
unit=unit,
symbols=self.symbols,
)
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}'
raise ValueError(msg)
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self:
def rescale_to_unit_system(
self, unit_system: spux.UnitSystem | None = None
) -> typ.Self:
"""Replaces the units, **with** rescaling of the bounds.
Parameters:
@ -439,28 +634,11 @@ class RangeFlow:
Raises:
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
"""
if self.unit is not None:
log.debug(
'%s: Scaled to new unit system (new unit = %s)',
self,
unit_system[spux.PhysicalType.from_unit(self.unit)],
)
return RangeFlow(
start=spux.strip_unit_system(
spux.convert_to_unit_system(self.start * self.unit, unit_system),
unit_system,
),
stop=spux.strip_unit_system(
spux.convert_to_unit_system(self.stop * self.unit, unit_system),
unit_system,
),
steps=self.steps,
scaling=self.scaling,
unit=unit_system[spux.PhysicalType.from_unit(self.unit)],
symbols=self.symbols,
)
msg = (
f'Tried to rescale unitless LazyDataValueRange to unit system {unit_system}'
return RangeFlow(
start=spux.scale_to_unit_system(self.start * self.unit, unit_system),
stop=spux.scale_to_unit_system(self.stop * self.unit, unit_system),
steps=self.steps,
scaling=self.scaling,
unit=spux.convert_to_unit_system(self.unit, unit_system),
symbols=self.symbols,
)
raise ValueError(msg)

View File

@ -17,15 +17,19 @@
import dataclasses
import functools
import typing as typ
from fractions import Fraction
from types import MappingProxyType
import jaxtyping as jtyp
import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger, sim_symbols
from .array import ArrayFlow
from .expr_info import ExprInfo
from .flow_kinds import FlowKind
from .lazy_range import RangeFlow
# from .info import InfoFlow
@ -34,13 +38,22 @@ log = logger.get(__name__)
@dataclasses.dataclass(frozen=True, kw_only=True)
class ParamsFlow:
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
Returns:
All symbols valid for use in the expression.
"""
func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list)
func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict)
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
####################
# - Symbols
####################
@functools.cached_property
def sorted_symbols(self) -> list[sp.Symbol]:
def sorted_symbols(self) -> list[sim_symbols.SimSymbol]:
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
Returns:
@ -48,52 +61,179 @@ class ParamsFlow:
"""
return sorted(self.symbols, key=lambda sym: sym.name)
@functools.cached_property
def sorted_sp_symbols(self) -> list[sp.Symbol | sp.MatrixSymbol]:
"""Computes `sympy` symbols from `self.sorted_symbols`.
When the output is shaped, a single shaped symbol (`sp.MatrixSymbol`) is used to represent the symbolic name and shaping.
This choice is made due to `MatrixSymbol`'s compatibility with `.lambdify` JIT.
Returns:
All symbols valid for use in the expression.
"""
return [sym.sp_symbol_matsym for sym in self.sorted_symbols]
####################
# - JIT'ed Callables for Numerical Function Arguments
####################
def func_args_n(
self, target_syms: list[sim_symbols.SimSymbol]
) -> list[
typ.Callable[
[int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
int | float | complex | jtyp.Inexact[jtyp.Array, '...'],
]
]:
"""Callable functions for evaluating each `self.func_args` entry numerically.
Before simplification, each `self.func_args` entry will be conformed to the corresponding (by-index) `SimSymbol` in `target_syms`.
Notes:
Before using any `sympy` expressions as arguments to the returned callablees, they **must** be fully conformed and scaled to the corresponding `self.symbols` entry using that entry's `SimSymbol.scale()` method.
This ensures conformance to the `SimSymbol` properties (like units), as well as adherance to a numerical type identity compatible with `sp.lambdify()`.
Parameters:
target_syms: `SimSymbol`s describing how a particular `ParamsFlow` function argument should be scaled when performing a purely numerical insertion.
"""
return [
sp.lambdify(
self.sorted_sp_symbols,
target_sym.conform(func_arg, strip_unit=True),
'jax',
)
for func_arg, target_sym in zip(self.func_args, target_syms, strict=True)
]
def func_kwargs_n(
self, target_syms: dict[str, sim_symbols.SimSymbol]
) -> dict[
str,
typ.Callable[
[int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
int | float | complex | jtyp.Inexact[jtyp.Array, '...'],
],
]:
"""Callable functions for evaluating each `self.func_kwargs` entry numerically.
The arguments of each function **must** be pre-treated using `SimSymbol.scale()`.
This ensures conformance to the `SimSymbol` properties, as well as adherance to a numerical type identity compatible with `sp.lambdify()`
"""
return {
func_arg_key: sp.lambdify(
self.sorted_sp_symbols,
target_syms[func_arg_key].scale(func_arg),
'jax',
)
for func_arg_key, func_arg in self.func_kwargs.items()
}
####################
# - Realization
####################
def realize_symbols(
self,
symbol_values: dict[
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
] = MappingProxyType({}),
) -> dict[
sp.Symbol,
int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...'] :,
]:
"""Fully realize all symbols by assigning them a value.
Three kinds of values for `symbol_values` are supported, fundamentally:
- **Sympy Expression**: When the value is a sympy expression with units, the unit of the `SimSymbol` key which unit the value if converted to.
If the `SimSymbol`'s unit is `None`, then the value is left as-is.
- **Range**: When the value is a `RangeFlow`, units are converted to the `SimSymbol`'s unit using `.rescale_to_unit()`.
If the `SimSymbol`'s unit is `None`, then the value is left as-is.
- **Array**: When the value is an `ArrayFlow`, units are converted to the `SimSymbol`'s unit using `.rescale_to_unit()`.
If the `SimSymbol`'s unit is `None`, then the value is left as-is.
Returns:
A dictionary almost with `.subs()`, other than `jax` arrays.
"""
if set(self.symbols) == set(symbol_values.keys()):
realized_syms = {}
for sym in self.sorted_symbols:
sym_value = symbol_values[sym]
if isinstance(sym_value, spux.SympyType):
v = sym.scale(sym_value)
elif isinstance(sym_value, ArrayFlow | RangeFlow):
v = sym_value.rescale_to_unit(sym.unit).values
## NOTE: RangeFlow must not be symbolic.
else:
msg = f'No support for symbolic value {sym_value} (type={type(sym_value)})'
raise NotImplementedError(msg)
realized_syms |= {sym: v}
return realized_syms
msg = f'ParamsFlow: Not all symbols were given a value during realization (symbols={self.symbols}, symbol_values={symbol_values})'
raise ValueError(msg)
####################
# - Realize Arguments
####################
def scaled_func_args(
self,
unit_system: spux.UnitSystem | None = None,
target_syms: list[sim_symbols.SimSymbol] = (),
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
):
"""Realize the function arguments contained in this `ParamsFlow`, making it ready for insertion into `Func.func()`.
) -> list[
int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...']
]:
"""Realize correctly conformed numerical arguments for `self.func_args`.
For all `arg`s in `self.func_args`, the following operations are performed.
Because we allow symbols to be used in `self.func_args`, producing a numerical value that can be passed directly to a `FuncFlow` becomes a two-step process:
Notes:
This method is created for the purpose of being able to make this exact call in an `events.on_value_changed` method:
1. Conform Symbols: Arbitrary `sympy` expressions passed as `symbol_values` must first be conformed to match the ex. units of `SimSymbol`s found in `self.symbols`, before they can be used.
2. Conform Function Arguments: Arbitrary `sympy` expressions encoded in `self.func_args` must, **after** inserting the conformed numerical symbols, themselves be conformed to the expected ex. units of the function that they are to be used within.
**`ParamsFlow` doesn't contain information about the `SimSymbol`s that `self.func_args` are expected to conform to** (on purpose).
Therefore, the user is required to pass a `target_syms` with identical length to `self.func_args`, describing the `SimSymbol`s to conform the function arguments to.
Our implementation attempts to utilize simple, powerful primitives to accomplish this in roughly three steps:
1. **Realize Symbols**: Particular passed symbolic values `symbol_values`, which are arbitrary `sympy` expressions, are conformed to the definitions in `self.symbols` (ex. to match units), then cast to numerical values (pure Python / jax array).
2. **Lazy Function Arguments**: Stored function arguments `self.func_args`, which are arbitrary `sympy` expressions, are conformed to the definitions in `target_syms` (ex. to match units), then cast to numerical values (pure Python / jax array).
_Technically, this happens as part of `self.func_args_n`._
3. **Numerical Evaluation**: The numerical values for each symbol are passed as parameters to each (callable) element of `self.func_args_n`, which produces a correct numerical value for each function argument.
Parameters:
target_syms: `SimSymbol`s describing how the function arguments returned by this method are intended to be used.
**Generally**, the parallel `FuncFlow.func_args` should be inserted here, and guarantees correct results when this output is inserted into `FuncFlow.func(...)`.
symbol_values: Particular values for all symbols in `self.symbols`, which will be conformed and used to compute the function arguments (before they are conformed to `target_syms`).
"""
if not all(sym in self.symbols for sym in symbol_values):
msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}"
raise ValueError(msg)
## TODO: MutableDenseMatrix causes error with 'in' check bc it isn't hashable.
realized_symbols = list(self.realize_symbols(symbol_values).values())
return [
(
spux.scale_to_unit_system(arg, unit_system, use_jax_array=True)
if arg not in symbol_values
else symbol_values[arg]
)
for arg in self.func_args
func_arg_n(*realized_symbols)
for func_arg_n in self.func_args_n(target_syms)
]
def scaled_func_kwargs(
self,
unit_system: spux.UnitSystem | None = None,
target_syms: list[sim_symbols.SimSymbol] = (),
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
):
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments."""
if not all(sym in self.symbols for sym in symbol_values):
msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}"
raise ValueError(msg)
) -> dict[
str, int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...']
]:
"""Realize correctly conformed numerical arguments for `self.func_kwargs`.
Other than the `dict[str, ...]` key, the semantics are identical to `self.scaled_func_args()`.
"""
realized_symbols = self.realize_symbols(symbol_values)
return {
arg_name: spux.convert_to_unit_system(arg, unit_system, use_jax_array=True)
if arg not in symbol_values
else symbol_values[arg]
for arg_name, arg in self.func_kwargs.items()
func_arg_name: func_arg_n(**realized_symbols)
for func_arg_name, func_arg_n in self.func_kwargs_n(target_syms).items()
}
####################
@ -129,8 +269,8 @@ class ParamsFlow:
####################
# - Generate ExprSocketDef
####################
def sym_expr_infos(self, info, use_range: bool = False) -> dict[str, ExprInfo]:
"""Generate all information needed to define expressions that realize all symbolic parameters in this `ParamsFlow`.
def sym_expr_infos(self, use_range: bool = False) -> dict[str, ExprInfo]:
"""Generate keyword arguments for defining all `ExprSocket`s needed to realize all `self.symbols`.
Many nodes need actual data, and as such, they require that the user select actual values for any symbols in the `ParamsFlow`.
The best way to do this is to create one `ExprSocket` for each symbol that needs realizing.
@ -151,35 +291,22 @@ class ParamsFlow:
The `ExprInfo`s can be directly defererenced `**expr_info`)
"""
for sim_sym in self.sorted_symbols:
if use_range and sim_sym.mathtype is spux.MathType.Complex:
for sym in self.sorted_symbols:
if use_range and sym.mathtype is spux.MathType.Complex:
msg = 'No support for complex range in ExprInfo'
raise NotImplementedError(msg)
if use_range and (sim_sym.rows > 1 or sim_sym.cols > 1):
if use_range and (sym.rows > 1 or sym.cols > 1):
msg = 'No support for non-scalar elements of range in ExprInfo'
raise NotImplementedError(msg)
if sim_sym.rows > 3 or sim_sym.cols > 1:
if sym.rows > 3 or sym.cols > 1:
msg = 'No support for >Vec3 / Matrix values in ExprInfo'
raise NotImplementedError(msg)
return {
sim_sym.name: {
# Declare Kind/Size
## -> Kind: Value prevents user-alteration of config.
## -> Size: Always scalar, since symbols are scalar (for now).
sym.name: {
'active_kind': FlowKind.Value if not use_range else FlowKind.Range,
'size': spux.NumberSize1D.Scalar,
# Declare MathType/PhysicalType
## -> MathType: Lookup symbol name in info dimensions.
## -> PhysicalType: Same.
'mathtype': self.dims[sim_sym].mathtype,
'physical_type': self.dims[sim_sym].physical_type,
# TODO: Default Value
# FlowKind.Value: Default Value
#'default_value':
# FlowKind.Range: Default Min/Max/Steps
'default_min': sim_sym.domain.start,
'default_max': sim_sym.domain.end,
'default_steps': 50,
}
for sim_sym in self.sorted_symbols
| sym.expr_info
for sym in self.sorted_symbols
}

View File

@ -29,6 +29,7 @@ import tidy3d as td
from blender_maxwell.contracts import BLEnumElement
from blender_maxwell.services import tdcloud
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger
from .flow_kinds.info import InfoFlow
@ -482,6 +483,93 @@ class DataFileFormat(enum.StrEnum):
## - When sidecars aren't found, the user would "fill in the blanks".
## - ...Thus achieving the same result as if there were a sidecar.
####################
# - Functions: DataFrame
####################
@staticmethod
def to_df(
data: jtyp.Shaped[jtyp.Array, 'x_size y_size'], info: InfoFlow
) -> pl.DataFrame:
"""Utility method to convert raw data to a `polars.DataFrame`, as guided by an `InfoFlow`.
Only works with 2D data (obviously).
Raises:
ValueError: If the data has more than two dimensions, all `info` dimensions are not discrete/labelled, or the dimensionality of `info` doesn't match.
"""
if info.order > 2: # noqa: PLR2004
msg = f'Data may not have more than two dimensions (info={info}, data.shape={data.shape})'
raise ValueError(msg)
if any(info.has_idx_cont(dim) for dim in info.dims):
msg = f'To convert data|info to a dataframe, no dimensions can have continuous indices (info={info})'
raise ValueError(msg)
data_np = np.array(data)
MT = spux.MathType
match (
info.input_mathtypes,
info.output.mathtype,
info.output.rows,
info.output.cols,
):
# (R,Z) -> Complex Scalar
## -> Polars (also pandas) doesn't have a complex type.
## -> Will be treated as (R, Z, 2) -> Real Scalar.
case ((MT.Rational | MT.Real, MT.Integer), MT.Complex, 1, 1):
row_dim = info.first_dim
col_dim = info.last_dim
return pl.DataFrame(
{row_dim.name: info.dims[row_dim]}
| {
col_label + postfix: re_im(data_np[:, col])
for col, col_label in enumerate(info.dims[col_dim])
for postfix, re_im in [('_re', np.real), ('_im', np.imag)]
}
)
# (R,Z) -> Scalar
case ((MT.Rational | MT.Real, MT.Integer), _, 1, 1):
row_dim = info.first_dim
col_dim = info.last_dim
return pl.DataFrame(
{row_dim.name: info.dims[row_dim]}
| {
col_label: data_np[:, col]
for col, col_label in enumerate(info.dims[col_dim])
}
)
# (Z) -> Complex Vector/Covector
case ((MT.Integer,), MT.Complex, r, c) if (r > 1 and c == 1) or (
r == 1 and c > 1
):
col_dim = info.last_dim
return pl.DataFrame(
{
col_label + postfix: re_im(data_np[col, :])
for col, col_label in enumerate(info.dims[col_dim])
for postfix, re_im in [('_re', np.real), ('_im', np.imag)]
}
)
# (Z) -> Real Vector
## -> Each integer index will be treated as a column index.
## -> This will effectively transpose the data.
case ((MT.Integer,), _, r, c) if (r > 1 and c == 1) or (r == 1 and c > 1):
col_dim = info.last_dim
return pl.DataFrame(
{
col_label: data_np[col, :]
for col, col_label in enumerate(info.dims[col_dim])
}
)
####################
# - Functions: Saver
####################
@ -506,50 +594,7 @@ class DataFileFormat(enum.StrEnum):
np.savetxt(path, data)
def save_csv(path, data, info):
data_np = np.array(data)
# Extract Input Coordinates
dim_columns = {
dim.name: np.array(dim_idx.realize_array)
for i, (dim, dim_idx) in enumerate(info.dims)
} ## TODO: realize_array might not be defined on some index arrays
# Declare Function to Extract Output Values
output_columns = {}
def declare_output_col(data_col, output_idx=0, use_output_idx=False):
nonlocal output_columns
# Complex: Split to Two Columns
output_idx_str = f'[{output_idx}]' if use_output_idx else ''
if bool(np.any(np.iscomplex(data_col))):
output_columns |= {
f'{info.output.name}{output_idx_str}_re': np.real(data_col),
f'{info.output.name}{output_idx_str}_im': np.imag(data_col),
}
# Else: Use Array Directly
else:
output_columns |= {
f'{info.output.name}{output_idx_str}': data_col,
}
## TODO: Maybe a check to ensure dtype!=object?
# Extract Output Values
## -> 2D: Iterate over columns by-index.
## -> 1D: Declare the array as the only column.
if len(data_np.shape) == 2:
for output_idx in data_np.shape[1]:
declare_output_col(data_np[:, output_idx], output_idx, True)
else:
declare_output_col(data_np)
# Compute DataFrame & Write CSV
df = pl.DataFrame(dim_columns | output_columns)
log.debug('Writing Polars DataFrame to CSV:')
log.debug(df)
df = self.to_df(data, info)
df.write_csv(path)
def save_npy(path, data, info):

View File

@ -264,17 +264,22 @@ class ManagedBLImage(base.ManagedObj):
# times = [time.perf_counter()]
# Compute Plot Dimensions
aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = (
self.gen_image_geometry(width_inches, height_inches, dpi)
)
# aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = (
# self.gen_image_geometry(width_inches, height_inches, dpi)
# )
# times.append(['Image Geometry', time.perf_counter() - times[0]])
# log.critical(
# [aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px]
# )
# Create MPL Figure, Axes, and Compute Figure Geometry
fig, canvas, ax = image_ops.mpl_fig_canvas_ax(
_width_inches, _height_inches, _dpi
)
# fig, canvas, ax = image_ops.mpl_fig_canvas_ax(
# _width_inches, _height_inches, _dpi
# )
fig, canvas, ax = image_ops.mpl_fig_canvas_ax(width_inches, height_inches, dpi)
# times.append(['MPL Fig Canvas Axis', time.perf_counter() - times[0]])
# fig.clear()
ax.clear()
# times.append(['Clear Axis', time.perf_counter() - times[0]])

View File

@ -46,6 +46,7 @@ class FilterOperation(enum.StrEnum):
"""
# Slice
Slice = enum.auto()
SliceIdx = enum.auto()
# Pin
@ -53,9 +54,8 @@ class FilterOperation(enum.StrEnum):
Pin = enum.auto()
PinIdx = enum.auto()
# Reinterpret
# Dimension
Swap = enum.auto()
SetDim = enum.auto()
####################
# - UI
@ -65,14 +65,14 @@ class FilterOperation(enum.StrEnum):
FO = FilterOperation
return {
# Slice
FO.SliceIdx: 'a[...]',
FO.Slice: '=a[i:j]',
FO.SliceIdx: '≈a[v₁:v₂]',
# Pin
FO.PinLen1: 'pinₐ =1',
FO.PinLen1: 'pinₐ',
FO.Pin: 'pinₐ ≈v',
FO.PinIdx: 'pinₐ =a[v]',
FO.PinIdx: 'pinₐ =i',
# Reinterpret
FO.Swap: 'a₁ ↔ a₂',
FO.SetDim: 'setₐ =v',
}[value]
@staticmethod
@ -118,11 +118,6 @@ class FilterOperation(enum.StrEnum):
if len(info.dims) >= 2: # noqa: PLR2004
operations.append(FO.Swap)
## SetDim
## -> There must be a dimension to correct.
if info.dims:
operations.append(FO.SetDim)
return operations
####################
@ -145,6 +140,7 @@ class FilterOperation(enum.StrEnum):
FO = FilterOperation
return {
# Slice
FO.Slice: 1,
FO.SliceIdx: 1,
# Pin
FO.PinLen1: 1,
@ -152,40 +148,35 @@ class FilterOperation(enum.StrEnum):
FO.PinIdx: 1,
# Reinterpret
FO.Swap: 2,
FO.SetDim: 1,
}[self]
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
FO = FilterOperation
match self:
case FO.SliceIdx | FO.Swap:
return info.dims
# Slice
case FO.Slice:
return [dim for dim in info.dims if not dim.has_idx_labels(dim)]
# PinLen1: Only allow dimensions with length=1.
case FO.SliceIdx:
return [dim for dim in info.dims if not dim.has_idx_labels(dim)]
# Pin
case FO.PinLen1:
return [
dim
for dim, dim_idx in info.dims.items()
if dim_idx is not None and len(dim_idx) == 1
if not info.has_idx_cont(dim) and len(dim_idx) == 1
]
# Pin: Only allow dimensions with discrete index.
## TODO: Shouldn't 'Pin' be allowed to index continuous indices too?
case FO.Pin | FO.PinIdx:
return [
dim
for dim, dim_idx in info.dims
if dim_idx is not None and len(dim_idx) > 0
]
case FO.Pin:
return info.dims
case FO.SetDim:
return [
dim
for dim, dim_idx in info.dims
if dim_idx is not None
and not isinstance(dim_idx, list)
and dim_idx.mathtype == spux.MathType.Integer
]
case FO.PinIdx:
return [dim for dim in info.dims if not info.has_idx_cont(dim)]
# Dimension
case FO.Swap:
return info.dims
return []
@ -193,9 +184,14 @@ class FilterOperation(enum.StrEnum):
self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None
) -> bool:
"""Check whether the given dimension inputs are valid in the context of this operation, and of the information."""
return (self.num_dim_inputs in [1, 2] and dim_0 in self.valid_dims(info)) or (
self.num_dim_inputs == 2 and dim_1 in self.valid_dims(info)
)
if self.num_dim_inputs == 1:
return dim_0 in self.valid_dims(info)
if self.num_dim_inputs == 2: # noqa: PLR2004
valid_dims = self.valid_dims(info)
return dim_0 in valid_dims and dim_1 in valid_dims
return False
####################
# - UI
@ -209,6 +205,9 @@ class FilterOperation(enum.StrEnum):
FO = FilterOperation
return {
# Pin
FO.Slice: lambda expr: jlax.slice_in_dim(
expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
),
FO.SliceIdx: lambda expr: jlax.slice_in_dim(
expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
),
@ -216,9 +215,8 @@ class FilterOperation(enum.StrEnum):
FO.PinLen1: lambda expr: jnp.squeeze(expr, axis_0),
FO.Pin: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
# Reinterpret
# Dimension
FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1),
FO.SetDim: lambda expr: expr,
}[self]
def transform_info(
@ -228,10 +226,10 @@ class FilterOperation(enum.StrEnum):
dim_1: sim_symbols.SimSymbol,
pin_idx: int | None = None,
slice_tuple: tuple[int, int, int] | None = None,
replaced_dim: tuple[str, tuple[str, ct.ArrayFlow | ct.RangeFlow]] | None = None,
):
FO = FilterOperation
return {
FO.Slice: lambda: info.slice_dim(dim_0, slice_tuple),
FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple),
# Pin
FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
@ -239,7 +237,6 @@ class FilterOperation(enum.StrEnum):
FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
# Reinterpret
FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
FO.SetDim: lambda: info.replace_dim(*replaced_dim),
}[self]()
@ -330,8 +327,8 @@ class FilterMathNode(base.MaxwellSimNode):
def search_dims(self) -> list[ct.BLEnumElement]:
if self.expr_info is not None and self.operation is not None:
return [
(dim_name, dim_name, dim_name, '', i)
for i, dim_name in enumerate(self.operation.valid_dims(self.expr_info))
(dim.name, dim.name_pretty, dim.name, '', i)
for i, dim in enumerate(self.operation.valid_dims(self.expr_info))
]
return []
@ -380,8 +377,6 @@ class FilterMathNode(base.MaxwellSimNode):
# Reinterpret
case FO.Swap:
return f'Filter: Swap [{self.active_dim_0}]|[{self.active_dim_1}]'
case FO.SetDim:
return f'Filter: Set [{self.active_dim_0}]'
case _:
return self.bl_label
@ -480,30 +475,6 @@ class FilterMathNode(base.MaxwellSimNode):
)
}
# Loose Sockets: Set Dim
## -> The user must provide a () -> array.
## -> It must be of identical length to the replaced axis.
elif props['operation'] is FilterOperation.SetDim and dim_0 is not None:
dim = dim_0
current_bl_socket = self.loose_input_sockets.get('Dim')
if (
current_bl_socket is None
or current_bl_socket.active_kind != ct.FlowKind.Func
or current_bl_socket.size is not spux.NumberSize1D.Scalar
or current_bl_socket.mathtype != dim.mathtype
or current_bl_socket.physical_type != dim.physical_type
):
self.loose_input_sockets = {
'Dim': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Func,
physical_type=dim.physical_type,
mathtype=dim.mathtype,
default_unit=dim.unit,
show_func_ui=False,
show_info_columns=True,
)
}
# No Loose Value: Remove Input Sockets
elif self.loose_input_sockets:
self.loose_input_sockets = {}
@ -570,60 +541,11 @@ class FilterMathNode(base.MaxwellSimNode):
has_info = not ct.FlowSignal.check(info)
# Dim (Op.SetDim)
dim_func = input_sockets['Dim'][ct.FlowKind.Func]
dim_params = input_sockets['Dim'][ct.FlowKind.Params]
dim_info = input_sockets['Dim'][ct.FlowKind.Info]
has_dim_func = not ct.FlowSignal.check(dim_func)
has_dim_params = not ct.FlowSignal.check(dim_params)
has_dim_info = not ct.FlowSignal.check(dim_info)
# Dimension(s)
dim_0 = props['dim_0']
dim_1 = props['dim_1']
slice_tuple = props['slice_tuple']
if has_info and operation is not None:
# Set Dimension: Retrieve Array
if props['operation'] is FilterOperation.SetDim:
new_dim = (
next(dim_info.dims.keys()) if len(dim_info.dims) >= 1 else None
)
if (
dim_0 is not None
and new_dim is not None
and has_dim_info
and has_dim_params
# Check New Dimension Index Array Sizing
and len(dim_info.dims) == 1
and dim_info.output.rows == 1
and dim_info.output.cols == 1
# Check Lack of Params Symbols
and not dim_params.symbols
# Check Expr Dim | New Dim Compatibility
and info.has_idx_discrete(dim_0)
and dim_info.has_idx_discrete(new_dim)
and len(info.dims[dim_0]) == len(dim_info.dims[new_dim])
):
# Retrieve Dimension Coordinate Array
## -> It must be strictly compatible.
values = dim_func.realize(dim_params, spux.UNITS_SI)
# Transform Info w/Corrected Dimension
## -> The existing dimension will be replaced.
new_dim_idx = ct.ArrayFlow(
values=values,
unit=spux.convert_to_unit_system(
dim_info.output.unit, spux.UNITS_SI
),
).rescale_to_unit(dim_info.output.unit)
replaced_dim = [dim_0, (dim_info.output.name, new_dim_idx)]
return operation.transform_info(
info, dim_0, dim_1, replaced_dim=replaced_dim
)
return ct.FlowSignal.FlowPending
return operation.transform_info(info, dim_0, dim_1, slice_tuple=slice_tuple)
return ct.FlowSignal.FlowPending

View File

@ -496,7 +496,7 @@ class MapMathNode(base.MaxwellSimNode):
)
def search_operations(self) -> list[ct.BLEnumElement]:
if self.info is not None:
if self.expr_info is not None:
return [
operation.bl_enum_element(i)
for i, operation in enumerate(MapOperation.by_expr_info(self.expr_info))

View File

@ -20,8 +20,11 @@ import typing as typ
import bpy
import jax.numpy as jnp
import sympy as sp
import sympy.physics.quantum as spq
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct
from .... import sockets
@ -37,37 +40,47 @@ class BinaryOperation(enum.StrEnum):
"""Valid operations for the `OperateMathNode`.
Attributes:
Add: Addition w/broadcasting.
Sub: Subtraction w/broadcasting.
Mul: Hadamard-product multiplication.
Div: Hadamard-product based division.
Pow: Elementwise expontiation.
Atan2: Quadrant-respecting arctangent variant.
VecVecDot: Dot product for vectors.
Cross: Cross product.
MatVecDot: Matrix-Vector dot product.
Mul: Scalar multiplication.
Div: Scalar division.
Pow: Scalar exponentiation.
Add: Elementwise addition.
Sub: Elementwise subtraction.
HadamMul: Elementwise multiplication (hadamard product).
HadamPow: Principled shape-aware exponentiation (hadamard power).
Atan2: Quadrant-respecting 2D arctangent.
VecVecDot: Dot product for identically shaped vectors w/transpose.
Cross: Cross product between identically shaped 3D vectors.
VecVecOuter: Vector-vector outer product.
LinSolve: Solve a linear system.
LsqSolve: Minimize error of an underdetermined linear system.
MatMatDot: Matrix-Matrix dot product.
VecMatOuter: Vector-matrix outer product.
MatMatDot: Matrix-matrix dot product.
"""
# Number | Number
Add = enum.auto()
Sub = enum.auto()
Mul = enum.auto()
Div = enum.auto()
Pow = enum.auto()
# Elements | Elements
Add = enum.auto()
Sub = enum.auto()
HadamMul = enum.auto()
# HadamPow = enum.auto() ## TODO: Sympy's HadamardPower is problematic.
Atan2 = enum.auto()
# Vector | Vector
VecVecDot = enum.auto()
Cross = enum.auto()
VecVecOuter = enum.auto()
# Matrix | Vector
MatVecDot = enum.auto()
LinSolve = enum.auto()
LsqSolve = enum.auto()
# Vector | Matrix
VecMatOuter = enum.auto()
# Matrix | Matrix
MatMatDot = enum.auto()
@ -79,19 +92,24 @@ class BinaryOperation(enum.StrEnum):
BO = BinaryOperation
return {
# Number | Number
BO.Mul: ' · r',
BO.Div: ' / r',
BO.Pow: ' ^ r',
# Elements | Elements
BO.Add: ' + r',
BO.Sub: ' - r',
BO.Mul: ' ⊙ r', ## Notation for Hadamard Product
BO.Div: ' / r',
BO.Pow: 'ℓʳ',
BO.Atan2: 'atan2(,r)',
BO.HadamMul: '𝐋𝐑',
# BO.HadamPow: '𝐥 ⊙^ 𝐫',
BO.Atan2: 'atan2(:x, r:y)',
# Vector | Vector
BO.VecVecDot: '𝐥 · 𝐫',
BO.Cross: 'cross(L,R)',
BO.Cross: 'cross(𝐥,𝐫)',
BO.VecVecOuter: '𝐥𝐫',
# Matrix | Vector
BO.MatVecDot: '𝐋 · 𝐫',
BO.LinSolve: '𝐋 𝐫',
BO.LsqSolve: 'argminₓ∥𝐋𝐱𝐫∥₂',
# Vector | Matrix
BO.VecMatOuter: '𝐋𝐫',
# Matrix | Matrix
BO.MatMatDot: '𝐋 · 𝐑',
}[value]
@ -118,56 +136,104 @@ class BinaryOperation(enum.StrEnum):
"""Deduce valid binary operations from the shapes of the inputs."""
BO = BinaryOperation
ops_number_number = [
ops_el_el = [
BO.Add,
BO.Sub,
BO.Mul,
BO.Div,
BO.Pow,
BO.Atan2,
BO.HadamMul,
# BO.HadamPow,
]
match (info_l.output_shape_len, info_r.output_shape_len):
outl = info_l.output
outr = info_r.output
match (outl.shape_len, outr.shape_len):
# match (ol.shape_len, info_r.output.shape_len):
# Number | *
## Number | Number
case (0, 0):
return ops_number_number
ops = [
BO.Add,
BO.Sub,
BO.Mul,
BO.Div,
BO.Pow,
]
if (
info_l.output.physical_type == spux.PhysicalType.Length
and info_l.output.unit == info_r.output.unit
):
ops += [BO.Atan2]
return ops
## Number | Vector
## -> Broadcasting allows Number|Number ops to work as-is.
case (0, 1):
return ops_number_number
return [BO.Mul] # , BO.HadamPow]
## Number | Matrix
## -> Broadcasting allows Number|Number ops to work as-is.
case (0, 2):
return ops_number_number
return [BO.Mul] # , BO.HadamPow]
# Vector | *
## Vector | Number
case (1, 0):
return ops_number_number
return [BO.Mul] # , BO.HadamPow]
## Vector | Number
## Vector | Vector
case (1, 1):
return [*ops_number_number, BO.VecVecDot, BO.Cross]
ops = []
# Vector | Vector
## -> Dot: Convenience; utilize special vec-vec dot w/transp.
if outl.rows > outl.cols and outr.rows > outr.cols:
ops += [BO.VecVecDot, BO.VecVecOuter]
# Covector | Vector
## -> Dot: Directly use matrix-matrix dot, as it's now correct.
if outl.rows < outl.cols and outr.rows > outr.cols:
ops += [BO.MatMatDot, BO.VecVecOuter]
# Vector | Covector
## -> Dot: Directly use matrix-matrix dot, as it's now correct.
## -> These are both the same operation, in this case.
if outl.rows > outl.cols and outr.rows < outr.cols:
ops += [BO.MatMatDot, BO.VecVecOuter]
# Covector | Covector
## -> Dot: Convenience; utilize special vec-vec dot w/transp.
if outl.rows < outl.cols and outr.rows < outr.cols:
ops += [BO.VecVecDot, BO.VecVecOuter]
# Cross Product
## -> Enforce that both are 3x1 or 1x3.
## -> See https://docs.sympy.org/latest/modules/matrices/matrices.html#sympy.matrices.matrices.MatrixBase.cross
if (outl.rows == 3 and outr.rows == 3) or (
outl.cols == 3 and outl.cols == 3
):
ops += [BO.Cross]
return ops
## Vector | Matrix
case (1, 2):
return []
return [BO.VecMatOuter]
# Matrix | *
## Matrix | Number
case (2, 0):
return [*ops_number_number, BO.MatMatDot]
return [BO.Mul] # , BO.HadamPow]
## Matrix | Vector
case (2, 1):
return [BO.MatVecDot, BO.LinSolve, BO.LsqSolve]
prepend_ops = []
# Mat-Vec Dot: Enforce RHS Column Vector
if outr.rows > outl.cols:
prepend_ops += [BO.MatMatDot]
return [*ops, BO.LinSolve, BO.LsqSolve] # , BO.HadamPow]
## Matrix | Matrix
case (2, 2):
return [*ops_number_number, BO.MatMatDot]
return [*ops_el_el, BO.MatMatDot]
return []
@ -182,34 +248,86 @@ class BinaryOperation(enum.StrEnum):
## TODO: Make this compatible with sp.Matrix inputs
return {
# Number | Number
BO.Add: lambda exprs: exprs[0] + exprs[1],
BO.Sub: lambda exprs: exprs[0] - exprs[1],
BO.Mul: lambda exprs: exprs[0] * exprs[1],
BO.Div: lambda exprs: exprs[0] / exprs[1],
BO.Pow: lambda exprs: exprs[0] ** exprs[1],
# Elements | Elements
BO.Add: lambda exprs: exprs[0] + exprs[1],
BO.Sub: lambda exprs: exprs[0] - exprs[1],
BO.HadamMul: lambda exprs: sp.hadamard_product(exprs[0], exprs[1]),
# BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]),
# Vector | Vector
BO.VecVecDot: lambda exprs: (exprs[0].T @ exprs[1])[0],
BO.Cross: lambda exprs: exprs[0].cross(exprs[1]),
BO.VecVecOuter: lambda exprs: exprs[0] @ exprs[1].T,
# Matrix | Vector
BO.LinSolve: lambda exprs: exprs[0].solve(exprs[1]),
BO.LsqSolve: lambda exprs: exprs[0].solve_least_squares(exprs[1]),
# Vector | Matrix
BO.VecMatOuter: lambda exprs: spq.TensorProduct(exprs[0], exprs[1]),
# Matrix | Matrix
BO.MatMatDot: lambda exprs: exprs[0] @ exprs[1],
}[self]
@property
def unit_func(self):
"""The binary function to apply to both unit expressions, in order to deduce the unit expression of the output."""
BO = BinaryOperation
## TODO: Make this compatible with sp.Matrix inputs
return {
# Number | Number
BO.Mul: BO.Mul.sp_func,
BO.Div: BO.Div.sp_func,
BO.Pow: BO.Pow.sp_func,
# Elements | Elements
BO.Add: BO.Add.sp_func,
BO.Sub: BO.Sub.sp_func,
BO.HadamMul: BO.Mul.sp_func,
# BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
BO.Atan2: lambda _: spu.radian,
# Vector | Vector
BO.VecVecDot: BO.Mul.sp_func,
BO.Cross: BO.Mul.sp_func,
BO.VecVecOuter: BO.Mul.sp_func,
# Matrix | Vector
## -> A,b in Ax = b have units, and the equality must hold.
## -> Therefore, A \ b must have the units [b]/[A].
BO.LinSolve: lambda exprs: exprs[1] / exprs[0],
BO.LsqSolve: lambda exprs: exprs[1] / exprs[0],
# Vector | Matrix
BO.VecMatOuter: BO.Mul.sp_func,
# Matrix | Matrix
BO.MatMatDot: BO.Mul.sp_func,
}[self]
@property
def jax_func(self):
"""Deduce an appropriate jax-based function that implements the binary operation for array inputs."""
## TODO: Scale the units of one side to the other.
BO = BinaryOperation
return {
# Number | Number
BO.Add: lambda exprs: exprs[0] + exprs[1],
BO.Sub: lambda exprs: exprs[0] - exprs[1],
BO.Mul: lambda exprs: exprs[0] * exprs[1],
BO.Div: lambda exprs: exprs[0] / exprs[1],
BO.Pow: lambda exprs: exprs[0] ** exprs[1],
BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]),
# Elements | Elements
BO.Add: lambda exprs: exprs[0] + exprs[1],
BO.Sub: lambda exprs: exprs[0] - exprs[1],
BO.HadamMul: lambda exprs: exprs[0] * exprs[1],
# BO.HadamPow: lambda exprs: exprs[0] ** exprs[1],
BO.Atan2: lambda exprs: jnp.atan2(exprs[1], exprs[0]),
# Vector | Vector
BO.VecVecDot: lambda exprs: jnp.dot(exprs[0], exprs[1]),
BO.Cross: lambda exprs: jnp.cross(exprs[0], exprs[1]),
BO.VecVecOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
# Matrix | Vector
BO.MatVecDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]),
BO.LinSolve: lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]),
BO.LsqSolve: lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]),
# Vector | Matrix
BO.VecMatOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
# Matrix | Matrix
BO.MatMatDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]),
}[self]
@ -218,30 +336,12 @@ class BinaryOperation(enum.StrEnum):
# - InfoFlow Transform
####################
def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow):
BO = BinaryOperation
info_largest = (
info_l if info_l.output_shape_len > info_l.output_shape_len else info_l
"""Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression."""
return info_l.operate_output(
info_r,
lambda a, b: self.sp_func([a, b]),
lambda a, b: self.unit_func([a, b]),
)
info_any = info_largest
return {
# Number | * or * | Number
BO.Add: info_largest,
BO.Sub: info_largest,
BO.Mul: info_largest,
BO.Div: info_largest,
BO.Pow: info_largest,
BO.Atan2: info_largest,
# Vector | Vector
BO.VecVecDot: info_any,
BO.Cross: info_any,
# Matrix | Vector
BO.MatVecDot: info_r,
BO.LinSolve: info_r,
BO.LsqSolve: info_r,
# Matrix | Matrix
BO.MatMatDot: info_any,
}[self]
####################
@ -367,9 +467,9 @@ class OperateMathNode(base.MaxwellSimNode):
# Compute Sympy Function
## -> The operation enum directly provides the appropriate function.
if has_expr_l_value and has_expr_r_value and operation is not None:
operation.sp_func([expr_l, expr_r])
return operation.sp_func([expr_l, expr_r])
return ct.Flowsignal.FlowPending
return ct.FlowSignal.FlowPending
@events.computes_output_socket(
'Expr',
@ -396,7 +496,7 @@ class OperateMathNode(base.MaxwellSimNode):
## -> The operation enum directly provides the appropriate function.
if has_expr_l and has_expr_r:
return (expr_l | expr_r).compose_within(
operation.jax_func,
enclosing_func=operation.jax_func,
supports_jax=True,
)
return ct.FlowSignal.FlowPending

View File

@ -17,13 +17,13 @@
"""Declares `TransformMathNode`."""
import enum
import functools
import typing as typ
import bpy
import jax.numpy as jnp
import jaxtyping as jtyp
import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, logger, sci_constants, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
@ -51,6 +51,9 @@ class TransformOperation(enum.StrEnum):
# Covariant Transform
FreqToVacWL = enum.auto()
VacWLToFreq = enum.auto()
ConvertIdxUnit = enum.auto()
SetIdxUnit = enum.auto()
FirstColToFirstIdx = enum.auto()
# Fold
IntDimToComplex = enum.auto()
@ -58,8 +61,8 @@ class TransformOperation(enum.StrEnum):
DimsToMat = enum.auto()
# Fourier
FFT1D = enum.auto()
InvFFT1D = enum.auto()
FT1D = enum.auto()
InvFT1D = enum.auto()
# TODO: Affine
## TODO
@ -74,17 +77,22 @@ class TransformOperation(enum.StrEnum):
# Covariant Transform
TO.FreqToVacWL: '𝑓 → λᵥ',
TO.VacWLToFreq: 'λᵥ → 𝑓',
TO.ConvertIdxUnit: 'Convert Dim',
TO.SetIdxUnit: 'Set Dim',
TO.FirstColToFirstIdx: '1st Col → Dim',
# Fold
TO.IntDimToComplex: '',
TO.DimToVec: '→ Vector',
TO.DimsToMat: '→ Matrix',
## TODO: Vector to new last-dim integer
## TODO: Matrix to two last-dim integers
# Fourier
TO.FFT1D: 't → 𝑓',
TO.InvFFT1D: '𝑓 t',
TO.FT1D: '𝑓',
TO.InvFT1D: '𝑓',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> str:
def to_icon(_: typ.Self) -> str:
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
@ -98,121 +106,216 @@ class TransformOperation(enum.StrEnum):
)
####################
# - Ops from Shape
# - Methods
####################
@property
def num_dim_inputs(self) -> None:
"""The number of axes that should be passed as inputs to `func_jax` when evaluating it.
Especially useful for `ParamFlow`, when deciding whether to pass an integer-axis argument based on a user-selected dimension.
"""
TO = TransformOperation
return {
# Covariant Transform
TO.FreqToVacWL: 1,
TO.VacWLToFreq: 1,
TO.ConvertIdxUnit: 1,
TO.SetIdxUnit: 1,
TO.FirstColToFirstIdx: 0,
# Fold
TO.IntDimToComplex: 0,
TO.DimToVec: 0,
TO.DimsToMat: 0,
## TODO: Vector to new last-dim integer
## TODO: Matrix to two last-dim integers
# Fourier
TO.FT1D: 1,
TO.InvFT1D: 1,
}[self]
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
TO = TransformOperation
match self:
case TO.FreqToVacWL | TO.FT1D:
return [
dim
for dim in info.dims
if dim.physical_type is spux.PhysicalType.Freq
]
case TO.VacWLToFreq | TO.InvFT1D:
return [
dim
for dim in info.dims
if dim.physical_type is spux.PhysicalType.Length
]
case TO.ConvertIdxUnit | TO.SetIdxUnit:
return [dim for dim in info.dims if not info.has_idx_labels(dim)]
## ColDimToComplex: Implicit Last Dimension
## DimToVec: Implicit Last Dimension
## DimsToMat: Implicit Last 2 Dimensions
case TO.FT1D | TO.InvFT1D:
# Filter by Axis Uniformity
## -> FT requires uniform axis (aka. must be RangeFlow).
## -> NOTE: If FT isn't popping up, check ExtractDataNode.
return [dim for dim in info.dims if info.is_idx_uniform(dim)]
return []
@staticmethod
def by_element_shape(info: ct.InfoFlow) -> list[typ.Self]:
def by_info(info: ct.InfoFlow) -> list[typ.Self]:
TO = TransformOperation
operations = []
# Covariant Transform
## Freq <-> VacWL
for dim in info.dims:
if dim.physical_type == spux.PhysicalType.Freq:
operations.append(TO.FreqToVacWL)
## Freq -> VacWL
if TO.FreqToVacWL.valid_dims(info):
operations += [TO.FreqToVacWL]
if dim.physical_type == spux.PhysicalType.Freq:
operations.append(TO.VacWLToFreq)
## VacWL -> Freq
if TO.VacWLToFreq.valid_dims(info):
operations += [TO.VacWLToFreq]
## Convert Index Unit
if TO.ConvertIdxUnit.valid_dims(info):
operations += [TO.ConvertIdxUnit]
if TO.SetIdxUnit.valid_dims(info):
operations += [TO.SetIdxUnit]
## Column to First Index (Array)
if (
len(info.dims) == 2 # noqa: PLR2004
and info.first_dim.mathtype is spux.MathType.Integer
and info.last_dim.mathtype is spux.MathType.Integer
and info.output.shape_len == 0
):
operations += [TO.FirstColToFirstIdx]
# Fold
## (Last) Int Dim (=2) to Complex
if len(info.dims) >= 1:
if not info.has_idx_labels(info.last_dim) and len(info.last_dim) == 2: # noqa: PLR2004
operations.append(TO.IntDimToComplex)
## Last Dim -> Complex
if (
info.dims
# Output is Int|Rat|Real
and (
info.output.mathtype
in [spux.MathType.Integer, spux.MathType.Rational, spux.MathType.Real]
)
# Last Axis is Integer of Length 2
and info.last_dim.mathtype is spux.MathType.Integer
and info.has_idx_labels(info.last_dim)
and len(info.dims[info.last_dim]) == 2 # noqa: PLR2004
):
operations += [TO.IntDimToComplex]
## To Vector
if len(info.dims) >= 1:
operations.append(TO.DimToVec)
## Last Dim -> Vector
if len(info.dims) >= 1 and info.output.shape_len == 0:
operations += [TO.DimToVec]
## To Matrix
if len(info.dims) >= 2: # noqa: PLR2004
operations.append(TO.DimsToMat)
## Last Dim -> Matrix
if len(info.dims) >= 2 and info.output.shape_len == 0: # noqa: PLR2004
operations += [TO.DimsToMat]
# Fourier
## 1D Fourier
if info.dims:
last_physical_type = info.last_dim.physical_type
if last_physical_type == spux.PhysicalType.Time:
operations.append(TO.FFT1D)
if last_physical_type == spux.PhysicalType.Freq:
operations.append(TO.InvFFT1D)
if TO.FT1D.valid_dims(info):
operations += [TO.FT1D]
if TO.InvFT1D.valid_dims(info):
operations += [TO.InvFT1D]
return operations
####################
# - Function Properties
####################
@property
def sp_func(self):
TO = TransformOperation
return {
# Covariant Transform
TO.FreqToVacWL: lambda expr: expr,
TO.VacWLToFreq: lambda expr: expr,
# Fold
# TO.IntDimToComplex: lambda expr: expr, ## TODO: Won't work?
TO.DimToVec: lambda expr: expr,
TO.DimsToMat: lambda expr: expr,
# Fourier
TO.FFT1D: lambda expr: sp.fourier_transform(
expr, sim_symbols.t, sim_symbols.freq
),
TO.InvFFT1D: lambda expr: sp.fourier_transform(
expr, sim_symbols.freq, sim_symbols.t
),
}[self]
@property
@functools.cached_property
def jax_func(self):
TO = TransformOperation
return {
# Covariant Transform
TO.FreqToVacWL: lambda expr: expr,
TO.VacWLToFreq: lambda expr: expr,
## -> Freq <-> WL is a rescale (noop) AND flip (not noop).
TO.FreqToVacWL: lambda expr, axis: jnp.flip(expr, axis=axis),
TO.VacWLToFreq: lambda expr, axis: jnp.flip(expr, axis=axis),
TO.ConvertIdxUnit: lambda expr: expr,
TO.SetIdxUnit: lambda expr: expr,
TO.FirstColToFirstIdx: lambda expr: jnp.delete(expr, 0, axis=1),
# Fold
## -> To Complex: With a little imagination, this is a noop :)
## -> **Requires** dims[-1] to be integer-indexed w/length of 2.
TO.IntDimToComplex: lambda expr: expr.view(dtype=jnp.complex64).squeeze(),
## -> To Complex: This should generally be a no-op.
TO.IntDimToComplex: lambda expr: jnp.squeeze(
expr.view(dtype=jnp.complex64), axis=-1
),
TO.DimToVec: lambda expr: expr,
TO.DimsToMat: lambda expr: expr,
# Fourier
TO.FFT1D: lambda expr: jnp.fft(expr),
TO.InvFFT1D: lambda expr: jnp.ifft(expr),
TO.FT1D: lambda expr, axis: jnp.fft(expr, axis=axis),
TO.InvFT1D: lambda expr, axis: jnp.ifft(expr, axis=axis),
}[self]
def transform_info(
self,
info: ct.InfoFlow | None,
data: jtyp.Shaped[jtyp.Array, '...'] | None = None,
info: ct.InfoFlow,
dim: sim_symbols.SimSymbol | None = None,
data_col: jtyp.Shaped[jtyp.Array, ' size'] | None = None,
new_dim_name: str | None = None,
unit: spux.Unit | None = None,
) -> ct.InfoFlow | None:
physical_type: spux.PhysicalType | None = None,
) -> ct.InfoFlow:
TO = TransformOperation
if not info.dims:
return None
return {
# Covariant Transform
TO.FreqToVacWL: lambda: info.replace_dim(
(f_dim := info.last_dim),
(f_dim := dim),
[
sim_symbols.wl(spu.nanometer),
sim_symbols.wl(unit),
info.dims[f_dim].rescale(
lambda el: sci_constants.vac_speed_of_light / el,
reverse=True,
new_unit=spu.nanometer,
new_unit=unit,
),
],
),
TO.VacWLToFreq: lambda: info.replace_dim(
(wl_dim := info.last_dim),
(wl_dim := dim),
[
sim_symbols.freq(spux.THz),
sim_symbols.freq(unit),
info.dims[wl_dim].rescale(
lambda el: sci_constants.vac_speed_of_light / el,
reverse=True,
new_unit=spux.THz,
new_unit=unit,
),
],
),
TO.ConvertIdxUnit: lambda: info.replace_dim(
dim,
dim.update(unit=unit),
(
info.dims[dim].rescale_to_unit(unit)
if info.has_idx_discrete(dim)
else None ## Continuous -- dim SimSymbol already scaled
),
),
TO.SetIdxUnit: lambda: info.replace_dim(
dim,
dim.update(
sym_name=new_dim_name, physical_type=physical_type, unit=unit
),
(
info.dims[dim].correct_unit(unit)
if info.has_idx_discrete(dim)
else None ## Continuous -- dim SimSymbol already scaled
),
),
TO.FirstColToFirstIdx: lambda: info.replace_dim(
info.first_dim,
info.first_dim.update(
mathtype=spux.MathType.from_jax_array(data_col),
unit=unit,
),
ct.ArrayFlow(values=data_col, unit=unit),
).slice_dim(info.last_dim, (1, len(info.dims[info.last_dim]), 1)),
# Fold
TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output(
mathtype=spux.MathType.Complex
@ -220,21 +323,31 @@ class TransformOperation(enum.StrEnum):
TO.DimToVec: lambda: info.fold_last_input(),
TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(),
# Fourier
TO.FFT1D: lambda: info.replace_dim(
info.last_dim,
TO.FT1D: lambda: info.replace_dim(
dim,
[
sim_symbols.freq(spux.THz),
None,
# FT'ed Unit: Reciprocal of the Original Unit
dim.update(
unit=1 / dim.unit if dim.unit is not None else 1
), ## TODO: Okay to not scale interval?
# FT'ed Bounds: Reciprocal of the Original Unit
info.dims[dim].bound_fourier_transform,
],
),
TO.InvFFT1D: info.replace_dim(
TO.InvFT1D: lambda: info.replace_dim(
info.last_dim,
[
sim_symbols.t(spu.second),
None,
# FT'ed Unit: Reciprocal of the Original Unit
dim.update(
unit=1 / dim.unit if dim.unit is not None else 1
), ## TODO: Okay to not scale interval?
# FT'ed Bounds: Reciprocal of the Original Unit
## -> Note the midpoint may revert to 0.
## -> See docs for `RangeFlow.bound_inv_fourier_transform` for more.
info.dims[dim].bound_inv_fourier_transform,
],
),
}.get(self, lambda: info)()
}[self]()
####################
@ -274,7 +387,6 @@ class TransformMathNode(base.MaxwellSimNode):
)
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
info_pending = ct.FlowSignal.check_single(
input_sockets['Expr'], ct.FlowSignal.FlowPending
)
@ -304,45 +416,125 @@ class TransformMathNode(base.MaxwellSimNode):
return [
operation.bl_enum_element(i)
for i, operation in enumerate(
TransformOperation.by_element_shape(self.expr_info)
TransformOperation.by_info(self.expr_info)
)
]
return []
####################
# - Properties: Dimension Selection
####################
active_dim: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_dims(),
cb_depends_on={'operation', 'expr_info'},
)
def search_dims(self) -> list[ct.BLEnumElement]:
if self.expr_info is not None and self.operation is not None:
return [
(dim.name, dim.name_pretty, dim.name, '', i)
for i, dim in enumerate(self.operation.valid_dims(self.expr_info))
]
return []
@bl_cache.cached_bl_property(depends_on={'expr_info', 'active_dim'})
def dim(self) -> sim_symbols.SimSymbol | None:
if self.expr_info is not None and self.active_dim is not None:
return self.expr_info.dim_by_name(self.active_dim)
return None
####################
# - Properties: New Dimension Properties
####################
new_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.Expr
)
new_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
active_new_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(),
cb_depends_on={'dim', 'new_physical_type'},
)
def search_units(self) -> list[ct.BLEnumElement]:
if self.dim is not None:
if self.dim.physical_type is not spux.PhysicalType.NonPhysical:
unit_name = sp.sstr(self.dim.unit)
return [
(
sp.sstr(unit),
spux.sp_to_str(unit),
sp.sstr(unit),
'',
0,
)
for unit in self.dim.physical_type.valid_units
]
if self.dim.unit is not None:
unit_name = sp.sstr(self.dim.unit)
return [
(
unit_name,
spux.sp_to_str(self.dim.unit),
unit_name,
'',
0,
)
]
if self.new_physical_type is not spux.PhysicalType.NonPhysical:
return [
(
sp.sstr(unit),
spux.sp_to_str(unit),
sp.sstr(unit),
'',
i,
)
for i, unit in enumerate(self.new_physical_type.valid_units)
]
return []
@bl_cache.cached_bl_property(depends_on={'active_new_unit'})
def new_unit(self) -> spux.Unit:
if self.active_new_unit is not None:
return spux.unit_str_to_unit(self.active_new_unit)
return None
####################
# - UI
####################
def draw_label(self):
if self.operation is not None:
return 'Transform: ' + TransformOperation.to_name(self.operation)
return 'T: ' + TransformOperation.to_name(self.operation)
return self.bl_label
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, self.blfields['operation'], text='')
if self.operation is not None and self.operation.num_dim_inputs == 1:
TO = TransformOperation
layout.prop(self, self.blfields['active_dim'], text='')
if self.operation in [TO.ConvertIdxUnit, TO.SetIdxUnit]:
col = layout.column(align=True)
if self.operation is TransformOperation.ConvertIdxUnit:
col.prop(self, self.blfields['active_new_unit'], text='')
if self.operation is TransformOperation.SetIdxUnit:
col.prop(self, self.blfields['new_physical_type'], text='')
row = col.row(align=True)
row.prop(self, self.blfields['new_name'], text='')
row.prop(self, self.blfields['active_new_unit'], text='')
####################
# - Compute: Func / Array
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Value,
props={'operation'},
input_sockets={'Expr'},
)
def compute_value(self, props, input_sockets) -> ct.ValueFlow | ct.FlowSignal:
operation = props['operation']
expr = input_sockets['Expr']
has_expr_value = not ct.FlowSignal.check(expr)
# Compute Sympy Function
## -> The operation enum directly provides the appropriate function.
if has_expr_value and operation is not None:
return operation.sp_func(expr)
return ct.Flowsignal.FlowPending
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Func,
@ -354,54 +546,103 @@ class TransformMathNode(base.MaxwellSimNode):
)
def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal:
operation = props['operation']
expr = input_sockets['Expr']
lazy_func = input_sockets['Expr']
has_expr = not ct.FlowSignal.check(expr)
has_lazy_func = not ct.FlowSignal.check(lazy_func)
if has_expr and operation is not None:
return expr.compose_within(
if has_lazy_func and operation is not None:
return lazy_func.compose_within(
operation.jax_func,
supports_jax=True,
)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Info|Params
# - FlowKind.Info
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Info,
props={'operation'},
props={'operation', 'dim', 'new_name', 'new_unit', 'new_physical_type'},
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info},
input_socket_kinds={
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
},
)
def compute_info(
self, props: dict, input_sockets: dict
) -> ct.InfoFlow | typ.Literal[ct.FlowSignal.FlowPending]:
operation = props['operation']
info = input_sockets['Expr']
info = input_sockets['Expr'][ct.FlowKind.Info]
has_info = not ct.FlowSignal.check(info)
dim = props['dim']
new_name = props['new_name']
new_unit = props['new_unit']
new_physical_type = props['new_physical_type']
if has_info and operation is not None:
transformed_info = operation.transform_info(info)
# First Column to First Index
## -> We have to evaluate the lazy function at this point.
## -> It's the only way to get at the column data.
if operation is TransformOperation.FirstColToFirstIdx:
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
params = input_sockets['Expr'][ct.FlowKind.Params]
has_lazy_func = not ct.FlowSignal.check(lazy_func)
has_params = not ct.FlowSignal.check(lazy_func)
if transformed_info is None:
if has_lazy_func and has_params and not params.symbols:
data = lazy_func.realize(params)
if data.shape is not None and len(data.shape) == 2:
data_col = data[:, 0]
return operation.transform_info(info, data_col=data_col)
return ct.FlowSignal.FlowPending
return transformed_info
# Check Not-Yet-Updated Dimension
## - Operation changes before dimensions.
## - If InfoFlow is requested in this interim, big problem.
if dim is None and operation.num_dim_inputs > 0:
return ct.FlowSignal.FlowPending
return operation.transform_info(
info,
dim=dim,
new_dim_name=new_name,
unit=new_unit,
physical_type=new_physical_type,
)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Params,
props={'operation', 'dim'},
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Params},
input_socket_kinds={'Expr': {ct.FlowKind.Params, ct.FlowKind.Info}},
)
def compute_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal:
has_params = not ct.FlowSignal.check(input_sockets['Expr'])
if has_params:
return input_sockets['Expr']
def compute_params(self, props, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
info = input_sockets['Expr'][ct.FlowKind.Info]
params = input_sockets['Expr'][ct.FlowKind.Params]
has_info = not ct.FlowSignal.check(info)
has_params = not ct.FlowSignal.check(params)
operation = props['operation']
dim = props['dim']
if has_info and has_params and operation is not None:
# Axis Required: Insert by-Dimension
## -> Some transformations ex. FT require setting an axis.
## -> The user selects which dimension the op should be done along.
## -> This dimension is converted to an axis integer.
## -> Finally, we pass the argument via params.
if operation.num_dim_inputs == 1:
axis = info.dim_axis(dim) if dim is not None else None
return params.compose_within(enclosing_func_args=[axis])
return params
return ct.FlowSignal.FlowPending

View File

@ -21,6 +21,7 @@ import bpy
import jaxtyping as jtyp
import matplotlib.axis as mpl_ax
import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, image_ops, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
@ -74,8 +75,33 @@ class VizMode(enum.StrEnum):
SqueezedHeatmap2D = enum.auto()
Heatmap3D = enum.auto()
####################
# - UI
####################
@staticmethod
def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None:
def to_name(value: typ.Self) -> str:
return {
VizMode.BoxPlot1D: 'Box Plot',
VizMode.Curve2D: 'Curve',
VizMode.Points2D: 'Points',
VizMode.Bar: 'Bar',
VizMode.Curves2D: 'Curves',
VizMode.FilledCurves2D: 'Filled Curves',
VizMode.Heatmap2D: 'Heatmap',
VizMode.SqueezedHeatmap2D: 'Heatmap (Squeezed)',
VizMode.Heatmap3D: 'Heatmap (3D)',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> ct.BLIcon:
return ''
####################
# - Validity
####################
@staticmethod
def by_info(info: ct.InfoFlow) -> list[typ.Self] | None:
"""Given the input `InfoFlow`, deduce which visualization modes are valid to use with the described data."""
Z = spux.MathType.Integer
R = spux.MathType.Real
VM = VizMode
@ -102,15 +128,18 @@ class VizMode(enum.StrEnum):
],
}.get(
(
tuple([dim.mathtype for dim in info.dims.values()]),
tuple([dim.mathtype for dim in info.dims]),
(info.output.rows, info.output.cols, info.output.mathtype),
),
[],
)
@staticmethod
def to_plotter(
value: typ.Self,
####################
# - Properties
####################
@property
def mpl_plotter(
self,
) -> typ.Callable[
[jtyp.Float32[jtyp.Array, '...'], ct.InfoFlow, mpl_ax.Axis], None
]:
@ -124,25 +153,7 @@ class VizMode(enum.StrEnum):
VizMode.Heatmap2D: image_ops.plot_heatmap_2d,
# NO PLOTTER: VizMode.SqueezedHeatmap2D
# NO PLOTTER: VizMode.Heatmap3D
}[value]
@staticmethod
def to_name(value: typ.Self) -> str:
return {
VizMode.BoxPlot1D: 'Box Plot',
VizMode.Curve2D: 'Curve',
VizMode.Points2D: 'Points',
VizMode.Bar: 'Bar',
VizMode.Curves2D: 'Curves',
VizMode.FilledCurves2D: 'Filled Curves',
VizMode.Heatmap2D: 'Heatmap',
VizMode.SqueezedHeatmap2D: 'Heatmap (Squeezed)',
VizMode.Heatmap3D: 'Heatmap (3D)',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> ct.BLIcon:
return ''
}[self]
class VizTarget(enum.StrEnum):
@ -181,6 +192,10 @@ class VizTarget(enum.StrEnum):
return ''
sym_x_um = sim_symbols.space_x(spu.um)
x_um = sym_x_um.sp_symbol
class VizNode(base.MaxwellSimNode):
"""Node for visualizing simulation data, by querying its monitors.
@ -188,7 +203,6 @@ class VizNode(base.MaxwellSimNode):
Attributes:
colormap: Colormap to apply to 0..1 output.
"""
node_type = ct.NodeType.Viz
@ -201,8 +215,8 @@ class VizNode(base.MaxwellSimNode):
input_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Func,
default_symbols=[sim_symbols.x],
default_value=2 * sim_symbols.x.sp_symbol,
default_symbols=[sym_x_um],
default_value=sp.exp(-(x_um**2)),
),
}
output_sockets: typ.ClassVar = {
@ -240,16 +254,21 @@ class VizNode(base.MaxwellSimNode):
return None
viz_mode: enum.StrEnum = bl_cache.BLField(
viz_mode: VizMode = bl_cache.BLField(
enum_cb=lambda self, _: self.search_viz_modes(),
cb_depends_on={'expr_info'},
)
viz_target: enum.StrEnum = bl_cache.BLField(
viz_target: VizTarget = bl_cache.BLField(
enum_cb=lambda self, _: self.search_targets(),
cb_depends_on={'viz_mode'},
)
# Mode-Dependent Properties
# Plot
plot_width: float = bl_cache.BLField(6.0, abs_min=0.1)
plot_height: float = bl_cache.BLField(3.0, abs_min=0.1)
plot_dpi: int = bl_cache.BLField(150, abs_min=25)
# Pixels
colormap: image_ops.Colormap = bl_cache.BLField(
image_ops.Colormap.Viridis,
)
@ -267,7 +286,7 @@ class VizNode(base.MaxwellSimNode):
VizMode.to_icon(viz_mode),
i,
)
for i, viz_mode in enumerate(VizMode.valid_modes_for(self.expr_info))
for i, viz_mode in enumerate(VizMode.by_info(self.expr_info))
]
return []
@ -300,9 +319,22 @@ class VizNode(base.MaxwellSimNode):
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
col.prop(self, self.blfields['viz_mode'], text='')
col.prop(self, self.blfields['viz_target'], text='')
if self.viz_target in [VizTarget.Pixels, VizTarget.PixelsPlane]:
col.prop(self, self.blfields['colormap'], text='')
if self.viz_target is VizTarget.Plot2D:
row = col.row(align=True)
row.alignment = 'CENTER'
row.label(text='Width/Height/DPI')
row = col.row(align=True)
row.prop(self, self.blfields['plot_width'], text='')
row.prop(self, self.blfields['plot_height'], text='')
row = col.row(align=True)
col.prop(self, self.blfields['plot_dpi'], text='')
####################
# - Events
####################
@ -320,16 +352,16 @@ class VizNode(base.MaxwellSimNode):
has_info = not ct.FlowSignal.check(info)
has_params = not ct.FlowSignal.check(params)
# Provide Sockets for Symbol Realization
# Declare Loose Sockets that Realize Symbols
## -> This happens if Params contains not-yet-realized symbols.
if has_info and has_params and params.symbols:
if set(self.loose_input_sockets) != {
dim.name for dim in params.symbols if dim in info.dims
sym.name for sym in params.symbols if sym in info.dims
}:
self.loose_input_sockets = {
dim_name: sockets.ExprSocketDef(**expr_info)
for dim_name, expr_info in params.sym_expr_infos(
info, use_range=True
use_range=True
).items()
}
@ -343,7 +375,14 @@ class VizNode(base.MaxwellSimNode):
'Preview',
kind=ct.FlowKind.Value,
# Loaded
props={'viz_mode', 'viz_target', 'colormap'},
props={
'viz_mode',
'viz_target',
'colormap',
'plot_width',
'plot_height',
'plot_dpi',
},
input_sockets={'Expr'},
input_socket_kinds={
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
@ -359,7 +398,14 @@ class VizNode(base.MaxwellSimNode):
#####################
@events.on_show_plot(
managed_objs={'plot'},
props={'viz_mode', 'viz_target', 'colormap'},
props={
'viz_mode',
'viz_target',
'colormap',
'plot_width',
'plot_height',
'plot_dpi',
},
input_sockets={'Expr'},
input_socket_kinds={
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
@ -370,7 +416,6 @@ class VizNode(base.MaxwellSimNode):
def on_show_plot(
self, managed_objs, props, input_sockets, loose_input_sockets
) -> None:
# Retrieve Inputs
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
info = input_sockets['Expr'][ct.FlowKind.Info]
params = input_sockets['Expr'][ct.FlowKind.Params]
@ -378,51 +423,59 @@ class VizNode(base.MaxwellSimNode):
has_info = not ct.FlowSignal.check(info)
has_params = not ct.FlowSignal.check(params)
if (
not has_info
or not has_params
or props['viz_mode'] is None
or props['viz_target'] is None
):
return
plot = managed_objs['plot']
viz_mode = props['viz_mode']
viz_target = props['viz_target']
if has_info and has_params and viz_mode is not None and viz_target is not None:
# Realize Data w/Realized Symbols
## -> The loose input socket values are user-selected symbol values.
## -> These expressions are used to realize the lazy data.
## -> `.realize()` ensures all ex. units are correctly conformed.
realized_syms = {
sym: loose_input_sockets[sym.name] for sym in params.sorted_symbols
}
output_data = lazy_func.realize(params, symbol_values=realized_syms)
# Compute Ranges for Symbols from Loose Sockets
## -> In a quite nice turn of events, all this is cached lookups.
## -> ...Unless something changed, in which case, well. It changed.
symbol_array_values = {
sim_syms: (
loose_input_sockets[sim_syms]
.rescale_to_unit(sim_syms.unit)
.realize_array
)
for sim_syms in params.sorted_symbols
}
data = lazy_func.realize(params, symbol_values=symbol_array_values)
# Replace InfoFlow Indices w/Realized Symbolic Ranges
## -> This ensures correct axis scaling.
if params.symbols:
info = info.replace_dims(symbol_array_values)
match props['viz_target']:
case VizTarget.Plot2D:
managed_objs['plot'].mpl_plot_to_image(
lambda ax: VizMode.to_plotter(props['viz_mode'])(data, info, ax),
bl_select=True,
data = {
dim: (
realized_syms[dim].values
if dim in realized_syms
else info.dims[dim]
)
for dim in info.dims
} | {info.output: output_data}
case VizTarget.Pixels:
managed_objs['plot'].map_2d_to_image(
data,
colormap=props['colormap'],
bl_select=True,
)
# Match Viz Type & Perform Visualization
## -> Viz Target determines how to plot.
## -> Viz Mode may help select a particular plotting method.
## -> Other parameters may be uses, depending on context.
match viz_target:
case VizTarget.Plot2D:
plot_width = props['plot_width']
plot_height = props['plot_height']
plot_dpi = props['plot_dpi']
plot.mpl_plot_to_image(
lambda ax: viz_mode.mpl_plotter(data, ax),
width_inches=plot_width,
height_inches=plot_height,
dpi=plot_dpi,
bl_select=True,
)
case VizTarget.PixelsPlane:
raise NotImplementedError
case VizTarget.Pixels:
colormap = props['colormap']
if colormap is not None:
plot.map_2d_to_image(
data,
colormap=colormap,
bl_select=True,
)
case VizTarget.Voxels:
raise NotImplementedError
case VizTarget.PixelsPlane:
raise NotImplementedError
case VizTarget.Voxels:
raise NotImplementedError
####################

View File

@ -610,8 +610,8 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
## -> Anyone needing results will need to wait on preinit().
return ct.FlowSignal.FlowInitializing
if optional:
return ct.FlowSignal.NoFlow
# if optional:
return ct.FlowSignal.NoFlow
msg = f'{self.sim_node_name}: Input socket "{input_socket_name}" cannot be computed, as it is not an active input socket'
raise ValueError(msg)
@ -659,11 +659,12 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
return output_socket_methods[0](self)
# Auxiliary Fallbacks
if optional or kind in [ct.FlowKind.Info, ct.FlowKind.Params]:
return ct.FlowSignal.NoFlow
return ct.FlowSignal.NoFlow
# if optional or kind in [ct.FlowKind.Info, ct.FlowKind.Params]:
# return ct.FlowSignal.NoFlow
msg = f'No output method for ({output_socket_name}, {kind})'
raise ValueError(msg)
# msg = f'No output method for ({output_socket_name}, {kind})'
# raise ValueError(msg)
####################
# - Event Trigger

View File

@ -30,6 +30,7 @@ class ExprConstantNode(base.MaxwellSimNode):
input_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Func,
show_name_selector=True,
),
}
output_sockets: typ.ClassVar = {

View File

@ -17,8 +17,11 @@
import typing as typ
import bpy
import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, sci_constants
from blender_maxwell.utils import bl_cache, sci_constants, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct
from .... import sockets
@ -30,15 +33,19 @@ class ScientificConstantNode(base.MaxwellSimNode):
bl_label = 'Scientific Constant'
output_sockets: typ.ClassVar = {
'Value': sockets.ExprSocketDef(),
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
}
####################
# - Properties
####################
sci_constant: str = bl_cache.BLField(
use_symbol: bool = bl_cache.BLField(False)
sci_constant_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerU
)
sci_constant_str: str = bl_cache.BLField(
'',
prop_ui=True,
str_cb=lambda self, _, edit_text: self.search_sci_constants(edit_text),
)
@ -52,27 +59,139 @@ class ScientificConstantNode(base.MaxwellSimNode):
if edit_text.lower() in name.lower()
]
@bl_cache.cached_bl_property(depends_on={'sci_constant_str'})
def sci_constant(self) -> spux.SympyExpr | None:
"""Retrieve the expression for the scientific constant."""
return sci_constants.SCI_CONSTANTS.get(self.sci_constant_str)
@bl_cache.cached_bl_property(depends_on={'sci_constant_str'})
def sci_constant_info(self) -> spux.SympyExpr | None:
"""Retrieve the information for the selected scientific constant."""
return sci_constants.SCI_CONSTANTS_INFO.get(self.sci_constant_str)
@bl_cache.cached_bl_property(
depends_on={'sci_constant', 'sci_constant_info', 'sci_constant_name'}
)
def sci_constant_sym(self) -> spux.SympyExpr | None:
"""Retrieve a symbol for the scientific constant."""
if self.sci_constant is not None and self.sci_constant_info is not None:
unit = self.sci_constant_info['units']
return sim_symbols.SimSymbol(
sym_name=self.sci_constant_name,
mathtype=spux.MathType.from_expr(self.sci_constant),
# physical_type= ## TODO: Formalize unit w/o physical_type
unit=unit,
is_constant=True,
)
return None
####################
# - UI
####################
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
col.prop(self, self.blfields['sci_constant'], text='')
col.prop(self, self.blfields['sci_constant_str'], text='')
row = col.row(align=True)
row.alignment = 'CENTER'
row.label(text='Assign Symbol')
col.prop(self, self.blfields['sci_constant_name'], text='')
col.prop(self, self.blfields['use_symbol'], text='Use Symbol', toggle=True)
def draw_info(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
if self.sci_constant:
col.label(
text=f'Units: {sci_constants.SCI_CONSTANTS_INFO[self.sci_constant]["units"]}'
)
col.label(
text=f'Uncertainty: {sci_constants.SCI_CONSTANTS_INFO[self.sci_constant]["uncertainty"]}'
)
box = col.box()
split = box.split(factor=0.25, align=True)
# Left: Units
_col = split.column(align=True)
row = _col.row(align=True)
# row.alignment = 'CENTER'
row.label(text='Src')
if self.sci_constant_info:
row = _col.row(align=True)
# row.alignment = 'CENTER'
row.label(text='Unit')
row = _col.row(align=True)
# row.alignment = 'CENTER'
row.label(text='Err')
# Right: Values
_col = split.column(align=True)
row = _col.row(align=True)
# row.alignment = 'CENTER'
row.label(text='CODATA2018')
if self.sci_constant_info:
row = _col.row(align=True)
# row.alignment = 'CENTER'
row.label(text=f'{self.sci_constant_info["units"]}')
row = _col.row(align=True)
# row.alignment = 'CENTER'
row.label(text=f'{self.sci_constant_info["uncertainty"]}')
####################
# - Output
####################
@events.computes_output_socket('Value', props={'sci_constant'})
def compute_value(self, props: dict) -> typ.Any:
return sci_constants.SCI_CONSTANTS[props['sci_constant']]
@events.computes_output_socket(
'Expr',
props={'use_symbol', 'sci_constant', 'sci_constant_sym'},
)
def compute_value(self, props) -> typ.Any:
sci_constant = props['sci_constant']
sci_constant_sym = props['sci_constant_sym']
if props['use_symbol'] and sci_constant_sym is not None:
return sci_constant_sym.sp_symbol
if sci_constant is not None:
return sci_constant
return ct.FlowSignal.FlowPending
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Func,
props={'sci_constant', 'sci_constant_sym'},
)
def compute_lazy_func(self, props) -> typ.Any:
sci_constant = props['sci_constant']
sci_constant_sym = props['sci_constant_sym']
if sci_constant is not None:
return ct.FuncFlow(
func=sp.lambdify(
[sci_constant_sym.sp_symbol], sci_constant_sym.sp_symbol, 'jax'
),
func_args=[sci_constant_sym],
supports_jax=True,
)
return ct.FlowSignal.FlowPending
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Info,
props={'sci_constant_sym'},
)
def compute_info(self, props: dict) -> typ.Any:
sci_constant_sym = props['sci_constant_sym']
if sci_constant_sym is not None:
return ct.InfoFlow(output=sci_constant_sym)
return ct.FlowSignal.FlowPending
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Params,
props={'sci_constant'},
)
def compute_params(self, props: dict) -> typ.Any:
sci_constant = props['sci_constant']
if sci_constant is not None:
return ct.ParamsFlow(func_args=[sci_constant])
return ct.FlowSignal.FlowPending
####################

View File

@ -95,62 +95,35 @@ class DataFileImporterNode(base.MaxwellSimNode):
####################
# - Info Guides
####################
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(sim_symbols.SimSymbolName)
output_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.Data
)
output_mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real)
output_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
output_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type),
enum_cb=lambda self, _: self.search_units(self.output_physical_type),
cb_depends_on={'output_physical_type'},
)
dim_0_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerA
)
dim_0_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_0_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_0_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type),
cb_depends_on={'dim_0_physical_type'},
)
dim_1_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerB
)
dim_1_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_1_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_1_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_1_physical_type),
cb_depends_on={'dim_1_physical_type'},
)
dim_2_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerC
)
dim_2_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_2_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_2_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_2_physical_type),
cb_depends_on={'dim_2_physical_type'},
)
dim_3_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerD
)
dim_3_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_3_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
dim_4_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerE
)
dim_3_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_3_physical_type),
cb_depends_on={'dim_3_physical_type'},
dim_5_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerF
)
def search_units(self, physical_type: spux.PhysicalType) -> list[ct.BLEnumElement]:
@ -161,19 +134,6 @@ class DataFileImporterNode(base.MaxwellSimNode):
]
return []
def dim(self, i: int):
dim_name = getattr(self, f'dim_{i}_name')
dim_mathtype = getattr(self, f'dim_{i}_mathtype')
dim_physical_type = getattr(self, f'dim_{i}_physical_type')
dim_unit = getattr(self, f'dim_{i}_unit')
return sim_symbols.SimSymbol(
sym_name=dim_name,
mathtype=dim_mathtype,
physical_type=dim_physical_type,
unit=spux.unit_str_to_unit(dim_unit),
)
####################
# - UI
####################
@ -202,19 +162,21 @@ class DataFileImporterNode(base.MaxwellSimNode):
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
"""Draw loaded properties."""
for i in range(len(self.expr_info.dims)):
col = layout.column(align=True)
col = layout.column(align=True)
if self.expr_info is not None:
for i in range(len(self.expr_info.dims)):
col.prop(self, self.blfields[f'dim_{i}_name'], text=f'Dim {i}')
row = col.row(align=True)
row.alignment = 'CENTER'
row.label(text=f'Load Dim {i}')
row.label(text='Output')
row = col.row(align=True)
row.prop(self, self.blfields[f'dim_{i}_name'], text='')
row.prop(self, self.blfields[f'dim_{i}_mathtype'], text='')
row.prop(self, self.blfields['output_name'], text='')
row.prop(self, self.blfields['output_mathtype'], text='')
row = col.row(align=True)
row.prop(self, self.blfields[f'dim_{i}_physical_type'], text='')
row.prop(self, self.blfields[f'dim_{i}_unit'], text='')
row.prop(self, self.blfields['output_physical_type'], text='')
row.prop(self, self.blfields['output_unit'], text='')
####################
# - FlowKind.Array|Func
@ -271,7 +233,8 @@ class DataFileImporterNode(base.MaxwellSimNode):
'Expr',
kind=ct.FlowKind.Info,
# Loaded
props={'output_name', 'output_physical_type', 'output_unit'},
props={'output_name', 'output_mathtype', 'output_physical_type', 'output_unit'}
| {f'dim_{i}_name' for i in range(6)},
output_sockets={'Expr'},
output_socket_kinds={'Expr': ct.FlowKind.Func},
)
@ -285,32 +248,31 @@ class DataFileImporterNode(base.MaxwellSimNode):
A completely empty `ParamsFlow`, ready to be composed.
"""
expr = output_sockets['Expr']
has_expr_func = not ct.FlowSignal.check(expr)
if has_expr_func:
data = expr.func_jax()
# Deduce Dimensionality
_shape = data.shape
shape = _shape if _shape is not None else ()
dim_syms = [self.dim(i) for i in range(len(shape))]
# Deduce Dimension Symbols
## -> They are all chronically integer indices.
## -> The FilterNode can be used to "steal" an index from the data.
shape = data.shape if data.shape is not None else ()
dims = {
sim_symbols.idx(None).update(
sym_name=props[f'dim_{i}_name'],
interval_finite_z=(0, elements),
interval_inf=(False, False),
interval_closed=(True, True),
): [str(j) for j in range(elements)]
for i, elements in enumerate(shape)
}
# Return InfoFlow
return ct.InfoFlow(
dims={
dim_sym: ct.RangeFlow(
start=sp.S(0),
stop=sp.S(shape[i] - 1),
steps=shape[i],
unit=self.dim(i).unit,
)
for i, dim_sym in enumerate(dim_syms)
},
dims=dims,
output=sim_symbols.SimSymbol(
sym_name=props['output_name'],
mathtype=props['output_mathtype'],
physical_type=props['output_physical_type'],
unit=props['output_unit'],
),
)
return ct.FlowSignal.FlowPending

View File

@ -259,9 +259,8 @@ class LibraryMediumNode(base.MaxwellSimNode):
)
def compute_valid_freqs_lazy(self, props) -> sp.Expr:
return ct.RangeFlow(
start=props['freq_range'][0] / spux.THz,
stop=props['freq_range'][1] / spux.THz,
steps=0,
start=spu.scale_to_unit(['freq_range'][0], spux.THz),
stop=spu.scale_to_unit(props['freq_range'][1], spux.THz),
scaling=ct.ScalingMode.Lin,
unit=spux.THz,
)
@ -273,9 +272,8 @@ class LibraryMediumNode(base.MaxwellSimNode):
)
def compute_valid_wls_lazy(self, props) -> sp.Expr:
return ct.RangeFlow(
start=props['wl_range'][0] / spu.nm,
stop=props['wl_range'][0] / spu.nm,
steps=0,
start=spu.scale_to_unit(['wl_range'][0], spu.nm),
stop=spu.scale_to_unit(['wl_range'][0], spu.nm),
scaling=ct.ScalingMode.Lin,
unit=spu.nm,
)

View File

@ -73,43 +73,138 @@ class ViewerNode(base.MaxwellSimNode):
####################
# - Properties
####################
print_kind: ct.FlowKind = bl_cache.BLField(ct.FlowKind.Value)
auto_plot: bool = bl_cache.BLField(False)
auto_expr: bool = bl_cache.BLField(True)
debug_mode: bool = bl_cache.BLField(False)
# Debug Mode
console_print_kind: ct.FlowKind = bl_cache.BLField(ct.FlowKind.Value)
auto_plot: bool = bl_cache.BLField(True)
auto_3d_preview: bool = bl_cache.BLField(True)
####################
# - Properties: Computed FlowKinds
####################
@events.on_value_changed(
socket_name='Any',
)
def on_input_changed(self) -> None:
self.input_flow = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def input_flow(self) -> dict[ct.FlowKind, typ.Any | None]:
input_flow = {}
for flow_kind in list(ct.FlowKind):
flow = self._compute_input('Any', kind=flow_kind)
has_flow = not ct.FlowSignal.check(flow)
if has_flow:
input_flow |= {flow_kind: flow}
else:
input_flow |= {flow_kind: None}
return input_flow
####################
# - Property: Input Expression String Lines
####################
@bl_cache.cached_bl_property(depends_on={'input_flow'})
def input_expr_str_entries(self) -> list[list[str]] | None:
value = self.input_flow.get(ct.FlowKind.Value)
def sp_pretty(v: spux.SympyExpr) -> spux.SympyExpr:
## sp.pretty makes new lines and wreaks havoc.
return spux.sp_to_str(v.n(4))
if isinstance(value, spux.SympyType):
if isinstance(value, sp.MatrixBase):
return [
[sp_pretty(value[row, col]) for col in range(value.shape[1])]
for row in range(value.shape[0])
]
return [[sp_pretty(value)]]
return None
####################
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
layout.prop(self, self.blfields['print_kind'], text='')
row = layout.row(align=True)
# Debug Mode On/Off
row.prop(self, self.blfields['debug_mode'], text='Debug', toggle=True)
# Automatic Expression Printing
row.prop(self, self.blfields['auto_expr'], text='Expr', toggle=True)
# Debug Mode Operators
if self.debug_mode:
layout.prop(self, self.blfields['console_print_kind'], text='')
def draw_operators(self, _: bpy.types.Context, layout: bpy.types.UILayout):
split = layout.split(factor=0.4)
# Live Expression
if self.debug_mode:
layout.operator(ConsoleViewOperator.bl_idname, text='Console Print')
# Split LHS
col = split.column(align=False)
col.label(text='Console')
col.label(text='Plot')
col.label(text='3D')
split = layout.split(factor=0.4)
# Split RHS
col = split.column(align=False)
# Split LHS
col = split.column(align=False)
col.label(text='Plot')
col.label(text='3D')
## Console Options
col.operator(ConsoleViewOperator.bl_idname, text='Print')
# Split RHS
col = split.column(align=False)
## Plot Options
row = col.row(align=True)
row.prop(self, self.blfields['auto_plot'], text='Plot', toggle=True)
row.operator(
RefreshPlotViewOperator.bl_idname,
text='',
icon='FILE_REFRESH',
)
## Plot Options
row = col.row(align=True)
row.prop(self, self.blfields['auto_plot'], text='Plot', toggle=True)
row.operator(
RefreshPlotViewOperator.bl_idname,
text='',
icon='FILE_REFRESH',
)
## 3D Preview Options
row = col.row(align=True)
row.prop(self, self.blfields['auto_3d_preview'], text='3D Preview', toggle=True)
## 3D Preview Options
row = col.row(align=True)
row.prop(
self, self.blfields['auto_3d_preview'], text='3D Preview', toggle=True
)
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout):
# Live Expression
if self.auto_expr and self.input_expr_str_entries is not None:
box = layout.box()
expr_rows = len(self.input_expr_str_entries)
expr_cols = len(self.input_expr_str_entries[0])
shape_str = (
f'({expr_rows}×{expr_cols})'
if expr_rows != 1 or expr_cols != 1
else '(Scalar)'
)
row = box.row()
row.alignment = 'CENTER'
row.label(text=f'Expr {shape_str}')
if (
len(self.input_expr_str_entries) == 1
and len(self.input_expr_str_entries[0]) == 1
):
row = box.row()
row.alignment = 'CENTER'
row.label(text=self.input_expr_str_entries[0][0])
else:
grid = box.grid_flow(
row_major=True,
columns=len(self.input_expr_str_entries[0]),
align=True,
)
for row in self.input_expr_str_entries:
for entry in row:
grid.label(text=entry)
####################
# - Methods
@ -119,7 +214,7 @@ class ViewerNode(base.MaxwellSimNode):
return
log.info('Printing to Console')
data = self._compute_input('Any', kind=self.print_kind, optional=True)
data = self._compute_input('Any', kind=self.console_print_kind, optional=True)
if isinstance(data, spux.SympyType):
console.print(sp.pretty(data, use_unicode=True))

View File

@ -40,6 +40,8 @@ _max_e_socket_def = sockets.ExprSocketDef(
)
_offset_socket_def = sockets.ExprSocketDef(default_value=5, abs_min=2.5)
t_ps = sim_symbols.t(spu.picosecond)
class TemporalShapeNode(base.MaxwellSimNode):
"""Declare a source-time dependence for use in simulation source nodes."""
@ -82,8 +84,8 @@ class TemporalShapeNode(base.MaxwellSimNode):
default_steps=100,
),
'Envelope': sockets.ExprSocketDef(
default_symbols=[sim_symbols.t],
default_value=10 * sim_symbols.t.sp_symbol,
default_symbols=[t_ps],
default_value=10 * t_ps.sp_symbol,
),
},
}

View File

@ -593,6 +593,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
ValueError: When referencing a socket that's meant to be directly referenced.
"""
kind_data_map = {
ct.FlowKind.Capabilities: lambda: self.capabilities,
ct.FlowKind.Value: lambda: self.value,
ct.FlowKind.Array: lambda: self.array,
ct.FlowKind.Func: lambda: self.lazy_func,

View File

@ -111,29 +111,38 @@ class ExprBLSocket(base.MaxwellSimSocket):
bl_label = 'Expr'
####################
# - Properties
# - Socket Interface
####################
size: spux.NumberSize1D = bl_cache.BLField(spux.NumberSize1D.Scalar)
mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real)
physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical)
# Symbols
####################
# - Symbols
####################
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.Expr
)
active_symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([])
@property
def symbols(self) -> set[sp.Symbol]:
"""Current symbols as an unordered set."""
return {sim_symbol.sp_symbol for sim_symbol in self.active_symbols}
symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([])
@bl_cache.cached_bl_property(depends_on={'symbols'})
def sorted_symbols(self) -> list[sp.Symbol]:
def sp_symbols(self) -> set[sp.Symbol | sp.MatrixSymbol]:
"""Sympy symbols as an unordered set."""
return {sim_symbol.sp_symbol_matsym for sim_symbol in self.symbols}
@bl_cache.cached_bl_property(depends_on={'symbols'})
def sorted_symbols(self) -> list[sim_symbols.SimSymbol]:
"""Current symbols as a sorted list."""
return sorted(self.symbols, key=lambda sym: sym.name)
# Unit
@bl_cache.cached_bl_property(depends_on={'symbols'})
def sorted_sp_symbols(self) -> list[sp.Symbol | sp.MatrixSymbol]:
"""Computes `sympy` symbols from `self.sorted_symbols`."""
return [sym.sp_symbol_matsym for sym in self.sorted_symbols]
####################
# - Units
####################
active_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_valid_units(),
cb_depends_on={'physical_type'},
@ -148,6 +157,29 @@ class ExprBLSocket(base.MaxwellSimSocket):
]
return []
@bl_cache.cached_bl_property(depends_on={'active_unit'})
def unit(self) -> spux.Unit | None:
"""Gets the current active unit.
Returns:
The current active `sympy` unit.
If the socket expression is unitless, this returns `None`.
"""
if self.active_unit is not None:
return spux.unit_str_to_unit(self.active_unit)
return None
@property
def unit_factor(self) -> spux.Unit | None:
return sp.Integer(1) if self.unit is None else self.unit
prev_unit: str | None = bl_cache.BLField(None)
####################
# - UI Values
####################
# UI: Value
## Expression
raw_value_spstr: str = bl_cache.BLField('0.0')
@ -186,6 +218,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
)
# UI: Info
show_name_selector: bool = bl_cache.BLField(False)
show_func_ui: bool = bl_cache.BLField(True)
show_info_columns: bool = bl_cache.BLField(False)
info_columns: set[InfoDisplayCol] = bl_cache.BLField(
@ -207,25 +240,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
def raw_max_sp(self) -> spux.SympyExpr:
return self._parse_expr_str(self.raw_max_spstr)
####################
# - Computed Unit
####################
@bl_cache.cached_bl_property(depends_on={'active_unit'})
def unit(self) -> spux.Unit | None:
"""Gets the current active unit.
Returns:
The current active `sympy` unit.
If the socket expression is unitless, this returns `None`.
"""
if self.active_unit is not None:
return spux.unit_str_to_unit(self.active_unit)
return None
prev_unit: str | None = bl_cache.BLField(None)
####################
# - Prop-Change Callback
####################
@ -272,7 +286,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
raise ValueError(msg)
# Parse Symbols
if expr.free_symbols and not expr.free_symbols.issubset(self.symbols):
if expr.free_symbols and not expr.free_symbols.issubset(self.sp_symbols):
msg = f'Tried to set expr {expr} with free symbols {expr.free_symbols}, which is incompatible with socket symbols {self.symbols}'
raise ValueError(msg)
@ -320,7 +334,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
"""
expr = sp.sympify(
expr_spstr,
locals={sym.name: sym for sym in self.symbols},
locals={sym.name: sym.sp_symbol_matsym for sym in self.symbols},
strict=False,
convert_xor=True,
).subs(spux.UNIT_BY_SYMBOL)
@ -562,11 +576,11 @@ class ExprBLSocket(base.MaxwellSimSocket):
if self.symbols:
return ct.FuncFlow(
func=sp.lambdify(
self.sorted_symbols,
spux.scale_to_unit(self.value, self.unit),
self.sorted_sp_symbols,
spux.strip_unit_system(self.value),
'jax',
),
func_args=[spux.MathType.from_expr(sym) for sym in self.sorted_symbols],
func_args=list(self.sorted_symbols),
supports_jax=True,
)
@ -578,7 +592,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
return ct.FuncFlow(
func=lambda v: v,
func_args=[
self.physical_type if self.physical_type is not None else self.mathtype
sim_symbols.SimSymbol.from_expr(
sim_symbols.SimSymbolName.Constant, self.value, self.unit_factor
)
],
supports_jax=True,
)
@ -597,8 +613,8 @@ class ExprBLSocket(base.MaxwellSimSocket):
## -> NOTE: func_args must have the same symbol order as was lambdified.
if self.symbols:
return ct.ParamsFlow(
func_args=self.sorted_symbols,
symbols=self.symbols,
func_args=[sym.sp_symbol_phy for sym in self.sorted_symbols],
symbols=self.sorted_symbols,
)
# Constant
@ -618,24 +634,27 @@ class ExprBLSocket(base.MaxwellSimSocket):
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
"""
output_sim_sym = (
sim_symbols.SimSymbol(
sym_name=self.output_name,
mathtype=self.mathtype,
physical_type=self.physical_type,
unit=self.unit,
rows=self.size.rows,
cols=self.size.cols,
),
output_sym = sim_symbols.SimSymbol(
sym_name=self.output_name,
mathtype=self.mathtype,
physical_type=self.physical_type,
unit=self.unit,
rows=self.size.rows,
cols=self.size.cols,
)
# Constant
## -> The input SimSymbols become continuous dimensional indices.
## -> All domain validity information is defined on the SimSymbol keys.
if self.symbols:
return ct.InfoFlow(
dims={sim_sym: None for sim_sym in self.active_symbols},
output=output_sim_sym,
dims={sym: None for sym in self.sorted_symbols},
output=output_sym,
)
# Constant
return ct.InfoFlow(output=output_sim_sym)
## -> We only need the output symbol to describe the raw data.
return ct.InfoFlow(output=output_sym)
####################
# - FlowKind: Capabilities
@ -645,6 +664,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
return ct.CapabilitiesFlow(
socket_type=self.socket_type,
active_kind=self.active_kind,
allow_out_to_in={ct.FlowKind.Func: ct.FlowKind.Value},
)
####################
@ -795,7 +815,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
col = split.column()
col.alignment = 'RIGHT'
for sym in self.symbols:
col.label(text=spux.pretty_symbol(sym))
col.label(text=sym.def_label)
def draw_lazy_range(self, col: bpy.types.UILayout) -> None:
"""Draw the socket body for a simple, uniform range of values between two values/expressions.
@ -840,14 +860,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
Uses `draw_value` to draw the base UI
"""
if self.show_func_ui:
# Output Name Selector
## -> The name of the output
col.prop(self, self.blfields['output_name'], text='')
# Physical Type Selector
## -> Determines whether/which unit-dropdown will be shown.
col.prop(self, self.blfields['physical_type'], text='')
# Non-Symbolic: Size/Mathtype Selector
## -> Symbols imply str expr input.
## -> For arbitrary str exprs, size/mathtype are derived from the expr.
@ -861,15 +873,30 @@ class ExprBLSocket(base.MaxwellSimSocket):
## -> Draws the UI appropriate for the above choice of constraints.
self.draw_value(col)
# Physical Type Selector
## -> Determines whether/which unit-dropdown will be shown.
col.prop(self, self.blfields['physical_type'], text='')
# Symbol UI
## -> Draws the UI appropriate for the above choice of constraints.
## -> TODO
# Output Name Selector
## -> The name of the output
if self.show_name_selector:
row = col.row()
row.prop(self, self.blfields['output_name'], text='Name')
####################
# - UI: InfoFlow
####################
def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None:
if self.active_kind == ct.FlowKind.Func and self.show_info_columns:
"""Visualize the `InfoFlow` information passing through the socket."""
if (
self.active_kind == ct.FlowKind.Func
and self.show_info_columns
and self.is_linked
):
row = col.row()
box = row.box()
grid = box.grid_flow(
@ -881,38 +908,23 @@ class ExprBLSocket(base.MaxwellSimSocket):
)
# Dimensions
for dim in info.dims:
dim_idx = info.dims[dim]
grid.label(text=dim.name_pretty)
for dim_name_pretty, dim_label_info in info.dim_labels.items():
grid.label(text=dim_name_pretty)
if InfoDisplayCol.Length in self.info_columns:
grid.label(text=str(len(dim_idx)))
grid.label(text=dim_label_info['length'])
if InfoDisplayCol.MathType in self.info_columns:
grid.label(text=spux.MathType.to_str(dim_idx.mathtype))
grid.label(text=dim_label_info['mathtype'])
if InfoDisplayCol.Unit in self.info_columns:
grid.label(text=spux.sp_to_str(dim_idx.unit))
grid.label(text=dim_label_info['unit'])
# Outputs
grid.label(text=info.output.name_pretty)
if InfoDisplayCol.Length in self.info_columns:
grid.label(text='', icon=ct.Icon.DataSocketOutput)
if InfoDisplayCol.MathType in self.info_columns:
grid.label(
text=(
spux.MathType.to_str(info.output.mathtype)
+ (
'ˣ'.join(
[
unicode_superscript(out_axis)
for out_axis in info.output.shape
]
)
if info.output.shape
else ''
)
)
)
grid.label(text=info.output.def_label)
if InfoDisplayCol.Unit in self.info_columns:
grid.label(text=f'{spux.sp_to_str(info.output.unit)}')
grid.label(text=info.output.unit_label)
####################
@ -926,7 +938,7 @@ class ExprSocketDef(base.SocketDef):
ct.FlowKind.Array,
ct.FlowKind.Func,
] = ct.FlowKind.Value
output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName
output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName.Expr
# Socket Interface
size: spux.NumberSize1D = spux.NumberSize1D.Scalar
@ -948,9 +960,15 @@ class ExprSocketDef(base.SocketDef):
default_scaling: ct.ScalingMode = ct.ScalingMode.Lin
# UI
show_name_selector: bool = False
show_func_ui: bool = True
show_info_columns: bool = False
@property
def sp_symbols(self) -> set[sp.Symbol | sp.MatrixSymbol]:
"""Default symbols as an unordered set."""
return {sym.sp_symbol_matsym for sym in self.default_symbols}
####################
# - Parse Unit and/or Physical Type
####################
@ -1149,12 +1167,13 @@ class ExprSocketDef(base.SocketDef):
raise ValueError(msg)
# Coerce from Infinite
if bound.is_infinite and self.mathtype is spux.MathType.Integer:
new_bounds[i] = sp.S(-1) if i == 0 else sp.S(1)
if bound.is_infinite and self.mathtype is spux.MathType.Rational:
new_bounds[i] = sp.Rational(-1, 1) if i == 0 else sp.Rational(1, 1)
if bound.is_infinite and self.mathtype is spux.MathType.Real:
new_bounds[i] = sp.S(-1.0) if i == 0 else sp.S(1.0)
if isinstance(bound, spux.SympyType):
if bound.is_infinite and self.mathtype is spux.MathType.Integer:
new_bounds[i] = sp.S(-1) if i == 0 else sp.S(1)
if bound.is_infinite and self.mathtype is spux.MathType.Rational:
new_bounds[i] = sp.Rational(-1, 1) if i == 0 else sp.Rational(1, 1)
if bound.is_infinite and self.mathtype is spux.MathType.Real:
new_bounds[i] = sp.S(-1.0) if i == 0 else sp.S(1.0)
if new_bounds[0] is not None:
self.default_min = new_bounds[0]
@ -1194,9 +1213,9 @@ class ExprSocketDef(base.SocketDef):
def symbols_value(self) -> typ.Self:
if (
self.default_value.free_symbols
and not self.default_value.free_symbols.issubset(self.symbols)
and not self.default_value.free_symbols.issubset(self.sp_symbols)
):
msg = f'Tried to set default value {self.default_value} with free symbols {self.default_value.free_symbols}, which is incompatible with socket symbols {self.symbols}'
msg = f'Tried to set default value {self.default_value} with free symbols {self.default_value.free_symbols}, which is incompatible with socket symbols {self.sp_symbols}'
raise ValueError(msg)
return self
@ -1227,7 +1246,7 @@ class ExprSocketDef(base.SocketDef):
bl_socket.size = self.size
bl_socket.mathtype = self.mathtype
bl_socket.physical_type = self.physical_type
bl_socket.active_symbols = self.symbols
bl_socket.symbols = self.default_symbols
# FlowKind.Value
## -> We must take units into account when setting bl_socket.value
@ -1252,6 +1271,7 @@ class ExprSocketDef(base.SocketDef):
# UI
bl_socket.show_func_ui = self.show_func_ui
bl_socket.show_info_columns = self.show_info_columns
bl_socket.show_name_selector = self.show_name_selector
# Info Draw
bl_socket.use_info_draw = True

View File

@ -389,14 +389,15 @@ class BLField:
Reset by setting the descriptor to `Signal.ResetStrSearch`.
"""
cached_items = self.bl_prop_str_search.read_nonpersist(_self)
if cached_items is not Signal.CacheNotReady:
if cached_items is Signal.CacheEmpty:
computed_items = self.str_cb(_self, context, edit_text)
self.bl_prop_str_search.write_nonpersist(_self, computed_items)
return computed_items
return cached_items
return []
return self.str_cb(_self, context, edit_text)
# cached_items = self.bl_prop_str_search.read_nonpersist(_self)
# if cached_items is not Signal.CacheNotReady:
# if cached_items is Signal.CacheEmpty:
# computed_items = self.str_cb(_self, context, edit_text)
# self.bl_prop_str_search.write_nonpersist(_self, computed_items)
# return computed_items
# return cached_items
# return []
def safe_enum_cb(
self, _self: bl_instance.BLInstance, context: bpy.types.Context

View File

@ -28,11 +28,13 @@ Attributes:
import enum
import functools
import sys
import typing as typ
from fractions import Fraction
import jax
import jax.numpy as jnp
import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
import sympy.physics.units as spu
@ -144,6 +146,21 @@ class MathType(enum.StrEnum):
complex: MathType.Complex,
}[dtype]
@staticmethod
def from_jax_array(data: jtyp.Shaped[jtyp.Array, '...']) -> type:
"""Deduce the MathType corresponding to a JAX array.
We go about this by leveraging that:
- `data` is of a homogeneous type.
- `data.item(0)` returns a single element of the array w/pure-python type.
By combing this with `type()` and `MathType.from_pytype`, we can effectively deduce the `MathType` of the entire array with relative efficiency.
Notes:
Should also work with numpy arrays.
"""
return MathType.from_pytype(type(data.item(0)))
@staticmethod
def has_mathtype(obj: typ.Any) -> typ.Literal['pytype', 'expr'] | None:
if isinstance(obj, bool | int | Fraction | float | complex):
@ -173,6 +190,39 @@ class MathType(enum.StrEnum):
MT.Complex: sp.Complexes,
}[self]
@property
def inf_finite(self) -> type:
"""Opinionated finite representation of "infinity" within this `MathType`.
These are chosen using `sys.maxsize` and `sys.float_info`.
As such, while not arbitrary, this "finite representation of infinity" certainly is opinionated.
**Note** that, in practice, most systems will have no trouble working with values that exceed those defined here.
Notes:
Values should be presumed to vary by-platform, as the `sys` attributes may be influenced by CPU architecture, OS, runtime environment, etc. .
These values can be used directly in `jax` arrays, but at the cost of an overflow warning (in part because `jax` generally only allows the use of `float32`).
In this case, the warning doesn't matter, as the value will be cast to `jnp.inf` anyway.
However, it's generally cleaner to directly use `jnp.inf` if infinite values must be defined in an array context.
"""
MT = MathType
Z = MT.Integer
R = MT.Integer
return {
MT.Integer: (-sys.maxsize, sys.maxsize),
MT.Rational: (
Fraction(Z.inf_finite[0], 1),
Fraction(Z.inf_finite[1], 1),
),
MT.Real: -(sys.float_info.min, sys.float_info.max),
MT.Complex: (
complex(R.inf_finite[0], R.inf_finite[0]),
complex(R.inf_finite[1], R.inf_finite[1]),
),
}[self]
@property
def sp_symbol_a(self) -> type:
MT = MathType
@ -192,6 +242,10 @@ class MathType(enum.StrEnum):
MathType.Complex: '',
}[value]
@property
def label_pretty(self) -> str:
return MathType.to_str(self)
@staticmethod
def to_name(value: typ.Self) -> str:
return MathType.to_str(value)
@ -819,14 +873,15 @@ def sp_to_str(sp_obj: SympyExpr) -> str:
A string representing the expression for human use.
_The string is not re-encodable to the expression._
"""
## TODO: A bool flag property that does a lot of find/replace to make it super pretty
return _SYMPY_EXPR_PRINTER_STR.doprint(sp_obj)
def pretty_symbol(sym: sp.Symbol) -> str:
return f'{sym.name}' + (
''
if sym.is_complex
else ('' if sym.is_real else ('' if sym.is_integer else '?'))
''
if sym.is_integer
else ('' if sym.is_real else ('' if sym.is_complex else '?'))
)
@ -1039,20 +1094,24 @@ class PhysicalType(enum.StrEnum):
PT.LumIntensity: spu.candela,
PT.LumFlux: spu.candela * spu.steradian,
PT.Illuminance: spu.candela / spu.meter**2,
# Optics
PT.OrdinaryWaveVector: terahertz,
PT.AngularWaveVector: spu.radian * terahertz,
}[self]
@functools.cached_property
def valid_units(self) -> list[Unit]:
"""Retrieve an ordered (by subjective usefulness) list of units for this physical type.
Notes:
The order in which valid units are declared is the exact same order that UI dropdowns display them.
**Altering the order of units breaks backwards compatibility**.
"""
PT = PhysicalType
return {
PT.NonPhysical: [None],
# Global
PT.Time: [
femtosecond,
spu.picosecond,
femtosecond,
spu.nanosecond,
spu.microsecond,
spu.millisecond,
@ -1070,11 +1129,11 @@ class PhysicalType(enum.StrEnum):
],
PT.Freq: (
_valid_freqs := [
terahertz,
spu.hertz,
kilohertz,
megahertz,
gigahertz,
terahertz,
petahertz,
exahertz,
]
@ -1083,10 +1142,10 @@ class PhysicalType(enum.StrEnum):
# Cartesian
PT.Length: (
_valid_lens := [
spu.micrometer,
spu.nanometer,
spu.picometer,
spu.angstrom,
spu.nanometer,
spu.micrometer,
spu.millimeter,
spu.centimeter,
spu.meter,
@ -1102,24 +1161,24 @@ class PhysicalType(enum.StrEnum):
PT.Vel: [_unit / spu.second for _unit in _valid_lens],
PT.Accel: [_unit / spu.second**2 for _unit in _valid_lens],
PT.Mass: [
spu.kilogram,
spu.electron_rest_mass,
spu.dalton,
spu.microgram,
spu.milligram,
spu.gram,
spu.kilogram,
spu.metric_ton,
],
PT.Force: [
spu.kg * spu.meter / spu.second**2,
nanonewton,
micronewton,
nanonewton,
millinewton,
spu.newton,
spu.kg * spu.meter / spu.second**2,
],
PT.Pressure: [
millibar,
spu.bar,
millibar,
spu.pascal,
hectopascal,
spu.atmosphere,
@ -1129,8 +1188,8 @@ class PhysicalType(enum.StrEnum):
],
# Energy
PT.Work: [
spu.electronvolt,
spu.joule,
spu.electronvolt,
],
PT.Power: [
spu.watt,
@ -1194,18 +1253,17 @@ class PhysicalType(enum.StrEnum):
PT.Illuminance: [
spu.candela / spu.meter**2,
],
# Optics
PT.OrdinaryWaveVector: _valid_freqs,
PT.AngularWaveVector: [spu.radian * _unit for _unit in _valid_freqs],
}[self]
@staticmethod
def from_unit(unit: Unit) -> list[Unit]:
def from_unit(unit: Unit, optional: bool = False) -> list[Unit] | None:
for physical_type in list(PhysicalType):
if unit in physical_type.valid_units:
return physical_type
## TODO: Optimize
if optional:
return None
msg = f'Could not determine PhysicalType for {unit}'
raise ValueError(msg)
@ -1400,20 +1458,9 @@ def sympy_to_python(
####################
# - Convert to Unit System
####################
def convert_to_unit_system(
sp_obj: SympyExpr, unit_system: UnitSystem | None
def strip_unit_system(
sp_obj: SympyExpr, unit_system: UnitSystem | None = None
) -> SympyExpr:
"""Convert an expression to the units of a given unit system, with appropriate scaling."""
if unit_system is None:
return sp_obj
return spu.convert_to(
sp_obj,
{unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)},
)
def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem | None) -> SympyExpr:
"""Strip units occurring in the given unit system from the expression.
Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`".
@ -1427,6 +1474,19 @@ def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem | None) -> Symp
return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None})
def convert_to_unit_system(
sp_obj: SympyExpr, unit_system: UnitSystem | None
) -> SympyExpr:
"""Convert an expression to the units of a given unit system."""
if unit_system is None:
return sp_obj
return spu.convert_to(
sp_obj,
{unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)},
)
def scale_to_unit_system(
sp_obj: SympyExpr, unit_system: UnitSystem | None, use_jax_array: bool = False
) -> int | float | complex | tuple | jax.Array:

View File

@ -18,7 +18,6 @@
import enum
import functools
import time
import typing as typ
import jax
@ -34,7 +33,7 @@ import seaborn as sns
from blender_maxwell import contracts as ct
from blender_maxwell.utils import logger
mplstyle.use('fast') ## TODO: Does this do anything?
# mplstyle.use('fast') ## TODO: Does this do anything?
sns.set_theme()
log = logger.get(__name__)
@ -59,6 +58,9 @@ class Colormap(enum.StrEnum):
Viridis = enum.auto()
Grayscale = enum.auto()
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
return {
@ -139,7 +141,9 @@ def rgba_image_from_2d_map(
####################
@functools.lru_cache(maxsize=16)
def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
fig = matplotlib.figure.Figure(figsize=[width_inches, height_inches], dpi=dpi)
fig = matplotlib.figure.Figure(
figsize=[width_inches, height_inches], dpi=dpi, layout='tight'
)
canvas = matplotlib.backends.backend_agg.FigureCanvasAgg(fig)
ax = fig.add_subplot()
@ -152,66 +156,53 @@ def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
# - Plotters
####################
# () ->
def plot_box_plot_1d(
data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis
) -> None:
x_sym = info.last_dim
y_sym = info.output
def plot_box_plot_1d(data, ax: mpl_ax.Axis) -> None:
x_sym, y_sym = list(data.keys())
ax.boxplot([data])
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.boxplot([data[y_sym]])
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None:
x_sym = info.last_dim
y_sym = info.output
def plot_bar(data, ax: mpl_ax.Axis) -> None:
x_sym, heights_sym = list(data.keys())
p = ax.bar(info.dims[x_sym], data)
p = ax.bar(data[x_sym], data[heights_sym])
ax.bar_label(p, label_type='center')
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_title(f'{x_sym.name_pretty} -> {heights_sym.name_pretty}')
ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
ax.set_xlabel(heights_sym.plot_label)
# () ->
def plot_curve_2d(
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis
) -> None:
x_sym = info.last_dim
y_sym = info.output
def plot_curve_2d(data, ax: mpl_ax.Axis) -> None:
x_sym, y_sym = list(data.keys())
ax.plot(info.dims[x_sym].realize_array.values, data)
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.plot(data[x_sym], data[y_sym])
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
def plot_points_2d(
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis
) -> None:
x_sym = info.last_dim
y_sym = info.output
def plot_points_2d(data, ax: mpl_ax.Axis) -> None:
x_sym, y_sym = list(data.keys())
ax.scatter(x_sym.realize_array.values, data)
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.scatter(data[x_sym], data[y_sym])
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
# (, ) ->
def plot_curves_2d(
data: jtyp.Float32[jtyp.Array, 'x_size categories'], info, ax: mpl_ax.Axis
) -> None:
x_sym = info.first_dim
y_sym = info.output
def plot_curves_2d(data, ax: mpl_ax.Axis) -> None:
x_sym, label_sym, y_sym = list(data.keys())
for i, category in enumerate(info.dims[info.last_dim]):
ax.plot(info.dims[x_sym], data[:, i], label=category)
for i, label in enumerate(data[label_sym]):
ax.plot(data[x_sym], data[y_sym][:, i], label=label)
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
ax.legend()
@ -220,12 +211,10 @@ def plot_curves_2d(
def plot_filled_curves_2d(
data: jtyp.Float32[jtyp.Array, 'x_size 2'], info, ax: mpl_ax.Axis
) -> None:
x_sym = info.first_dim
y_sym = info.output
x_sym, _, y_sym = list(data.keys())
shared_x_idx = info.dims[info.last_dim]
ax.fill_between(shared_x_idx, data[:, 0], shared_x_idx, data[:, 1])
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.fill_between(data[x_sym], data[y_sym][:, 0], data[x_sym], data[y_sym][:, 1])
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
ax.legend()
@ -235,11 +224,9 @@ def plot_filled_curves_2d(
def plot_heatmap_2d(
data: jtyp.Float32[jtyp.Array, 'x_size y_size'], info, ax: mpl_ax.Axis
) -> None:
x_sym = info.first_dim
y_sym = info.last_dim
c_sym = info.output
x_sym, y_sym, c_sym = list(data.keys())
heatmap = ax.imshow(data, aspect='equal', interpolation='none')
heatmap = ax.imshow(data[c_sym], aspect='equal', interpolation='none')
ax.figure.colorbar(heatmap, cax=ax)
ax.set_title(f'({x_sym.name_pretty}, {y_sym.name_pretty}) -> {c_sym.plot_label}')

View File

@ -99,6 +99,7 @@ class TypeID(enum.StrEnum):
SympyType: str = '!type=sympytype'
SympyExpr: str = '!type=sympyexpr'
SocketDef: str = '!type=socketdef'
SimSymbol: str = '!type=simsymbol'
ManagedObj: str = '!type=managedobj'
@ -161,11 +162,12 @@ def _dec_hook(_type: type, obj: NaivelyEncodableType) -> typ.Any:
return sp.sympify(obj_value).subs(spux.UNIT_BY_SYMBOL)
if hasattr(_type, 'parse_as_msgspec') and (
is_representation(obj) and obj[0] in [TypeID.SocketDef, TypeID.ManagedObj]
is_representation(obj)
and obj[0] in [TypeID.SocketDef, TypeID.ManagedObj, TypeID.SimSymbol]
):
return _type.parse_as_msgspec(obj)
msg = f'Can\'t decode "{obj}" to type {type(obj)}'
msg = f'can\'t decode "{obj}" to type {type(obj)}'
raise NotImplementedError(msg)

View File

@ -14,36 +14,60 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import dataclasses
import enum
import functools
import string
import sys
import typing as typ
from fractions import Fraction
import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
import sympy.physics.units as spu
from . import extra_sympy_units as spux
from . import logger, serialize
int_min = -(2**64)
int_max = 2**64
float_min = sys.float_info.min
float_max = sys.float_info.max
log = logger.get(__name__)
def unicode_superscript(n: int) -> str:
"""Transform an integer into its unicode-based superscript character."""
return ''.join(['⁰¹²³⁴⁵⁶⁷⁸⁹'[ord(c) - ord('0')] for c in str(n)])
####################
# - Simulation Symbol Names
####################
_l = ''
_it_lower = iter(string.ascii_lowercase)
class SimSymbolName(enum.StrEnum):
# Lower
LowerA = enum.auto()
LowerB = enum.auto()
LowerC = enum.auto()
LowerD = enum.auto()
LowerI = enum.auto()
LowerT = enum.auto()
LowerX = enum.auto()
LowerY = enum.auto()
LowerZ = enum.auto()
# Generic
Constant = enum.auto()
Expr = enum.auto()
Data = enum.auto()
# Ascii Letters
while True:
try:
globals()['_l'] = next(globals()['_it_lower'])
except StopIteration:
break
locals()[f'Lower{globals()["_l"].upper()}'] = enum.auto()
locals()[f'Upper{globals()["_l"].upper()}'] = enum.auto()
# Greek Letters
LowerTheta = enum.auto()
LowerPhi = enum.auto()
# Fields
Ex = enum.auto()
@ -64,18 +88,15 @@ class SimSymbolName(enum.StrEnum):
Wavelength = enum.auto()
Frequency = enum.auto()
Flux = enum.auto()
PermXX = enum.auto()
PermYY = enum.auto()
PermZZ = enum.auto()
Flux = enum.auto()
DiffOrderX = enum.auto()
DiffOrderY = enum.auto()
# Generic
Expr = enum.auto()
####################
# - UI
####################
@ -109,49 +130,66 @@ class SimSymbolName(enum.StrEnum):
@property
def name(self) -> str:
SSN = SimSymbolName
return {
# Lower
SSN.LowerA: 'a',
SSN.LowerB: 'b',
SSN.LowerC: 'c',
SSN.LowerD: 'd',
SSN.LowerI: 'i',
SSN.LowerT: 't',
SSN.LowerX: 'x',
SSN.LowerY: 'y',
SSN.LowerZ: 'z',
# Fields
SSN.Ex: 'Ex',
SSN.Ey: 'Ey',
SSN.Ez: 'Ez',
SSN.Hx: 'Hx',
SSN.Hy: 'Hy',
SSN.Hz: 'Hz',
SSN.Er: 'Ex',
SSN.Etheta: 'Ey',
SSN.Ephi: 'Ez',
SSN.Hr: 'Hx',
SSN.Htheta: 'Hy',
SSN.Hphi: 'Hz',
# Optics
SSN.Wavelength: 'wl',
SSN.Frequency: 'freq',
SSN.Flux: 'flux',
SSN.PermXX: 'eps_xx',
SSN.PermYY: 'eps_yy',
SSN.PermZZ: 'eps_zz',
SSN.DiffOrderX: 'order_x',
SSN.DiffOrderY: 'order_y',
# Generic
SSN.Expr: 'expr',
}[self]
return (
# Ascii Letters
{SSN[f'Lower{letter.upper()}']: letter for letter in string.ascii_lowercase}
| {
SSN[f'Upper{letter.upper()}']: letter.upper()
for letter in string.ascii_lowercase
}
| {
# Generic
SSN.Constant: 'constant',
SSN.Expr: 'expr',
SSN.Data: 'data',
# Greek Letters
SSN.LowerTheta: 'theta',
SSN.LowerPhi: 'phi',
# Fields
SSN.Ex: 'Ex',
SSN.Ey: 'Ey',
SSN.Ez: 'Ez',
SSN.Hx: 'Hx',
SSN.Hy: 'Hy',
SSN.Hz: 'Hz',
SSN.Er: 'Ex',
SSN.Etheta: 'Ey',
SSN.Ephi: 'Ez',
SSN.Hr: 'Hx',
SSN.Htheta: 'Hy',
SSN.Hphi: 'Hz',
# Optics
SSN.Wavelength: 'wl',
SSN.Frequency: 'freq',
SSN.PermXX: 'eps_xx',
SSN.PermYY: 'eps_yy',
SSN.PermZZ: 'eps_zz',
SSN.Flux: 'flux',
SSN.DiffOrderX: 'order_x',
SSN.DiffOrderY: 'order_y',
}
)[self]
@property
def name_pretty(self) -> str:
SSN = SimSymbolName
return {
# Generic
# Greek Letters
SSN.LowerTheta: 'θ',
SSN.LowerPhi: 'φ',
# Fields
SSN.Etheta: '',
SSN.Ephi: '',
SSN.Hr: 'Hr',
SSN.Htheta: '',
SSN.Hphi: '',
# Optics
SSN.Wavelength: 'λ',
SSN.Frequency: '𝑓',
SSN.PermXX: 'ε_xx',
SSN.PermYY: 'ε_yy',
SSN.PermZZ: 'ε_zz',
}.get(self, self.name)
@ -173,8 +211,7 @@ def mk_interval(
)
@dataclasses.dataclass(kw_only=True, frozen=True)
class SimSymbol:
class SimSymbol(pyd.BaseModel):
"""A declarative representation of a symbolic variable.
`sympy`'s symbols aren't quite flexible enough for our needs: The symbols that we're transporting often need exact domain information, an associated unit dimension, and a great deal of determinism in checks thereof.
@ -183,6 +220,8 @@ class SimSymbol:
It's easy to persist, easy to transport, and has many helpful properties which greatly simplify working with symbols.
"""
model_config = pyd.ConfigDict(frozen=True)
sym_name: SimSymbolName
mathtype: spux.MathType = spux.MathType.Real
physical_type: spux.PhysicalType = spux.PhysicalType.NonPhysical
@ -191,6 +230,9 @@ class SimSymbol:
## -> 'None' indicates that no particular unit has yet been chosen.
## -> Not exposed in the UI; must be set some other way.
unit: spux.Unit | None = None
## -> TODO: We currently allowing units that don't match PhysicalType
## -> -- In particular, NonPhysical w/units means "unknown units".
## -> -- This is essential for the Scientific Constant Node.
# Size
## -> All SimSymbol sizes are "2D", but interpreted by convention.
@ -205,43 +247,76 @@ class SimSymbol:
## -> NOTE: interval_finite_*[0] must be strictly smaller than [1].
## -> See self.domain.
## -> We have to deconstruct symbolic interval semantics a bit for UI.
is_constant: bool = False
interval_finite_z: tuple[int, int] = (0, 1)
interval_finite_q: tuple[tuple[int, int], tuple[int, int]] = ((0, 1), (1, 1))
interval_finite_re: tuple[float, float] = (0, 1)
interval_finite_re: tuple[float, float] = (0.0, 1.0)
interval_inf: tuple[bool, bool] = (True, True)
interval_closed: tuple[bool, bool] = (False, False)
interval_finite_im: tuple[float, float] = (0, 1)
interval_finite_im: tuple[float, float] = (0.0, 1.0)
interval_inf_im: tuple[bool, bool] = (True, True)
interval_closed_im: tuple[bool, bool] = (False, False)
####################
# - Properties
# - Labels
####################
@property
@functools.cached_property
def name(self) -> str:
"""Usable name for the symbol."""
return self.sym_name.name
@property
@functools.cached_property
def name_pretty(self) -> str:
"""Pretty (possibly unicode) name for the thing."""
return self.sym_name.name_pretty
## TODO: Formatting conventions for bolding/etc. of vectors/mats/...
@property
@functools.cached_property
def mathtype_size_label(self) -> str:
"""Pretty label that shows both mathtype and size."""
return f'{self.mathtype.label_pretty}' + (
'ˣ'.join([unicode_superscript(out_axis) for out_axis in self.shape])
if self.shape
else ''
)
@functools.cached_property
def unit_label(self) -> str:
"""Pretty unit label, which is an empty string when there is no unit."""
return spux.sp_to_str(self.unit) if self.unit is not None else ''
@functools.cached_property
def def_label(self) -> str:
"""Pretty definition label, exposing the symbol definition."""
return f'{self.name_pretty} | {self.unit_label}{self.mathtype_size_label}'
## TODO: Domain of validity from self.domain?
@functools.cached_property
def plot_label(self) -> str:
"""Pretty plot-oriented label."""
return f'{self.name_pretty}' + (
f'({self.unit})' if self.unit is not None else ''
)
@property
####################
# - Computed Properties
####################
@functools.cached_property
def unit_factor(self) -> spux.SympyExpr:
"""Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking."""
return self.unit if self.unit is not None else sp.S(1)
@property
@functools.cached_property
def size(self) -> tuple[int, ...] | None:
return {
(1, 1): spux.NumberSize1D.Scalar,
(2, 1): spux.NumberSize1D.Vec2,
(3, 1): spux.NumberSize1D.Vec3,
(4, 1): spux.NumberSize1D.Vec4,
}.get((self.rows, self.cols))
@functools.cached_property
def shape(self) -> tuple[int, ...]:
match (self.rows, self.cols):
case (1, 1):
@ -253,7 +328,12 @@ class SimSymbol:
case (_, _):
return (self.rows, self.cols)
@property
@functools.cached_property
def shape_len(self) -> spux.SympyExpr:
"""Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking."""
return len(self.shape)
@functools.cached_property
def domain(self) -> sp.Interval | sp.Set:
"""Return the scalar domain of valid values for each element of the symbol.
@ -303,11 +383,31 @@ class SimSymbol:
),
)
@functools.cached_property
def valid_domain_value(self) -> spux.SympyExpr:
"""A single value guaranteed to be conformant to this `SimSymbol` and within `self.domain`."""
match (self.domain.start.is_finite, self.domain.end.is_finite):
case (True, True):
if self.mathtype is spux.MathType.Integer:
return (self.domain.start + self.domain.end) // 2
return (self.domain.start + self.domain.end) / 2
case (True, False):
one = sp.S(self.mathtype.coerce_compatible_pyobj(-1))
return self.domain.start + one
case (False, True):
one = sp.S(self.mathtype.coerce_compatible_pyobj(-1))
return self.domain.end - one
case (False, False):
return sp.S(self.mathtype.coerce_compatible_pyobj(-1))
####################
# - Properties
####################
@property
def sp_symbol(self) -> sp.Symbol:
@functools.cached_property
def sp_symbol(self) -> sp.Symbol | sp.ImmutableMatrix:
"""Return a symbolic variable w/unit, corresponding to this `SimSymbol`.
As much as possible, appropriate `assumptions` are set in the constructor of `sp.Symbol`, insofar as they can be determined.
@ -352,7 +452,82 @@ class SimSymbol:
elif self.domain.right <= 0:
mathtype_kwargs |= {'negative': True}
return sp.Symbol(self.sym_name.name, **mathtype_kwargs) * self.unit_factor
# Scalar: Return Symbol
if self.rows == 1 and self.cols == 1:
return sp.Symbol(self.sym_name.name, **mathtype_kwargs)
# Vector|Matrix: Return Matrix of Symbols
## -> MatrixSymbol doesn't support assumptions.
## -> This little construction does.
return sp.ImmutableMatrix(
[
[
sp.Symbol(self.sym_name.name + f'_{row}{col}', **mathtype_kwargs)
for col in range(self.cols)
]
for row in range(self.rows)
]
)
@functools.cached_property
def sp_symbol_matsym(self) -> sp.Symbol | sp.MatrixSymbol:
"""Return a symbolic variable w/unit, corresponding to this `SimSymbol`, w/variable shape support.
To preserve as many assumptions as possible, `self.sp_symbol` returns a matrix of individual `sp.Symbol`s whenever the `SimSymbol` is non-scalar.
However, this isn't always the most useful representation: For example, if the intention is to use a shaped symbolic variable as an argument to `sympy.lambdify()`, one would have to flatten each individual `sp.Symbol` and pass each matrix element as a single element, greatly complicating things like broadcasting.
For this reason, this property is provided.
Whenever the `SimSymbol` is scalar, it works identically to `self.sp_symbol`.
However, when the `SimSymbol` is shaped, an appropriate `sp.MatrixSymbol` is returned instead.
Notes:
`sp.MatrixSymbol` doesn't support assumptions.
As such, things like deduction of `MathType` from expressions involving a matrix symbol simply won't work.
"""
if self.shape_len == 0:
return self.sp_symbol
return sp.MatrixSymbol(self.sym_name.name, self.rows, self.cols)
@functools.cached_property
def sp_symbol_phy(self) -> spux.SympyExpr:
"""Physical symbol containing `self.sp_symbol` multiplied by `self.unit`."""
return self.sp_symbol * self.unit_factor
@functools.cached_property
def expr_info(self) -> dict[str, typ.Any]:
"""Generate keyword arguments for an ExprSocket, whose output values will be guaranteed to conform to this `SimSymbol`.
Notes:
Before use, `active_kind=ct.FlowKind.Range` can be added to make the `ExprSocket`.
Default values are set for both `Value` and `Range`.
To this end, `self.domain` is used.
Since `ExprSocketDef` allows the use of infinite bounds for `default_min` and `default_max`, we defer the decision of how to treat finite-fallback to the `ExprSocketDef`.
"""
if self.size is not None:
if self.unit in self.physical_type.valid_units:
return {
'output_name': self.sym_name,
# Socket Interface
'size': self.size,
'mathtype': self.mathtype,
'physical_type': self.physical_type,
# Defaults: Units
'default_unit': self.unit,
'default_symbols': [],
# Defaults: FlowKind.Value
'default_value': self.conform(
self.valid_domain_value, strip_unit=True
),
# Defaults: FlowKind.Range
'default_min': self.domain.start,
'default_max': self.domain.end,
}
msg = f'Tried to generate an ExprSocket from a SymSymbol "{self.name}", but its unit ({self.unit}) is not a valid unit of its physical type ({self.physical_type}) (SimSymbol={self})'
raise NotImplementedError(msg)
msg = f'Tried to generate an ExprSocket from a SymSymbol "{self.name}", but its size ({self.rows} by {self.cols}) is incompatible with ExprSocket (SimSymbol={self})'
raise NotImplementedError(msg)
####################
# - Operations
@ -373,7 +548,7 @@ class SimSymbol:
cols=get_attr('cols'),
interval_finite_z=get_attr('interval_finite_z'),
interval_finite_q=get_attr('interval_finite_q'),
interval_finite_re=get_attr('interval_finite_q'),
interval_finite_re=get_attr('interval_finite_re'),
interval_inf=get_attr('interval_inf'),
interval_closed=get_attr('interval_closed'),
interval_finite_im=get_attr('interval_finite_im'),
@ -381,24 +556,199 @@ class SimSymbol:
interval_closed_im=get_attr('interval_closed_im'),
)
def set_finite_domain( # noqa: PLR0913
self,
start: int | float,
end: int | float,
start_closed: bool = True,
end_closed: bool = True,
start_im: bool = float,
end_im: bool = float,
start_closed_im: bool = True,
end_closed_im: bool = True,
) -> typ.Self:
"""Update the symbol with a finite range."""
closed_re = (start_closed, end_closed)
closed_im = (start_closed_im, end_closed_im)
match self.mathtype:
case spux.MathType.Integer:
return self.update(
interval_finite_z=(start, end),
interval_inf=(False, False),
interval_closed=closed_re,
)
case spux.MathType.Rational:
return self.update(
interval_finite_q=(start, end),
interval_inf=(False, False),
interval_closed=closed_re,
)
case spux.MathType.Real:
return self.update(
interval_finite_re=(start, end),
interval_inf=(False, False),
interval_closed=closed_re,
)
case spux.MathType.Complex:
return self.update(
interval_finite_re=(start, end),
interval_finite_im=(start_im, end_im),
interval_inf=(False, False),
interval_closed=closed_re,
interval_closed_im=closed_im,
)
def set_size(self, rows: int, cols: int) -> typ.Self:
return self.update(rows=rows, cols=cols)
def conform(
self, sp_obj: spux.SympyType, strip_unit: bool = False
) -> spux.SympyType:
"""Conform a sympy object to the properties of this `SimSymbol`, if possible.
To achieve this, a number of operations may be performed:
- **Unit Conversion**: If the object has no units, but should, multiply by `self.unit`. If the object has units, but shouldn't, strip them. Otherwise, convert its unit to `self.unit`.
- **Broadcast Expansion**: If the object is a scalar, but the `SimSymbol` is shaped, then an `sp.ImmutableMatrix` is returned with the scalar at each position.
Returns:
A transformed sympy object guaranteed usable as a particular value of this `SimSymbol` variable.
Raises:
ValueError: If the units of `sp_obj` can't be cleanly converted to `self.unit`.
"""
res = sp_obj
# Unit Conversion
match (spux.uses_units(sp_obj), self.unit is not None):
case (True, True):
res = spux.scale_to_unit(sp_obj, self.unit) * self.unit
case (False, True):
res = sp_obj * self.unit
case (True, False):
res = spux.strip_unit_system(sp_obj)
if strip_unit:
res = spux.strip_unit_system(sp_obj)
# Broadcast Expansion
if self.rows > 1 or self.cols > 1 and not isinstance(res, spux.MatrixBase):
res = sp_obj * sp.ImmutableMatrix.ones(self.rows, self.cols)
return res
def scale(
self, sp_obj: spux.SympyType, use_jax_array: bool = True
) -> int | float | complex | jtyp.Inexact[jtyp.Array, '...']:
"""Remove all symbolic elements from the conformed `sp_obj`, preparing it for use in contexts that don't support unrealized symbols.
On top of `self.conform()`, a number of operations are performed.
- **Unit Stripping**: The `self.unit` of the expression returned by `self.conform()` will be stripped.
- **Sympy to Python**: The now symbol-less expression will be converted to either a pure Python type, or to a `jax` array (if `use_jax_array` is set).
Notes:
When creating numerical functions of expressions using `.lambdify`, `self.scale()` **must be used** in place of `self.conform()` before the parameterized expression is used.
Returns:
A "raw" (pure Python / jax array) type guaranteed usable as a particular **numerical** value of this `SymSymbol` variable.
"""
# Conform
res = self.conform(sp_obj)
# Strip Units
res = spux.scale_to_unit(sp_obj, self.unit)
# Sympy to Python
res = spux.sympy_to_python(res, use_jax_array=use_jax_array)
return res # noqa: RET504
@staticmethod
def from_expr(
sym_name: SimSymbolName,
expr: spux.SympyExpr,
unit_expr: spux.SympyExpr,
) -> typ.Self:
"""Deduce a `SimSymbol` that matches the output of a given expression (and unit expression).
This is an essential method, allowing for the ded
Notes:
`PhysicalType` **cannot be set** from an expression in the generic sense.
Therefore, the trick of using `NonPhysical` with non-`None` unit to denote unknown `PhysicalType` is used in the output.
All intervals are kept at their defaults.
Parameters:
sym_name: The `SimSymbolName` to set to the resulting symbol.
expr: The unit-aware expression to parse and encapsulate as a symbol.
unit_expr: A dimensional analysis expression (set to `1` to make the resulting symbol unitless).
Fundamentally, units are just the variables of scalar terms.
'1' for unitless terms are, in the dimanyl sense, constants.
Doing it like this may be a little messy, but is accurate.
Returns:
A fresh new `SimSymbol` that tries to match the given expression (and unit expression) well enough to be usable in place of it.
"""
# MathType from Expr Assumptions
## -> All input symbols have assumptions, because we are very pedantic.
## -> Therefore, we should be able to reconstruct the MathType.
mathtype = spux.MathType.from_expr(expr)
# PhysicalType as "NonPhysical"
## -> 'unit' still applies - but we can't guarantee a PhysicalType will.
## -> Therefore, this is what we gotta do.
physical_type = spux.PhysicalType.NonPhysical
# Rows/Cols from Expr (if Matrix)
rows, cols = expr.shape if isinstance(expr, sp.MatrixBase) else (1, 1)
return SimSymbol(
sym_name=self.sym_name,
mathtype=self.mathtype,
physical_type=self.physical_type,
unit=self.unit,
sym_name=sym_name,
mathtype=mathtype,
physical_type=physical_type,
unit=unit_expr if unit_expr != 1 else None,
rows=rows,
cols=cols,
interval_finite_z=self.interval_finite_z,
interval_finite_q=self.interval_finite_q,
interval_finite_re=self.interval_finite_re,
interval_inf=self.interval_inf,
interval_closed=self.interval_closed,
interval_finite_im=self.interval_finite_im,
interval_inf_im=self.interval_inf_im,
interval_closed_im=self.interval_closed_im,
)
####################
# - Serialization
####################
def dump_as_msgspec(self) -> serialize.NaiveRepresentation:
"""Transforms this `SimSymbol` into an object that can be natively serialized by `msgspec`.
Notes:
Makes use of `pydantic.BaseModel.model_dump()` to cast any special fields into a serializable format.
If this method is failing, check that `pydantic` can actually cast all the fields in your model.
Returns:
A particular `list`, with two elements:
1. The `serialize`-provided "Type Identifier", to differentiate this list from generic list.
2. A dictionary containing simple Python types, as cast by `pydantic`.
"""
return [serialize.TypeID.SimSymbol, self.__class__.__name__, self.model_dump()]
@staticmethod
def parse_as_msgspec(obj: serialize.NaiveRepresentation) -> typ.Self:
"""Transforms an object made by `self.dump_as_msgspec()` into an instance of `SimSymbol`.
Notes:
The method presumes that the deserialized object produced by `msgspec` perfectly matches the object originally created by `self.dump_as_msgspec()`.
This is a **mostly robust** presumption, as `pydantic` attempts to be quite consistent in how to interpret types with almost identical semantics.
Still, yet-unknown edge cases may challenge these presumptions.
Returns:
A new instance of `SimSymbol`, initialized using the `model_dump()` dictionary.
"""
return SimSymbol(**obj[2])
####################
# - Common Sim Symbols
@ -453,14 +803,10 @@ class CommonSimSymbol(enum.StrEnum):
Wavelength = enum.auto()
Frequency = enum.auto()
DiffOrderX = enum.auto()
DiffOrderY = enum.auto()
Flux = enum.auto()
WaveVecX = enum.auto()
WaveVecY = enum.auto()
WaveVecZ = enum.auto()
DiffOrderX = enum.auto()
DiffOrderY = enum.auto()
####################
# - UI
@ -549,10 +895,10 @@ class CommonSimSymbol(enum.StrEnum):
if eh == 'e'
else spux.PhysicalType.HField,
unit=unit,
interval_finite_re=(0, sys.float_info.max),
interval_finite_re=(0, float_max),
interval_inf_re=(False, True),
interval_closed_re=(True, False),
interval_finite_im=(sys.float_info.min, sys.float_info.max),
interval_finite_im=(float_min, float_max),
interval_inf_im=(True, True),
)
@ -575,7 +921,7 @@ class CommonSimSymbol(enum.StrEnum):
sym_name=self.name,
physical_type=spux.PhysicalType.Time,
unit=unit,
interval_finite_re=(0, sys.float_info.max),
interval_finite_re=(0, float_max),
interval_inf=(False, True),
interval_closed=(True, False),
),
@ -592,19 +938,13 @@ class CommonSimSymbol(enum.StrEnum):
CSS.FieldHr: sym_field('h'),
CSS.FieldHtheta: sym_field('h'),
CSS.FieldHphi: sym_field('h'),
CSS.Flux: SimSymbol(
sym_name=SimSymbolName.Flux,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Power,
unit=unit,
),
# Optics
CSS.Wavelength: SimSymbol(
sym_name=self.name,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Length,
unit=unit,
interval_finite=(0, sys.float_info.max),
interval_finite=(0, float_max),
interval_inf=(False, True),
interval_closed=(False, False),
),
@ -613,10 +953,30 @@ class CommonSimSymbol(enum.StrEnum):
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Freq,
unit=unit,
interval_finite=(0, sys.float_info.max),
interval_finite=(0, float_max),
interval_inf=(False, True),
interval_closed=(False, False),
),
CSS.Flux: SimSymbol(
sym_name=SimSymbolName.Flux,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Power,
unit=unit,
),
CSS.DiffOrderX: SimSymbol(
sym_name=self.name,
mathtype=spux.MathType.Integer,
interval_finite=(int_min, int_max),
interval_inf=(True, True),
interval_closed=(False, False),
),
CSS.DiffOrderY: SimSymbol(
sym_name=self.name,
mathtype=spux.MathType.Integer,
interval_finite=(int_min, int_max),
interval_inf=(True, True),
interval_closed=(False, False),
),
}[self]