refactor: revamped symbolic flow (inaccurate unit conversions)
parent
353a2c997e
commit
bcba444a8b
|
@ -18,8 +18,8 @@ from .array import ArrayFlow
|
|||
from .capabilities import CapabilitiesFlow
|
||||
from .flow_kinds import FlowKind
|
||||
from .info import InfoFlow
|
||||
from .lazy_range import RangeFlow, ScalingMode
|
||||
from .lazy_func import FuncFlow
|
||||
from .lazy_range import RangeFlow, ScalingMode
|
||||
from .params import ParamsFlow
|
||||
from .value import ValueFlow
|
||||
|
||||
|
|
|
@ -50,11 +50,6 @@ class ArrayFlow:
|
|||
####################
|
||||
# - Computed Properties
|
||||
####################
|
||||
@property
|
||||
def is_symbolic(self) -> bool:
|
||||
"""Always False, as ArrayFlows are never unrealized."""
|
||||
return False
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Outer length of the contained array."""
|
||||
return len(self.values)
|
||||
|
@ -196,5 +191,11 @@ class ArrayFlow:
|
|||
"""
|
||||
return self.rescale(lambda v: v, new_unit=new_unit)
|
||||
|
||||
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self:
|
||||
raise NotImplementedError
|
||||
def rescale_to_unit_system(self, unit_system: spux.UnitSystem | None) -> typ.Self:
|
||||
if unit_system is None:
|
||||
return self.values
|
||||
|
||||
return self.correct_unit(None).rescale(
|
||||
lambda v: spux.scale_to_unit_system(v * self.unit, unit_system),
|
||||
new_unit=spux.convert_to_unit_system(self.unit, unit_system),
|
||||
)
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
import dataclasses
|
||||
import typing as typ
|
||||
from types import MappingProxyType
|
||||
|
||||
from ..socket_types import SocketType
|
||||
from .flow_kinds import FlowKind
|
||||
|
@ -25,6 +26,7 @@ from .flow_kinds import FlowKind
|
|||
class CapabilitiesFlow:
|
||||
socket_type: SocketType
|
||||
active_kind: FlowKind
|
||||
allow_out_to_in: dict[FlowKind, FlowKind] = dataclasses.field(default_factory=dict)
|
||||
|
||||
is_universal: bool = False
|
||||
|
||||
|
@ -40,7 +42,13 @@ class CapabilitiesFlow:
|
|||
def is_compatible_with(self, other: typ.Self) -> bool:
|
||||
return other.is_universal or (
|
||||
self.socket_type == other.socket_type
|
||||
and self.active_kind == other.active_kind
|
||||
and (
|
||||
self.active_kind == other.active_kind
|
||||
or (
|
||||
other.active_kind in other.allow_out_to_in
|
||||
and self.active_kind == other.allow_out_to_in[other.active_kind]
|
||||
)
|
||||
)
|
||||
# == Constraint
|
||||
and all(
|
||||
name in other.must_match
|
||||
|
|
|
@ -67,8 +67,9 @@ class InfoFlow:
|
|||
default_factory=dict
|
||||
)
|
||||
|
||||
# Access
|
||||
@functools.cached_property
|
||||
def last_dim(self) -> sim_symbols.SimSymbol | None:
|
||||
def first_dim(self) -> sim_symbols.SimSymbol | None:
|
||||
"""The integer axis occupied by the dimension.
|
||||
|
||||
Can be used to index `.shape` of the represented raw array.
|
||||
|
@ -87,13 +88,24 @@ class InfoFlow:
|
|||
return list(self.dims.keys())[-1]
|
||||
return None
|
||||
|
||||
def dim_axis(self, dim: sim_symbols.SimSymbol) -> int:
|
||||
def dim_by_idx(self, idx: int) -> sim_symbols.SimSymbol | None:
|
||||
if idx > 0 and idx < len(self.dims) - 1:
|
||||
return list(self.dims.keys())[idx]
|
||||
return None
|
||||
|
||||
def dim_by_name(self, dim_name: str) -> int:
|
||||
"""The integer axis occupied by the dimension.
|
||||
|
||||
Can be used to index `.shape` of the represented raw array.
|
||||
"""
|
||||
return list(self.dims.keys()).index(dim)
|
||||
dims_with_name = [dim for dim in self.dims if dim.name == dim_name]
|
||||
if len(dims_with_name) == 1:
|
||||
return dims_with_name[0]
|
||||
|
||||
msg = f'Dim name {dim_name} not found in InfoFlow (or >1 found)'
|
||||
raise ValueError(msg)
|
||||
|
||||
# Information By-Dim
|
||||
def has_idx_cont(self, dim: sim_symbols.SimSymbol) -> bool:
|
||||
"""Whether the dim's index is continuous, and therefore index array.
|
||||
|
||||
|
@ -114,6 +126,23 @@ class InfoFlow:
|
|||
return isinstance(self.dims[dim], list)
|
||||
return False
|
||||
|
||||
def is_idx_uniform(self, dim: sim_symbols.SimSymbol) -> bool:
|
||||
"""Whether the (int) dim has explicitly uniform indexing.
|
||||
|
||||
This is needed primarily to check whether a Fourier Transform can be meaningfully performed on the data over the dimension's axis.
|
||||
|
||||
In practice, we've decided that only `RangeFlow` really truly _guarantees_ uniform indexing.
|
||||
While `ArrayFlow` may be uniform in practice, it's a very expensive to check, and it's far better to enforce that the user perform that check and opt for a `RangeFlow` instead, at the time of dimension definition.
|
||||
"""
|
||||
return isinstance(self.dims[dim], RangeFlow) and self.dims[dim].scaling == 'lin'
|
||||
|
||||
def dim_axis(self, dim: sim_symbols.SimSymbol) -> int:
|
||||
"""The integer axis occupied by the dimension.
|
||||
|
||||
Can be used to index `.shape` of the represented raw array.
|
||||
"""
|
||||
return list(self.dims.keys()).index(dim)
|
||||
|
||||
####################
|
||||
# - Output: Contravariant Value
|
||||
####################
|
||||
|
@ -128,6 +157,49 @@ class InfoFlow:
|
|||
default_factory=dict
|
||||
)
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
@functools.cached_property
|
||||
def input_mathtypes(self) -> tuple[spux.MathType, ...]:
|
||||
return tuple([dim.mathtype for dim in self.dims])
|
||||
|
||||
@functools.cached_property
|
||||
def output_mathtypes(self) -> tuple[spux.MathType, int, int]:
|
||||
return [self.output.mathtype for _ in range(len(self.output.shape) + 1)]
|
||||
|
||||
@functools.cached_property
|
||||
def order(self) -> tuple[spux.MathType, ...]:
|
||||
r"""The order of the tensor represented by this info.
|
||||
|
||||
While that sounds fancy and all, it boils down to:
|
||||
|
||||
$$
|
||||
\texttt{dims} + |\texttt{output}.\texttt{shape}|
|
||||
$$
|
||||
|
||||
Doing so characterizes the full dimensionality of the tensor, which also perfectly matches the length of the raw data's shape exactly.
|
||||
|
||||
Notes:
|
||||
Corresponds to `len(raw_data.shape)`, if `raw_data` is the n-dimensional array corresponding to this `InfoFlow`.
|
||||
"""
|
||||
return len(self.input_mathtypes) + self.output_shape_len
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
@functools.cached_property
|
||||
def dim_labels(self) -> dict[str, dict[str, str]]:
|
||||
"""Return a dictionary mapping pretty dim names to information oriented for columnar information display."""
|
||||
return {
|
||||
dim.name_pretty: {
|
||||
'length': str(len(dim_idx)) if dim_idx is not None else '∞',
|
||||
'mathtype': dim.mathtype.label_pretty,
|
||||
'unit': dim.unit_label,
|
||||
}
|
||||
for dim, dim_idx in self.dims.items()
|
||||
}
|
||||
|
||||
####################
|
||||
# - Operations: Dimensions
|
||||
####################
|
||||
|
@ -147,9 +219,11 @@ class InfoFlow:
|
|||
"""Slice a dimensional array by-index along a particular dimension."""
|
||||
return InfoFlow(
|
||||
dims={
|
||||
_dim: dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
|
||||
_dim: (
|
||||
dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
|
||||
if _dim == dim
|
||||
else _dim
|
||||
else dim_idx
|
||||
)
|
||||
for _dim, dim_idx in self.dims.items()
|
||||
},
|
||||
output=self.output,
|
||||
|
@ -166,7 +240,7 @@ class InfoFlow:
|
|||
return InfoFlow(
|
||||
dims={
|
||||
(new_dim if _dim == old_dim else _dim): (
|
||||
new_dim_idx if _dim == old_dim else _dim
|
||||
new_dim_idx if _dim == old_dim else dim_idx
|
||||
)
|
||||
for _dim, dim_idx in self.dims.items()
|
||||
},
|
||||
|
@ -235,6 +309,26 @@ class InfoFlow:
|
|||
pinned_values=self.pinned_values,
|
||||
)
|
||||
|
||||
def operate_output(
|
||||
self,
|
||||
other: typ.Self,
|
||||
op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
|
||||
unit_op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
|
||||
) -> spux.SympyExpr:
|
||||
if self.dims == other.dims:
|
||||
sym_name = sim_symbols.SimSymbolName.Expr
|
||||
expr = op(self.output.sp_symbol_phy, other.output.sp_symbol_phy)
|
||||
unit_expr = unit_op(self.output.unit_factor, other.output.unit_factor)
|
||||
|
||||
return InfoFlow(
|
||||
dims=self.dims,
|
||||
output=sim_symbols.SimSymbol.from_expr(sym_name, expr, unit_expr),
|
||||
pinned_values=self.pinned_values,
|
||||
)
|
||||
|
||||
msg = f'InfoFlow: operate_output cannot be used when dimensions are not identical ({self.dims} | {other.dims}).'
|
||||
raise ValueError(msg)
|
||||
|
||||
####################
|
||||
# - Operations: Fold
|
||||
####################
|
||||
|
|
|
@ -22,7 +22,7 @@ from types import MappingProxyType
|
|||
import jax
|
||||
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
from blender_maxwell.utils import logger
|
||||
from blender_maxwell.utils import logger, sim_symbols
|
||||
|
||||
from .params import ParamsFlow
|
||||
|
||||
|
@ -244,17 +244,10 @@ class FuncFlow:
|
|||
"""
|
||||
|
||||
func: LazyFunction
|
||||
func_args: list[spux.MathType | spux.PhysicalType] = dataclasses.field(
|
||||
default_factory=list
|
||||
)
|
||||
func_kwargs: dict[str, spux.MathType | spux.PhysicalType] = dataclasses.field(
|
||||
func_args: list[sim_symbols.SimSymbol] = dataclasses.field(default_factory=list)
|
||||
func_kwargs: dict[str, sim_symbols.SimSymbol] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
## TODO: Use SimSymbol instead of the MathType|PT union.
|
||||
## -- SimSymbol is an ideal pivot point for both, as well as valid domains.
|
||||
## -- SimSymbol has more semantic meaning, including a name.
|
||||
## -- If desired, SimSymbols could maybe even require a specific unit.
|
||||
## It could greatly simplify a whole lot of pain associated with func_args.
|
||||
supports_jax: bool = False
|
||||
|
||||
####################
|
||||
|
@ -315,17 +308,18 @@ class FuncFlow:
|
|||
def realize(
|
||||
self,
|
||||
params: ParamsFlow,
|
||||
unit_system: spux.UnitSystem | None = None,
|
||||
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
|
||||
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||
{}
|
||||
),
|
||||
) -> typ.Self:
|
||||
if self.supports_jax:
|
||||
return self.func_jax(
|
||||
*params.scaled_func_args(unit_system, symbol_values),
|
||||
*params.scaled_func_kwargs(unit_system, symbol_values),
|
||||
*params.scaled_func_args(self.func_args, symbol_values),
|
||||
*params.scaled_func_kwargs(self.func_args, symbol_values),
|
||||
)
|
||||
return self.func(
|
||||
*params.scaled_func_args(unit_system, symbol_values),
|
||||
*params.scaled_func_kwargs(unit_system, symbol_values),
|
||||
*params.scaled_func_args(self.func_kwargs, symbol_values),
|
||||
*params.scaled_func_kwargs(self.func_kwargs, symbol_values),
|
||||
)
|
||||
|
||||
####################
|
||||
|
|
|
@ -18,17 +18,18 @@ import dataclasses
|
|||
import enum
|
||||
import functools
|
||||
import typing as typ
|
||||
from fractions import Fraction
|
||||
from types import MappingProxyType
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jaxtyping as jtyp
|
||||
import sympy as sp
|
||||
import sympy.physics.units as spu
|
||||
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
from blender_maxwell.utils import logger
|
||||
from blender_maxwell.utils import logger, sim_symbols
|
||||
|
||||
from .array import ArrayFlow
|
||||
from .lazy_func import FuncFlow
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
@ -62,7 +63,7 @@ class ScalingMode(enum.StrEnum):
|
|||
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class RangeFlow:
|
||||
r"""Represents a spaced array using symbolic boundary expressions.
|
||||
r"""Represents a finite spaced array using symbolic boundary expressions.
|
||||
|
||||
Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous.
|
||||
|
||||
|
@ -79,33 +80,76 @@ class RangeFlow:
|
|||
Since the bounds support exact (ex. rational) calculations and symbolic manipulations (_by virtue of being symbolic expressions_), the opportunities for certain kinds of numerical instability are mitigated.
|
||||
|
||||
Attributes:
|
||||
start: An expression generating a scalar, unitless, complex value for the array's lower bound.
|
||||
start: An expression representing the unitless part of the finite, scalar, complex value for the array's lower bound.
|
||||
_Integer, rational, and real values are also supported._
|
||||
stop: An expression generating a scalar, unitless, complex value for the array's upper bound.
|
||||
start: An expression representing the unitless part of the finite, scalar, complex value for the array's upper bound.
|
||||
_Integer, rational, and real values are also supported._
|
||||
steps: The amount of steps (**inclusive**) to generate from `start` to `stop`.
|
||||
scaling: The method of distributing `step` values between the two endpoints.
|
||||
Generally, the linear default is sufficient.
|
||||
|
||||
unit: The unit of the generated array values
|
||||
unit: The unit to interpret the values as.
|
||||
|
||||
symbols: Set of variables from which `start` and/or `stop` are determined.
|
||||
"""
|
||||
|
||||
start: spux.ScalarUnitlessComplexExpr
|
||||
stop: spux.ScalarUnitlessComplexExpr
|
||||
steps: int
|
||||
steps: int = 0
|
||||
scaling: ScalingMode = ScalingMode.Lin
|
||||
|
||||
unit: spux.Unit | None = None
|
||||
|
||||
symbols: frozenset[spux.Symbol] = frozenset()
|
||||
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
|
||||
|
||||
# Helper Attributes
|
||||
pre_fourier_ideal_midpoint: spux.ScalarUnitlessComplexExpr | None = None
|
||||
|
||||
####################
|
||||
# - Computed Properties
|
||||
# - SimSymbol Interop
|
||||
####################
|
||||
@staticmethod
|
||||
def from_sym(
|
||||
sym: sim_symbols.SimSymbol,
|
||||
steps: int = 50,
|
||||
scaling: ScalingMode | str = ScalingMode.Lin,
|
||||
) -> typ.Self:
|
||||
if sym.domain.start.is_infinite or sym.domain.end.is_infinite:
|
||||
use_steps = 0
|
||||
else:
|
||||
use_steps = steps
|
||||
|
||||
return RangeFlow(
|
||||
start=sym.domain.start if sym.domain.start.is_finite else sp.S(-1),
|
||||
stop=sym.domain.end if sym.domain.end.is_finite else sp.S(1),
|
||||
steps=use_steps,
|
||||
scaling=ScalingMode(scaling),
|
||||
unit=sym.unit,
|
||||
)
|
||||
|
||||
def to_sym(
|
||||
self,
|
||||
sym_name: sim_symbols.SimSymbolName,
|
||||
) -> typ.Self:
|
||||
physical_type = spux.PhysicalType.from_unit(self.unit, optional=True)
|
||||
|
||||
return sim_symbols.SimSymbol(
|
||||
sym_name=sym_name,
|
||||
mathtype=self.mathtype,
|
||||
physical_type=(
|
||||
physical_type
|
||||
if physical_type is not None
|
||||
else spux.PhysicalType.NonPhysical
|
||||
),
|
||||
unit=self.unit,
|
||||
rows=1,
|
||||
cols=1,
|
||||
).set_domain(start=self.realize_start(), end=self.realize_end())
|
||||
|
||||
####################
|
||||
# - Symbols
|
||||
####################
|
||||
@functools.cached_property
|
||||
def sorted_symbols(self) -> list[sp.Symbol]:
|
||||
def sorted_symbols(self) -> list[sim_symbols.SimSymbol]:
|
||||
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
||||
|
||||
The order is guaranteed to be **deterministic**.
|
||||
|
@ -115,10 +159,21 @@ class RangeFlow:
|
|||
"""
|
||||
return sorted(self.symbols, key=lambda sym: sym.name)
|
||||
|
||||
@property
|
||||
def is_symbolic(self) -> bool:
|
||||
"""Whether the `RangeFlow` has unrealized symbols."""
|
||||
return len(self.symbols) > 0
|
||||
@functools.cached_property
|
||||
def sorted_sp_symbols(self) -> list[spux.Symbol]:
|
||||
"""Computes `sympy` symbols from `self.sorted_symbols`.
|
||||
|
||||
Returns:
|
||||
All symbols valid for use in the expression.
|
||||
"""
|
||||
return [sym.sp_symbol for sym in self.sorted_symbols]
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
@functools.cached_property
|
||||
def unit_factor(self) -> spux.SympyExpr:
|
||||
return self.unit if self.unit is not None else sp.S(1)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Compute the length of the array that would be realized.
|
||||
|
@ -166,6 +221,14 @@ class RangeFlow:
|
|||
####################
|
||||
# - Methods
|
||||
####################
|
||||
@property
|
||||
def ideal_midpoint(self) -> spux.SympyExpr:
|
||||
return (self.stop + self.start) / 2
|
||||
|
||||
@property
|
||||
def ideal_range(self) -> spux.SympyExpr:
|
||||
return self.stop - self.start
|
||||
|
||||
def rescale(
|
||||
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
|
||||
) -> typ.Self:
|
||||
|
@ -181,8 +244,8 @@ class RangeFlow:
|
|||
new_pre_start = self.start if not reverse else self.stop
|
||||
new_pre_stop = self.stop if not reverse else self.start
|
||||
|
||||
new_start = rescale_func(new_pre_start * self.unit)
|
||||
new_stop = rescale_func(new_pre_stop * self.unit)
|
||||
new_start = rescale_func(new_pre_start * self.unit_factor)
|
||||
new_stop = rescale_func(new_pre_stop * self.unit_factor)
|
||||
|
||||
return RangeFlow(
|
||||
start=(
|
||||
|
@ -204,6 +267,99 @@ class RangeFlow:
|
|||
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@functools.cached_property
|
||||
def bound_fourier_transform(self):
|
||||
r"""Treat this `RangeFlow` it as an axis along which a fourier transform is being performed, such that its bounds scale according to the Nyquist Limit.
|
||||
|
||||
# Sampling Theory
|
||||
In general, the Fourier Transform is an operator that works on infinite, periodic, continuous functions.
|
||||
In this context alone it is a ideal transform (in terms of information retention), one which degrades quite gracefully in the face of practicalities like windowing (allowing them to apply analytically to non-periodic functions too).
|
||||
While often used to transform an axis of time to an axis of frequency, in general this transform "simply" extracts all repeating structures from a function.
|
||||
This is illustrated beautifully in the way that the output unit becomes the reciprocal of the input unit, which is the theory underlying why we say that measurements recieved as a reciprocal unit are in a "reciprocal space" (also called "k-space").
|
||||
|
||||
The real world is not so nice, of course, and as such we must generally make do with the Discrete Fourier Transform.
|
||||
Even with bounded discrete information, we can annoy many mathematicians by defining a DFT in such a way that "structure per thing" ($\frac{1}{\texttt{unit}}$) still makes sense to us (to them, maybe not).
|
||||
A DFT can still only retain the information given to it, but so long as we have enough "original structure", any "repeating structure" should be extractable with sufficient clarity to be useful.
|
||||
|
||||
What "sufficient clarity" means is the basis for the entire field of "sampling theory".
|
||||
The theoretical maximum for the "fineness of repetition" that is "noticeable" in the fourier-transformed of some data is characterized by a theoretical upper bound called the Nyquist Frequency / Limit, which ends up being half of the sampling rate.
|
||||
Thus, to determine bounds on the data, use of the Nyquist Limit is generally a good starting point.
|
||||
|
||||
Of course, when the discrete data comes from a discretization of a continuous signal, information from higher frequencies might still affect the discrete results.
|
||||
They do little else than cause havoc, though - best causing noise, and at worst causing structured artifacts (sometimes called "aliasing").
|
||||
Some of the first innovations in sampling theory were related to "anti-aliasing" filters, whose sole purpose is to try to remove informational frequencies above the Nyquist Limit of the discrete sensor.
|
||||
|
||||
In FDTD simulation, we're generally already ahead when it comes to aliasing, since our field values come from an already-discrete process.
|
||||
That is, unless we start overly "binning" (averaging over $n$ discrete observations); in this case, care should be taken to make sure that interesting results aren't merely unfortunately structured aliasing artifacts.
|
||||
|
||||
For more, see <https://en.wikipedia.org/wiki/Nyquist%E2%80%93Shannon_sampling_theorem>
|
||||
|
||||
# Implementation
|
||||
In practice, our goal in `RangeFlow` is to compute the bounds of the index array along the fourier-transformed data axis.
|
||||
The reciprocal of the unit will be taken (when unitless, `1/1=`).
|
||||
The raw Nyquist Limit $n_q$ will be used to bound the unitless part of the output as $[-n_q, n_q]$
|
||||
|
||||
Raises:
|
||||
ValueError: If `self.scaling` is not linear, since the FT can only be performed on uniformly spaced data.
|
||||
"""
|
||||
if self.scaling is ScalingMode.Lin:
|
||||
nyquist_limit = self.steps / self.ideal_range
|
||||
|
||||
# Return New Bounds w/Nyquist Theorem
|
||||
## -> The Nyquist Limit describes "max repeated info per sample".
|
||||
## -> Information can still record "faster" than the Nyquist Limit.
|
||||
## -> It will just be either noise (best case), or banded artifacts.
|
||||
## -> This is called "aliasing", and it's best to try and filter it.
|
||||
## -> Sims generally "bin" yee cells, which sacrifices some NyqLim.
|
||||
return RangeFlow(
|
||||
start=-nyquist_limit,
|
||||
stop=nyquist_limit,
|
||||
scaling=self.scaling,
|
||||
unit=1 / self.unit if self.unit is not None else None,
|
||||
pre_fourier_ideal_midpoint=self.ideal_midpoint,
|
||||
)
|
||||
|
||||
msg = f'Cant fourier-transform an index array as a boundary, when the RangeArray has a non-linear bound {self.scaling}'
|
||||
raise ValueError(msg)
|
||||
|
||||
@functools.cached_property
|
||||
def bound_inv_fourier_transform(self):
|
||||
r"""Treat this `RangeFlow` as an axis along which an inverse fourier transform is being performed, such that its Nyquist-limit bounds are transformed back into values along the original unit dimension.
|
||||
|
||||
See `self.bound_fourier_transform` for the theoretical concepts.
|
||||
|
||||
Notes:
|
||||
**The discrete inverse fourier transform always centers its output at $0$**.
|
||||
|
||||
Of course, it's entirely probable that the original signal was not centered at $0$.
|
||||
For this reason, when performing a Fourier transform, `self.bound_fourier_transform` sets a special variable, `self.pre_fourier_ideal_midpoint`.
|
||||
When set, it will retain the `self.ideal_midpoint` around which both `self.start` and `self.stop` should be centered after an inverse FT.
|
||||
|
||||
If `self.pre_fourier_ideal_midpoint` is set, then it will be used as the midpoint of the output's `start`/`stop`.
|
||||
Otherwise, $0$ will be used - in which case the user should themselves, manually, shift the output if needed.
|
||||
"""
|
||||
if self.scaling is ScalingMode.Lin:
|
||||
orig_ideal_range = self.steps / self.ideal_range
|
||||
|
||||
orig_start_centered = -orig_ideal_range
|
||||
orig_stop_centered = orig_ideal_range
|
||||
orig_ideal_midpoint = (
|
||||
self.pre_fourier_ideal_midpoint
|
||||
if self.pre_fourier_ideal_midpoint is not None
|
||||
else sp.S(0)
|
||||
)
|
||||
|
||||
# Return New Bounds w/Inverse of Nyquist Theorem
|
||||
return RangeFlow(
|
||||
start=-orig_start_centered + orig_ideal_midpoint,
|
||||
stop=orig_stop_centered + orig_ideal_midpoint,
|
||||
scaling=self.scaling,
|
||||
unit=1 / self.unit if self.unit is not None else None,
|
||||
)
|
||||
|
||||
msg = f'Cant fourier-transform an index array as a boundary, when the RangeArray has a non-linear bound {self.scaling}'
|
||||
raise ValueError(msg)
|
||||
|
||||
####################
|
||||
# - Exporters
|
||||
####################
|
||||
|
@ -237,15 +393,15 @@ class RangeFlow:
|
|||
"""Create a function that can compute the non-lazy output array as a function of the symbols in the expressions for `start` and `stop`.
|
||||
|
||||
Notes:
|
||||
The ordering of the symbols is identical to `self.symbols`, which is guaranteed to be a deterministically sorted list of symbols.
|
||||
The ordering of the symbols is identical to `self.sorted_symbols`, which is guaranteed to be a deterministically sorted list of symbols.
|
||||
|
||||
Returns:
|
||||
A `FuncFlow` that, given the input symbols defined in `self.symbols`,
|
||||
A function that generates a 1D numerical array equivalent to the range represented in this `RangeFlow`.
|
||||
"""
|
||||
# Compile JAX Functions for Start/End Expressions
|
||||
## -> FYI, JAX-in-JAX works perfectly fine.
|
||||
start_jax = sp.lambdify(self.sorted_symbols, self.start, 'jax')
|
||||
stop_jax = sp.lambdify(self.sorted_symbols, self.stop, 'jax')
|
||||
start_jax = sp.lambdify(self.sorted_sp_symbols, self.start, 'jax')
|
||||
stop_jax = sp.lambdify(self.sorted_sp_symbols, self.stop, 'jax')
|
||||
|
||||
# Compile ArrayGen Function
|
||||
def gen_array(
|
||||
|
@ -256,54 +412,80 @@ class RangeFlow:
|
|||
# Return ArrayGen Function
|
||||
return gen_array
|
||||
|
||||
@functools.cached_property
|
||||
def as_lazy_func(self) -> FuncFlow:
|
||||
"""Creates a `FuncFlow` using the output of `self.as_func`.
|
||||
|
||||
This is useful for ex. parameterizing the first array in the node graph, without binding an entire computed array.
|
||||
|
||||
Notes:
|
||||
The the function enclosed in the `FuncFlow` is identical to the one returned by `self.as_func`.
|
||||
|
||||
Returns:
|
||||
A `FuncFlow` containing `self.as_func`, as well as appropriate supporting settings.
|
||||
"""
|
||||
return FuncFlow(
|
||||
func=self.as_func,
|
||||
func_args=[(spux.MathType.from_expr(sym)) for sym in self.symbols],
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
####################
|
||||
# - Realization
|
||||
####################
|
||||
def realize_symbols(
|
||||
self,
|
||||
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||
{}
|
||||
),
|
||||
) -> dict[sp.Symbol, spux.ScalarUnitlessComplexExpr]:
|
||||
"""Realize **all** input symbols to the `RangeFlow`.
|
||||
|
||||
Parameters:
|
||||
symbol_values: A scalar, unitless, complex `sympy` expression for each symbol defined in `self.symbols`.
|
||||
|
||||
Returns:
|
||||
A dictionary directly usable in expression substitutions using `sp.Basic.subs()`.
|
||||
"""
|
||||
if self.symbols == set(symbol_values.keys()):
|
||||
realized_syms = {}
|
||||
for sym in self.sorted_symbols:
|
||||
sym_value = symbol_values[sym]
|
||||
|
||||
# Sympy Expression
|
||||
## -> We need to conform the expression to the SimSymbol.
|
||||
## -> Mainly, this is
|
||||
if (
|
||||
isinstance(sym_value, spux.SympyType)
|
||||
and not isinstance(sym_value, sp.MatrixBase)
|
||||
and not spux.uses_units(sym_value)
|
||||
):
|
||||
v = sym.conform(sym_value)
|
||||
else:
|
||||
msg = f'RangeFlow: No realization support for symbolic value {sym_value} (type={type(sym_value)})'
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
realized_syms |= {sym: v}
|
||||
|
||||
msg = f'RangeFlow: Not all symbols were given a value during realization (symbols={self.symbols}, symbol_values={symbol_values})'
|
||||
raise ValueError(msg)
|
||||
|
||||
def realize_start(
|
||||
self,
|
||||
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
|
||||
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||
{}
|
||||
),
|
||||
) -> int | float | complex:
|
||||
"""Realize the start-bound by inserting particular values for each symbol."""
|
||||
return spux.sympy_to_python(
|
||||
self.start.subs({sym: symbol_values[sym.name] for sym in self.symbols})
|
||||
)
|
||||
realized_symbols = self.realize_symbols(symbol_values)
|
||||
return spux.sympy_to_python(self.start.subs(realized_symbols))
|
||||
|
||||
def realize_stop(
|
||||
self,
|
||||
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
|
||||
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||
{}
|
||||
),
|
||||
) -> int | float | complex:
|
||||
"""Realize the stop-bound by inserting particular values for each symbol."""
|
||||
return spux.sympy_to_python(
|
||||
self.stop.subs({sym: symbol_values[sym.name] for sym in self.symbols})
|
||||
)
|
||||
realized_symbols = self.realize_symbols(symbol_values)
|
||||
return spux.sympy_to_python(self.stop.subs(realized_symbols))
|
||||
|
||||
def realize_step_size(
|
||||
self,
|
||||
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
|
||||
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||
{}
|
||||
),
|
||||
) -> int | float | complex:
|
||||
"""Realize the stop-bound by inserting particular values for each symbol."""
|
||||
if self.scaling is not ScalingMode.Lin:
|
||||
raise NotImplementedError('Non-linear scaling mode not yet suported')
|
||||
msg = 'Non-linear scaling mode not yet suported'
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
raw_step_size = (self.realize_stop() - self.realize_start() + 1) / self.steps
|
||||
raw_step_size = (
|
||||
self.realize_stop(symbol_values) - self.realize_start(symbol_values) + 1
|
||||
) / self.steps
|
||||
|
||||
if self.mathtype is spux.MathType.Integer and raw_step_size.is_integer():
|
||||
return int(raw_step_size)
|
||||
|
@ -311,7 +493,9 @@ class RangeFlow:
|
|||
|
||||
def realize(
|
||||
self,
|
||||
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
|
||||
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||
{}
|
||||
),
|
||||
) -> ArrayFlow:
|
||||
"""Realize the array represented by this `RangeFlow` by realizing each bound, then generating all intermediate values as an array.
|
||||
|
||||
|
@ -324,16 +508,34 @@ class RangeFlow:
|
|||
## TODO: Check symbol values for coverage.
|
||||
|
||||
return ArrayFlow(
|
||||
values=self.as_func(*[symbol_values[sym] for sym in self.sorted_symbols]),
|
||||
values=self.as_func(
|
||||
*[
|
||||
spux.scale_to_unit_system(symbol_values[sym])
|
||||
for sym in self.sorted_symbols
|
||||
]
|
||||
),
|
||||
unit=self.unit,
|
||||
is_sorted=True,
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def realize_array(self) -> ArrayFlow:
|
||||
"""Standardized access to `self.realize()` when there are no symbols."""
|
||||
"""Standardized access to `self.realize()` when there are no symbols.
|
||||
|
||||
Raises:
|
||||
ValueError: If there are symbols defined in `self.symbols`.
|
||||
"""
|
||||
if not self.symbols:
|
||||
return self.realize()
|
||||
|
||||
msg = f'RangeFlow: Cannot use ".realize_array" when symbols are defined (symbols={self.symbols}, RangeFlow={self}'
|
||||
raise ValueError(msg)
|
||||
|
||||
@property
|
||||
def values(self) -> jtyp.Inexact[jtyp.Array, '...']:
|
||||
"""Alias for `realize_array.values`."""
|
||||
return self.realize_array.values
|
||||
|
||||
def __getitem__(self, subscript: slice):
|
||||
"""Implement indexing and slicing in a sane way.
|
||||
|
||||
|
@ -379,12 +581,6 @@ class RangeFlow:
|
|||
Raises:
|
||||
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
||||
"""
|
||||
if self.unit is not None:
|
||||
log.debug(
|
||||
'%s: Corrected unit to %s',
|
||||
self,
|
||||
corrected_unit,
|
||||
)
|
||||
return RangeFlow(
|
||||
start=self.start,
|
||||
stop=self.stop,
|
||||
|
@ -394,9 +590,6 @@ class RangeFlow:
|
|||
symbols=self.symbols,
|
||||
)
|
||||
|
||||
msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"'
|
||||
raise ValueError(msg)
|
||||
|
||||
def rescale_to_unit(self, unit: spux.Unit) -> typ.Self:
|
||||
"""Replaces the unit, **with** rescaling of the bounds.
|
||||
|
||||
|
@ -410,11 +603,6 @@ class RangeFlow:
|
|||
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
||||
"""
|
||||
if self.unit is not None:
|
||||
log.debug(
|
||||
'%s: Scaled to unit %s',
|
||||
self,
|
||||
unit,
|
||||
)
|
||||
return RangeFlow(
|
||||
start=spux.scale_to_unit(self.start * self.unit, unit),
|
||||
stop=spux.scale_to_unit(self.stop * self.unit, unit),
|
||||
|
@ -423,11 +611,18 @@ class RangeFlow:
|
|||
unit=unit,
|
||||
symbols=self.symbols,
|
||||
)
|
||||
return RangeFlow(
|
||||
start=self.start * unit,
|
||||
stop=self.stop * unit,
|
||||
steps=self.steps,
|
||||
scaling=self.scaling,
|
||||
unit=unit,
|
||||
symbols=self.symbols,
|
||||
)
|
||||
|
||||
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}'
|
||||
raise ValueError(msg)
|
||||
|
||||
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self:
|
||||
def rescale_to_unit_system(
|
||||
self, unit_system: spux.UnitSystem | None = None
|
||||
) -> typ.Self:
|
||||
"""Replaces the units, **with** rescaling of the bounds.
|
||||
|
||||
Parameters:
|
||||
|
@ -439,28 +634,11 @@ class RangeFlow:
|
|||
Raises:
|
||||
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
||||
"""
|
||||
if self.unit is not None:
|
||||
log.debug(
|
||||
'%s: Scaled to new unit system (new unit = %s)',
|
||||
self,
|
||||
unit_system[spux.PhysicalType.from_unit(self.unit)],
|
||||
)
|
||||
return RangeFlow(
|
||||
start=spux.strip_unit_system(
|
||||
spux.convert_to_unit_system(self.start * self.unit, unit_system),
|
||||
unit_system,
|
||||
),
|
||||
stop=spux.strip_unit_system(
|
||||
spux.convert_to_unit_system(self.stop * self.unit, unit_system),
|
||||
unit_system,
|
||||
),
|
||||
start=spux.scale_to_unit_system(self.start * self.unit, unit_system),
|
||||
stop=spux.scale_to_unit_system(self.stop * self.unit, unit_system),
|
||||
steps=self.steps,
|
||||
scaling=self.scaling,
|
||||
unit=unit_system[spux.PhysicalType.from_unit(self.unit)],
|
||||
unit=spux.convert_to_unit_system(self.unit, unit_system),
|
||||
symbols=self.symbols,
|
||||
)
|
||||
|
||||
msg = (
|
||||
f'Tried to rescale unitless LazyDataValueRange to unit system {unit_system}'
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
|
|
@ -17,15 +17,19 @@
|
|||
import dataclasses
|
||||
import functools
|
||||
import typing as typ
|
||||
from fractions import Fraction
|
||||
from types import MappingProxyType
|
||||
|
||||
import jaxtyping as jtyp
|
||||
import sympy as sp
|
||||
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
from blender_maxwell.utils import logger, sim_symbols
|
||||
|
||||
from .array import ArrayFlow
|
||||
from .expr_info import ExprInfo
|
||||
from .flow_kinds import FlowKind
|
||||
from .lazy_range import RangeFlow
|
||||
|
||||
# from .info import InfoFlow
|
||||
|
||||
|
@ -34,13 +38,22 @@ log = logger.get(__name__)
|
|||
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class ParamsFlow:
|
||||
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
||||
|
||||
Returns:
|
||||
All symbols valid for use in the expression.
|
||||
"""
|
||||
|
||||
func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list)
|
||||
func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict)
|
||||
|
||||
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
|
||||
|
||||
####################
|
||||
# - Symbols
|
||||
####################
|
||||
@functools.cached_property
|
||||
def sorted_symbols(self) -> list[sp.Symbol]:
|
||||
def sorted_symbols(self) -> list[sim_symbols.SimSymbol]:
|
||||
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
||||
|
||||
Returns:
|
||||
|
@ -48,52 +61,179 @@ class ParamsFlow:
|
|||
"""
|
||||
return sorted(self.symbols, key=lambda sym: sym.name)
|
||||
|
||||
@functools.cached_property
|
||||
def sorted_sp_symbols(self) -> list[sp.Symbol | sp.MatrixSymbol]:
|
||||
"""Computes `sympy` symbols from `self.sorted_symbols`.
|
||||
|
||||
When the output is shaped, a single shaped symbol (`sp.MatrixSymbol`) is used to represent the symbolic name and shaping.
|
||||
This choice is made due to `MatrixSymbol`'s compatibility with `.lambdify` JIT.
|
||||
|
||||
Returns:
|
||||
All symbols valid for use in the expression.
|
||||
"""
|
||||
return [sym.sp_symbol_matsym for sym in self.sorted_symbols]
|
||||
|
||||
####################
|
||||
# - JIT'ed Callables for Numerical Function Arguments
|
||||
####################
|
||||
def func_args_n(
|
||||
self, target_syms: list[sim_symbols.SimSymbol]
|
||||
) -> list[
|
||||
typ.Callable[
|
||||
[int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
|
||||
int | float | complex | jtyp.Inexact[jtyp.Array, '...'],
|
||||
]
|
||||
]:
|
||||
"""Callable functions for evaluating each `self.func_args` entry numerically.
|
||||
|
||||
Before simplification, each `self.func_args` entry will be conformed to the corresponding (by-index) `SimSymbol` in `target_syms`.
|
||||
|
||||
Notes:
|
||||
Before using any `sympy` expressions as arguments to the returned callablees, they **must** be fully conformed and scaled to the corresponding `self.symbols` entry using that entry's `SimSymbol.scale()` method.
|
||||
|
||||
This ensures conformance to the `SimSymbol` properties (like units), as well as adherance to a numerical type identity compatible with `sp.lambdify()`.
|
||||
|
||||
Parameters:
|
||||
target_syms: `SimSymbol`s describing how a particular `ParamsFlow` function argument should be scaled when performing a purely numerical insertion.
|
||||
"""
|
||||
return [
|
||||
sp.lambdify(
|
||||
self.sorted_sp_symbols,
|
||||
target_sym.conform(func_arg, strip_unit=True),
|
||||
'jax',
|
||||
)
|
||||
for func_arg, target_sym in zip(self.func_args, target_syms, strict=True)
|
||||
]
|
||||
|
||||
def func_kwargs_n(
|
||||
self, target_syms: dict[str, sim_symbols.SimSymbol]
|
||||
) -> dict[
|
||||
str,
|
||||
typ.Callable[
|
||||
[int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
|
||||
int | float | complex | jtyp.Inexact[jtyp.Array, '...'],
|
||||
],
|
||||
]:
|
||||
"""Callable functions for evaluating each `self.func_kwargs` entry numerically.
|
||||
|
||||
The arguments of each function **must** be pre-treated using `SimSymbol.scale()`.
|
||||
This ensures conformance to the `SimSymbol` properties, as well as adherance to a numerical type identity compatible with `sp.lambdify()`
|
||||
"""
|
||||
return {
|
||||
func_arg_key: sp.lambdify(
|
||||
self.sorted_sp_symbols,
|
||||
target_syms[func_arg_key].scale(func_arg),
|
||||
'jax',
|
||||
)
|
||||
for func_arg_key, func_arg in self.func_kwargs.items()
|
||||
}
|
||||
|
||||
####################
|
||||
# - Realization
|
||||
####################
|
||||
def realize_symbols(
|
||||
self,
|
||||
symbol_values: dict[
|
||||
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
|
||||
] = MappingProxyType({}),
|
||||
) -> dict[
|
||||
sp.Symbol,
|
||||
int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...'] :,
|
||||
]:
|
||||
"""Fully realize all symbols by assigning them a value.
|
||||
|
||||
Three kinds of values for `symbol_values` are supported, fundamentally:
|
||||
|
||||
- **Sympy Expression**: When the value is a sympy expression with units, the unit of the `SimSymbol` key which unit the value if converted to.
|
||||
If the `SimSymbol`'s unit is `None`, then the value is left as-is.
|
||||
- **Range**: When the value is a `RangeFlow`, units are converted to the `SimSymbol`'s unit using `.rescale_to_unit()`.
|
||||
If the `SimSymbol`'s unit is `None`, then the value is left as-is.
|
||||
- **Array**: When the value is an `ArrayFlow`, units are converted to the `SimSymbol`'s unit using `.rescale_to_unit()`.
|
||||
If the `SimSymbol`'s unit is `None`, then the value is left as-is.
|
||||
|
||||
Returns:
|
||||
A dictionary almost with `.subs()`, other than `jax` arrays.
|
||||
"""
|
||||
if set(self.symbols) == set(symbol_values.keys()):
|
||||
realized_syms = {}
|
||||
for sym in self.sorted_symbols:
|
||||
sym_value = symbol_values[sym]
|
||||
|
||||
if isinstance(sym_value, spux.SympyType):
|
||||
v = sym.scale(sym_value)
|
||||
|
||||
elif isinstance(sym_value, ArrayFlow | RangeFlow):
|
||||
v = sym_value.rescale_to_unit(sym.unit).values
|
||||
## NOTE: RangeFlow must not be symbolic.
|
||||
|
||||
else:
|
||||
msg = f'No support for symbolic value {sym_value} (type={type(sym_value)})'
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
realized_syms |= {sym: v}
|
||||
|
||||
return realized_syms
|
||||
|
||||
msg = f'ParamsFlow: Not all symbols were given a value during realization (symbols={self.symbols}, symbol_values={symbol_values})'
|
||||
raise ValueError(msg)
|
||||
|
||||
####################
|
||||
# - Realize Arguments
|
||||
####################
|
||||
def scaled_func_args(
|
||||
self,
|
||||
unit_system: spux.UnitSystem | None = None,
|
||||
target_syms: list[sim_symbols.SimSymbol] = (),
|
||||
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||
{}
|
||||
),
|
||||
):
|
||||
"""Realize the function arguments contained in this `ParamsFlow`, making it ready for insertion into `Func.func()`.
|
||||
) -> list[
|
||||
int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...']
|
||||
]:
|
||||
"""Realize correctly conformed numerical arguments for `self.func_args`.
|
||||
|
||||
For all `arg`s in `self.func_args`, the following operations are performed.
|
||||
Because we allow symbols to be used in `self.func_args`, producing a numerical value that can be passed directly to a `FuncFlow` becomes a two-step process:
|
||||
|
||||
Notes:
|
||||
This method is created for the purpose of being able to make this exact call in an `events.on_value_changed` method:
|
||||
1. Conform Symbols: Arbitrary `sympy` expressions passed as `symbol_values` must first be conformed to match the ex. units of `SimSymbol`s found in `self.symbols`, before they can be used.
|
||||
|
||||
2. Conform Function Arguments: Arbitrary `sympy` expressions encoded in `self.func_args` must, **after** inserting the conformed numerical symbols, themselves be conformed to the expected ex. units of the function that they are to be used within.
|
||||
**`ParamsFlow` doesn't contain information about the `SimSymbol`s that `self.func_args` are expected to conform to** (on purpose).
|
||||
Therefore, the user is required to pass a `target_syms` with identical length to `self.func_args`, describing the `SimSymbol`s to conform the function arguments to.
|
||||
|
||||
Our implementation attempts to utilize simple, powerful primitives to accomplish this in roughly three steps:
|
||||
|
||||
1. **Realize Symbols**: Particular passed symbolic values `symbol_values`, which are arbitrary `sympy` expressions, are conformed to the definitions in `self.symbols` (ex. to match units), then cast to numerical values (pure Python / jax array).
|
||||
|
||||
2. **Lazy Function Arguments**: Stored function arguments `self.func_args`, which are arbitrary `sympy` expressions, are conformed to the definitions in `target_syms` (ex. to match units), then cast to numerical values (pure Python / jax array).
|
||||
_Technically, this happens as part of `self.func_args_n`._
|
||||
|
||||
3. **Numerical Evaluation**: The numerical values for each symbol are passed as parameters to each (callable) element of `self.func_args_n`, which produces a correct numerical value for each function argument.
|
||||
|
||||
Parameters:
|
||||
target_syms: `SimSymbol`s describing how the function arguments returned by this method are intended to be used.
|
||||
**Generally**, the parallel `FuncFlow.func_args` should be inserted here, and guarantees correct results when this output is inserted into `FuncFlow.func(...)`.
|
||||
symbol_values: Particular values for all symbols in `self.symbols`, which will be conformed and used to compute the function arguments (before they are conformed to `target_syms`).
|
||||
"""
|
||||
if not all(sym in self.symbols for sym in symbol_values):
|
||||
msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}"
|
||||
raise ValueError(msg)
|
||||
|
||||
## TODO: MutableDenseMatrix causes error with 'in' check bc it isn't hashable.
|
||||
realized_symbols = list(self.realize_symbols(symbol_values).values())
|
||||
return [
|
||||
(
|
||||
spux.scale_to_unit_system(arg, unit_system, use_jax_array=True)
|
||||
if arg not in symbol_values
|
||||
else symbol_values[arg]
|
||||
)
|
||||
for arg in self.func_args
|
||||
func_arg_n(*realized_symbols)
|
||||
for func_arg_n in self.func_args_n(target_syms)
|
||||
]
|
||||
|
||||
def scaled_func_kwargs(
|
||||
self,
|
||||
unit_system: spux.UnitSystem | None = None,
|
||||
target_syms: list[sim_symbols.SimSymbol] = (),
|
||||
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
|
||||
):
|
||||
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments."""
|
||||
if not all(sym in self.symbols for sym in symbol_values):
|
||||
msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}"
|
||||
raise ValueError(msg)
|
||||
) -> dict[
|
||||
str, int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...']
|
||||
]:
|
||||
"""Realize correctly conformed numerical arguments for `self.func_kwargs`.
|
||||
|
||||
Other than the `dict[str, ...]` key, the semantics are identical to `self.scaled_func_args()`.
|
||||
"""
|
||||
realized_symbols = self.realize_symbols(symbol_values)
|
||||
return {
|
||||
arg_name: spux.convert_to_unit_system(arg, unit_system, use_jax_array=True)
|
||||
if arg not in symbol_values
|
||||
else symbol_values[arg]
|
||||
for arg_name, arg in self.func_kwargs.items()
|
||||
func_arg_name: func_arg_n(**realized_symbols)
|
||||
for func_arg_name, func_arg_n in self.func_kwargs_n(target_syms).items()
|
||||
}
|
||||
|
||||
####################
|
||||
|
@ -129,8 +269,8 @@ class ParamsFlow:
|
|||
####################
|
||||
# - Generate ExprSocketDef
|
||||
####################
|
||||
def sym_expr_infos(self, info, use_range: bool = False) -> dict[str, ExprInfo]:
|
||||
"""Generate all information needed to define expressions that realize all symbolic parameters in this `ParamsFlow`.
|
||||
def sym_expr_infos(self, use_range: bool = False) -> dict[str, ExprInfo]:
|
||||
"""Generate keyword arguments for defining all `ExprSocket`s needed to realize all `self.symbols`.
|
||||
|
||||
Many nodes need actual data, and as such, they require that the user select actual values for any symbols in the `ParamsFlow`.
|
||||
The best way to do this is to create one `ExprSocket` for each symbol that needs realizing.
|
||||
|
@ -151,35 +291,22 @@ class ParamsFlow:
|
|||
|
||||
The `ExprInfo`s can be directly defererenced `**expr_info`)
|
||||
"""
|
||||
for sim_sym in self.sorted_symbols:
|
||||
if use_range and sim_sym.mathtype is spux.MathType.Complex:
|
||||
for sym in self.sorted_symbols:
|
||||
if use_range and sym.mathtype is spux.MathType.Complex:
|
||||
msg = 'No support for complex range in ExprInfo'
|
||||
raise NotImplementedError(msg)
|
||||
if use_range and (sim_sym.rows > 1 or sim_sym.cols > 1):
|
||||
if use_range and (sym.rows > 1 or sym.cols > 1):
|
||||
msg = 'No support for non-scalar elements of range in ExprInfo'
|
||||
raise NotImplementedError(msg)
|
||||
if sim_sym.rows > 3 or sim_sym.cols > 1:
|
||||
if sym.rows > 3 or sym.cols > 1:
|
||||
msg = 'No support for >Vec3 / Matrix values in ExprInfo'
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
return {
|
||||
sim_sym.name: {
|
||||
# Declare Kind/Size
|
||||
## -> Kind: Value prevents user-alteration of config.
|
||||
## -> Size: Always scalar, since symbols are scalar (for now).
|
||||
sym.name: {
|
||||
'active_kind': FlowKind.Value if not use_range else FlowKind.Range,
|
||||
'size': spux.NumberSize1D.Scalar,
|
||||
# Declare MathType/PhysicalType
|
||||
## -> MathType: Lookup symbol name in info dimensions.
|
||||
## -> PhysicalType: Same.
|
||||
'mathtype': self.dims[sim_sym].mathtype,
|
||||
'physical_type': self.dims[sim_sym].physical_type,
|
||||
# TODO: Default Value
|
||||
# FlowKind.Value: Default Value
|
||||
#'default_value':
|
||||
# FlowKind.Range: Default Min/Max/Steps
|
||||
'default_min': sim_sym.domain.start,
|
||||
'default_max': sim_sym.domain.end,
|
||||
'default_steps': 50,
|
||||
}
|
||||
for sim_sym in self.sorted_symbols
|
||||
| sym.expr_info
|
||||
for sym in self.sorted_symbols
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@ import tidy3d as td
|
|||
|
||||
from blender_maxwell.contracts import BLEnumElement
|
||||
from blender_maxwell.services import tdcloud
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
from blender_maxwell.utils import logger
|
||||
|
||||
from .flow_kinds.info import InfoFlow
|
||||
|
@ -482,6 +483,93 @@ class DataFileFormat(enum.StrEnum):
|
|||
## - When sidecars aren't found, the user would "fill in the blanks".
|
||||
## - ...Thus achieving the same result as if there were a sidecar.
|
||||
|
||||
####################
|
||||
# - Functions: DataFrame
|
||||
####################
|
||||
@staticmethod
|
||||
def to_df(
|
||||
data: jtyp.Shaped[jtyp.Array, 'x_size y_size'], info: InfoFlow
|
||||
) -> pl.DataFrame:
|
||||
"""Utility method to convert raw data to a `polars.DataFrame`, as guided by an `InfoFlow`.
|
||||
|
||||
Only works with 2D data (obviously).
|
||||
|
||||
Raises:
|
||||
ValueError: If the data has more than two dimensions, all `info` dimensions are not discrete/labelled, or the dimensionality of `info` doesn't match.
|
||||
"""
|
||||
if info.order > 2: # noqa: PLR2004
|
||||
msg = f'Data may not have more than two dimensions (info={info}, data.shape={data.shape})'
|
||||
raise ValueError(msg)
|
||||
|
||||
if any(info.has_idx_cont(dim) for dim in info.dims):
|
||||
msg = f'To convert data|info to a dataframe, no dimensions can have continuous indices (info={info})'
|
||||
raise ValueError(msg)
|
||||
|
||||
data_np = np.array(data)
|
||||
|
||||
MT = spux.MathType
|
||||
match (
|
||||
info.input_mathtypes,
|
||||
info.output.mathtype,
|
||||
info.output.rows,
|
||||
info.output.cols,
|
||||
):
|
||||
# (R,Z) -> Complex Scalar
|
||||
## -> Polars (also pandas) doesn't have a complex type.
|
||||
## -> Will be treated as (R, Z, 2) -> Real Scalar.
|
||||
case ((MT.Rational | MT.Real, MT.Integer), MT.Complex, 1, 1):
|
||||
row_dim = info.first_dim
|
||||
col_dim = info.last_dim
|
||||
|
||||
return pl.DataFrame(
|
||||
{row_dim.name: info.dims[row_dim]}
|
||||
| {
|
||||
col_label + postfix: re_im(data_np[:, col])
|
||||
for col, col_label in enumerate(info.dims[col_dim])
|
||||
for postfix, re_im in [('_re', np.real), ('_im', np.imag)]
|
||||
}
|
||||
)
|
||||
|
||||
# (R,Z) -> Scalar
|
||||
case ((MT.Rational | MT.Real, MT.Integer), _, 1, 1):
|
||||
row_dim = info.first_dim
|
||||
col_dim = info.last_dim
|
||||
|
||||
return pl.DataFrame(
|
||||
{row_dim.name: info.dims[row_dim]}
|
||||
| {
|
||||
col_label: data_np[:, col]
|
||||
for col, col_label in enumerate(info.dims[col_dim])
|
||||
}
|
||||
)
|
||||
|
||||
# (Z) -> Complex Vector/Covector
|
||||
case ((MT.Integer,), MT.Complex, r, c) if (r > 1 and c == 1) or (
|
||||
r == 1 and c > 1
|
||||
):
|
||||
col_dim = info.last_dim
|
||||
|
||||
return pl.DataFrame(
|
||||
{
|
||||
col_label + postfix: re_im(data_np[col, :])
|
||||
for col, col_label in enumerate(info.dims[col_dim])
|
||||
for postfix, re_im in [('_re', np.real), ('_im', np.imag)]
|
||||
}
|
||||
)
|
||||
|
||||
# (Z) -> Real Vector
|
||||
## -> Each integer index will be treated as a column index.
|
||||
## -> This will effectively transpose the data.
|
||||
case ((MT.Integer,), _, r, c) if (r > 1 and c == 1) or (r == 1 and c > 1):
|
||||
col_dim = info.last_dim
|
||||
|
||||
return pl.DataFrame(
|
||||
{
|
||||
col_label: data_np[col, :]
|
||||
for col, col_label in enumerate(info.dims[col_dim])
|
||||
}
|
||||
)
|
||||
|
||||
####################
|
||||
# - Functions: Saver
|
||||
####################
|
||||
|
@ -506,50 +594,7 @@ class DataFileFormat(enum.StrEnum):
|
|||
np.savetxt(path, data)
|
||||
|
||||
def save_csv(path, data, info):
|
||||
data_np = np.array(data)
|
||||
|
||||
# Extract Input Coordinates
|
||||
dim_columns = {
|
||||
dim.name: np.array(dim_idx.realize_array)
|
||||
for i, (dim, dim_idx) in enumerate(info.dims)
|
||||
} ## TODO: realize_array might not be defined on some index arrays
|
||||
|
||||
# Declare Function to Extract Output Values
|
||||
output_columns = {}
|
||||
|
||||
def declare_output_col(data_col, output_idx=0, use_output_idx=False):
|
||||
nonlocal output_columns
|
||||
|
||||
# Complex: Split to Two Columns
|
||||
output_idx_str = f'[{output_idx}]' if use_output_idx else ''
|
||||
if bool(np.any(np.iscomplex(data_col))):
|
||||
output_columns |= {
|
||||
f'{info.output.name}{output_idx_str}_re': np.real(data_col),
|
||||
f'{info.output.name}{output_idx_str}_im': np.imag(data_col),
|
||||
}
|
||||
|
||||
# Else: Use Array Directly
|
||||
else:
|
||||
output_columns |= {
|
||||
f'{info.output.name}{output_idx_str}': data_col,
|
||||
}
|
||||
|
||||
## TODO: Maybe a check to ensure dtype!=object?
|
||||
|
||||
# Extract Output Values
|
||||
## -> 2D: Iterate over columns by-index.
|
||||
## -> 1D: Declare the array as the only column.
|
||||
if len(data_np.shape) == 2:
|
||||
for output_idx in data_np.shape[1]:
|
||||
declare_output_col(data_np[:, output_idx], output_idx, True)
|
||||
else:
|
||||
declare_output_col(data_np)
|
||||
|
||||
# Compute DataFrame & Write CSV
|
||||
df = pl.DataFrame(dim_columns | output_columns)
|
||||
|
||||
log.debug('Writing Polars DataFrame to CSV:')
|
||||
log.debug(df)
|
||||
df = self.to_df(data, info)
|
||||
df.write_csv(path)
|
||||
|
||||
def save_npy(path, data, info):
|
||||
|
|
|
@ -264,17 +264,22 @@ class ManagedBLImage(base.ManagedObj):
|
|||
# times = [time.perf_counter()]
|
||||
|
||||
# Compute Plot Dimensions
|
||||
aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = (
|
||||
self.gen_image_geometry(width_inches, height_inches, dpi)
|
||||
)
|
||||
# aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = (
|
||||
# self.gen_image_geometry(width_inches, height_inches, dpi)
|
||||
# )
|
||||
# times.append(['Image Geometry', time.perf_counter() - times[0]])
|
||||
# log.critical(
|
||||
# [aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px]
|
||||
# )
|
||||
|
||||
# Create MPL Figure, Axes, and Compute Figure Geometry
|
||||
fig, canvas, ax = image_ops.mpl_fig_canvas_ax(
|
||||
_width_inches, _height_inches, _dpi
|
||||
)
|
||||
# fig, canvas, ax = image_ops.mpl_fig_canvas_ax(
|
||||
# _width_inches, _height_inches, _dpi
|
||||
# )
|
||||
fig, canvas, ax = image_ops.mpl_fig_canvas_ax(width_inches, height_inches, dpi)
|
||||
# times.append(['MPL Fig Canvas Axis', time.perf_counter() - times[0]])
|
||||
|
||||
# fig.clear()
|
||||
ax.clear()
|
||||
# times.append(['Clear Axis', time.perf_counter() - times[0]])
|
||||
|
||||
|
|
|
@ -46,6 +46,7 @@ class FilterOperation(enum.StrEnum):
|
|||
"""
|
||||
|
||||
# Slice
|
||||
Slice = enum.auto()
|
||||
SliceIdx = enum.auto()
|
||||
|
||||
# Pin
|
||||
|
@ -53,9 +54,8 @@ class FilterOperation(enum.StrEnum):
|
|||
Pin = enum.auto()
|
||||
PinIdx = enum.auto()
|
||||
|
||||
# Reinterpret
|
||||
# Dimension
|
||||
Swap = enum.auto()
|
||||
SetDim = enum.auto()
|
||||
|
||||
####################
|
||||
# - UI
|
||||
|
@ -65,14 +65,14 @@ class FilterOperation(enum.StrEnum):
|
|||
FO = FilterOperation
|
||||
return {
|
||||
# Slice
|
||||
FO.SliceIdx: 'a[...]',
|
||||
FO.Slice: '=a[i:j]',
|
||||
FO.SliceIdx: '≈a[v₁:v₂]',
|
||||
# Pin
|
||||
FO.PinLen1: 'pinₐ =1',
|
||||
FO.PinLen1: 'pinₐ',
|
||||
FO.Pin: 'pinₐ ≈v',
|
||||
FO.PinIdx: 'pinₐ =a[v]',
|
||||
FO.PinIdx: 'pinₐ =i',
|
||||
# Reinterpret
|
||||
FO.Swap: 'a₁ ↔ a₂',
|
||||
FO.SetDim: 'setₐ =v',
|
||||
}[value]
|
||||
|
||||
@staticmethod
|
||||
|
@ -118,11 +118,6 @@ class FilterOperation(enum.StrEnum):
|
|||
if len(info.dims) >= 2: # noqa: PLR2004
|
||||
operations.append(FO.Swap)
|
||||
|
||||
## SetDim
|
||||
## -> There must be a dimension to correct.
|
||||
if info.dims:
|
||||
operations.append(FO.SetDim)
|
||||
|
||||
return operations
|
||||
|
||||
####################
|
||||
|
@ -145,6 +140,7 @@ class FilterOperation(enum.StrEnum):
|
|||
FO = FilterOperation
|
||||
return {
|
||||
# Slice
|
||||
FO.Slice: 1,
|
||||
FO.SliceIdx: 1,
|
||||
# Pin
|
||||
FO.PinLen1: 1,
|
||||
|
@ -152,40 +148,35 @@ class FilterOperation(enum.StrEnum):
|
|||
FO.PinIdx: 1,
|
||||
# Reinterpret
|
||||
FO.Swap: 2,
|
||||
FO.SetDim: 1,
|
||||
}[self]
|
||||
|
||||
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
|
||||
FO = FilterOperation
|
||||
match self:
|
||||
case FO.SliceIdx | FO.Swap:
|
||||
return info.dims
|
||||
# Slice
|
||||
case FO.Slice:
|
||||
return [dim for dim in info.dims if not dim.has_idx_labels(dim)]
|
||||
|
||||
# PinLen1: Only allow dimensions with length=1.
|
||||
case FO.SliceIdx:
|
||||
return [dim for dim in info.dims if not dim.has_idx_labels(dim)]
|
||||
|
||||
# Pin
|
||||
case FO.PinLen1:
|
||||
return [
|
||||
dim
|
||||
for dim, dim_idx in info.dims.items()
|
||||
if dim_idx is not None and len(dim_idx) == 1
|
||||
if not info.has_idx_cont(dim) and len(dim_idx) == 1
|
||||
]
|
||||
|
||||
# Pin: Only allow dimensions with discrete index.
|
||||
## TODO: Shouldn't 'Pin' be allowed to index continuous indices too?
|
||||
case FO.Pin | FO.PinIdx:
|
||||
return [
|
||||
dim
|
||||
for dim, dim_idx in info.dims
|
||||
if dim_idx is not None and len(dim_idx) > 0
|
||||
]
|
||||
case FO.Pin:
|
||||
return info.dims
|
||||
|
||||
case FO.SetDim:
|
||||
return [
|
||||
dim
|
||||
for dim, dim_idx in info.dims
|
||||
if dim_idx is not None
|
||||
and not isinstance(dim_idx, list)
|
||||
and dim_idx.mathtype == spux.MathType.Integer
|
||||
]
|
||||
case FO.PinIdx:
|
||||
return [dim for dim in info.dims if not info.has_idx_cont(dim)]
|
||||
|
||||
# Dimension
|
||||
case FO.Swap:
|
||||
return info.dims
|
||||
|
||||
return []
|
||||
|
||||
|
@ -193,9 +184,14 @@ class FilterOperation(enum.StrEnum):
|
|||
self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None
|
||||
) -> bool:
|
||||
"""Check whether the given dimension inputs are valid in the context of this operation, and of the information."""
|
||||
return (self.num_dim_inputs in [1, 2] and dim_0 in self.valid_dims(info)) or (
|
||||
self.num_dim_inputs == 2 and dim_1 in self.valid_dims(info)
|
||||
)
|
||||
if self.num_dim_inputs == 1:
|
||||
return dim_0 in self.valid_dims(info)
|
||||
|
||||
if self.num_dim_inputs == 2: # noqa: PLR2004
|
||||
valid_dims = self.valid_dims(info)
|
||||
return dim_0 in valid_dims and dim_1 in valid_dims
|
||||
|
||||
return False
|
||||
|
||||
####################
|
||||
# - UI
|
||||
|
@ -209,6 +205,9 @@ class FilterOperation(enum.StrEnum):
|
|||
FO = FilterOperation
|
||||
return {
|
||||
# Pin
|
||||
FO.Slice: lambda expr: jlax.slice_in_dim(
|
||||
expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
|
||||
),
|
||||
FO.SliceIdx: lambda expr: jlax.slice_in_dim(
|
||||
expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
|
||||
),
|
||||
|
@ -216,9 +215,8 @@ class FilterOperation(enum.StrEnum):
|
|||
FO.PinLen1: lambda expr: jnp.squeeze(expr, axis_0),
|
||||
FO.Pin: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
|
||||
FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
|
||||
# Reinterpret
|
||||
# Dimension
|
||||
FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1),
|
||||
FO.SetDim: lambda expr: expr,
|
||||
}[self]
|
||||
|
||||
def transform_info(
|
||||
|
@ -228,10 +226,10 @@ class FilterOperation(enum.StrEnum):
|
|||
dim_1: sim_symbols.SimSymbol,
|
||||
pin_idx: int | None = None,
|
||||
slice_tuple: tuple[int, int, int] | None = None,
|
||||
replaced_dim: tuple[str, tuple[str, ct.ArrayFlow | ct.RangeFlow]] | None = None,
|
||||
):
|
||||
FO = FilterOperation
|
||||
return {
|
||||
FO.Slice: lambda: info.slice_dim(dim_0, slice_tuple),
|
||||
FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple),
|
||||
# Pin
|
||||
FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
|
||||
|
@ -239,7 +237,6 @@ class FilterOperation(enum.StrEnum):
|
|||
FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
|
||||
# Reinterpret
|
||||
FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
|
||||
FO.SetDim: lambda: info.replace_dim(*replaced_dim),
|
||||
}[self]()
|
||||
|
||||
|
||||
|
@ -330,8 +327,8 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
def search_dims(self) -> list[ct.BLEnumElement]:
|
||||
if self.expr_info is not None and self.operation is not None:
|
||||
return [
|
||||
(dim_name, dim_name, dim_name, '', i)
|
||||
for i, dim_name in enumerate(self.operation.valid_dims(self.expr_info))
|
||||
(dim.name, dim.name_pretty, dim.name, '', i)
|
||||
for i, dim in enumerate(self.operation.valid_dims(self.expr_info))
|
||||
]
|
||||
return []
|
||||
|
||||
|
@ -380,8 +377,6 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
# Reinterpret
|
||||
case FO.Swap:
|
||||
return f'Filter: Swap [{self.active_dim_0}]|[{self.active_dim_1}]'
|
||||
case FO.SetDim:
|
||||
return f'Filter: Set [{self.active_dim_0}]'
|
||||
|
||||
case _:
|
||||
return self.bl_label
|
||||
|
@ -480,30 +475,6 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
)
|
||||
}
|
||||
|
||||
# Loose Sockets: Set Dim
|
||||
## -> The user must provide a (ℤ) -> ℝ array.
|
||||
## -> It must be of identical length to the replaced axis.
|
||||
elif props['operation'] is FilterOperation.SetDim and dim_0 is not None:
|
||||
dim = dim_0
|
||||
current_bl_socket = self.loose_input_sockets.get('Dim')
|
||||
if (
|
||||
current_bl_socket is None
|
||||
or current_bl_socket.active_kind != ct.FlowKind.Func
|
||||
or current_bl_socket.size is not spux.NumberSize1D.Scalar
|
||||
or current_bl_socket.mathtype != dim.mathtype
|
||||
or current_bl_socket.physical_type != dim.physical_type
|
||||
):
|
||||
self.loose_input_sockets = {
|
||||
'Dim': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.Func,
|
||||
physical_type=dim.physical_type,
|
||||
mathtype=dim.mathtype,
|
||||
default_unit=dim.unit,
|
||||
show_func_ui=False,
|
||||
show_info_columns=True,
|
||||
)
|
||||
}
|
||||
|
||||
# No Loose Value: Remove Input Sockets
|
||||
elif self.loose_input_sockets:
|
||||
self.loose_input_sockets = {}
|
||||
|
@ -570,60 +541,11 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
|
||||
# Dim (Op.SetDim)
|
||||
dim_func = input_sockets['Dim'][ct.FlowKind.Func]
|
||||
dim_params = input_sockets['Dim'][ct.FlowKind.Params]
|
||||
dim_info = input_sockets['Dim'][ct.FlowKind.Info]
|
||||
|
||||
has_dim_func = not ct.FlowSignal.check(dim_func)
|
||||
has_dim_params = not ct.FlowSignal.check(dim_params)
|
||||
has_dim_info = not ct.FlowSignal.check(dim_info)
|
||||
|
||||
# Dimension(s)
|
||||
dim_0 = props['dim_0']
|
||||
dim_1 = props['dim_1']
|
||||
slice_tuple = props['slice_tuple']
|
||||
if has_info and operation is not None:
|
||||
# Set Dimension: Retrieve Array
|
||||
if props['operation'] is FilterOperation.SetDim:
|
||||
new_dim = (
|
||||
next(dim_info.dims.keys()) if len(dim_info.dims) >= 1 else None
|
||||
)
|
||||
|
||||
if (
|
||||
dim_0 is not None
|
||||
and new_dim is not None
|
||||
and has_dim_info
|
||||
and has_dim_params
|
||||
# Check New Dimension Index Array Sizing
|
||||
and len(dim_info.dims) == 1
|
||||
and dim_info.output.rows == 1
|
||||
and dim_info.output.cols == 1
|
||||
# Check Lack of Params Symbols
|
||||
and not dim_params.symbols
|
||||
# Check Expr Dim | New Dim Compatibility
|
||||
and info.has_idx_discrete(dim_0)
|
||||
and dim_info.has_idx_discrete(new_dim)
|
||||
and len(info.dims[dim_0]) == len(dim_info.dims[new_dim])
|
||||
):
|
||||
# Retrieve Dimension Coordinate Array
|
||||
## -> It must be strictly compatible.
|
||||
values = dim_func.realize(dim_params, spux.UNITS_SI)
|
||||
|
||||
# Transform Info w/Corrected Dimension
|
||||
## -> The existing dimension will be replaced.
|
||||
new_dim_idx = ct.ArrayFlow(
|
||||
values=values,
|
||||
unit=spux.convert_to_unit_system(
|
||||
dim_info.output.unit, spux.UNITS_SI
|
||||
),
|
||||
).rescale_to_unit(dim_info.output.unit)
|
||||
|
||||
replaced_dim = [dim_0, (dim_info.output.name, new_dim_idx)]
|
||||
return operation.transform_info(
|
||||
info, dim_0, dim_1, replaced_dim=replaced_dim
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
return operation.transform_info(info, dim_0, dim_1, slice_tuple=slice_tuple)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
|
|
@ -496,7 +496,7 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
)
|
||||
|
||||
def search_operations(self) -> list[ct.BLEnumElement]:
|
||||
if self.info is not None:
|
||||
if self.expr_info is not None:
|
||||
return [
|
||||
operation.bl_enum_element(i)
|
||||
for i, operation in enumerate(MapOperation.by_expr_info(self.expr_info))
|
||||
|
|
|
@ -20,8 +20,11 @@ import typing as typ
|
|||
import bpy
|
||||
import jax.numpy as jnp
|
||||
import sympy as sp
|
||||
import sympy.physics.quantum as spq
|
||||
import sympy.physics.units as spu
|
||||
|
||||
from blender_maxwell.utils import bl_cache, logger
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
|
||||
from .... import contracts as ct
|
||||
from .... import sockets
|
||||
|
@ -37,37 +40,47 @@ class BinaryOperation(enum.StrEnum):
|
|||
"""Valid operations for the `OperateMathNode`.
|
||||
|
||||
Attributes:
|
||||
Add: Addition w/broadcasting.
|
||||
Sub: Subtraction w/broadcasting.
|
||||
Mul: Hadamard-product multiplication.
|
||||
Div: Hadamard-product based division.
|
||||
Pow: Elementwise expontiation.
|
||||
Atan2: Quadrant-respecting arctangent variant.
|
||||
VecVecDot: Dot product for vectors.
|
||||
Cross: Cross product.
|
||||
MatVecDot: Matrix-Vector dot product.
|
||||
Mul: Scalar multiplication.
|
||||
Div: Scalar division.
|
||||
Pow: Scalar exponentiation.
|
||||
Add: Elementwise addition.
|
||||
Sub: Elementwise subtraction.
|
||||
HadamMul: Elementwise multiplication (hadamard product).
|
||||
HadamPow: Principled shape-aware exponentiation (hadamard power).
|
||||
Atan2: Quadrant-respecting 2D arctangent.
|
||||
VecVecDot: Dot product for identically shaped vectors w/transpose.
|
||||
Cross: Cross product between identically shaped 3D vectors.
|
||||
VecVecOuter: Vector-vector outer product.
|
||||
LinSolve: Solve a linear system.
|
||||
LsqSolve: Minimize error of an underdetermined linear system.
|
||||
MatMatDot: Matrix-Matrix dot product.
|
||||
VecMatOuter: Vector-matrix outer product.
|
||||
MatMatDot: Matrix-matrix dot product.
|
||||
"""
|
||||
|
||||
# Number | Number
|
||||
Add = enum.auto()
|
||||
Sub = enum.auto()
|
||||
Mul = enum.auto()
|
||||
Div = enum.auto()
|
||||
Pow = enum.auto()
|
||||
|
||||
# Elements | Elements
|
||||
Add = enum.auto()
|
||||
Sub = enum.auto()
|
||||
HadamMul = enum.auto()
|
||||
# HadamPow = enum.auto() ## TODO: Sympy's HadamardPower is problematic.
|
||||
Atan2 = enum.auto()
|
||||
|
||||
# Vector | Vector
|
||||
VecVecDot = enum.auto()
|
||||
Cross = enum.auto()
|
||||
VecVecOuter = enum.auto()
|
||||
|
||||
# Matrix | Vector
|
||||
MatVecDot = enum.auto()
|
||||
LinSolve = enum.auto()
|
||||
LsqSolve = enum.auto()
|
||||
|
||||
# Vector | Matrix
|
||||
VecMatOuter = enum.auto()
|
||||
|
||||
# Matrix | Matrix
|
||||
MatMatDot = enum.auto()
|
||||
|
||||
|
@ -79,19 +92,24 @@ class BinaryOperation(enum.StrEnum):
|
|||
BO = BinaryOperation
|
||||
return {
|
||||
# Number | Number
|
||||
BO.Mul: 'ℓ · r',
|
||||
BO.Div: 'ℓ / r',
|
||||
BO.Pow: 'ℓ ^ r',
|
||||
# Elements | Elements
|
||||
BO.Add: 'ℓ + r',
|
||||
BO.Sub: 'ℓ - r',
|
||||
BO.Mul: 'ℓ ⊙ r', ## Notation for Hadamard Product
|
||||
BO.Div: 'ℓ / r',
|
||||
BO.Pow: 'ℓʳ',
|
||||
BO.Atan2: 'atan2(ℓ,r)',
|
||||
BO.HadamMul: '𝐋 ⊙ 𝐑',
|
||||
# BO.HadamPow: '𝐥 ⊙^ 𝐫',
|
||||
BO.Atan2: 'atan2(ℓ:x, r:y)',
|
||||
# Vector | Vector
|
||||
BO.VecVecDot: '𝐥 · 𝐫',
|
||||
BO.Cross: 'cross(L,R)',
|
||||
BO.Cross: 'cross(𝐥,𝐫)',
|
||||
BO.VecVecOuter: '𝐥 ⊗ 𝐫',
|
||||
# Matrix | Vector
|
||||
BO.MatVecDot: '𝐋 · 𝐫',
|
||||
BO.LinSolve: '𝐋 ∖ 𝐫',
|
||||
BO.LsqSolve: 'argminₓ∥𝐋𝐱−𝐫∥₂',
|
||||
# Vector | Matrix
|
||||
BO.VecMatOuter: '𝐋 ⊗ 𝐫',
|
||||
# Matrix | Matrix
|
||||
BO.MatMatDot: '𝐋 · 𝐑',
|
||||
}[value]
|
||||
|
@ -118,56 +136,104 @@ class BinaryOperation(enum.StrEnum):
|
|||
"""Deduce valid binary operations from the shapes of the inputs."""
|
||||
BO = BinaryOperation
|
||||
|
||||
ops_number_number = [
|
||||
ops_el_el = [
|
||||
BO.Add,
|
||||
BO.Sub,
|
||||
BO.HadamMul,
|
||||
# BO.HadamPow,
|
||||
]
|
||||
|
||||
outl = info_l.output
|
||||
outr = info_r.output
|
||||
match (outl.shape_len, outr.shape_len):
|
||||
# match (ol.shape_len, info_r.output.shape_len):
|
||||
# Number | *
|
||||
## Number | Number
|
||||
case (0, 0):
|
||||
ops = [
|
||||
BO.Add,
|
||||
BO.Sub,
|
||||
BO.Mul,
|
||||
BO.Div,
|
||||
BO.Pow,
|
||||
BO.Atan2,
|
||||
]
|
||||
|
||||
match (info_l.output_shape_len, info_r.output_shape_len):
|
||||
# Number | *
|
||||
## Number | Number
|
||||
case (0, 0):
|
||||
return ops_number_number
|
||||
if (
|
||||
info_l.output.physical_type == spux.PhysicalType.Length
|
||||
and info_l.output.unit == info_r.output.unit
|
||||
):
|
||||
ops += [BO.Atan2]
|
||||
return ops
|
||||
|
||||
## Number | Vector
|
||||
## -> Broadcasting allows Number|Number ops to work as-is.
|
||||
case (0, 1):
|
||||
return ops_number_number
|
||||
return [BO.Mul] # , BO.HadamPow]
|
||||
|
||||
## Number | Matrix
|
||||
## -> Broadcasting allows Number|Number ops to work as-is.
|
||||
case (0, 2):
|
||||
return ops_number_number
|
||||
return [BO.Mul] # , BO.HadamPow]
|
||||
|
||||
# Vector | *
|
||||
## Vector | Number
|
||||
case (1, 0):
|
||||
return ops_number_number
|
||||
return [BO.Mul] # , BO.HadamPow]
|
||||
|
||||
## Vector | Number
|
||||
## Vector | Vector
|
||||
case (1, 1):
|
||||
return [*ops_number_number, BO.VecVecDot, BO.Cross]
|
||||
ops = []
|
||||
|
||||
# Vector | Vector
|
||||
## -> Dot: Convenience; utilize special vec-vec dot w/transp.
|
||||
if outl.rows > outl.cols and outr.rows > outr.cols:
|
||||
ops += [BO.VecVecDot, BO.VecVecOuter]
|
||||
|
||||
# Covector | Vector
|
||||
## -> Dot: Directly use matrix-matrix dot, as it's now correct.
|
||||
if outl.rows < outl.cols and outr.rows > outr.cols:
|
||||
ops += [BO.MatMatDot, BO.VecVecOuter]
|
||||
|
||||
# Vector | Covector
|
||||
## -> Dot: Directly use matrix-matrix dot, as it's now correct.
|
||||
## -> These are both the same operation, in this case.
|
||||
if outl.rows > outl.cols and outr.rows < outr.cols:
|
||||
ops += [BO.MatMatDot, BO.VecVecOuter]
|
||||
|
||||
# Covector | Covector
|
||||
## -> Dot: Convenience; utilize special vec-vec dot w/transp.
|
||||
if outl.rows < outl.cols and outr.rows < outr.cols:
|
||||
ops += [BO.VecVecDot, BO.VecVecOuter]
|
||||
|
||||
# Cross Product
|
||||
## -> Enforce that both are 3x1 or 1x3.
|
||||
## -> See https://docs.sympy.org/latest/modules/matrices/matrices.html#sympy.matrices.matrices.MatrixBase.cross
|
||||
if (outl.rows == 3 and outr.rows == 3) or (
|
||||
outl.cols == 3 and outl.cols == 3
|
||||
):
|
||||
ops += [BO.Cross]
|
||||
|
||||
return ops
|
||||
|
||||
## Vector | Matrix
|
||||
case (1, 2):
|
||||
return []
|
||||
return [BO.VecMatOuter]
|
||||
|
||||
# Matrix | *
|
||||
## Matrix | Number
|
||||
case (2, 0):
|
||||
return [*ops_number_number, BO.MatMatDot]
|
||||
return [BO.Mul] # , BO.HadamPow]
|
||||
|
||||
## Matrix | Vector
|
||||
case (2, 1):
|
||||
return [BO.MatVecDot, BO.LinSolve, BO.LsqSolve]
|
||||
prepend_ops = []
|
||||
|
||||
# Mat-Vec Dot: Enforce RHS Column Vector
|
||||
if outr.rows > outl.cols:
|
||||
prepend_ops += [BO.MatMatDot]
|
||||
|
||||
return [*ops, BO.LinSolve, BO.LsqSolve] # , BO.HadamPow]
|
||||
|
||||
## Matrix | Matrix
|
||||
case (2, 2):
|
||||
return [*ops_number_number, BO.MatMatDot]
|
||||
return [*ops_el_el, BO.MatMatDot]
|
||||
|
||||
return []
|
||||
|
||||
|
@ -182,34 +248,86 @@ class BinaryOperation(enum.StrEnum):
|
|||
## TODO: Make this compatible with sp.Matrix inputs
|
||||
return {
|
||||
# Number | Number
|
||||
BO.Add: lambda exprs: exprs[0] + exprs[1],
|
||||
BO.Sub: lambda exprs: exprs[0] - exprs[1],
|
||||
BO.Mul: lambda exprs: exprs[0] * exprs[1],
|
||||
BO.Div: lambda exprs: exprs[0] / exprs[1],
|
||||
BO.Pow: lambda exprs: exprs[0] ** exprs[1],
|
||||
# Elements | Elements
|
||||
BO.Add: lambda exprs: exprs[0] + exprs[1],
|
||||
BO.Sub: lambda exprs: exprs[0] - exprs[1],
|
||||
BO.HadamMul: lambda exprs: sp.hadamard_product(exprs[0], exprs[1]),
|
||||
# BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
|
||||
BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]),
|
||||
# Vector | Vector
|
||||
BO.VecVecDot: lambda exprs: (exprs[0].T @ exprs[1])[0],
|
||||
BO.Cross: lambda exprs: exprs[0].cross(exprs[1]),
|
||||
BO.VecVecOuter: lambda exprs: exprs[0] @ exprs[1].T,
|
||||
# Matrix | Vector
|
||||
BO.LinSolve: lambda exprs: exprs[0].solve(exprs[1]),
|
||||
BO.LsqSolve: lambda exprs: exprs[0].solve_least_squares(exprs[1]),
|
||||
# Vector | Matrix
|
||||
BO.VecMatOuter: lambda exprs: spq.TensorProduct(exprs[0], exprs[1]),
|
||||
# Matrix | Matrix
|
||||
BO.MatMatDot: lambda exprs: exprs[0] @ exprs[1],
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def unit_func(self):
|
||||
"""The binary function to apply to both unit expressions, in order to deduce the unit expression of the output."""
|
||||
BO = BinaryOperation
|
||||
|
||||
## TODO: Make this compatible with sp.Matrix inputs
|
||||
return {
|
||||
# Number | Number
|
||||
BO.Mul: BO.Mul.sp_func,
|
||||
BO.Div: BO.Div.sp_func,
|
||||
BO.Pow: BO.Pow.sp_func,
|
||||
# Elements | Elements
|
||||
BO.Add: BO.Add.sp_func,
|
||||
BO.Sub: BO.Sub.sp_func,
|
||||
BO.HadamMul: BO.Mul.sp_func,
|
||||
# BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
|
||||
BO.Atan2: lambda _: spu.radian,
|
||||
# Vector | Vector
|
||||
BO.VecVecDot: BO.Mul.sp_func,
|
||||
BO.Cross: BO.Mul.sp_func,
|
||||
BO.VecVecOuter: BO.Mul.sp_func,
|
||||
# Matrix | Vector
|
||||
## -> A,b in Ax = b have units, and the equality must hold.
|
||||
## -> Therefore, A \ b must have the units [b]/[A].
|
||||
BO.LinSolve: lambda exprs: exprs[1] / exprs[0],
|
||||
BO.LsqSolve: lambda exprs: exprs[1] / exprs[0],
|
||||
# Vector | Matrix
|
||||
BO.VecMatOuter: BO.Mul.sp_func,
|
||||
# Matrix | Matrix
|
||||
BO.MatMatDot: BO.Mul.sp_func,
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def jax_func(self):
|
||||
"""Deduce an appropriate jax-based function that implements the binary operation for array inputs."""
|
||||
## TODO: Scale the units of one side to the other.
|
||||
BO = BinaryOperation
|
||||
|
||||
return {
|
||||
# Number | Number
|
||||
BO.Add: lambda exprs: exprs[0] + exprs[1],
|
||||
BO.Sub: lambda exprs: exprs[0] - exprs[1],
|
||||
BO.Mul: lambda exprs: exprs[0] * exprs[1],
|
||||
BO.Div: lambda exprs: exprs[0] / exprs[1],
|
||||
BO.Pow: lambda exprs: exprs[0] ** exprs[1],
|
||||
BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]),
|
||||
# Elements | Elements
|
||||
BO.Add: lambda exprs: exprs[0] + exprs[1],
|
||||
BO.Sub: lambda exprs: exprs[0] - exprs[1],
|
||||
BO.HadamMul: lambda exprs: exprs[0] * exprs[1],
|
||||
# BO.HadamPow: lambda exprs: exprs[0] ** exprs[1],
|
||||
BO.Atan2: lambda exprs: jnp.atan2(exprs[1], exprs[0]),
|
||||
# Vector | Vector
|
||||
BO.VecVecDot: lambda exprs: jnp.dot(exprs[0], exprs[1]),
|
||||
BO.Cross: lambda exprs: jnp.cross(exprs[0], exprs[1]),
|
||||
BO.VecVecOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
|
||||
# Matrix | Vector
|
||||
BO.MatVecDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]),
|
||||
BO.LinSolve: lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]),
|
||||
BO.LsqSolve: lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]),
|
||||
# Vector | Matrix
|
||||
BO.VecMatOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
|
||||
# Matrix | Matrix
|
||||
BO.MatMatDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]),
|
||||
}[self]
|
||||
|
@ -218,30 +336,12 @@ class BinaryOperation(enum.StrEnum):
|
|||
# - InfoFlow Transform
|
||||
####################
|
||||
def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow):
|
||||
BO = BinaryOperation
|
||||
|
||||
info_largest = (
|
||||
info_l if info_l.output_shape_len > info_l.output_shape_len else info_l
|
||||
"""Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression."""
|
||||
return info_l.operate_output(
|
||||
info_r,
|
||||
lambda a, b: self.sp_func([a, b]),
|
||||
lambda a, b: self.unit_func([a, b]),
|
||||
)
|
||||
info_any = info_largest
|
||||
return {
|
||||
# Number | * or * | Number
|
||||
BO.Add: info_largest,
|
||||
BO.Sub: info_largest,
|
||||
BO.Mul: info_largest,
|
||||
BO.Div: info_largest,
|
||||
BO.Pow: info_largest,
|
||||
BO.Atan2: info_largest,
|
||||
# Vector | Vector
|
||||
BO.VecVecDot: info_any,
|
||||
BO.Cross: info_any,
|
||||
# Matrix | Vector
|
||||
BO.MatVecDot: info_r,
|
||||
BO.LinSolve: info_r,
|
||||
BO.LsqSolve: info_r,
|
||||
# Matrix | Matrix
|
||||
BO.MatMatDot: info_any,
|
||||
}[self]
|
||||
|
||||
|
||||
####################
|
||||
|
@ -367,9 +467,9 @@ class OperateMathNode(base.MaxwellSimNode):
|
|||
# Compute Sympy Function
|
||||
## -> The operation enum directly provides the appropriate function.
|
||||
if has_expr_l_value and has_expr_r_value and operation is not None:
|
||||
operation.sp_func([expr_l, expr_r])
|
||||
return operation.sp_func([expr_l, expr_r])
|
||||
|
||||
return ct.Flowsignal.FlowPending
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
|
@ -396,7 +496,7 @@ class OperateMathNode(base.MaxwellSimNode):
|
|||
## -> The operation enum directly provides the appropriate function.
|
||||
if has_expr_l and has_expr_r:
|
||||
return (expr_l | expr_r).compose_within(
|
||||
operation.jax_func,
|
||||
enclosing_func=operation.jax_func,
|
||||
supports_jax=True,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
|
|
@ -17,13 +17,13 @@
|
|||
"""Declares `TransformMathNode`."""
|
||||
|
||||
import enum
|
||||
import functools
|
||||
import typing as typ
|
||||
|
||||
import bpy
|
||||
import jax.numpy as jnp
|
||||
import jaxtyping as jtyp
|
||||
import sympy as sp
|
||||
import sympy.physics.units as spu
|
||||
|
||||
from blender_maxwell.utils import bl_cache, logger, sci_constants, sim_symbols
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
|
@ -51,6 +51,9 @@ class TransformOperation(enum.StrEnum):
|
|||
# Covariant Transform
|
||||
FreqToVacWL = enum.auto()
|
||||
VacWLToFreq = enum.auto()
|
||||
ConvertIdxUnit = enum.auto()
|
||||
SetIdxUnit = enum.auto()
|
||||
FirstColToFirstIdx = enum.auto()
|
||||
|
||||
# Fold
|
||||
IntDimToComplex = enum.auto()
|
||||
|
@ -58,8 +61,8 @@ class TransformOperation(enum.StrEnum):
|
|||
DimsToMat = enum.auto()
|
||||
|
||||
# Fourier
|
||||
FFT1D = enum.auto()
|
||||
InvFFT1D = enum.auto()
|
||||
FT1D = enum.auto()
|
||||
InvFT1D = enum.auto()
|
||||
|
||||
# TODO: Affine
|
||||
## TODO
|
||||
|
@ -74,17 +77,22 @@ class TransformOperation(enum.StrEnum):
|
|||
# Covariant Transform
|
||||
TO.FreqToVacWL: '𝑓 → λᵥ',
|
||||
TO.VacWLToFreq: 'λᵥ → 𝑓',
|
||||
TO.ConvertIdxUnit: 'Convert Dim',
|
||||
TO.SetIdxUnit: 'Set Dim',
|
||||
TO.FirstColToFirstIdx: '1st Col → Dim',
|
||||
# Fold
|
||||
TO.IntDimToComplex: '→ ℂ',
|
||||
TO.DimToVec: '→ Vector',
|
||||
TO.DimsToMat: '→ Matrix',
|
||||
## TODO: Vector to new last-dim integer
|
||||
## TODO: Matrix to two last-dim integers
|
||||
# Fourier
|
||||
TO.FFT1D: 't → 𝑓',
|
||||
TO.InvFFT1D: '𝑓 → t',
|
||||
TO.FT1D: '→ 𝑓',
|
||||
TO.InvFT1D: '𝑓 →',
|
||||
}[value]
|
||||
|
||||
@staticmethod
|
||||
def to_icon(value: typ.Self) -> str:
|
||||
def to_icon(_: typ.Self) -> str:
|
||||
return ''
|
||||
|
||||
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
|
||||
|
@ -98,121 +106,216 @@ class TransformOperation(enum.StrEnum):
|
|||
)
|
||||
|
||||
####################
|
||||
# - Ops from Shape
|
||||
# - Methods
|
||||
####################
|
||||
@property
|
||||
def num_dim_inputs(self) -> None:
|
||||
"""The number of axes that should be passed as inputs to `func_jax` when evaluating it.
|
||||
|
||||
Especially useful for `ParamFlow`, when deciding whether to pass an integer-axis argument based on a user-selected dimension.
|
||||
"""
|
||||
TO = TransformOperation
|
||||
return {
|
||||
# Covariant Transform
|
||||
TO.FreqToVacWL: 1,
|
||||
TO.VacWLToFreq: 1,
|
||||
TO.ConvertIdxUnit: 1,
|
||||
TO.SetIdxUnit: 1,
|
||||
TO.FirstColToFirstIdx: 0,
|
||||
# Fold
|
||||
TO.IntDimToComplex: 0,
|
||||
TO.DimToVec: 0,
|
||||
TO.DimsToMat: 0,
|
||||
## TODO: Vector to new last-dim integer
|
||||
## TODO: Matrix to two last-dim integers
|
||||
# Fourier
|
||||
TO.FT1D: 1,
|
||||
TO.InvFT1D: 1,
|
||||
}[self]
|
||||
|
||||
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
|
||||
TO = TransformOperation
|
||||
match self:
|
||||
case TO.FreqToVacWL | TO.FT1D:
|
||||
return [
|
||||
dim
|
||||
for dim in info.dims
|
||||
if dim.physical_type is spux.PhysicalType.Freq
|
||||
]
|
||||
|
||||
case TO.VacWLToFreq | TO.InvFT1D:
|
||||
return [
|
||||
dim
|
||||
for dim in info.dims
|
||||
if dim.physical_type is spux.PhysicalType.Length
|
||||
]
|
||||
|
||||
case TO.ConvertIdxUnit | TO.SetIdxUnit:
|
||||
return [dim for dim in info.dims if not info.has_idx_labels(dim)]
|
||||
|
||||
## ColDimToComplex: Implicit Last Dimension
|
||||
## DimToVec: Implicit Last Dimension
|
||||
## DimsToMat: Implicit Last 2 Dimensions
|
||||
|
||||
case TO.FT1D | TO.InvFT1D:
|
||||
# Filter by Axis Uniformity
|
||||
## -> FT requires uniform axis (aka. must be RangeFlow).
|
||||
## -> NOTE: If FT isn't popping up, check ExtractDataNode.
|
||||
return [dim for dim in info.dims if info.is_idx_uniform(dim)]
|
||||
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def by_element_shape(info: ct.InfoFlow) -> list[typ.Self]:
|
||||
def by_info(info: ct.InfoFlow) -> list[typ.Self]:
|
||||
TO = TransformOperation
|
||||
operations = []
|
||||
|
||||
# Covariant Transform
|
||||
## Freq <-> VacWL
|
||||
for dim in info.dims:
|
||||
if dim.physical_type == spux.PhysicalType.Freq:
|
||||
operations.append(TO.FreqToVacWL)
|
||||
## Freq -> VacWL
|
||||
if TO.FreqToVacWL.valid_dims(info):
|
||||
operations += [TO.FreqToVacWL]
|
||||
|
||||
if dim.physical_type == spux.PhysicalType.Freq:
|
||||
operations.append(TO.VacWLToFreq)
|
||||
## VacWL -> Freq
|
||||
if TO.VacWLToFreq.valid_dims(info):
|
||||
operations += [TO.VacWLToFreq]
|
||||
|
||||
## Convert Index Unit
|
||||
if TO.ConvertIdxUnit.valid_dims(info):
|
||||
operations += [TO.ConvertIdxUnit]
|
||||
|
||||
if TO.SetIdxUnit.valid_dims(info):
|
||||
operations += [TO.SetIdxUnit]
|
||||
|
||||
## Column to First Index (Array)
|
||||
if (
|
||||
len(info.dims) == 2 # noqa: PLR2004
|
||||
and info.first_dim.mathtype is spux.MathType.Integer
|
||||
and info.last_dim.mathtype is spux.MathType.Integer
|
||||
and info.output.shape_len == 0
|
||||
):
|
||||
operations += [TO.FirstColToFirstIdx]
|
||||
|
||||
# Fold
|
||||
## (Last) Int Dim (=2) to Complex
|
||||
if len(info.dims) >= 1:
|
||||
if not info.has_idx_labels(info.last_dim) and len(info.last_dim) == 2: # noqa: PLR2004
|
||||
operations.append(TO.IntDimToComplex)
|
||||
## Last Dim -> Complex
|
||||
if (
|
||||
info.dims
|
||||
# Output is Int|Rat|Real
|
||||
and (
|
||||
info.output.mathtype
|
||||
in [spux.MathType.Integer, spux.MathType.Rational, spux.MathType.Real]
|
||||
)
|
||||
# Last Axis is Integer of Length 2
|
||||
and info.last_dim.mathtype is spux.MathType.Integer
|
||||
and info.has_idx_labels(info.last_dim)
|
||||
and len(info.dims[info.last_dim]) == 2 # noqa: PLR2004
|
||||
):
|
||||
operations += [TO.IntDimToComplex]
|
||||
|
||||
## To Vector
|
||||
if len(info.dims) >= 1:
|
||||
operations.append(TO.DimToVec)
|
||||
## Last Dim -> Vector
|
||||
if len(info.dims) >= 1 and info.output.shape_len == 0:
|
||||
operations += [TO.DimToVec]
|
||||
|
||||
## To Matrix
|
||||
if len(info.dims) >= 2: # noqa: PLR2004
|
||||
operations.append(TO.DimsToMat)
|
||||
## Last Dim -> Matrix
|
||||
if len(info.dims) >= 2 and info.output.shape_len == 0: # noqa: PLR2004
|
||||
operations += [TO.DimsToMat]
|
||||
|
||||
# Fourier
|
||||
## 1D Fourier
|
||||
if info.dims:
|
||||
last_physical_type = info.last_dim.physical_type
|
||||
if last_physical_type == spux.PhysicalType.Time:
|
||||
operations.append(TO.FFT1D)
|
||||
if last_physical_type == spux.PhysicalType.Freq:
|
||||
operations.append(TO.InvFFT1D)
|
||||
if TO.FT1D.valid_dims(info):
|
||||
operations += [TO.FT1D]
|
||||
|
||||
if TO.InvFT1D.valid_dims(info):
|
||||
operations += [TO.InvFT1D]
|
||||
|
||||
return operations
|
||||
|
||||
####################
|
||||
# - Function Properties
|
||||
####################
|
||||
@property
|
||||
def sp_func(self):
|
||||
TO = TransformOperation
|
||||
return {
|
||||
# Covariant Transform
|
||||
TO.FreqToVacWL: lambda expr: expr,
|
||||
TO.VacWLToFreq: lambda expr: expr,
|
||||
# Fold
|
||||
# TO.IntDimToComplex: lambda expr: expr, ## TODO: Won't work?
|
||||
TO.DimToVec: lambda expr: expr,
|
||||
TO.DimsToMat: lambda expr: expr,
|
||||
# Fourier
|
||||
TO.FFT1D: lambda expr: sp.fourier_transform(
|
||||
expr, sim_symbols.t, sim_symbols.freq
|
||||
),
|
||||
TO.InvFFT1D: lambda expr: sp.fourier_transform(
|
||||
expr, sim_symbols.freq, sim_symbols.t
|
||||
),
|
||||
}[self]
|
||||
|
||||
@property
|
||||
@functools.cached_property
|
||||
def jax_func(self):
|
||||
TO = TransformOperation
|
||||
return {
|
||||
# Covariant Transform
|
||||
TO.FreqToVacWL: lambda expr: expr,
|
||||
TO.VacWLToFreq: lambda expr: expr,
|
||||
## -> Freq <-> WL is a rescale (noop) AND flip (not noop).
|
||||
TO.FreqToVacWL: lambda expr, axis: jnp.flip(expr, axis=axis),
|
||||
TO.VacWLToFreq: lambda expr, axis: jnp.flip(expr, axis=axis),
|
||||
TO.ConvertIdxUnit: lambda expr: expr,
|
||||
TO.SetIdxUnit: lambda expr: expr,
|
||||
TO.FirstColToFirstIdx: lambda expr: jnp.delete(expr, 0, axis=1),
|
||||
# Fold
|
||||
## -> To Complex: With a little imagination, this is a noop :)
|
||||
## -> **Requires** dims[-1] to be integer-indexed w/length of 2.
|
||||
TO.IntDimToComplex: lambda expr: expr.view(dtype=jnp.complex64).squeeze(),
|
||||
## -> To Complex: This should generally be a no-op.
|
||||
TO.IntDimToComplex: lambda expr: jnp.squeeze(
|
||||
expr.view(dtype=jnp.complex64), axis=-1
|
||||
),
|
||||
TO.DimToVec: lambda expr: expr,
|
||||
TO.DimsToMat: lambda expr: expr,
|
||||
# Fourier
|
||||
TO.FFT1D: lambda expr: jnp.fft(expr),
|
||||
TO.InvFFT1D: lambda expr: jnp.ifft(expr),
|
||||
TO.FT1D: lambda expr, axis: jnp.fft(expr, axis=axis),
|
||||
TO.InvFT1D: lambda expr, axis: jnp.ifft(expr, axis=axis),
|
||||
}[self]
|
||||
|
||||
def transform_info(
|
||||
self,
|
||||
info: ct.InfoFlow | None,
|
||||
data: jtyp.Shaped[jtyp.Array, '...'] | None = None,
|
||||
info: ct.InfoFlow,
|
||||
dim: sim_symbols.SimSymbol | None = None,
|
||||
data_col: jtyp.Shaped[jtyp.Array, ' size'] | None = None,
|
||||
new_dim_name: str | None = None,
|
||||
unit: spux.Unit | None = None,
|
||||
) -> ct.InfoFlow | None:
|
||||
physical_type: spux.PhysicalType | None = None,
|
||||
) -> ct.InfoFlow:
|
||||
TO = TransformOperation
|
||||
if not info.dims:
|
||||
return None
|
||||
return {
|
||||
# Covariant Transform
|
||||
TO.FreqToVacWL: lambda: info.replace_dim(
|
||||
(f_dim := info.last_dim),
|
||||
(f_dim := dim),
|
||||
[
|
||||
sim_symbols.wl(spu.nanometer),
|
||||
sim_symbols.wl(unit),
|
||||
info.dims[f_dim].rescale(
|
||||
lambda el: sci_constants.vac_speed_of_light / el,
|
||||
reverse=True,
|
||||
new_unit=spu.nanometer,
|
||||
new_unit=unit,
|
||||
),
|
||||
],
|
||||
),
|
||||
TO.VacWLToFreq: lambda: info.replace_dim(
|
||||
(wl_dim := info.last_dim),
|
||||
(wl_dim := dim),
|
||||
[
|
||||
sim_symbols.freq(spux.THz),
|
||||
sim_symbols.freq(unit),
|
||||
info.dims[wl_dim].rescale(
|
||||
lambda el: sci_constants.vac_speed_of_light / el,
|
||||
reverse=True,
|
||||
new_unit=spux.THz,
|
||||
new_unit=unit,
|
||||
),
|
||||
],
|
||||
),
|
||||
TO.ConvertIdxUnit: lambda: info.replace_dim(
|
||||
dim,
|
||||
dim.update(unit=unit),
|
||||
(
|
||||
info.dims[dim].rescale_to_unit(unit)
|
||||
if info.has_idx_discrete(dim)
|
||||
else None ## Continuous -- dim SimSymbol already scaled
|
||||
),
|
||||
),
|
||||
TO.SetIdxUnit: lambda: info.replace_dim(
|
||||
dim,
|
||||
dim.update(
|
||||
sym_name=new_dim_name, physical_type=physical_type, unit=unit
|
||||
),
|
||||
(
|
||||
info.dims[dim].correct_unit(unit)
|
||||
if info.has_idx_discrete(dim)
|
||||
else None ## Continuous -- dim SimSymbol already scaled
|
||||
),
|
||||
),
|
||||
TO.FirstColToFirstIdx: lambda: info.replace_dim(
|
||||
info.first_dim,
|
||||
info.first_dim.update(
|
||||
mathtype=spux.MathType.from_jax_array(data_col),
|
||||
unit=unit,
|
||||
),
|
||||
ct.ArrayFlow(values=data_col, unit=unit),
|
||||
).slice_dim(info.last_dim, (1, len(info.dims[info.last_dim]), 1)),
|
||||
# Fold
|
||||
TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output(
|
||||
mathtype=spux.MathType.Complex
|
||||
|
@ -220,21 +323,31 @@ class TransformOperation(enum.StrEnum):
|
|||
TO.DimToVec: lambda: info.fold_last_input(),
|
||||
TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(),
|
||||
# Fourier
|
||||
TO.FFT1D: lambda: info.replace_dim(
|
||||
info.last_dim,
|
||||
TO.FT1D: lambda: info.replace_dim(
|
||||
dim,
|
||||
[
|
||||
sim_symbols.freq(spux.THz),
|
||||
None,
|
||||
# FT'ed Unit: Reciprocal of the Original Unit
|
||||
dim.update(
|
||||
unit=1 / dim.unit if dim.unit is not None else 1
|
||||
), ## TODO: Okay to not scale interval?
|
||||
# FT'ed Bounds: Reciprocal of the Original Unit
|
||||
info.dims[dim].bound_fourier_transform,
|
||||
],
|
||||
),
|
||||
TO.InvFFT1D: info.replace_dim(
|
||||
TO.InvFT1D: lambda: info.replace_dim(
|
||||
info.last_dim,
|
||||
[
|
||||
sim_symbols.t(spu.second),
|
||||
None,
|
||||
# FT'ed Unit: Reciprocal of the Original Unit
|
||||
dim.update(
|
||||
unit=1 / dim.unit if dim.unit is not None else 1
|
||||
), ## TODO: Okay to not scale interval?
|
||||
# FT'ed Bounds: Reciprocal of the Original Unit
|
||||
## -> Note the midpoint may revert to 0.
|
||||
## -> See docs for `RangeFlow.bound_inv_fourier_transform` for more.
|
||||
info.dims[dim].bound_inv_fourier_transform,
|
||||
],
|
||||
),
|
||||
}.get(self, lambda: info)()
|
||||
}[self]()
|
||||
|
||||
|
||||
####################
|
||||
|
@ -274,7 +387,6 @@ class TransformMathNode(base.MaxwellSimNode):
|
|||
)
|
||||
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
|
||||
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
|
||||
|
||||
info_pending = ct.FlowSignal.check_single(
|
||||
input_sockets['Expr'], ct.FlowSignal.FlowPending
|
||||
)
|
||||
|
@ -304,45 +416,125 @@ class TransformMathNode(base.MaxwellSimNode):
|
|||
return [
|
||||
operation.bl_enum_element(i)
|
||||
for i, operation in enumerate(
|
||||
TransformOperation.by_element_shape(self.expr_info)
|
||||
TransformOperation.by_info(self.expr_info)
|
||||
)
|
||||
]
|
||||
return []
|
||||
|
||||
####################
|
||||
# - Properties: Dimension Selection
|
||||
####################
|
||||
active_dim: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_dims(),
|
||||
cb_depends_on={'operation', 'expr_info'},
|
||||
)
|
||||
|
||||
def search_dims(self) -> list[ct.BLEnumElement]:
|
||||
if self.expr_info is not None and self.operation is not None:
|
||||
return [
|
||||
(dim.name, dim.name_pretty, dim.name, '', i)
|
||||
for i, dim in enumerate(self.operation.valid_dims(self.expr_info))
|
||||
]
|
||||
return []
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'expr_info', 'active_dim'})
|
||||
def dim(self) -> sim_symbols.SimSymbol | None:
|
||||
if self.expr_info is not None and self.active_dim is not None:
|
||||
return self.expr_info.dim_by_name(self.active_dim)
|
||||
return None
|
||||
|
||||
####################
|
||||
# - Properties: New Dimension Properties
|
||||
####################
|
||||
new_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||
sim_symbols.SimSymbolName.Expr
|
||||
)
|
||||
new_physical_type: spux.PhysicalType = bl_cache.BLField(
|
||||
spux.PhysicalType.NonPhysical
|
||||
)
|
||||
active_new_unit: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_units(),
|
||||
cb_depends_on={'dim', 'new_physical_type'},
|
||||
)
|
||||
|
||||
def search_units(self) -> list[ct.BLEnumElement]:
|
||||
if self.dim is not None:
|
||||
if self.dim.physical_type is not spux.PhysicalType.NonPhysical:
|
||||
unit_name = sp.sstr(self.dim.unit)
|
||||
return [
|
||||
(
|
||||
sp.sstr(unit),
|
||||
spux.sp_to_str(unit),
|
||||
sp.sstr(unit),
|
||||
'',
|
||||
0,
|
||||
)
|
||||
for unit in self.dim.physical_type.valid_units
|
||||
]
|
||||
|
||||
if self.dim.unit is not None:
|
||||
unit_name = sp.sstr(self.dim.unit)
|
||||
return [
|
||||
(
|
||||
unit_name,
|
||||
spux.sp_to_str(self.dim.unit),
|
||||
unit_name,
|
||||
'',
|
||||
0,
|
||||
)
|
||||
]
|
||||
if self.new_physical_type is not spux.PhysicalType.NonPhysical:
|
||||
return [
|
||||
(
|
||||
sp.sstr(unit),
|
||||
spux.sp_to_str(unit),
|
||||
sp.sstr(unit),
|
||||
'',
|
||||
i,
|
||||
)
|
||||
for i, unit in enumerate(self.new_physical_type.valid_units)
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'active_new_unit'})
|
||||
def new_unit(self) -> spux.Unit:
|
||||
if self.active_new_unit is not None:
|
||||
return spux.unit_str_to_unit(self.active_new_unit)
|
||||
|
||||
return None
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
def draw_label(self):
|
||||
if self.operation is not None:
|
||||
return 'Transform: ' + TransformOperation.to_name(self.operation)
|
||||
return 'T: ' + TransformOperation.to_name(self.operation)
|
||||
|
||||
return self.bl_label
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
layout.prop(self, self.blfields['operation'], text='')
|
||||
|
||||
if self.operation is not None and self.operation.num_dim_inputs == 1:
|
||||
TO = TransformOperation
|
||||
layout.prop(self, self.blfields['active_dim'], text='')
|
||||
|
||||
if self.operation in [TO.ConvertIdxUnit, TO.SetIdxUnit]:
|
||||
col = layout.column(align=True)
|
||||
if self.operation is TransformOperation.ConvertIdxUnit:
|
||||
col.prop(self, self.blfields['active_new_unit'], text='')
|
||||
|
||||
if self.operation is TransformOperation.SetIdxUnit:
|
||||
col.prop(self, self.blfields['new_physical_type'], text='')
|
||||
|
||||
row = col.row(align=True)
|
||||
row.prop(self, self.blfields['new_name'], text='')
|
||||
row.prop(self, self.blfields['active_new_unit'], text='')
|
||||
|
||||
####################
|
||||
# - Compute: Func / Array
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Value,
|
||||
props={'operation'},
|
||||
input_sockets={'Expr'},
|
||||
)
|
||||
def compute_value(self, props, input_sockets) -> ct.ValueFlow | ct.FlowSignal:
|
||||
operation = props['operation']
|
||||
expr = input_sockets['Expr']
|
||||
|
||||
has_expr_value = not ct.FlowSignal.check(expr)
|
||||
|
||||
# Compute Sympy Function
|
||||
## -> The operation enum directly provides the appropriate function.
|
||||
if has_expr_value and operation is not None:
|
||||
return operation.sp_func(expr)
|
||||
|
||||
return ct.Flowsignal.FlowPending
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Func,
|
||||
|
@ -354,54 +546,103 @@ class TransformMathNode(base.MaxwellSimNode):
|
|||
)
|
||||
def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal:
|
||||
operation = props['operation']
|
||||
expr = input_sockets['Expr']
|
||||
lazy_func = input_sockets['Expr']
|
||||
|
||||
has_expr = not ct.FlowSignal.check(expr)
|
||||
has_lazy_func = not ct.FlowSignal.check(lazy_func)
|
||||
|
||||
if has_expr and operation is not None:
|
||||
return expr.compose_within(
|
||||
if has_lazy_func and operation is not None:
|
||||
return lazy_func.compose_within(
|
||||
operation.jax_func,
|
||||
supports_jax=True,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Info|Params
|
||||
# - FlowKind.Info
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Info,
|
||||
props={'operation'},
|
||||
props={'operation', 'dim', 'new_name', 'new_unit', 'new_physical_type'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||
input_socket_kinds={
|
||||
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
|
||||
},
|
||||
)
|
||||
def compute_info(
|
||||
self, props: dict, input_sockets: dict
|
||||
) -> ct.InfoFlow | typ.Literal[ct.FlowSignal.FlowPending]:
|
||||
operation = props['operation']
|
||||
info = input_sockets['Expr']
|
||||
info = input_sockets['Expr'][ct.FlowKind.Info]
|
||||
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
|
||||
dim = props['dim']
|
||||
new_name = props['new_name']
|
||||
new_unit = props['new_unit']
|
||||
new_physical_type = props['new_physical_type']
|
||||
if has_info and operation is not None:
|
||||
transformed_info = operation.transform_info(info)
|
||||
|
||||
if transformed_info is None:
|
||||
return ct.FlowSignal.FlowPending
|
||||
return transformed_info
|
||||
# First Column to First Index
|
||||
## -> We have to evaluate the lazy function at this point.
|
||||
## -> It's the only way to get at the column data.
|
||||
if operation is TransformOperation.FirstColToFirstIdx:
|
||||
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
|
||||
params = input_sockets['Expr'][ct.FlowKind.Params]
|
||||
has_lazy_func = not ct.FlowSignal.check(lazy_func)
|
||||
has_params = not ct.FlowSignal.check(lazy_func)
|
||||
|
||||
if has_lazy_func and has_params and not params.symbols:
|
||||
data = lazy_func.realize(params)
|
||||
if data.shape is not None and len(data.shape) == 2:
|
||||
data_col = data[:, 0]
|
||||
return operation.transform_info(info, data_col=data_col)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
# Check Not-Yet-Updated Dimension
|
||||
## - Operation changes before dimensions.
|
||||
## - If InfoFlow is requested in this interim, big problem.
|
||||
if dim is None and operation.num_dim_inputs > 0:
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
return operation.transform_info(
|
||||
info,
|
||||
dim=dim,
|
||||
new_dim_name=new_name,
|
||||
unit=new_unit,
|
||||
physical_type=new_physical_type,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Params
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Params,
|
||||
props={'operation', 'dim'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': ct.FlowKind.Params},
|
||||
input_socket_kinds={'Expr': {ct.FlowKind.Params, ct.FlowKind.Info}},
|
||||
)
|
||||
def compute_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal:
|
||||
has_params = not ct.FlowSignal.check(input_sockets['Expr'])
|
||||
if has_params:
|
||||
return input_sockets['Expr']
|
||||
def compute_params(self, props, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
|
||||
info = input_sockets['Expr'][ct.FlowKind.Info]
|
||||
params = input_sockets['Expr'][ct.FlowKind.Params]
|
||||
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
has_params = not ct.FlowSignal.check(params)
|
||||
|
||||
operation = props['operation']
|
||||
dim = props['dim']
|
||||
if has_info and has_params and operation is not None:
|
||||
# Axis Required: Insert by-Dimension
|
||||
## -> Some transformations ex. FT require setting an axis.
|
||||
## -> The user selects which dimension the op should be done along.
|
||||
## -> This dimension is converted to an axis integer.
|
||||
## -> Finally, we pass the argument via params.
|
||||
if operation.num_dim_inputs == 1:
|
||||
axis = info.dim_axis(dim) if dim is not None else None
|
||||
return params.compose_within(enclosing_func_args=[axis])
|
||||
|
||||
return params
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ import bpy
|
|||
import jaxtyping as jtyp
|
||||
import matplotlib.axis as mpl_ax
|
||||
import sympy as sp
|
||||
import sympy.physics.units as spu
|
||||
|
||||
from blender_maxwell.utils import bl_cache, image_ops, logger, sim_symbols
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
|
@ -74,8 +75,33 @@ class VizMode(enum.StrEnum):
|
|||
SqueezedHeatmap2D = enum.auto()
|
||||
Heatmap3D = enum.auto()
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
@staticmethod
|
||||
def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None:
|
||||
def to_name(value: typ.Self) -> str:
|
||||
return {
|
||||
VizMode.BoxPlot1D: 'Box Plot',
|
||||
VizMode.Curve2D: 'Curve',
|
||||
VizMode.Points2D: 'Points',
|
||||
VizMode.Bar: 'Bar',
|
||||
VizMode.Curves2D: 'Curves',
|
||||
VizMode.FilledCurves2D: 'Filled Curves',
|
||||
VizMode.Heatmap2D: 'Heatmap',
|
||||
VizMode.SqueezedHeatmap2D: 'Heatmap (Squeezed)',
|
||||
VizMode.Heatmap3D: 'Heatmap (3D)',
|
||||
}[value]
|
||||
|
||||
@staticmethod
|
||||
def to_icon(value: typ.Self) -> ct.BLIcon:
|
||||
return ''
|
||||
|
||||
####################
|
||||
# - Validity
|
||||
####################
|
||||
@staticmethod
|
||||
def by_info(info: ct.InfoFlow) -> list[typ.Self] | None:
|
||||
"""Given the input `InfoFlow`, deduce which visualization modes are valid to use with the described data."""
|
||||
Z = spux.MathType.Integer
|
||||
R = spux.MathType.Real
|
||||
VM = VizMode
|
||||
|
@ -102,15 +128,18 @@ class VizMode(enum.StrEnum):
|
|||
],
|
||||
}.get(
|
||||
(
|
||||
tuple([dim.mathtype for dim in info.dims.values()]),
|
||||
tuple([dim.mathtype for dim in info.dims]),
|
||||
(info.output.rows, info.output.cols, info.output.mathtype),
|
||||
),
|
||||
[],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def to_plotter(
|
||||
value: typ.Self,
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
@property
|
||||
def mpl_plotter(
|
||||
self,
|
||||
) -> typ.Callable[
|
||||
[jtyp.Float32[jtyp.Array, '...'], ct.InfoFlow, mpl_ax.Axis], None
|
||||
]:
|
||||
|
@ -124,25 +153,7 @@ class VizMode(enum.StrEnum):
|
|||
VizMode.Heatmap2D: image_ops.plot_heatmap_2d,
|
||||
# NO PLOTTER: VizMode.SqueezedHeatmap2D
|
||||
# NO PLOTTER: VizMode.Heatmap3D
|
||||
}[value]
|
||||
|
||||
@staticmethod
|
||||
def to_name(value: typ.Self) -> str:
|
||||
return {
|
||||
VizMode.BoxPlot1D: 'Box Plot',
|
||||
VizMode.Curve2D: 'Curve',
|
||||
VizMode.Points2D: 'Points',
|
||||
VizMode.Bar: 'Bar',
|
||||
VizMode.Curves2D: 'Curves',
|
||||
VizMode.FilledCurves2D: 'Filled Curves',
|
||||
VizMode.Heatmap2D: 'Heatmap',
|
||||
VizMode.SqueezedHeatmap2D: 'Heatmap (Squeezed)',
|
||||
VizMode.Heatmap3D: 'Heatmap (3D)',
|
||||
}[value]
|
||||
|
||||
@staticmethod
|
||||
def to_icon(value: typ.Self) -> ct.BLIcon:
|
||||
return ''
|
||||
}[self]
|
||||
|
||||
|
||||
class VizTarget(enum.StrEnum):
|
||||
|
@ -181,6 +192,10 @@ class VizTarget(enum.StrEnum):
|
|||
return ''
|
||||
|
||||
|
||||
sym_x_um = sim_symbols.space_x(spu.um)
|
||||
x_um = sym_x_um.sp_symbol
|
||||
|
||||
|
||||
class VizNode(base.MaxwellSimNode):
|
||||
"""Node for visualizing simulation data, by querying its monitors.
|
||||
|
||||
|
@ -188,7 +203,6 @@ class VizNode(base.MaxwellSimNode):
|
|||
|
||||
Attributes:
|
||||
colormap: Colormap to apply to 0..1 output.
|
||||
|
||||
"""
|
||||
|
||||
node_type = ct.NodeType.Viz
|
||||
|
@ -201,8 +215,8 @@ class VizNode(base.MaxwellSimNode):
|
|||
input_sockets: typ.ClassVar = {
|
||||
'Expr': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.Func,
|
||||
default_symbols=[sim_symbols.x],
|
||||
default_value=2 * sim_symbols.x.sp_symbol,
|
||||
default_symbols=[sym_x_um],
|
||||
default_value=sp.exp(-(x_um**2)),
|
||||
),
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
|
@ -240,16 +254,21 @@ class VizNode(base.MaxwellSimNode):
|
|||
|
||||
return None
|
||||
|
||||
viz_mode: enum.StrEnum = bl_cache.BLField(
|
||||
viz_mode: VizMode = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_viz_modes(),
|
||||
cb_depends_on={'expr_info'},
|
||||
)
|
||||
viz_target: enum.StrEnum = bl_cache.BLField(
|
||||
viz_target: VizTarget = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_targets(),
|
||||
cb_depends_on={'viz_mode'},
|
||||
)
|
||||
|
||||
# Mode-Dependent Properties
|
||||
# Plot
|
||||
plot_width: float = bl_cache.BLField(6.0, abs_min=0.1)
|
||||
plot_height: float = bl_cache.BLField(3.0, abs_min=0.1)
|
||||
plot_dpi: int = bl_cache.BLField(150, abs_min=25)
|
||||
|
||||
# Pixels
|
||||
colormap: image_ops.Colormap = bl_cache.BLField(
|
||||
image_ops.Colormap.Viridis,
|
||||
)
|
||||
|
@ -267,7 +286,7 @@ class VizNode(base.MaxwellSimNode):
|
|||
VizMode.to_icon(viz_mode),
|
||||
i,
|
||||
)
|
||||
for i, viz_mode in enumerate(VizMode.valid_modes_for(self.expr_info))
|
||||
for i, viz_mode in enumerate(VizMode.by_info(self.expr_info))
|
||||
]
|
||||
|
||||
return []
|
||||
|
@ -300,9 +319,22 @@ class VizNode(base.MaxwellSimNode):
|
|||
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
|
||||
col.prop(self, self.blfields['viz_mode'], text='')
|
||||
col.prop(self, self.blfields['viz_target'], text='')
|
||||
|
||||
if self.viz_target in [VizTarget.Pixels, VizTarget.PixelsPlane]:
|
||||
col.prop(self, self.blfields['colormap'], text='')
|
||||
|
||||
if self.viz_target is VizTarget.Plot2D:
|
||||
row = col.row(align=True)
|
||||
row.alignment = 'CENTER'
|
||||
row.label(text='Width/Height/DPI')
|
||||
|
||||
row = col.row(align=True)
|
||||
row.prop(self, self.blfields['plot_width'], text='')
|
||||
row.prop(self, self.blfields['plot_height'], text='')
|
||||
|
||||
row = col.row(align=True)
|
||||
col.prop(self, self.blfields['plot_dpi'], text='')
|
||||
|
||||
####################
|
||||
# - Events
|
||||
####################
|
||||
|
@ -320,16 +352,16 @@ class VizNode(base.MaxwellSimNode):
|
|||
has_info = not ct.FlowSignal.check(info)
|
||||
has_params = not ct.FlowSignal.check(params)
|
||||
|
||||
# Provide Sockets for Symbol Realization
|
||||
# Declare Loose Sockets that Realize Symbols
|
||||
## -> This happens if Params contains not-yet-realized symbols.
|
||||
if has_info and has_params and params.symbols:
|
||||
if set(self.loose_input_sockets) != {
|
||||
dim.name for dim in params.symbols if dim in info.dims
|
||||
sym.name for sym in params.symbols if sym in info.dims
|
||||
}:
|
||||
self.loose_input_sockets = {
|
||||
dim_name: sockets.ExprSocketDef(**expr_info)
|
||||
for dim_name, expr_info in params.sym_expr_infos(
|
||||
info, use_range=True
|
||||
use_range=True
|
||||
).items()
|
||||
}
|
||||
|
||||
|
@ -343,7 +375,14 @@ class VizNode(base.MaxwellSimNode):
|
|||
'Preview',
|
||||
kind=ct.FlowKind.Value,
|
||||
# Loaded
|
||||
props={'viz_mode', 'viz_target', 'colormap'},
|
||||
props={
|
||||
'viz_mode',
|
||||
'viz_target',
|
||||
'colormap',
|
||||
'plot_width',
|
||||
'plot_height',
|
||||
'plot_dpi',
|
||||
},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={
|
||||
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
|
||||
|
@ -359,7 +398,14 @@ class VizNode(base.MaxwellSimNode):
|
|||
#####################
|
||||
@events.on_show_plot(
|
||||
managed_objs={'plot'},
|
||||
props={'viz_mode', 'viz_target', 'colormap'},
|
||||
props={
|
||||
'viz_mode',
|
||||
'viz_target',
|
||||
'colormap',
|
||||
'plot_width',
|
||||
'plot_height',
|
||||
'plot_dpi',
|
||||
},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={
|
||||
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
|
||||
|
@ -370,7 +416,6 @@ class VizNode(base.MaxwellSimNode):
|
|||
def on_show_plot(
|
||||
self, managed_objs, props, input_sockets, loose_input_sockets
|
||||
) -> None:
|
||||
# Retrieve Inputs
|
||||
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
|
||||
info = input_sockets['Expr'][ct.FlowKind.Info]
|
||||
params = input_sockets['Expr'][ct.FlowKind.Params]
|
||||
|
@ -378,43 +423,51 @@ class VizNode(base.MaxwellSimNode):
|
|||
has_info = not ct.FlowSignal.check(info)
|
||||
has_params = not ct.FlowSignal.check(params)
|
||||
|
||||
if (
|
||||
not has_info
|
||||
or not has_params
|
||||
or props['viz_mode'] is None
|
||||
or props['viz_target'] is None
|
||||
):
|
||||
return
|
||||
|
||||
# Compute Ranges for Symbols from Loose Sockets
|
||||
## -> In a quite nice turn of events, all this is cached lookups.
|
||||
## -> ...Unless something changed, in which case, well. It changed.
|
||||
symbol_array_values = {
|
||||
sim_syms: (
|
||||
loose_input_sockets[sim_syms]
|
||||
.rescale_to_unit(sim_syms.unit)
|
||||
.realize_array
|
||||
)
|
||||
for sim_syms in params.sorted_symbols
|
||||
plot = managed_objs['plot']
|
||||
viz_mode = props['viz_mode']
|
||||
viz_target = props['viz_target']
|
||||
if has_info and has_params and viz_mode is not None and viz_target is not None:
|
||||
# Realize Data w/Realized Symbols
|
||||
## -> The loose input socket values are user-selected symbol values.
|
||||
## -> These expressions are used to realize the lazy data.
|
||||
## -> `.realize()` ensures all ex. units are correctly conformed.
|
||||
realized_syms = {
|
||||
sym: loose_input_sockets[sym.name] for sym in params.sorted_symbols
|
||||
}
|
||||
data = lazy_func.realize(params, symbol_values=symbol_array_values)
|
||||
output_data = lazy_func.realize(params, symbol_values=realized_syms)
|
||||
|
||||
# Replace InfoFlow Indices w/Realized Symbolic Ranges
|
||||
## -> This ensures correct axis scaling.
|
||||
if params.symbols:
|
||||
info = info.replace_dims(symbol_array_values)
|
||||
data = {
|
||||
dim: (
|
||||
realized_syms[dim].values
|
||||
if dim in realized_syms
|
||||
else info.dims[dim]
|
||||
)
|
||||
for dim in info.dims
|
||||
} | {info.output: output_data}
|
||||
|
||||
match props['viz_target']:
|
||||
# Match Viz Type & Perform Visualization
|
||||
## -> Viz Target determines how to plot.
|
||||
## -> Viz Mode may help select a particular plotting method.
|
||||
## -> Other parameters may be uses, depending on context.
|
||||
match viz_target:
|
||||
case VizTarget.Plot2D:
|
||||
managed_objs['plot'].mpl_plot_to_image(
|
||||
lambda ax: VizMode.to_plotter(props['viz_mode'])(data, info, ax),
|
||||
plot_width = props['plot_width']
|
||||
plot_height = props['plot_height']
|
||||
plot_dpi = props['plot_dpi']
|
||||
plot.mpl_plot_to_image(
|
||||
lambda ax: viz_mode.mpl_plotter(data, ax),
|
||||
width_inches=plot_width,
|
||||
height_inches=plot_height,
|
||||
dpi=plot_dpi,
|
||||
bl_select=True,
|
||||
)
|
||||
|
||||
case VizTarget.Pixels:
|
||||
managed_objs['plot'].map_2d_to_image(
|
||||
colormap = props['colormap']
|
||||
if colormap is not None:
|
||||
plot.map_2d_to_image(
|
||||
data,
|
||||
colormap=props['colormap'],
|
||||
colormap=colormap,
|
||||
bl_select=True,
|
||||
)
|
||||
|
||||
|
|
|
@ -610,7 +610,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
|||
## -> Anyone needing results will need to wait on preinit().
|
||||
return ct.FlowSignal.FlowInitializing
|
||||
|
||||
if optional:
|
||||
# if optional:
|
||||
return ct.FlowSignal.NoFlow
|
||||
|
||||
msg = f'{self.sim_node_name}: Input socket "{input_socket_name}" cannot be computed, as it is not an active input socket'
|
||||
|
@ -659,11 +659,12 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
|||
return output_socket_methods[0](self)
|
||||
|
||||
# Auxiliary Fallbacks
|
||||
if optional or kind in [ct.FlowKind.Info, ct.FlowKind.Params]:
|
||||
return ct.FlowSignal.NoFlow
|
||||
# if optional or kind in [ct.FlowKind.Info, ct.FlowKind.Params]:
|
||||
# return ct.FlowSignal.NoFlow
|
||||
|
||||
msg = f'No output method for ({output_socket_name}, {kind})'
|
||||
raise ValueError(msg)
|
||||
# msg = f'No output method for ({output_socket_name}, {kind})'
|
||||
# raise ValueError(msg)
|
||||
|
||||
####################
|
||||
# - Event Trigger
|
||||
|
|
|
@ -30,6 +30,7 @@ class ExprConstantNode(base.MaxwellSimNode):
|
|||
input_sockets: typ.ClassVar = {
|
||||
'Expr': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.Func,
|
||||
show_name_selector=True,
|
||||
),
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
|
|
|
@ -17,8 +17,11 @@
|
|||
import typing as typ
|
||||
|
||||
import bpy
|
||||
import sympy as sp
|
||||
import sympy.physics.units as spu
|
||||
|
||||
from blender_maxwell.utils import bl_cache, sci_constants
|
||||
from blender_maxwell.utils import bl_cache, sci_constants, sim_symbols
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
|
||||
from .... import contracts as ct
|
||||
from .... import sockets
|
||||
|
@ -30,15 +33,19 @@ class ScientificConstantNode(base.MaxwellSimNode):
|
|||
bl_label = 'Scientific Constant'
|
||||
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Value': sockets.ExprSocketDef(),
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
|
||||
}
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
sci_constant: str = bl_cache.BLField(
|
||||
use_symbol: bool = bl_cache.BLField(False)
|
||||
|
||||
sci_constant_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||
sim_symbols.SimSymbolName.LowerU
|
||||
)
|
||||
sci_constant_str: str = bl_cache.BLField(
|
||||
'',
|
||||
prop_ui=True,
|
||||
str_cb=lambda self, _, edit_text: self.search_sci_constants(edit_text),
|
||||
)
|
||||
|
||||
|
@ -52,27 +59,139 @@ class ScientificConstantNode(base.MaxwellSimNode):
|
|||
if edit_text.lower() in name.lower()
|
||||
]
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'sci_constant_str'})
|
||||
def sci_constant(self) -> spux.SympyExpr | None:
|
||||
"""Retrieve the expression for the scientific constant."""
|
||||
return sci_constants.SCI_CONSTANTS.get(self.sci_constant_str)
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'sci_constant_str'})
|
||||
def sci_constant_info(self) -> spux.SympyExpr | None:
|
||||
"""Retrieve the information for the selected scientific constant."""
|
||||
return sci_constants.SCI_CONSTANTS_INFO.get(self.sci_constant_str)
|
||||
|
||||
@bl_cache.cached_bl_property(
|
||||
depends_on={'sci_constant', 'sci_constant_info', 'sci_constant_name'}
|
||||
)
|
||||
def sci_constant_sym(self) -> spux.SympyExpr | None:
|
||||
"""Retrieve a symbol for the scientific constant."""
|
||||
if self.sci_constant is not None and self.sci_constant_info is not None:
|
||||
unit = self.sci_constant_info['units']
|
||||
return sim_symbols.SimSymbol(
|
||||
sym_name=self.sci_constant_name,
|
||||
mathtype=spux.MathType.from_expr(self.sci_constant),
|
||||
# physical_type= ## TODO: Formalize unit w/o physical_type
|
||||
unit=unit,
|
||||
is_constant=True,
|
||||
)
|
||||
return None
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
||||
col.prop(self, self.blfields['sci_constant'], text='')
|
||||
col.prop(self, self.blfields['sci_constant_str'], text='')
|
||||
|
||||
row = col.row(align=True)
|
||||
row.alignment = 'CENTER'
|
||||
row.label(text='Assign Symbol')
|
||||
col.prop(self, self.blfields['sci_constant_name'], text='')
|
||||
col.prop(self, self.blfields['use_symbol'], text='Use Symbol', toggle=True)
|
||||
|
||||
def draw_info(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
||||
if self.sci_constant:
|
||||
col.label(
|
||||
text=f'Units: {sci_constants.SCI_CONSTANTS_INFO[self.sci_constant]["units"]}'
|
||||
)
|
||||
col.label(
|
||||
text=f'Uncertainty: {sci_constants.SCI_CONSTANTS_INFO[self.sci_constant]["uncertainty"]}'
|
||||
)
|
||||
box = col.box()
|
||||
split = box.split(factor=0.25, align=True)
|
||||
|
||||
# Left: Units
|
||||
_col = split.column(align=True)
|
||||
row = _col.row(align=True)
|
||||
# row.alignment = 'CENTER'
|
||||
row.label(text='Src')
|
||||
|
||||
if self.sci_constant_info:
|
||||
row = _col.row(align=True)
|
||||
# row.alignment = 'CENTER'
|
||||
row.label(text='Unit')
|
||||
|
||||
row = _col.row(align=True)
|
||||
# row.alignment = 'CENTER'
|
||||
row.label(text='Err')
|
||||
|
||||
# Right: Values
|
||||
_col = split.column(align=True)
|
||||
row = _col.row(align=True)
|
||||
# row.alignment = 'CENTER'
|
||||
row.label(text='CODATA2018')
|
||||
|
||||
if self.sci_constant_info:
|
||||
row = _col.row(align=True)
|
||||
# row.alignment = 'CENTER'
|
||||
row.label(text=f'{self.sci_constant_info["units"]}')
|
||||
|
||||
row = _col.row(align=True)
|
||||
# row.alignment = 'CENTER'
|
||||
row.label(text=f'{self.sci_constant_info["uncertainty"]}')
|
||||
|
||||
####################
|
||||
# - Output
|
||||
####################
|
||||
@events.computes_output_socket('Value', props={'sci_constant'})
|
||||
def compute_value(self, props: dict) -> typ.Any:
|
||||
return sci_constants.SCI_CONSTANTS[props['sci_constant']]
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
props={'use_symbol', 'sci_constant', 'sci_constant_sym'},
|
||||
)
|
||||
def compute_value(self, props) -> typ.Any:
|
||||
sci_constant = props['sci_constant']
|
||||
sci_constant_sym = props['sci_constant_sym']
|
||||
|
||||
if props['use_symbol'] and sci_constant_sym is not None:
|
||||
return sci_constant_sym.sp_symbol
|
||||
|
||||
if sci_constant is not None:
|
||||
return sci_constant
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Func,
|
||||
props={'sci_constant', 'sci_constant_sym'},
|
||||
)
|
||||
def compute_lazy_func(self, props) -> typ.Any:
|
||||
sci_constant = props['sci_constant']
|
||||
sci_constant_sym = props['sci_constant_sym']
|
||||
|
||||
if sci_constant is not None:
|
||||
return ct.FuncFlow(
|
||||
func=sp.lambdify(
|
||||
[sci_constant_sym.sp_symbol], sci_constant_sym.sp_symbol, 'jax'
|
||||
),
|
||||
func_args=[sci_constant_sym],
|
||||
supports_jax=True,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Info,
|
||||
props={'sci_constant_sym'},
|
||||
)
|
||||
def compute_info(self, props: dict) -> typ.Any:
|
||||
sci_constant_sym = props['sci_constant_sym']
|
||||
|
||||
if sci_constant_sym is not None:
|
||||
return ct.InfoFlow(output=sci_constant_sym)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Params,
|
||||
props={'sci_constant'},
|
||||
)
|
||||
def compute_params(self, props: dict) -> typ.Any:
|
||||
sci_constant = props['sci_constant']
|
||||
|
||||
if sci_constant is not None:
|
||||
return ct.ParamsFlow(func_args=[sci_constant])
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
||||
####################
|
||||
|
|
|
@ -95,62 +95,35 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
|||
####################
|
||||
# - Info Guides
|
||||
####################
|
||||
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(sim_symbols.SimSymbolName)
|
||||
output_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
|
||||
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||
sim_symbols.SimSymbolName.Data
|
||||
)
|
||||
output_mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real)
|
||||
output_physical_type: spux.PhysicalType = bl_cache.BLField(
|
||||
spux.PhysicalType.NonPhysical
|
||||
)
|
||||
output_unit: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type),
|
||||
enum_cb=lambda self, _: self.search_units(self.output_physical_type),
|
||||
cb_depends_on={'output_physical_type'},
|
||||
)
|
||||
|
||||
dim_0_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||
sim_symbols.SimSymbolName.LowerA
|
||||
)
|
||||
dim_0_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
|
||||
dim_0_physical_type: spux.PhysicalType = bl_cache.BLField(
|
||||
spux.PhysicalType.NonPhysical
|
||||
)
|
||||
dim_0_unit: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type),
|
||||
cb_depends_on={'dim_0_physical_type'},
|
||||
)
|
||||
|
||||
dim_1_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||
sim_symbols.SimSymbolName.LowerB
|
||||
)
|
||||
dim_1_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
|
||||
dim_1_physical_type: spux.PhysicalType = bl_cache.BLField(
|
||||
spux.PhysicalType.NonPhysical
|
||||
)
|
||||
dim_1_unit: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_units(self.dim_1_physical_type),
|
||||
cb_depends_on={'dim_1_physical_type'},
|
||||
)
|
||||
|
||||
dim_2_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||
sim_symbols.SimSymbolName.LowerC
|
||||
)
|
||||
dim_2_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
|
||||
dim_2_physical_type: spux.PhysicalType = bl_cache.BLField(
|
||||
spux.PhysicalType.NonPhysical
|
||||
)
|
||||
dim_2_unit: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_units(self.dim_2_physical_type),
|
||||
cb_depends_on={'dim_2_physical_type'},
|
||||
)
|
||||
|
||||
dim_3_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||
sim_symbols.SimSymbolName.LowerD
|
||||
)
|
||||
dim_3_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
|
||||
dim_3_physical_type: spux.PhysicalType = bl_cache.BLField(
|
||||
spux.PhysicalType.NonPhysical
|
||||
dim_4_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||
sim_symbols.SimSymbolName.LowerE
|
||||
)
|
||||
dim_3_unit: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_units(self.dim_3_physical_type),
|
||||
cb_depends_on={'dim_3_physical_type'},
|
||||
dim_5_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||
sim_symbols.SimSymbolName.LowerF
|
||||
)
|
||||
|
||||
def search_units(self, physical_type: spux.PhysicalType) -> list[ct.BLEnumElement]:
|
||||
|
@ -161,19 +134,6 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
|||
]
|
||||
return []
|
||||
|
||||
def dim(self, i: int):
|
||||
dim_name = getattr(self, f'dim_{i}_name')
|
||||
dim_mathtype = getattr(self, f'dim_{i}_mathtype')
|
||||
dim_physical_type = getattr(self, f'dim_{i}_physical_type')
|
||||
dim_unit = getattr(self, f'dim_{i}_unit')
|
||||
|
||||
return sim_symbols.SimSymbol(
|
||||
sym_name=dim_name,
|
||||
mathtype=dim_mathtype,
|
||||
physical_type=dim_physical_type,
|
||||
unit=spux.unit_str_to_unit(dim_unit),
|
||||
)
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
|
@ -202,19 +162,21 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
|||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
"""Draw loaded properties."""
|
||||
for i in range(len(self.expr_info.dims)):
|
||||
col = layout.column(align=True)
|
||||
if self.expr_info is not None:
|
||||
for i in range(len(self.expr_info.dims)):
|
||||
col.prop(self, self.blfields[f'dim_{i}_name'], text=f'Dim {i}')
|
||||
|
||||
row = col.row(align=True)
|
||||
row.alignment = 'CENTER'
|
||||
row.label(text=f'Load Dim {i}')
|
||||
row.label(text='Output')
|
||||
|
||||
row = col.row(align=True)
|
||||
row.prop(self, self.blfields[f'dim_{i}_name'], text='')
|
||||
row.prop(self, self.blfields[f'dim_{i}_mathtype'], text='')
|
||||
|
||||
row.prop(self, self.blfields['output_name'], text='')
|
||||
row.prop(self, self.blfields['output_mathtype'], text='')
|
||||
row = col.row(align=True)
|
||||
row.prop(self, self.blfields[f'dim_{i}_physical_type'], text='')
|
||||
row.prop(self, self.blfields[f'dim_{i}_unit'], text='')
|
||||
row.prop(self, self.blfields['output_physical_type'], text='')
|
||||
row.prop(self, self.blfields['output_unit'], text='')
|
||||
|
||||
####################
|
||||
# - FlowKind.Array|Func
|
||||
|
@ -271,7 +233,8 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
|||
'Expr',
|
||||
kind=ct.FlowKind.Info,
|
||||
# Loaded
|
||||
props={'output_name', 'output_physical_type', 'output_unit'},
|
||||
props={'output_name', 'output_mathtype', 'output_physical_type', 'output_unit'}
|
||||
| {f'dim_{i}_name' for i in range(6)},
|
||||
output_sockets={'Expr'},
|
||||
output_socket_kinds={'Expr': ct.FlowKind.Func},
|
||||
)
|
||||
|
@ -285,32 +248,31 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
|||
A completely empty `ParamsFlow`, ready to be composed.
|
||||
"""
|
||||
expr = output_sockets['Expr']
|
||||
|
||||
has_expr_func = not ct.FlowSignal.check(expr)
|
||||
|
||||
if has_expr_func:
|
||||
data = expr.func_jax()
|
||||
|
||||
# Deduce Dimensionality
|
||||
_shape = data.shape
|
||||
shape = _shape if _shape is not None else ()
|
||||
dim_syms = [self.dim(i) for i in range(len(shape))]
|
||||
# Deduce Dimension Symbols
|
||||
## -> They are all chronically integer indices.
|
||||
## -> The FilterNode can be used to "steal" an index from the data.
|
||||
shape = data.shape if data.shape is not None else ()
|
||||
dims = {
|
||||
sim_symbols.idx(None).update(
|
||||
sym_name=props[f'dim_{i}_name'],
|
||||
interval_finite_z=(0, elements),
|
||||
interval_inf=(False, False),
|
||||
interval_closed=(True, True),
|
||||
): [str(j) for j in range(elements)]
|
||||
for i, elements in enumerate(shape)
|
||||
}
|
||||
|
||||
# Return InfoFlow
|
||||
return ct.InfoFlow(
|
||||
dims={
|
||||
dim_sym: ct.RangeFlow(
|
||||
start=sp.S(0),
|
||||
stop=sp.S(shape[i] - 1),
|
||||
steps=shape[i],
|
||||
unit=self.dim(i).unit,
|
||||
)
|
||||
for i, dim_sym in enumerate(dim_syms)
|
||||
},
|
||||
dims=dims,
|
||||
output=sim_symbols.SimSymbol(
|
||||
sym_name=props['output_name'],
|
||||
mathtype=props['output_mathtype'],
|
||||
physical_type=props['output_physical_type'],
|
||||
unit=props['output_unit'],
|
||||
),
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
|
|
@ -259,9 +259,8 @@ class LibraryMediumNode(base.MaxwellSimNode):
|
|||
)
|
||||
def compute_valid_freqs_lazy(self, props) -> sp.Expr:
|
||||
return ct.RangeFlow(
|
||||
start=props['freq_range'][0] / spux.THz,
|
||||
stop=props['freq_range'][1] / spux.THz,
|
||||
steps=0,
|
||||
start=spu.scale_to_unit(['freq_range'][0], spux.THz),
|
||||
stop=spu.scale_to_unit(props['freq_range'][1], spux.THz),
|
||||
scaling=ct.ScalingMode.Lin,
|
||||
unit=spux.THz,
|
||||
)
|
||||
|
@ -273,9 +272,8 @@ class LibraryMediumNode(base.MaxwellSimNode):
|
|||
)
|
||||
def compute_valid_wls_lazy(self, props) -> sp.Expr:
|
||||
return ct.RangeFlow(
|
||||
start=props['wl_range'][0] / spu.nm,
|
||||
stop=props['wl_range'][0] / spu.nm,
|
||||
steps=0,
|
||||
start=spu.scale_to_unit(['wl_range'][0], spu.nm),
|
||||
stop=spu.scale_to_unit(['wl_range'][0], spu.nm),
|
||||
scaling=ct.ScalingMode.Lin,
|
||||
unit=spu.nm,
|
||||
)
|
||||
|
|
|
@ -73,31 +73,90 @@ class ViewerNode(base.MaxwellSimNode):
|
|||
####################
|
||||
# - Properties
|
||||
####################
|
||||
print_kind: ct.FlowKind = bl_cache.BLField(ct.FlowKind.Value)
|
||||
auto_plot: bool = bl_cache.BLField(False)
|
||||
auto_expr: bool = bl_cache.BLField(True)
|
||||
debug_mode: bool = bl_cache.BLField(False)
|
||||
|
||||
# Debug Mode
|
||||
console_print_kind: ct.FlowKind = bl_cache.BLField(ct.FlowKind.Value)
|
||||
auto_plot: bool = bl_cache.BLField(True)
|
||||
auto_3d_preview: bool = bl_cache.BLField(True)
|
||||
|
||||
####################
|
||||
# - Properties: Computed FlowKinds
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
socket_name='Any',
|
||||
)
|
||||
def on_input_changed(self) -> None:
|
||||
self.input_flow = bl_cache.Signal.InvalidateCache
|
||||
|
||||
@bl_cache.cached_bl_property()
|
||||
def input_flow(self) -> dict[ct.FlowKind, typ.Any | None]:
|
||||
input_flow = {}
|
||||
|
||||
for flow_kind in list(ct.FlowKind):
|
||||
flow = self._compute_input('Any', kind=flow_kind)
|
||||
has_flow = not ct.FlowSignal.check(flow)
|
||||
|
||||
if has_flow:
|
||||
input_flow |= {flow_kind: flow}
|
||||
else:
|
||||
input_flow |= {flow_kind: None}
|
||||
|
||||
return input_flow
|
||||
|
||||
####################
|
||||
# - Property: Input Expression String Lines
|
||||
####################
|
||||
@bl_cache.cached_bl_property(depends_on={'input_flow'})
|
||||
def input_expr_str_entries(self) -> list[list[str]] | None:
|
||||
value = self.input_flow.get(ct.FlowKind.Value)
|
||||
|
||||
def sp_pretty(v: spux.SympyExpr) -> spux.SympyExpr:
|
||||
## sp.pretty makes new lines and wreaks havoc.
|
||||
return spux.sp_to_str(v.n(4))
|
||||
|
||||
if isinstance(value, spux.SympyType):
|
||||
if isinstance(value, sp.MatrixBase):
|
||||
return [
|
||||
[sp_pretty(value[row, col]) for col in range(value.shape[1])]
|
||||
for row in range(value.shape[0])
|
||||
]
|
||||
|
||||
return [[sp_pretty(value)]]
|
||||
return None
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
|
||||
layout.prop(self, self.blfields['print_kind'], text='')
|
||||
row = layout.row(align=True)
|
||||
|
||||
# Debug Mode On/Off
|
||||
row.prop(self, self.blfields['debug_mode'], text='Debug', toggle=True)
|
||||
|
||||
# Automatic Expression Printing
|
||||
row.prop(self, self.blfields['auto_expr'], text='Expr', toggle=True)
|
||||
|
||||
# Debug Mode Operators
|
||||
if self.debug_mode:
|
||||
layout.prop(self, self.blfields['console_print_kind'], text='')
|
||||
|
||||
def draw_operators(self, _: bpy.types.Context, layout: bpy.types.UILayout):
|
||||
# Live Expression
|
||||
if self.debug_mode:
|
||||
layout.operator(ConsoleViewOperator.bl_idname, text='Console Print')
|
||||
|
||||
split = layout.split(factor=0.4)
|
||||
|
||||
# Split LHS
|
||||
col = split.column(align=False)
|
||||
col.label(text='Console')
|
||||
col.label(text='Plot')
|
||||
col.label(text='3D')
|
||||
|
||||
# Split RHS
|
||||
col = split.column(align=False)
|
||||
|
||||
## Console Options
|
||||
col.operator(ConsoleViewOperator.bl_idname, text='Print')
|
||||
|
||||
## Plot Options
|
||||
row = col.row(align=True)
|
||||
row.prop(self, self.blfields['auto_plot'], text='Plot', toggle=True)
|
||||
|
@ -109,7 +168,43 @@ class ViewerNode(base.MaxwellSimNode):
|
|||
|
||||
## 3D Preview Options
|
||||
row = col.row(align=True)
|
||||
row.prop(self, self.blfields['auto_3d_preview'], text='3D Preview', toggle=True)
|
||||
row.prop(
|
||||
self, self.blfields['auto_3d_preview'], text='3D Preview', toggle=True
|
||||
)
|
||||
|
||||
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout):
|
||||
# Live Expression
|
||||
if self.auto_expr and self.input_expr_str_entries is not None:
|
||||
box = layout.box()
|
||||
|
||||
expr_rows = len(self.input_expr_str_entries)
|
||||
expr_cols = len(self.input_expr_str_entries[0])
|
||||
shape_str = (
|
||||
f'({expr_rows}×{expr_cols})'
|
||||
if expr_rows != 1 or expr_cols != 1
|
||||
else '(Scalar)'
|
||||
)
|
||||
|
||||
row = box.row()
|
||||
row.alignment = 'CENTER'
|
||||
row.label(text=f'Expr {shape_str}')
|
||||
|
||||
if (
|
||||
len(self.input_expr_str_entries) == 1
|
||||
and len(self.input_expr_str_entries[0]) == 1
|
||||
):
|
||||
row = box.row()
|
||||
row.alignment = 'CENTER'
|
||||
row.label(text=self.input_expr_str_entries[0][0])
|
||||
else:
|
||||
grid = box.grid_flow(
|
||||
row_major=True,
|
||||
columns=len(self.input_expr_str_entries[0]),
|
||||
align=True,
|
||||
)
|
||||
for row in self.input_expr_str_entries:
|
||||
for entry in row:
|
||||
grid.label(text=entry)
|
||||
|
||||
####################
|
||||
# - Methods
|
||||
|
@ -119,7 +214,7 @@ class ViewerNode(base.MaxwellSimNode):
|
|||
return
|
||||
|
||||
log.info('Printing to Console')
|
||||
data = self._compute_input('Any', kind=self.print_kind, optional=True)
|
||||
data = self._compute_input('Any', kind=self.console_print_kind, optional=True)
|
||||
|
||||
if isinstance(data, spux.SympyType):
|
||||
console.print(sp.pretty(data, use_unicode=True))
|
||||
|
|
|
@ -40,6 +40,8 @@ _max_e_socket_def = sockets.ExprSocketDef(
|
|||
)
|
||||
_offset_socket_def = sockets.ExprSocketDef(default_value=5, abs_min=2.5)
|
||||
|
||||
t_ps = sim_symbols.t(spu.picosecond)
|
||||
|
||||
|
||||
class TemporalShapeNode(base.MaxwellSimNode):
|
||||
"""Declare a source-time dependence for use in simulation source nodes."""
|
||||
|
@ -82,8 +84,8 @@ class TemporalShapeNode(base.MaxwellSimNode):
|
|||
default_steps=100,
|
||||
),
|
||||
'Envelope': sockets.ExprSocketDef(
|
||||
default_symbols=[sim_symbols.t],
|
||||
default_value=10 * sim_symbols.t.sp_symbol,
|
||||
default_symbols=[t_ps],
|
||||
default_value=10 * t_ps.sp_symbol,
|
||||
),
|
||||
},
|
||||
}
|
||||
|
|
|
@ -593,6 +593,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
|||
ValueError: When referencing a socket that's meant to be directly referenced.
|
||||
"""
|
||||
kind_data_map = {
|
||||
ct.FlowKind.Capabilities: lambda: self.capabilities,
|
||||
ct.FlowKind.Value: lambda: self.value,
|
||||
ct.FlowKind.Array: lambda: self.array,
|
||||
ct.FlowKind.Func: lambda: self.lazy_func,
|
||||
|
|
|
@ -111,29 +111,38 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
bl_label = 'Expr'
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
# - Socket Interface
|
||||
####################
|
||||
size: spux.NumberSize1D = bl_cache.BLField(spux.NumberSize1D.Scalar)
|
||||
mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real)
|
||||
physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical)
|
||||
|
||||
# Symbols
|
||||
####################
|
||||
# - Symbols
|
||||
####################
|
||||
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||
sim_symbols.SimSymbolName.Expr
|
||||
)
|
||||
active_symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([])
|
||||
|
||||
@property
|
||||
def symbols(self) -> set[sp.Symbol]:
|
||||
"""Current symbols as an unordered set."""
|
||||
return {sim_symbol.sp_symbol for sim_symbol in self.active_symbols}
|
||||
symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([])
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'symbols'})
|
||||
def sorted_symbols(self) -> list[sp.Symbol]:
|
||||
def sp_symbols(self) -> set[sp.Symbol | sp.MatrixSymbol]:
|
||||
"""Sympy symbols as an unordered set."""
|
||||
return {sim_symbol.sp_symbol_matsym for sim_symbol in self.symbols}
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'symbols'})
|
||||
def sorted_symbols(self) -> list[sim_symbols.SimSymbol]:
|
||||
"""Current symbols as a sorted list."""
|
||||
return sorted(self.symbols, key=lambda sym: sym.name)
|
||||
|
||||
# Unit
|
||||
@bl_cache.cached_bl_property(depends_on={'symbols'})
|
||||
def sorted_sp_symbols(self) -> list[sp.Symbol | sp.MatrixSymbol]:
|
||||
"""Computes `sympy` symbols from `self.sorted_symbols`."""
|
||||
return [sym.sp_symbol_matsym for sym in self.sorted_symbols]
|
||||
|
||||
####################
|
||||
# - Units
|
||||
####################
|
||||
active_unit: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_valid_units(),
|
||||
cb_depends_on={'physical_type'},
|
||||
|
@ -148,6 +157,29 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
]
|
||||
return []
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'active_unit'})
|
||||
def unit(self) -> spux.Unit | None:
|
||||
"""Gets the current active unit.
|
||||
|
||||
Returns:
|
||||
The current active `sympy` unit.
|
||||
|
||||
If the socket expression is unitless, this returns `None`.
|
||||
"""
|
||||
if self.active_unit is not None:
|
||||
return spux.unit_str_to_unit(self.active_unit)
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def unit_factor(self) -> spux.Unit | None:
|
||||
return sp.Integer(1) if self.unit is None else self.unit
|
||||
|
||||
prev_unit: str | None = bl_cache.BLField(None)
|
||||
|
||||
####################
|
||||
# - UI Values
|
||||
####################
|
||||
# UI: Value
|
||||
## Expression
|
||||
raw_value_spstr: str = bl_cache.BLField('0.0')
|
||||
|
@ -186,6 +218,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
)
|
||||
|
||||
# UI: Info
|
||||
show_name_selector: bool = bl_cache.BLField(False)
|
||||
show_func_ui: bool = bl_cache.BLField(True)
|
||||
show_info_columns: bool = bl_cache.BLField(False)
|
||||
info_columns: set[InfoDisplayCol] = bl_cache.BLField(
|
||||
|
@ -207,25 +240,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
def raw_max_sp(self) -> spux.SympyExpr:
|
||||
return self._parse_expr_str(self.raw_max_spstr)
|
||||
|
||||
####################
|
||||
# - Computed Unit
|
||||
####################
|
||||
@bl_cache.cached_bl_property(depends_on={'active_unit'})
|
||||
def unit(self) -> spux.Unit | None:
|
||||
"""Gets the current active unit.
|
||||
|
||||
Returns:
|
||||
The current active `sympy` unit.
|
||||
|
||||
If the socket expression is unitless, this returns `None`.
|
||||
"""
|
||||
if self.active_unit is not None:
|
||||
return spux.unit_str_to_unit(self.active_unit)
|
||||
|
||||
return None
|
||||
|
||||
prev_unit: str | None = bl_cache.BLField(None)
|
||||
|
||||
####################
|
||||
# - Prop-Change Callback
|
||||
####################
|
||||
|
@ -272,7 +286,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
raise ValueError(msg)
|
||||
|
||||
# Parse Symbols
|
||||
if expr.free_symbols and not expr.free_symbols.issubset(self.symbols):
|
||||
if expr.free_symbols and not expr.free_symbols.issubset(self.sp_symbols):
|
||||
msg = f'Tried to set expr {expr} with free symbols {expr.free_symbols}, which is incompatible with socket symbols {self.symbols}'
|
||||
raise ValueError(msg)
|
||||
|
||||
|
@ -320,7 +334,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
"""
|
||||
expr = sp.sympify(
|
||||
expr_spstr,
|
||||
locals={sym.name: sym for sym in self.symbols},
|
||||
locals={sym.name: sym.sp_symbol_matsym for sym in self.symbols},
|
||||
strict=False,
|
||||
convert_xor=True,
|
||||
).subs(spux.UNIT_BY_SYMBOL)
|
||||
|
@ -562,11 +576,11 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
if self.symbols:
|
||||
return ct.FuncFlow(
|
||||
func=sp.lambdify(
|
||||
self.sorted_symbols,
|
||||
spux.scale_to_unit(self.value, self.unit),
|
||||
self.sorted_sp_symbols,
|
||||
spux.strip_unit_system(self.value),
|
||||
'jax',
|
||||
),
|
||||
func_args=[spux.MathType.from_expr(sym) for sym in self.sorted_symbols],
|
||||
func_args=list(self.sorted_symbols),
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
|
@ -578,7 +592,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
return ct.FuncFlow(
|
||||
func=lambda v: v,
|
||||
func_args=[
|
||||
self.physical_type if self.physical_type is not None else self.mathtype
|
||||
sim_symbols.SimSymbol.from_expr(
|
||||
sim_symbols.SimSymbolName.Constant, self.value, self.unit_factor
|
||||
)
|
||||
],
|
||||
supports_jax=True,
|
||||
)
|
||||
|
@ -597,8 +613,8 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
## -> NOTE: func_args must have the same symbol order as was lambdified.
|
||||
if self.symbols:
|
||||
return ct.ParamsFlow(
|
||||
func_args=self.sorted_symbols,
|
||||
symbols=self.symbols,
|
||||
func_args=[sym.sp_symbol_phy for sym in self.sorted_symbols],
|
||||
symbols=self.sorted_symbols,
|
||||
)
|
||||
|
||||
# Constant
|
||||
|
@ -618,24 +634,27 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
|
||||
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
|
||||
"""
|
||||
output_sim_sym = (
|
||||
sim_symbols.SimSymbol(
|
||||
output_sym = sim_symbols.SimSymbol(
|
||||
sym_name=self.output_name,
|
||||
mathtype=self.mathtype,
|
||||
physical_type=self.physical_type,
|
||||
unit=self.unit,
|
||||
rows=self.size.rows,
|
||||
cols=self.size.cols,
|
||||
),
|
||||
)
|
||||
if self.symbols:
|
||||
return ct.InfoFlow(
|
||||
dims={sim_sym: None for sim_sym in self.active_symbols},
|
||||
output=output_sim_sym,
|
||||
)
|
||||
|
||||
# Constant
|
||||
return ct.InfoFlow(output=output_sim_sym)
|
||||
## -> The input SimSymbols become continuous dimensional indices.
|
||||
## -> All domain validity information is defined on the SimSymbol keys.
|
||||
if self.symbols:
|
||||
return ct.InfoFlow(
|
||||
dims={sym: None for sym in self.sorted_symbols},
|
||||
output=output_sym,
|
||||
)
|
||||
|
||||
# Constant
|
||||
## -> We only need the output symbol to describe the raw data.
|
||||
return ct.InfoFlow(output=output_sym)
|
||||
|
||||
####################
|
||||
# - FlowKind: Capabilities
|
||||
|
@ -645,6 +664,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
return ct.CapabilitiesFlow(
|
||||
socket_type=self.socket_type,
|
||||
active_kind=self.active_kind,
|
||||
allow_out_to_in={ct.FlowKind.Func: ct.FlowKind.Value},
|
||||
)
|
||||
|
||||
####################
|
||||
|
@ -795,7 +815,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
col = split.column()
|
||||
col.alignment = 'RIGHT'
|
||||
for sym in self.symbols:
|
||||
col.label(text=spux.pretty_symbol(sym))
|
||||
col.label(text=sym.def_label)
|
||||
|
||||
def draw_lazy_range(self, col: bpy.types.UILayout) -> None:
|
||||
"""Draw the socket body for a simple, uniform range of values between two values/expressions.
|
||||
|
@ -840,14 +860,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
Uses `draw_value` to draw the base UI
|
||||
"""
|
||||
if self.show_func_ui:
|
||||
# Output Name Selector
|
||||
## -> The name of the output
|
||||
col.prop(self, self.blfields['output_name'], text='')
|
||||
|
||||
# Physical Type Selector
|
||||
## -> Determines whether/which unit-dropdown will be shown.
|
||||
col.prop(self, self.blfields['physical_type'], text='')
|
||||
|
||||
# Non-Symbolic: Size/Mathtype Selector
|
||||
## -> Symbols imply str expr input.
|
||||
## -> For arbitrary str exprs, size/mathtype are derived from the expr.
|
||||
|
@ -861,15 +873,30 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
## -> Draws the UI appropriate for the above choice of constraints.
|
||||
self.draw_value(col)
|
||||
|
||||
# Physical Type Selector
|
||||
## -> Determines whether/which unit-dropdown will be shown.
|
||||
col.prop(self, self.blfields['physical_type'], text='')
|
||||
|
||||
# Symbol UI
|
||||
## -> Draws the UI appropriate for the above choice of constraints.
|
||||
## -> TODO
|
||||
|
||||
# Output Name Selector
|
||||
## -> The name of the output
|
||||
if self.show_name_selector:
|
||||
row = col.row()
|
||||
row.prop(self, self.blfields['output_name'], text='Name')
|
||||
|
||||
####################
|
||||
# - UI: InfoFlow
|
||||
####################
|
||||
def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None:
|
||||
if self.active_kind == ct.FlowKind.Func and self.show_info_columns:
|
||||
"""Visualize the `InfoFlow` information passing through the socket."""
|
||||
if (
|
||||
self.active_kind == ct.FlowKind.Func
|
||||
and self.show_info_columns
|
||||
and self.is_linked
|
||||
):
|
||||
row = col.row()
|
||||
box = row.box()
|
||||
grid = box.grid_flow(
|
||||
|
@ -881,38 +908,23 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
)
|
||||
|
||||
# Dimensions
|
||||
for dim in info.dims:
|
||||
dim_idx = info.dims[dim]
|
||||
grid.label(text=dim.name_pretty)
|
||||
for dim_name_pretty, dim_label_info in info.dim_labels.items():
|
||||
grid.label(text=dim_name_pretty)
|
||||
if InfoDisplayCol.Length in self.info_columns:
|
||||
grid.label(text=str(len(dim_idx)))
|
||||
grid.label(text=dim_label_info['length'])
|
||||
if InfoDisplayCol.MathType in self.info_columns:
|
||||
grid.label(text=spux.MathType.to_str(dim_idx.mathtype))
|
||||
grid.label(text=dim_label_info['mathtype'])
|
||||
if InfoDisplayCol.Unit in self.info_columns:
|
||||
grid.label(text=spux.sp_to_str(dim_idx.unit))
|
||||
grid.label(text=dim_label_info['unit'])
|
||||
|
||||
# Outputs
|
||||
grid.label(text=info.output.name_pretty)
|
||||
if InfoDisplayCol.Length in self.info_columns:
|
||||
grid.label(text='', icon=ct.Icon.DataSocketOutput)
|
||||
if InfoDisplayCol.MathType in self.info_columns:
|
||||
grid.label(
|
||||
text=(
|
||||
spux.MathType.to_str(info.output.mathtype)
|
||||
+ (
|
||||
'ˣ'.join(
|
||||
[
|
||||
unicode_superscript(out_axis)
|
||||
for out_axis in info.output.shape
|
||||
]
|
||||
)
|
||||
if info.output.shape
|
||||
else ''
|
||||
)
|
||||
)
|
||||
)
|
||||
grid.label(text=info.output.def_label)
|
||||
if InfoDisplayCol.Unit in self.info_columns:
|
||||
grid.label(text=f'{spux.sp_to_str(info.output.unit)}')
|
||||
grid.label(text=info.output.unit_label)
|
||||
|
||||
|
||||
####################
|
||||
|
@ -926,7 +938,7 @@ class ExprSocketDef(base.SocketDef):
|
|||
ct.FlowKind.Array,
|
||||
ct.FlowKind.Func,
|
||||
] = ct.FlowKind.Value
|
||||
output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName
|
||||
output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName.Expr
|
||||
|
||||
# Socket Interface
|
||||
size: spux.NumberSize1D = spux.NumberSize1D.Scalar
|
||||
|
@ -948,9 +960,15 @@ class ExprSocketDef(base.SocketDef):
|
|||
default_scaling: ct.ScalingMode = ct.ScalingMode.Lin
|
||||
|
||||
# UI
|
||||
show_name_selector: bool = False
|
||||
show_func_ui: bool = True
|
||||
show_info_columns: bool = False
|
||||
|
||||
@property
|
||||
def sp_symbols(self) -> set[sp.Symbol | sp.MatrixSymbol]:
|
||||
"""Default symbols as an unordered set."""
|
||||
return {sym.sp_symbol_matsym for sym in self.default_symbols}
|
||||
|
||||
####################
|
||||
# - Parse Unit and/or Physical Type
|
||||
####################
|
||||
|
@ -1149,6 +1167,7 @@ class ExprSocketDef(base.SocketDef):
|
|||
raise ValueError(msg)
|
||||
|
||||
# Coerce from Infinite
|
||||
if isinstance(bound, spux.SympyType):
|
||||
if bound.is_infinite and self.mathtype is spux.MathType.Integer:
|
||||
new_bounds[i] = sp.S(-1) if i == 0 else sp.S(1)
|
||||
if bound.is_infinite and self.mathtype is spux.MathType.Rational:
|
||||
|
@ -1194,9 +1213,9 @@ class ExprSocketDef(base.SocketDef):
|
|||
def symbols_value(self) -> typ.Self:
|
||||
if (
|
||||
self.default_value.free_symbols
|
||||
and not self.default_value.free_symbols.issubset(self.symbols)
|
||||
and not self.default_value.free_symbols.issubset(self.sp_symbols)
|
||||
):
|
||||
msg = f'Tried to set default value {self.default_value} with free symbols {self.default_value.free_symbols}, which is incompatible with socket symbols {self.symbols}'
|
||||
msg = f'Tried to set default value {self.default_value} with free symbols {self.default_value.free_symbols}, which is incompatible with socket symbols {self.sp_symbols}'
|
||||
raise ValueError(msg)
|
||||
|
||||
return self
|
||||
|
@ -1227,7 +1246,7 @@ class ExprSocketDef(base.SocketDef):
|
|||
bl_socket.size = self.size
|
||||
bl_socket.mathtype = self.mathtype
|
||||
bl_socket.physical_type = self.physical_type
|
||||
bl_socket.active_symbols = self.symbols
|
||||
bl_socket.symbols = self.default_symbols
|
||||
|
||||
# FlowKind.Value
|
||||
## -> We must take units into account when setting bl_socket.value
|
||||
|
@ -1252,6 +1271,7 @@ class ExprSocketDef(base.SocketDef):
|
|||
# UI
|
||||
bl_socket.show_func_ui = self.show_func_ui
|
||||
bl_socket.show_info_columns = self.show_info_columns
|
||||
bl_socket.show_name_selector = self.show_name_selector
|
||||
|
||||
# Info Draw
|
||||
bl_socket.use_info_draw = True
|
||||
|
|
|
@ -389,14 +389,15 @@ class BLField:
|
|||
|
||||
Reset by setting the descriptor to `Signal.ResetStrSearch`.
|
||||
"""
|
||||
cached_items = self.bl_prop_str_search.read_nonpersist(_self)
|
||||
if cached_items is not Signal.CacheNotReady:
|
||||
if cached_items is Signal.CacheEmpty:
|
||||
computed_items = self.str_cb(_self, context, edit_text)
|
||||
self.bl_prop_str_search.write_nonpersist(_self, computed_items)
|
||||
return computed_items
|
||||
return cached_items
|
||||
return []
|
||||
return self.str_cb(_self, context, edit_text)
|
||||
# cached_items = self.bl_prop_str_search.read_nonpersist(_self)
|
||||
# if cached_items is not Signal.CacheNotReady:
|
||||
# if cached_items is Signal.CacheEmpty:
|
||||
# computed_items = self.str_cb(_self, context, edit_text)
|
||||
# self.bl_prop_str_search.write_nonpersist(_self, computed_items)
|
||||
# return computed_items
|
||||
# return cached_items
|
||||
# return []
|
||||
|
||||
def safe_enum_cb(
|
||||
self, _self: bl_instance.BLInstance, context: bpy.types.Context
|
||||
|
|
|
@ -28,11 +28,13 @@ Attributes:
|
|||
|
||||
import enum
|
||||
import functools
|
||||
import sys
|
||||
import typing as typ
|
||||
from fractions import Fraction
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jaxtyping as jtyp
|
||||
import pydantic as pyd
|
||||
import sympy as sp
|
||||
import sympy.physics.units as spu
|
||||
|
@ -144,6 +146,21 @@ class MathType(enum.StrEnum):
|
|||
complex: MathType.Complex,
|
||||
}[dtype]
|
||||
|
||||
@staticmethod
|
||||
def from_jax_array(data: jtyp.Shaped[jtyp.Array, '...']) -> type:
|
||||
"""Deduce the MathType corresponding to a JAX array.
|
||||
|
||||
We go about this by leveraging that:
|
||||
- `data` is of a homogeneous type.
|
||||
- `data.item(0)` returns a single element of the array w/pure-python type.
|
||||
|
||||
By combing this with `type()` and `MathType.from_pytype`, we can effectively deduce the `MathType` of the entire array with relative efficiency.
|
||||
|
||||
Notes:
|
||||
Should also work with numpy arrays.
|
||||
"""
|
||||
return MathType.from_pytype(type(data.item(0)))
|
||||
|
||||
@staticmethod
|
||||
def has_mathtype(obj: typ.Any) -> typ.Literal['pytype', 'expr'] | None:
|
||||
if isinstance(obj, bool | int | Fraction | float | complex):
|
||||
|
@ -173,6 +190,39 @@ class MathType(enum.StrEnum):
|
|||
MT.Complex: sp.Complexes,
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def inf_finite(self) -> type:
|
||||
"""Opinionated finite representation of "infinity" within this `MathType`.
|
||||
|
||||
These are chosen using `sys.maxsize` and `sys.float_info`.
|
||||
As such, while not arbitrary, this "finite representation of infinity" certainly is opinionated.
|
||||
|
||||
**Note** that, in practice, most systems will have no trouble working with values that exceed those defined here.
|
||||
|
||||
Notes:
|
||||
Values should be presumed to vary by-platform, as the `sys` attributes may be influenced by CPU architecture, OS, runtime environment, etc. .
|
||||
|
||||
These values can be used directly in `jax` arrays, but at the cost of an overflow warning (in part because `jax` generally only allows the use of `float32`).
|
||||
In this case, the warning doesn't matter, as the value will be cast to `jnp.inf` anyway.
|
||||
|
||||
However, it's generally cleaner to directly use `jnp.inf` if infinite values must be defined in an array context.
|
||||
"""
|
||||
MT = MathType
|
||||
Z = MT.Integer
|
||||
R = MT.Integer
|
||||
return {
|
||||
MT.Integer: (-sys.maxsize, sys.maxsize),
|
||||
MT.Rational: (
|
||||
Fraction(Z.inf_finite[0], 1),
|
||||
Fraction(Z.inf_finite[1], 1),
|
||||
),
|
||||
MT.Real: -(sys.float_info.min, sys.float_info.max),
|
||||
MT.Complex: (
|
||||
complex(R.inf_finite[0], R.inf_finite[0]),
|
||||
complex(R.inf_finite[1], R.inf_finite[1]),
|
||||
),
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def sp_symbol_a(self) -> type:
|
||||
MT = MathType
|
||||
|
@ -192,6 +242,10 @@ class MathType(enum.StrEnum):
|
|||
MathType.Complex: 'ℂ',
|
||||
}[value]
|
||||
|
||||
@property
|
||||
def label_pretty(self) -> str:
|
||||
return MathType.to_str(self)
|
||||
|
||||
@staticmethod
|
||||
def to_name(value: typ.Self) -> str:
|
||||
return MathType.to_str(value)
|
||||
|
@ -819,14 +873,15 @@ def sp_to_str(sp_obj: SympyExpr) -> str:
|
|||
A string representing the expression for human use.
|
||||
_The string is not re-encodable to the expression._
|
||||
"""
|
||||
## TODO: A bool flag property that does a lot of find/replace to make it super pretty
|
||||
return _SYMPY_EXPR_PRINTER_STR.doprint(sp_obj)
|
||||
|
||||
|
||||
def pretty_symbol(sym: sp.Symbol) -> str:
|
||||
return f'{sym.name} ∈ ' + (
|
||||
'ℂ'
|
||||
if sym.is_complex
|
||||
else ('ℝ' if sym.is_real else ('ℤ' if sym.is_integer else '?'))
|
||||
'ℤ'
|
||||
if sym.is_integer
|
||||
else ('ℝ' if sym.is_real else ('ℂ' if sym.is_complex else '?'))
|
||||
)
|
||||
|
||||
|
||||
|
@ -1039,20 +1094,24 @@ class PhysicalType(enum.StrEnum):
|
|||
PT.LumIntensity: spu.candela,
|
||||
PT.LumFlux: spu.candela * spu.steradian,
|
||||
PT.Illuminance: spu.candela / spu.meter**2,
|
||||
# Optics
|
||||
PT.OrdinaryWaveVector: terahertz,
|
||||
PT.AngularWaveVector: spu.radian * terahertz,
|
||||
}[self]
|
||||
|
||||
@functools.cached_property
|
||||
def valid_units(self) -> list[Unit]:
|
||||
"""Retrieve an ordered (by subjective usefulness) list of units for this physical type.
|
||||
|
||||
Notes:
|
||||
The order in which valid units are declared is the exact same order that UI dropdowns display them.
|
||||
|
||||
**Altering the order of units breaks backwards compatibility**.
|
||||
"""
|
||||
PT = PhysicalType
|
||||
return {
|
||||
PT.NonPhysical: [None],
|
||||
# Global
|
||||
PT.Time: [
|
||||
femtosecond,
|
||||
spu.picosecond,
|
||||
femtosecond,
|
||||
spu.nanosecond,
|
||||
spu.microsecond,
|
||||
spu.millisecond,
|
||||
|
@ -1070,11 +1129,11 @@ class PhysicalType(enum.StrEnum):
|
|||
],
|
||||
PT.Freq: (
|
||||
_valid_freqs := [
|
||||
terahertz,
|
||||
spu.hertz,
|
||||
kilohertz,
|
||||
megahertz,
|
||||
gigahertz,
|
||||
terahertz,
|
||||
petahertz,
|
||||
exahertz,
|
||||
]
|
||||
|
@ -1083,10 +1142,10 @@ class PhysicalType(enum.StrEnum):
|
|||
# Cartesian
|
||||
PT.Length: (
|
||||
_valid_lens := [
|
||||
spu.micrometer,
|
||||
spu.nanometer,
|
||||
spu.picometer,
|
||||
spu.angstrom,
|
||||
spu.nanometer,
|
||||
spu.micrometer,
|
||||
spu.millimeter,
|
||||
spu.centimeter,
|
||||
spu.meter,
|
||||
|
@ -1102,24 +1161,24 @@ class PhysicalType(enum.StrEnum):
|
|||
PT.Vel: [_unit / spu.second for _unit in _valid_lens],
|
||||
PT.Accel: [_unit / spu.second**2 for _unit in _valid_lens],
|
||||
PT.Mass: [
|
||||
spu.kilogram,
|
||||
spu.electron_rest_mass,
|
||||
spu.dalton,
|
||||
spu.microgram,
|
||||
spu.milligram,
|
||||
spu.gram,
|
||||
spu.kilogram,
|
||||
spu.metric_ton,
|
||||
],
|
||||
PT.Force: [
|
||||
spu.kg * spu.meter / spu.second**2,
|
||||
nanonewton,
|
||||
micronewton,
|
||||
nanonewton,
|
||||
millinewton,
|
||||
spu.newton,
|
||||
spu.kg * spu.meter / spu.second**2,
|
||||
],
|
||||
PT.Pressure: [
|
||||
millibar,
|
||||
spu.bar,
|
||||
millibar,
|
||||
spu.pascal,
|
||||
hectopascal,
|
||||
spu.atmosphere,
|
||||
|
@ -1129,8 +1188,8 @@ class PhysicalType(enum.StrEnum):
|
|||
],
|
||||
# Energy
|
||||
PT.Work: [
|
||||
spu.electronvolt,
|
||||
spu.joule,
|
||||
spu.electronvolt,
|
||||
],
|
||||
PT.Power: [
|
||||
spu.watt,
|
||||
|
@ -1194,18 +1253,17 @@ class PhysicalType(enum.StrEnum):
|
|||
PT.Illuminance: [
|
||||
spu.candela / spu.meter**2,
|
||||
],
|
||||
# Optics
|
||||
PT.OrdinaryWaveVector: _valid_freqs,
|
||||
PT.AngularWaveVector: [spu.radian * _unit for _unit in _valid_freqs],
|
||||
}[self]
|
||||
|
||||
@staticmethod
|
||||
def from_unit(unit: Unit) -> list[Unit]:
|
||||
def from_unit(unit: Unit, optional: bool = False) -> list[Unit] | None:
|
||||
for physical_type in list(PhysicalType):
|
||||
if unit in physical_type.valid_units:
|
||||
return physical_type
|
||||
## TODO: Optimize
|
||||
|
||||
if optional:
|
||||
return None
|
||||
msg = f'Could not determine PhysicalType for {unit}'
|
||||
raise ValueError(msg)
|
||||
|
||||
|
@ -1400,20 +1458,9 @@ def sympy_to_python(
|
|||
####################
|
||||
# - Convert to Unit System
|
||||
####################
|
||||
def convert_to_unit_system(
|
||||
sp_obj: SympyExpr, unit_system: UnitSystem | None
|
||||
def strip_unit_system(
|
||||
sp_obj: SympyExpr, unit_system: UnitSystem | None = None
|
||||
) -> SympyExpr:
|
||||
"""Convert an expression to the units of a given unit system, with appropriate scaling."""
|
||||
if unit_system is None:
|
||||
return sp_obj
|
||||
|
||||
return spu.convert_to(
|
||||
sp_obj,
|
||||
{unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)},
|
||||
)
|
||||
|
||||
|
||||
def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem | None) -> SympyExpr:
|
||||
"""Strip units occurring in the given unit system from the expression.
|
||||
|
||||
Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`".
|
||||
|
@ -1427,6 +1474,19 @@ def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem | None) -> Symp
|
|||
return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None})
|
||||
|
||||
|
||||
def convert_to_unit_system(
|
||||
sp_obj: SympyExpr, unit_system: UnitSystem | None
|
||||
) -> SympyExpr:
|
||||
"""Convert an expression to the units of a given unit system."""
|
||||
if unit_system is None:
|
||||
return sp_obj
|
||||
|
||||
return spu.convert_to(
|
||||
sp_obj,
|
||||
{unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)},
|
||||
)
|
||||
|
||||
|
||||
def scale_to_unit_system(
|
||||
sp_obj: SympyExpr, unit_system: UnitSystem | None, use_jax_array: bool = False
|
||||
) -> int | float | complex | tuple | jax.Array:
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
|
||||
import enum
|
||||
import functools
|
||||
import time
|
||||
import typing as typ
|
||||
|
||||
import jax
|
||||
|
@ -34,7 +33,7 @@ import seaborn as sns
|
|||
from blender_maxwell import contracts as ct
|
||||
from blender_maxwell.utils import logger
|
||||
|
||||
mplstyle.use('fast') ## TODO: Does this do anything?
|
||||
# mplstyle.use('fast') ## TODO: Does this do anything?
|
||||
sns.set_theme()
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
@ -59,6 +58,9 @@ class Colormap(enum.StrEnum):
|
|||
Viridis = enum.auto()
|
||||
Grayscale = enum.auto()
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
@staticmethod
|
||||
def to_name(value: typ.Self) -> str:
|
||||
return {
|
||||
|
@ -139,7 +141,9 @@ def rgba_image_from_2d_map(
|
|||
####################
|
||||
@functools.lru_cache(maxsize=16)
|
||||
def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
|
||||
fig = matplotlib.figure.Figure(figsize=[width_inches, height_inches], dpi=dpi)
|
||||
fig = matplotlib.figure.Figure(
|
||||
figsize=[width_inches, height_inches], dpi=dpi, layout='tight'
|
||||
)
|
||||
canvas = matplotlib.backends.backend_agg.FigureCanvasAgg(fig)
|
||||
ax = fig.add_subplot()
|
||||
|
||||
|
@ -152,66 +156,53 @@ def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
|
|||
# - Plotters
|
||||
####################
|
||||
# (ℤ) -> ℝ
|
||||
def plot_box_plot_1d(
|
||||
data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis
|
||||
) -> None:
|
||||
x_sym = info.last_dim
|
||||
y_sym = info.output
|
||||
def plot_box_plot_1d(data, ax: mpl_ax.Axis) -> None:
|
||||
x_sym, y_sym = list(data.keys())
|
||||
|
||||
ax.boxplot([data])
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
||||
ax.boxplot([data[y_sym]])
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
|
||||
ax.set_xlabel(x_sym.plot_label)
|
||||
ax.set_xlabel(y_sym.plot_label)
|
||||
|
||||
|
||||
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None:
|
||||
x_sym = info.last_dim
|
||||
y_sym = info.output
|
||||
def plot_bar(data, ax: mpl_ax.Axis) -> None:
|
||||
x_sym, heights_sym = list(data.keys())
|
||||
|
||||
p = ax.bar(info.dims[x_sym], data)
|
||||
p = ax.bar(data[x_sym], data[heights_sym])
|
||||
ax.bar_label(p, label_type='center')
|
||||
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {heights_sym.name_pretty}')
|
||||
ax.set_xlabel(x_sym.plot_label)
|
||||
ax.set_xlabel(y_sym.plot_label)
|
||||
ax.set_xlabel(heights_sym.plot_label)
|
||||
|
||||
|
||||
# (ℝ) -> ℝ
|
||||
def plot_curve_2d(
|
||||
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis
|
||||
) -> None:
|
||||
x_sym = info.last_dim
|
||||
y_sym = info.output
|
||||
def plot_curve_2d(data, ax: mpl_ax.Axis) -> None:
|
||||
x_sym, y_sym = list(data.keys())
|
||||
|
||||
ax.plot(info.dims[x_sym].realize_array.values, data)
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
||||
ax.plot(data[x_sym], data[y_sym])
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
|
||||
ax.set_xlabel(x_sym.plot_label)
|
||||
ax.set_xlabel(y_sym.plot_label)
|
||||
|
||||
|
||||
def plot_points_2d(
|
||||
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis
|
||||
) -> None:
|
||||
x_sym = info.last_dim
|
||||
y_sym = info.output
|
||||
def plot_points_2d(data, ax: mpl_ax.Axis) -> None:
|
||||
x_sym, y_sym = list(data.keys())
|
||||
|
||||
ax.scatter(x_sym.realize_array.values, data)
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
||||
ax.scatter(data[x_sym], data[y_sym])
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
|
||||
ax.set_xlabel(x_sym.plot_label)
|
||||
ax.set_xlabel(y_sym.plot_label)
|
||||
|
||||
|
||||
# (ℝ, ℤ) -> ℝ
|
||||
def plot_curves_2d(
|
||||
data: jtyp.Float32[jtyp.Array, 'x_size categories'], info, ax: mpl_ax.Axis
|
||||
) -> None:
|
||||
x_sym = info.first_dim
|
||||
y_sym = info.output
|
||||
def plot_curves_2d(data, ax: mpl_ax.Axis) -> None:
|
||||
x_sym, label_sym, y_sym = list(data.keys())
|
||||
|
||||
for i, category in enumerate(info.dims[info.last_dim]):
|
||||
ax.plot(info.dims[x_sym], data[:, i], label=category)
|
||||
for i, label in enumerate(data[label_sym]):
|
||||
ax.plot(data[x_sym], data[y_sym][:, i], label=label)
|
||||
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
|
||||
ax.set_xlabel(x_sym.plot_label)
|
||||
ax.set_xlabel(y_sym.plot_label)
|
||||
ax.legend()
|
||||
|
@ -220,12 +211,10 @@ def plot_curves_2d(
|
|||
def plot_filled_curves_2d(
|
||||
data: jtyp.Float32[jtyp.Array, 'x_size 2'], info, ax: mpl_ax.Axis
|
||||
) -> None:
|
||||
x_sym = info.first_dim
|
||||
y_sym = info.output
|
||||
x_sym, _, y_sym = list(data.keys())
|
||||
|
||||
shared_x_idx = info.dims[info.last_dim]
|
||||
ax.fill_between(shared_x_idx, data[:, 0], shared_x_idx, data[:, 1])
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
||||
ax.fill_between(data[x_sym], data[y_sym][:, 0], data[x_sym], data[y_sym][:, 1])
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name_pretty}')
|
||||
ax.set_xlabel(x_sym.plot_label)
|
||||
ax.set_xlabel(y_sym.plot_label)
|
||||
ax.legend()
|
||||
|
@ -235,11 +224,9 @@ def plot_filled_curves_2d(
|
|||
def plot_heatmap_2d(
|
||||
data: jtyp.Float32[jtyp.Array, 'x_size y_size'], info, ax: mpl_ax.Axis
|
||||
) -> None:
|
||||
x_sym = info.first_dim
|
||||
y_sym = info.last_dim
|
||||
c_sym = info.output
|
||||
x_sym, y_sym, c_sym = list(data.keys())
|
||||
|
||||
heatmap = ax.imshow(data, aspect='equal', interpolation='none')
|
||||
heatmap = ax.imshow(data[c_sym], aspect='equal', interpolation='none')
|
||||
ax.figure.colorbar(heatmap, cax=ax)
|
||||
|
||||
ax.set_title(f'({x_sym.name_pretty}, {y_sym.name_pretty}) -> {c_sym.plot_label}')
|
||||
|
|
|
@ -99,6 +99,7 @@ class TypeID(enum.StrEnum):
|
|||
SympyType: str = '!type=sympytype'
|
||||
SympyExpr: str = '!type=sympyexpr'
|
||||
SocketDef: str = '!type=socketdef'
|
||||
SimSymbol: str = '!type=simsymbol'
|
||||
ManagedObj: str = '!type=managedobj'
|
||||
|
||||
|
||||
|
@ -161,11 +162,12 @@ def _dec_hook(_type: type, obj: NaivelyEncodableType) -> typ.Any:
|
|||
return sp.sympify(obj_value).subs(spux.UNIT_BY_SYMBOL)
|
||||
|
||||
if hasattr(_type, 'parse_as_msgspec') and (
|
||||
is_representation(obj) and obj[0] in [TypeID.SocketDef, TypeID.ManagedObj]
|
||||
is_representation(obj)
|
||||
and obj[0] in [TypeID.SocketDef, TypeID.ManagedObj, TypeID.SimSymbol]
|
||||
):
|
||||
return _type.parse_as_msgspec(obj)
|
||||
|
||||
msg = f'Can\'t decode "{obj}" to type {type(obj)}'
|
||||
msg = f'can\'t decode "{obj}" to type {type(obj)}'
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
|
|
|
@ -14,36 +14,60 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import string
|
||||
import sys
|
||||
import typing as typ
|
||||
from fractions import Fraction
|
||||
|
||||
import jaxtyping as jtyp
|
||||
import pydantic as pyd
|
||||
import sympy as sp
|
||||
import sympy.physics.units as spu
|
||||
|
||||
from . import extra_sympy_units as spux
|
||||
from . import logger, serialize
|
||||
|
||||
int_min = -(2**64)
|
||||
int_max = 2**64
|
||||
float_min = sys.float_info.min
|
||||
float_max = sys.float_info.max
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
||||
def unicode_superscript(n: int) -> str:
|
||||
"""Transform an integer into its unicode-based superscript character."""
|
||||
return ''.join(['⁰¹²³⁴⁵⁶⁷⁸⁹'[ord(c) - ord('0')] for c in str(n)])
|
||||
|
||||
|
||||
####################
|
||||
# - Simulation Symbol Names
|
||||
####################
|
||||
_l = ''
|
||||
_it_lower = iter(string.ascii_lowercase)
|
||||
|
||||
|
||||
class SimSymbolName(enum.StrEnum):
|
||||
# Lower
|
||||
LowerA = enum.auto()
|
||||
LowerB = enum.auto()
|
||||
LowerC = enum.auto()
|
||||
LowerD = enum.auto()
|
||||
LowerI = enum.auto()
|
||||
LowerT = enum.auto()
|
||||
LowerX = enum.auto()
|
||||
LowerY = enum.auto()
|
||||
LowerZ = enum.auto()
|
||||
# Generic
|
||||
Constant = enum.auto()
|
||||
Expr = enum.auto()
|
||||
Data = enum.auto()
|
||||
|
||||
# Ascii Letters
|
||||
while True:
|
||||
try:
|
||||
globals()['_l'] = next(globals()['_it_lower'])
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
locals()[f'Lower{globals()["_l"].upper()}'] = enum.auto()
|
||||
locals()[f'Upper{globals()["_l"].upper()}'] = enum.auto()
|
||||
|
||||
# Greek Letters
|
||||
LowerTheta = enum.auto()
|
||||
LowerPhi = enum.auto()
|
||||
|
||||
# Fields
|
||||
Ex = enum.auto()
|
||||
|
@ -64,18 +88,15 @@ class SimSymbolName(enum.StrEnum):
|
|||
Wavelength = enum.auto()
|
||||
Frequency = enum.auto()
|
||||
|
||||
Flux = enum.auto()
|
||||
|
||||
PermXX = enum.auto()
|
||||
PermYY = enum.auto()
|
||||
PermZZ = enum.auto()
|
||||
|
||||
Flux = enum.auto()
|
||||
|
||||
DiffOrderX = enum.auto()
|
||||
DiffOrderY = enum.auto()
|
||||
|
||||
# Generic
|
||||
Expr = enum.auto()
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
|
@ -109,17 +130,21 @@ class SimSymbolName(enum.StrEnum):
|
|||
@property
|
||||
def name(self) -> str:
|
||||
SSN = SimSymbolName
|
||||
return {
|
||||
# Lower
|
||||
SSN.LowerA: 'a',
|
||||
SSN.LowerB: 'b',
|
||||
SSN.LowerC: 'c',
|
||||
SSN.LowerD: 'd',
|
||||
SSN.LowerI: 'i',
|
||||
SSN.LowerT: 't',
|
||||
SSN.LowerX: 'x',
|
||||
SSN.LowerY: 'y',
|
||||
SSN.LowerZ: 'z',
|
||||
return (
|
||||
# Ascii Letters
|
||||
{SSN[f'Lower{letter.upper()}']: letter for letter in string.ascii_lowercase}
|
||||
| {
|
||||
SSN[f'Upper{letter.upper()}']: letter.upper()
|
||||
for letter in string.ascii_lowercase
|
||||
}
|
||||
| {
|
||||
# Generic
|
||||
SSN.Constant: 'constant',
|
||||
SSN.Expr: 'expr',
|
||||
SSN.Data: 'data',
|
||||
# Greek Letters
|
||||
SSN.LowerTheta: 'theta',
|
||||
SSN.LowerPhi: 'phi',
|
||||
# Fields
|
||||
SSN.Ex: 'Ex',
|
||||
SSN.Ey: 'Ey',
|
||||
|
@ -136,22 +161,35 @@ class SimSymbolName(enum.StrEnum):
|
|||
# Optics
|
||||
SSN.Wavelength: 'wl',
|
||||
SSN.Frequency: 'freq',
|
||||
SSN.Flux: 'flux',
|
||||
SSN.PermXX: 'eps_xx',
|
||||
SSN.PermYY: 'eps_yy',
|
||||
SSN.PermZZ: 'eps_zz',
|
||||
SSN.Flux: 'flux',
|
||||
SSN.DiffOrderX: 'order_x',
|
||||
SSN.DiffOrderY: 'order_y',
|
||||
# Generic
|
||||
SSN.Expr: 'expr',
|
||||
}[self]
|
||||
}
|
||||
)[self]
|
||||
|
||||
@property
|
||||
def name_pretty(self) -> str:
|
||||
SSN = SimSymbolName
|
||||
return {
|
||||
# Generic
|
||||
# Greek Letters
|
||||
SSN.LowerTheta: 'θ',
|
||||
SSN.LowerPhi: 'φ',
|
||||
# Fields
|
||||
SSN.Etheta: 'Eθ',
|
||||
SSN.Ephi: 'Eφ',
|
||||
SSN.Hr: 'Hr',
|
||||
SSN.Htheta: 'Hθ',
|
||||
SSN.Hphi: 'Hφ',
|
||||
# Optics
|
||||
SSN.Wavelength: 'λ',
|
||||
SSN.Frequency: '𝑓',
|
||||
SSN.PermXX: 'ε_xx',
|
||||
SSN.PermYY: 'ε_yy',
|
||||
SSN.PermZZ: 'ε_zz',
|
||||
}.get(self, self.name)
|
||||
|
||||
|
||||
|
@ -173,8 +211,7 @@ def mk_interval(
|
|||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(kw_only=True, frozen=True)
|
||||
class SimSymbol:
|
||||
class SimSymbol(pyd.BaseModel):
|
||||
"""A declarative representation of a symbolic variable.
|
||||
|
||||
`sympy`'s symbols aren't quite flexible enough for our needs: The symbols that we're transporting often need exact domain information, an associated unit dimension, and a great deal of determinism in checks thereof.
|
||||
|
@ -183,6 +220,8 @@ class SimSymbol:
|
|||
It's easy to persist, easy to transport, and has many helpful properties which greatly simplify working with symbols.
|
||||
"""
|
||||
|
||||
model_config = pyd.ConfigDict(frozen=True)
|
||||
|
||||
sym_name: SimSymbolName
|
||||
mathtype: spux.MathType = spux.MathType.Real
|
||||
physical_type: spux.PhysicalType = spux.PhysicalType.NonPhysical
|
||||
|
@ -191,6 +230,9 @@ class SimSymbol:
|
|||
## -> 'None' indicates that no particular unit has yet been chosen.
|
||||
## -> Not exposed in the UI; must be set some other way.
|
||||
unit: spux.Unit | None = None
|
||||
## -> TODO: We currently allowing units that don't match PhysicalType
|
||||
## -> -- In particular, NonPhysical w/units means "unknown units".
|
||||
## -> -- This is essential for the Scientific Constant Node.
|
||||
|
||||
# Size
|
||||
## -> All SimSymbol sizes are "2D", but interpreted by convention.
|
||||
|
@ -205,43 +247,76 @@ class SimSymbol:
|
|||
## -> NOTE: interval_finite_*[0] must be strictly smaller than [1].
|
||||
## -> See self.domain.
|
||||
## -> We have to deconstruct symbolic interval semantics a bit for UI.
|
||||
is_constant: bool = False
|
||||
interval_finite_z: tuple[int, int] = (0, 1)
|
||||
interval_finite_q: tuple[tuple[int, int], tuple[int, int]] = ((0, 1), (1, 1))
|
||||
interval_finite_re: tuple[float, float] = (0, 1)
|
||||
interval_finite_re: tuple[float, float] = (0.0, 1.0)
|
||||
interval_inf: tuple[bool, bool] = (True, True)
|
||||
interval_closed: tuple[bool, bool] = (False, False)
|
||||
|
||||
interval_finite_im: tuple[float, float] = (0, 1)
|
||||
interval_finite_im: tuple[float, float] = (0.0, 1.0)
|
||||
interval_inf_im: tuple[bool, bool] = (True, True)
|
||||
interval_closed_im: tuple[bool, bool] = (False, False)
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
# - Labels
|
||||
####################
|
||||
@property
|
||||
@functools.cached_property
|
||||
def name(self) -> str:
|
||||
"""Usable name for the symbol."""
|
||||
return self.sym_name.name
|
||||
|
||||
@property
|
||||
@functools.cached_property
|
||||
def name_pretty(self) -> str:
|
||||
"""Pretty (possibly unicode) name for the thing."""
|
||||
return self.sym_name.name_pretty
|
||||
## TODO: Formatting conventions for bolding/etc. of vectors/mats/...
|
||||
|
||||
@property
|
||||
@functools.cached_property
|
||||
def mathtype_size_label(self) -> str:
|
||||
"""Pretty label that shows both mathtype and size."""
|
||||
return f'{self.mathtype.label_pretty}' + (
|
||||
'ˣ'.join([unicode_superscript(out_axis) for out_axis in self.shape])
|
||||
if self.shape
|
||||
else ''
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def unit_label(self) -> str:
|
||||
"""Pretty unit label, which is an empty string when there is no unit."""
|
||||
return spux.sp_to_str(self.unit) if self.unit is not None else ''
|
||||
|
||||
@functools.cached_property
|
||||
def def_label(self) -> str:
|
||||
"""Pretty definition label, exposing the symbol definition."""
|
||||
return f'{self.name_pretty} | {self.unit_label} ∈ {self.mathtype_size_label}'
|
||||
## TODO: Domain of validity from self.domain?
|
||||
|
||||
@functools.cached_property
|
||||
def plot_label(self) -> str:
|
||||
"""Pretty plot-oriented label."""
|
||||
return f'{self.name_pretty}' + (
|
||||
f'({self.unit})' if self.unit is not None else ''
|
||||
)
|
||||
|
||||
@property
|
||||
####################
|
||||
# - Computed Properties
|
||||
####################
|
||||
@functools.cached_property
|
||||
def unit_factor(self) -> spux.SympyExpr:
|
||||
"""Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking."""
|
||||
return self.unit if self.unit is not None else sp.S(1)
|
||||
|
||||
@property
|
||||
@functools.cached_property
|
||||
def size(self) -> tuple[int, ...] | None:
|
||||
return {
|
||||
(1, 1): spux.NumberSize1D.Scalar,
|
||||
(2, 1): spux.NumberSize1D.Vec2,
|
||||
(3, 1): spux.NumberSize1D.Vec3,
|
||||
(4, 1): spux.NumberSize1D.Vec4,
|
||||
}.get((self.rows, self.cols))
|
||||
|
||||
@functools.cached_property
|
||||
def shape(self) -> tuple[int, ...]:
|
||||
match (self.rows, self.cols):
|
||||
case (1, 1):
|
||||
|
@ -253,7 +328,12 @@ class SimSymbol:
|
|||
case (_, _):
|
||||
return (self.rows, self.cols)
|
||||
|
||||
@property
|
||||
@functools.cached_property
|
||||
def shape_len(self) -> spux.SympyExpr:
|
||||
"""Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking."""
|
||||
return len(self.shape)
|
||||
|
||||
@functools.cached_property
|
||||
def domain(self) -> sp.Interval | sp.Set:
|
||||
"""Return the scalar domain of valid values for each element of the symbol.
|
||||
|
||||
|
@ -303,11 +383,31 @@ class SimSymbol:
|
|||
),
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def valid_domain_value(self) -> spux.SympyExpr:
|
||||
"""A single value guaranteed to be conformant to this `SimSymbol` and within `self.domain`."""
|
||||
match (self.domain.start.is_finite, self.domain.end.is_finite):
|
||||
case (True, True):
|
||||
if self.mathtype is spux.MathType.Integer:
|
||||
return (self.domain.start + self.domain.end) // 2
|
||||
return (self.domain.start + self.domain.end) / 2
|
||||
|
||||
case (True, False):
|
||||
one = sp.S(self.mathtype.coerce_compatible_pyobj(-1))
|
||||
return self.domain.start + one
|
||||
|
||||
case (False, True):
|
||||
one = sp.S(self.mathtype.coerce_compatible_pyobj(-1))
|
||||
return self.domain.end - one
|
||||
|
||||
case (False, False):
|
||||
return sp.S(self.mathtype.coerce_compatible_pyobj(-1))
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
@property
|
||||
def sp_symbol(self) -> sp.Symbol:
|
||||
@functools.cached_property
|
||||
def sp_symbol(self) -> sp.Symbol | sp.ImmutableMatrix:
|
||||
"""Return a symbolic variable w/unit, corresponding to this `SimSymbol`.
|
||||
|
||||
As much as possible, appropriate `assumptions` are set in the constructor of `sp.Symbol`, insofar as they can be determined.
|
||||
|
@ -352,7 +452,82 @@ class SimSymbol:
|
|||
elif self.domain.right <= 0:
|
||||
mathtype_kwargs |= {'negative': True}
|
||||
|
||||
return sp.Symbol(self.sym_name.name, **mathtype_kwargs) * self.unit_factor
|
||||
# Scalar: Return Symbol
|
||||
if self.rows == 1 and self.cols == 1:
|
||||
return sp.Symbol(self.sym_name.name, **mathtype_kwargs)
|
||||
|
||||
# Vector|Matrix: Return Matrix of Symbols
|
||||
## -> MatrixSymbol doesn't support assumptions.
|
||||
## -> This little construction does.
|
||||
return sp.ImmutableMatrix(
|
||||
[
|
||||
[
|
||||
sp.Symbol(self.sym_name.name + f'_{row}{col}', **mathtype_kwargs)
|
||||
for col in range(self.cols)
|
||||
]
|
||||
for row in range(self.rows)
|
||||
]
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def sp_symbol_matsym(self) -> sp.Symbol | sp.MatrixSymbol:
|
||||
"""Return a symbolic variable w/unit, corresponding to this `SimSymbol`, w/variable shape support.
|
||||
|
||||
To preserve as many assumptions as possible, `self.sp_symbol` returns a matrix of individual `sp.Symbol`s whenever the `SimSymbol` is non-scalar.
|
||||
However, this isn't always the most useful representation: For example, if the intention is to use a shaped symbolic variable as an argument to `sympy.lambdify()`, one would have to flatten each individual `sp.Symbol` and pass each matrix element as a single element, greatly complicating things like broadcasting.
|
||||
|
||||
For this reason, this property is provided.
|
||||
Whenever the `SimSymbol` is scalar, it works identically to `self.sp_symbol`.
|
||||
However, when the `SimSymbol` is shaped, an appropriate `sp.MatrixSymbol` is returned instead.
|
||||
|
||||
Notes:
|
||||
`sp.MatrixSymbol` doesn't support assumptions.
|
||||
As such, things like deduction of `MathType` from expressions involving a matrix symbol simply won't work.
|
||||
"""
|
||||
if self.shape_len == 0:
|
||||
return self.sp_symbol
|
||||
return sp.MatrixSymbol(self.sym_name.name, self.rows, self.cols)
|
||||
|
||||
@functools.cached_property
|
||||
def sp_symbol_phy(self) -> spux.SympyExpr:
|
||||
"""Physical symbol containing `self.sp_symbol` multiplied by `self.unit`."""
|
||||
return self.sp_symbol * self.unit_factor
|
||||
|
||||
@functools.cached_property
|
||||
def expr_info(self) -> dict[str, typ.Any]:
|
||||
"""Generate keyword arguments for an ExprSocket, whose output values will be guaranteed to conform to this `SimSymbol`.
|
||||
|
||||
Notes:
|
||||
Before use, `active_kind=ct.FlowKind.Range` can be added to make the `ExprSocket`.
|
||||
|
||||
Default values are set for both `Value` and `Range`.
|
||||
To this end, `self.domain` is used.
|
||||
|
||||
Since `ExprSocketDef` allows the use of infinite bounds for `default_min` and `default_max`, we defer the decision of how to treat finite-fallback to the `ExprSocketDef`.
|
||||
"""
|
||||
if self.size is not None:
|
||||
if self.unit in self.physical_type.valid_units:
|
||||
return {
|
||||
'output_name': self.sym_name,
|
||||
# Socket Interface
|
||||
'size': self.size,
|
||||
'mathtype': self.mathtype,
|
||||
'physical_type': self.physical_type,
|
||||
# Defaults: Units
|
||||
'default_unit': self.unit,
|
||||
'default_symbols': [],
|
||||
# Defaults: FlowKind.Value
|
||||
'default_value': self.conform(
|
||||
self.valid_domain_value, strip_unit=True
|
||||
),
|
||||
# Defaults: FlowKind.Range
|
||||
'default_min': self.domain.start,
|
||||
'default_max': self.domain.end,
|
||||
}
|
||||
msg = f'Tried to generate an ExprSocket from a SymSymbol "{self.name}", but its unit ({self.unit}) is not a valid unit of its physical type ({self.physical_type}) (SimSymbol={self})'
|
||||
raise NotImplementedError(msg)
|
||||
msg = f'Tried to generate an ExprSocket from a SymSymbol "{self.name}", but its size ({self.rows} by {self.cols}) is incompatible with ExprSocket (SimSymbol={self})'
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
####################
|
||||
# - Operations
|
||||
|
@ -373,7 +548,7 @@ class SimSymbol:
|
|||
cols=get_attr('cols'),
|
||||
interval_finite_z=get_attr('interval_finite_z'),
|
||||
interval_finite_q=get_attr('interval_finite_q'),
|
||||
interval_finite_re=get_attr('interval_finite_q'),
|
||||
interval_finite_re=get_attr('interval_finite_re'),
|
||||
interval_inf=get_attr('interval_inf'),
|
||||
interval_closed=get_attr('interval_closed'),
|
||||
interval_finite_im=get_attr('interval_finite_im'),
|
||||
|
@ -381,24 +556,199 @@ class SimSymbol:
|
|||
interval_closed_im=get_attr('interval_closed_im'),
|
||||
)
|
||||
|
||||
def set_finite_domain( # noqa: PLR0913
|
||||
self,
|
||||
start: int | float,
|
||||
end: int | float,
|
||||
start_closed: bool = True,
|
||||
end_closed: bool = True,
|
||||
start_im: bool = float,
|
||||
end_im: bool = float,
|
||||
start_closed_im: bool = True,
|
||||
end_closed_im: bool = True,
|
||||
) -> typ.Self:
|
||||
"""Update the symbol with a finite range."""
|
||||
closed_re = (start_closed, end_closed)
|
||||
closed_im = (start_closed_im, end_closed_im)
|
||||
match self.mathtype:
|
||||
case spux.MathType.Integer:
|
||||
return self.update(
|
||||
interval_finite_z=(start, end),
|
||||
interval_inf=(False, False),
|
||||
interval_closed=closed_re,
|
||||
)
|
||||
case spux.MathType.Rational:
|
||||
return self.update(
|
||||
interval_finite_q=(start, end),
|
||||
interval_inf=(False, False),
|
||||
interval_closed=closed_re,
|
||||
)
|
||||
case spux.MathType.Real:
|
||||
return self.update(
|
||||
interval_finite_re=(start, end),
|
||||
interval_inf=(False, False),
|
||||
interval_closed=closed_re,
|
||||
)
|
||||
case spux.MathType.Complex:
|
||||
return self.update(
|
||||
interval_finite_re=(start, end),
|
||||
interval_finite_im=(start_im, end_im),
|
||||
interval_inf=(False, False),
|
||||
interval_closed=closed_re,
|
||||
interval_closed_im=closed_im,
|
||||
)
|
||||
|
||||
def set_size(self, rows: int, cols: int) -> typ.Self:
|
||||
return self.update(rows=rows, cols=cols)
|
||||
|
||||
def conform(
|
||||
self, sp_obj: spux.SympyType, strip_unit: bool = False
|
||||
) -> spux.SympyType:
|
||||
"""Conform a sympy object to the properties of this `SimSymbol`, if possible.
|
||||
|
||||
To achieve this, a number of operations may be performed:
|
||||
|
||||
- **Unit Conversion**: If the object has no units, but should, multiply by `self.unit`. If the object has units, but shouldn't, strip them. Otherwise, convert its unit to `self.unit`.
|
||||
- **Broadcast Expansion**: If the object is a scalar, but the `SimSymbol` is shaped, then an `sp.ImmutableMatrix` is returned with the scalar at each position.
|
||||
|
||||
Returns:
|
||||
A transformed sympy object guaranteed usable as a particular value of this `SimSymbol` variable.
|
||||
|
||||
Raises:
|
||||
ValueError: If the units of `sp_obj` can't be cleanly converted to `self.unit`.
|
||||
"""
|
||||
res = sp_obj
|
||||
|
||||
# Unit Conversion
|
||||
match (spux.uses_units(sp_obj), self.unit is not None):
|
||||
case (True, True):
|
||||
res = spux.scale_to_unit(sp_obj, self.unit) * self.unit
|
||||
|
||||
case (False, True):
|
||||
res = sp_obj * self.unit
|
||||
|
||||
case (True, False):
|
||||
res = spux.strip_unit_system(sp_obj)
|
||||
|
||||
if strip_unit:
|
||||
res = spux.strip_unit_system(sp_obj)
|
||||
|
||||
# Broadcast Expansion
|
||||
if self.rows > 1 or self.cols > 1 and not isinstance(res, spux.MatrixBase):
|
||||
res = sp_obj * sp.ImmutableMatrix.ones(self.rows, self.cols)
|
||||
|
||||
return res
|
||||
|
||||
def scale(
|
||||
self, sp_obj: spux.SympyType, use_jax_array: bool = True
|
||||
) -> int | float | complex | jtyp.Inexact[jtyp.Array, '...']:
|
||||
"""Remove all symbolic elements from the conformed `sp_obj`, preparing it for use in contexts that don't support unrealized symbols.
|
||||
|
||||
On top of `self.conform()`, a number of operations are performed.
|
||||
|
||||
- **Unit Stripping**: The `self.unit` of the expression returned by `self.conform()` will be stripped.
|
||||
- **Sympy to Python**: The now symbol-less expression will be converted to either a pure Python type, or to a `jax` array (if `use_jax_array` is set).
|
||||
|
||||
Notes:
|
||||
When creating numerical functions of expressions using `.lambdify`, `self.scale()` **must be used** in place of `self.conform()` before the parameterized expression is used.
|
||||
|
||||
Returns:
|
||||
A "raw" (pure Python / jax array) type guaranteed usable as a particular **numerical** value of this `SymSymbol` variable.
|
||||
"""
|
||||
# Conform
|
||||
res = self.conform(sp_obj)
|
||||
|
||||
# Strip Units
|
||||
res = spux.scale_to_unit(sp_obj, self.unit)
|
||||
|
||||
# Sympy to Python
|
||||
res = spux.sympy_to_python(res, use_jax_array=use_jax_array)
|
||||
|
||||
return res # noqa: RET504
|
||||
|
||||
@staticmethod
|
||||
def from_expr(
|
||||
sym_name: SimSymbolName,
|
||||
expr: spux.SympyExpr,
|
||||
unit_expr: spux.SympyExpr,
|
||||
) -> typ.Self:
|
||||
"""Deduce a `SimSymbol` that matches the output of a given expression (and unit expression).
|
||||
|
||||
This is an essential method, allowing for the ded
|
||||
|
||||
Notes:
|
||||
`PhysicalType` **cannot be set** from an expression in the generic sense.
|
||||
Therefore, the trick of using `NonPhysical` with non-`None` unit to denote unknown `PhysicalType` is used in the output.
|
||||
|
||||
All intervals are kept at their defaults.
|
||||
|
||||
Parameters:
|
||||
sym_name: The `SimSymbolName` to set to the resulting symbol.
|
||||
expr: The unit-aware expression to parse and encapsulate as a symbol.
|
||||
unit_expr: A dimensional analysis expression (set to `1` to make the resulting symbol unitless).
|
||||
Fundamentally, units are just the variables of scalar terms.
|
||||
'1' for unitless terms are, in the dimanyl sense, constants.
|
||||
|
||||
Doing it like this may be a little messy, but is accurate.
|
||||
|
||||
Returns:
|
||||
A fresh new `SimSymbol` that tries to match the given expression (and unit expression) well enough to be usable in place of it.
|
||||
"""
|
||||
# MathType from Expr Assumptions
|
||||
## -> All input symbols have assumptions, because we are very pedantic.
|
||||
## -> Therefore, we should be able to reconstruct the MathType.
|
||||
mathtype = spux.MathType.from_expr(expr)
|
||||
|
||||
# PhysicalType as "NonPhysical"
|
||||
## -> 'unit' still applies - but we can't guarantee a PhysicalType will.
|
||||
## -> Therefore, this is what we gotta do.
|
||||
physical_type = spux.PhysicalType.NonPhysical
|
||||
|
||||
# Rows/Cols from Expr (if Matrix)
|
||||
rows, cols = expr.shape if isinstance(expr, sp.MatrixBase) else (1, 1)
|
||||
|
||||
return SimSymbol(
|
||||
sym_name=self.sym_name,
|
||||
mathtype=self.mathtype,
|
||||
physical_type=self.physical_type,
|
||||
unit=self.unit,
|
||||
sym_name=sym_name,
|
||||
mathtype=mathtype,
|
||||
physical_type=physical_type,
|
||||
unit=unit_expr if unit_expr != 1 else None,
|
||||
rows=rows,
|
||||
cols=cols,
|
||||
interval_finite_z=self.interval_finite_z,
|
||||
interval_finite_q=self.interval_finite_q,
|
||||
interval_finite_re=self.interval_finite_re,
|
||||
interval_inf=self.interval_inf,
|
||||
interval_closed=self.interval_closed,
|
||||
interval_finite_im=self.interval_finite_im,
|
||||
interval_inf_im=self.interval_inf_im,
|
||||
interval_closed_im=self.interval_closed_im,
|
||||
)
|
||||
|
||||
####################
|
||||
# - Serialization
|
||||
####################
|
||||
def dump_as_msgspec(self) -> serialize.NaiveRepresentation:
|
||||
"""Transforms this `SimSymbol` into an object that can be natively serialized by `msgspec`.
|
||||
|
||||
Notes:
|
||||
Makes use of `pydantic.BaseModel.model_dump()` to cast any special fields into a serializable format.
|
||||
If this method is failing, check that `pydantic` can actually cast all the fields in your model.
|
||||
|
||||
Returns:
|
||||
A particular `list`, with two elements:
|
||||
|
||||
1. The `serialize`-provided "Type Identifier", to differentiate this list from generic list.
|
||||
2. A dictionary containing simple Python types, as cast by `pydantic`.
|
||||
"""
|
||||
return [serialize.TypeID.SimSymbol, self.__class__.__name__, self.model_dump()]
|
||||
|
||||
@staticmethod
|
||||
def parse_as_msgspec(obj: serialize.NaiveRepresentation) -> typ.Self:
|
||||
"""Transforms an object made by `self.dump_as_msgspec()` into an instance of `SimSymbol`.
|
||||
|
||||
Notes:
|
||||
The method presumes that the deserialized object produced by `msgspec` perfectly matches the object originally created by `self.dump_as_msgspec()`.
|
||||
|
||||
This is a **mostly robust** presumption, as `pydantic` attempts to be quite consistent in how to interpret types with almost identical semantics.
|
||||
Still, yet-unknown edge cases may challenge these presumptions.
|
||||
|
||||
Returns:
|
||||
A new instance of `SimSymbol`, initialized using the `model_dump()` dictionary.
|
||||
"""
|
||||
return SimSymbol(**obj[2])
|
||||
|
||||
|
||||
####################
|
||||
# - Common Sim Symbols
|
||||
|
@ -453,14 +803,10 @@ class CommonSimSymbol(enum.StrEnum):
|
|||
Wavelength = enum.auto()
|
||||
Frequency = enum.auto()
|
||||
|
||||
DiffOrderX = enum.auto()
|
||||
DiffOrderY = enum.auto()
|
||||
|
||||
Flux = enum.auto()
|
||||
|
||||
WaveVecX = enum.auto()
|
||||
WaveVecY = enum.auto()
|
||||
WaveVecZ = enum.auto()
|
||||
DiffOrderX = enum.auto()
|
||||
DiffOrderY = enum.auto()
|
||||
|
||||
####################
|
||||
# - UI
|
||||
|
@ -549,10 +895,10 @@ class CommonSimSymbol(enum.StrEnum):
|
|||
if eh == 'e'
|
||||
else spux.PhysicalType.HField,
|
||||
unit=unit,
|
||||
interval_finite_re=(0, sys.float_info.max),
|
||||
interval_finite_re=(0, float_max),
|
||||
interval_inf_re=(False, True),
|
||||
interval_closed_re=(True, False),
|
||||
interval_finite_im=(sys.float_info.min, sys.float_info.max),
|
||||
interval_finite_im=(float_min, float_max),
|
||||
interval_inf_im=(True, True),
|
||||
)
|
||||
|
||||
|
@ -575,7 +921,7 @@ class CommonSimSymbol(enum.StrEnum):
|
|||
sym_name=self.name,
|
||||
physical_type=spux.PhysicalType.Time,
|
||||
unit=unit,
|
||||
interval_finite_re=(0, sys.float_info.max),
|
||||
interval_finite_re=(0, float_max),
|
||||
interval_inf=(False, True),
|
||||
interval_closed=(True, False),
|
||||
),
|
||||
|
@ -592,19 +938,13 @@ class CommonSimSymbol(enum.StrEnum):
|
|||
CSS.FieldHr: sym_field('h'),
|
||||
CSS.FieldHtheta: sym_field('h'),
|
||||
CSS.FieldHphi: sym_field('h'),
|
||||
CSS.Flux: SimSymbol(
|
||||
sym_name=SimSymbolName.Flux,
|
||||
mathtype=spux.MathType.Real,
|
||||
physical_type=spux.PhysicalType.Power,
|
||||
unit=unit,
|
||||
),
|
||||
# Optics
|
||||
CSS.Wavelength: SimSymbol(
|
||||
sym_name=self.name,
|
||||
mathtype=spux.MathType.Real,
|
||||
physical_type=spux.PhysicalType.Length,
|
||||
unit=unit,
|
||||
interval_finite=(0, sys.float_info.max),
|
||||
interval_finite=(0, float_max),
|
||||
interval_inf=(False, True),
|
||||
interval_closed=(False, False),
|
||||
),
|
||||
|
@ -613,10 +953,30 @@ class CommonSimSymbol(enum.StrEnum):
|
|||
mathtype=spux.MathType.Real,
|
||||
physical_type=spux.PhysicalType.Freq,
|
||||
unit=unit,
|
||||
interval_finite=(0, sys.float_info.max),
|
||||
interval_finite=(0, float_max),
|
||||
interval_inf=(False, True),
|
||||
interval_closed=(False, False),
|
||||
),
|
||||
CSS.Flux: SimSymbol(
|
||||
sym_name=SimSymbolName.Flux,
|
||||
mathtype=spux.MathType.Real,
|
||||
physical_type=spux.PhysicalType.Power,
|
||||
unit=unit,
|
||||
),
|
||||
CSS.DiffOrderX: SimSymbol(
|
||||
sym_name=self.name,
|
||||
mathtype=spux.MathType.Integer,
|
||||
interval_finite=(int_min, int_max),
|
||||
interval_inf=(True, True),
|
||||
interval_closed=(False, False),
|
||||
),
|
||||
CSS.DiffOrderY: SimSymbol(
|
||||
sym_name=self.name,
|
||||
mathtype=spux.MathType.Integer,
|
||||
interval_finite=(int_min, int_max),
|
||||
interval_inf=(True, True),
|
||||
interval_closed=(False, False),
|
||||
),
|
||||
}[self]
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue