refactor: end-of-day commit (sim symbol flow for data import/export & inverse design)

main
Sofus Albert Høgsbro Rose 2024-05-21 22:57:56 +02:00
parent dccf952ad3
commit 353a2c997e
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
24 changed files with 2058 additions and 1710 deletions

View File

@ -27,6 +27,7 @@ dependencies = [
#"charset-normalizer==2.0.10", ## Conflict with dev-dep commitizen
"certifi==2021.10.8",
"polars>=0.20.26",
"seaborn[stats]>=0.13.2",
]
## When it comes to dev-dep conflicts:
## -> It's okay to leave Blender-pinned deps out of prod; Blender still has them.

View File

@ -81,6 +81,7 @@ locket==1.0.0
markupsafe==2.1.5
# via jinja2
matplotlib==3.8.3
# via seaborn
# via tidy3d
ml-dtypes==0.4.0
# via jax
@ -102,8 +103,11 @@ numpy==1.24.3
# via ml-dtypes
# via numba
# via opt-einsum
# via patsy
# via scipy
# via seaborn
# via shapely
# via statsmodels
# via tidy3d
# via trimesh
# via xarray
@ -114,11 +118,16 @@ packaging==24.0
# via dask
# via h5netcdf
# via matplotlib
# via statsmodels
# via xarray
pandas==2.2.1
# via seaborn
# via statsmodels
# via xarray
partd==1.4.1
# via dask
patsy==0.5.6
# via statsmodels
pillow==10.2.0
# via matplotlib
platformdirs==4.2.1
@ -167,13 +176,19 @@ s3transfer==0.5.2
scipy==1.12.0
# via jax
# via jaxlib
# via seaborn
# via statsmodels
# via tidy3d
seaborn==0.13.2
setuptools==69.5.1
# via nodeenv
shapely==2.0.3
# via tidy3d
six==1.16.0
# via patsy
# via python-dateutil
statsmodels==0.14.2
# via seaborn
sympy==1.12
termcolor==2.4.0
# via commitizen

View File

@ -59,6 +59,7 @@ llvmlite==0.42.0
locket==1.0.0
# via partd
matplotlib==3.8.3
# via seaborn
# via tidy3d
ml-dtypes==0.4.0
# via jax
@ -78,8 +79,11 @@ numpy==1.24.3
# via ml-dtypes
# via numba
# via opt-einsum
# via patsy
# via scipy
# via seaborn
# via shapely
# via statsmodels
# via tidy3d
# via trimesh
# via xarray
@ -89,11 +93,16 @@ packaging==24.0
# via dask
# via h5netcdf
# via matplotlib
# via statsmodels
# via xarray
pandas==2.2.1
# via seaborn
# via statsmodels
# via xarray
partd==1.4.1
# via dask
patsy==0.5.6
# via statsmodels
pillow==10.2.0
# via matplotlib
polars==0.20.26
@ -132,11 +141,17 @@ s3transfer==0.5.2
scipy==1.12.0
# via jax
# via jaxlib
# via seaborn
# via statsmodels
# via tidy3d
seaborn==0.13.2
shapely==2.0.3
# via tidy3d
six==1.16.0
# via patsy
# via python-dateutil
statsmodels==0.14.2
# via seaborn
sympy==1.12
tidy3d==2.6.3
toml==0.10.2

View File

@ -29,9 +29,12 @@ from blender_maxwell.utils import logger
log = logger.get(__name__)
# TODO: Our handling of 'is_sorted' is sloppy and probably wrong.
@dataclasses.dataclass(frozen=True, kw_only=True)
class ArrayFlow:
"""A simple, flat array of values with an optionally-attached unit.
"""A homogeneous, realized array of numerical values with an optionally-attached unit and sort-tracking.
While the principle is simple, arrays-with-units ends up being a powerful basis for derived and computed features/methods/processing.
Attributes:
values: An ND array-like object of arbitrary numerical type.
@ -44,13 +47,97 @@ class ArrayFlow:
is_sorted: bool = False
####################
# - Computed Properties
####################
@property
def is_symbolic(self) -> bool:
"""Always False, as ArrayFlows are never unrealized."""
return False
def __len__(self) -> int:
"""Outer length of the contained array."""
return len(self.values)
@functools.cached_property
def mathtype(self) -> spux.MathType:
"""Deduce the `spux.MathType` of the first element of the contained array.
This is generally a heuristic, but because `jax` enforces homogeneous arrays, this is actually a well-defined approach.
"""
return spux.MathType.from_pytype(type(self.values.item(0)))
@functools.cached_property
def physical_type(self) -> spux.MathType:
"""Deduce the `spux.PhysicalType` of the unit."""
return spux.PhysicalType.from_unit(self.unit)
####################
# - Array Features
####################
@property
def realize_array(self) -> jtyp.Shaped[jtyp.Array, '...']:
"""Standardized access to `self.values`."""
return self.values
@functools.cached_property
def shape(self) -> int:
"""Shape of the contained array."""
return self.values.shape
def __getitem__(self, subscript: slice) -> typ.Self | spux.SympyExpr:
"""Implement indexing and slicing in a sane way.
- **Integer Index**: For scalar output, return a `sympy` expression of the scalar multiplied by the unit, else just a sympy expression of the value.
- **Slice**: Slice the internal array directly, and wrap the result in a new `ArrayFlow`.
"""
if isinstance(subscript, slice):
return ArrayFlow(
values=self.values[subscript],
unit=self.unit,
is_sorted=self.is_sorted,
)
if isinstance(subscript, int):
value = self.values[subscript]
if len(value.shape) == 0:
return value * self.unit if self.unit is not None else sp.S(value)
return ArrayFlow(values=value, unit=self.unit, is_sorted=self.is_sorted)
raise NotImplementedError
####################
# - Methods
####################
def rescale(
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
) -> typ.Self:
"""Apply an order-preserving function to each element of the array, then (optionally) transform the result w/new unit and/or order.
An optimized expression will be built and applied to `self.values` using `sympy.lambdify()`.
Parameters:
rescale_func: An **order-preserving** function to apply to each array element.
reverse: Whether to reverse the order of the result.
new_unit: An (optional) new unit to scale the result to.
"""
# Compile JAX-Compatible Rescale Function
a = self.mathtype.sp_symbol_a
rescale_expr = (
spux.scale_to_unit(rescale_func(a * self.unit), new_unit)
if self.unit is not None
else rescale_func(a)
)
_rescale_func = sp.lambdify(a, rescale_expr, 'jax')
values = _rescale_func(self.values)
# Return ArrayFlow
return ArrayFlow(
values=values[::-1] if reverse else values,
unit=new_unit,
is_sorted=self.is_sorted,
)
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
"""Find the index of the value that is closest to the given value.
@ -88,56 +175,26 @@ class ArrayFlow:
return right_idx
def correct_unit(self, corrected_unit: spu.Quantity) -> typ.Self:
if self.unit is not None:
return ArrayFlow(
values=self.values, unit=corrected_unit, is_sorted=self.is_sorted
)
####################
# - Unit Transforms
####################
def correct_unit(self, unit: spux.Unit) -> typ.Self:
"""Simply replace the existing unit with the given one.
msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"'
raise ValueError(msg)
Parameters:
corrected_unit: The new unit to insert.
**MUST** be associable with a well-defined `PhysicalType`.
"""
return ArrayFlow(values=self.values, unit=unit, is_sorted=self.is_sorted)
def rescale_to_unit(self, unit: spu.Quantity | None) -> typ.Self:
## TODO: Cache by unit would be a very nice speedup for Viz node.
if self.unit is not None:
return ArrayFlow(
values=float(spux.scaling_factor(self.unit, unit)) * self.values,
unit=unit,
is_sorted=self.is_sorted,
)
def rescale_to_unit(self, new_unit: spux.Unit | None) -> typ.Self:
"""Rescale the `ArrayFlow` to be expressed in the given unit.
if unit is None:
return self
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}'
raise ValueError(msg)
def rescale(
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
) -> typ.Self:
# Compile JAX-Compatible Rescale Function
a = sp.Symbol('a')
rescale_expr = (
spux.scale_to_unit(rescale_func(a * self.unit), new_unit)
if self.unit is not None
else rescale_func(a * self.unit)
)
_rescale_func = sp.lambdify(a, rescale_expr, 'jax')
values = _rescale_func(self.values)
# Return ArrayFlow
return ArrayFlow(
values=values[::-1] if reverse else values,
unit=new_unit,
is_sorted=self.is_sorted,
)
def __getitem__(self, subscript: slice):
if isinstance(subscript, slice):
return ArrayFlow(
values=self.values[subscript],
unit=self.unit,
is_sorted=self.is_sorted,
)
Parameters:
corrected_unit: The new unit to insert.
**MUST** be associable with a well-defined `PhysicalType`.
"""
return self.rescale(lambda v: v, new_unit=new_unit)
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self:
raise NotImplementedError

View File

@ -0,0 +1,36 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import typing as typ
from blender_maxwell.utils import extra_sympy_units as spux
from . import FlowKind
class ExprInfo(typ.TypedDict):
active_kind: FlowKind
size: spux.NumberSize1D
mathtype: spux.MathType
physical_type: spux.PhysicalType
# Value
default_value: spux.SympyExpr
# Range
default_min: spux.SympyExpr
default_max: spux.SympyExpr
default_steps: int

View File

@ -19,6 +19,7 @@ import typing as typ
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils.staticproperty import staticproperty
log = logger.get(__name__)
@ -51,16 +52,71 @@ class FlowKind(enum.StrEnum):
Capabilities = enum.auto()
# Values
Value = enum.auto()
Array = enum.auto()
Value = enum.auto() ## 'value'
Array = enum.auto() ## 'array'
# Lazy
Func = enum.auto()
Range = enum.auto()
Func = enum.auto() ## 'lazy_func'
Range = enum.auto() ## 'lazy_range'
# Auxiliary
Params = enum.auto()
Info = enum.auto()
Params = enum.auto() ## 'params'
Info = enum.auto() ## 'info'
####################
# - UI
####################
@staticmethod
def to_name(v: typ.Self) -> str:
return {
FlowKind.Capabilities: 'Capabilities',
# Values
FlowKind.Value: 'Value',
FlowKind.Array: 'Array',
# Lazy
FlowKind.Range: 'Range',
FlowKind.Func: 'Func',
# Auxiliary
FlowKind.Params: 'Params',
FlowKind.Info: 'Info',
}[v]
@staticmethod
def to_icon(_: typ.Self) -> str:
return ''
####################
# - Static Properties
####################
@staticproperty
def active_kinds() -> list[typ.Self]:
"""Return a list of `FlowKind`s that are able to be considered "active".
"Active" `FlowKind`s are considered the primary data type of a socket's flow.
For example, for sockets to be linkeable, their active `FlowKind` must generally match.
"""
return [
FlowKind.Value,
FlowKind.Array,
FlowKind.Range,
FlowKind.Func,
]
@property
def socket_shape(self) -> str:
"""Return the socket shape associated with this `FlowKind`.
**ONLY** valid for `FlowKind`s that can be considered "active".
Raises:
ValueError: If this `FlowKind` cannot ever be considered "active".
"""
return {
FlowKind.Value: 'CIRCLE',
FlowKind.Array: 'SQUARE',
FlowKind.Range: 'SQUARE',
FlowKind.Func: 'DIAMOND',
}[self]
####################
# - Class Methods
@ -69,7 +125,7 @@ class FlowKind(enum.StrEnum):
def scale_to_unit_system(
cls,
kind: typ.Self,
flow_obj,
flow_obj: spux.SympyExpr,
unit_system: spux.UnitSystem,
):
# log.debug('%s: Scaling "%s" to Unit System', kind, str(flow_obj))
@ -87,43 +143,3 @@ class FlowKind(enum.StrEnum):
msg = 'Tried to scale unknown kind'
raise ValueError(msg)
####################
# - Computed
####################
@property
def flow_kind(self) -> str:
return {
FlowKind.Value: FlowKind.Value,
FlowKind.Array: FlowKind.Array,
FlowKind.Func: FlowKind.Func,
FlowKind.Range: FlowKind.Range,
}[self]
@property
def socket_shape(self) -> str:
return {
FlowKind.Value: 'CIRCLE',
FlowKind.Array: 'SQUARE',
FlowKind.Range: 'SQUARE',
FlowKind.Func: 'DIAMOND',
}[self]
####################
# - Blender Enum
####################
@staticmethod
def to_name(v: typ.Self) -> str:
return {
FlowKind.Capabilities: 'Capabilities',
FlowKind.Value: 'Value',
FlowKind.Array: 'Array',
FlowKind.Range: 'Range',
FlowKind.Func: 'Func',
FlowKind.Params: 'Parameters',
FlowKind.Info: 'Information',
}[v]
@staticmethod
def to_icon(_: typ.Self) -> str:
return ''

View File

@ -18,246 +18,247 @@ import dataclasses
import functools
import typing as typ
import jax
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils import logger, sim_symbols
from .array import ArrayFlow
from .lazy_range import RangeFlow
log = logger.get(__name__)
LabelArray: typ.TypeAlias = list[str]
# IndexArray: Identifies Discrete Dimension Values
## -> ArrayFlow (rat|real): Index by particular, not-guaranteed-uniform index.
## -> RangeFlow (rat|real): Index by unrealized array scaled between boundaries.
## -> LabelArray (int): For int-index arrays, interpret using these labels.
## -> None: Non-Discrete/unrealized indexing; use 'dim.domain'.
IndexArray: typ.TypeAlias = ArrayFlow | RangeFlow | LabelArray | None
@dataclasses.dataclass(frozen=True, kw_only=True)
class InfoFlow:
####################
# - Covariant Input
####################
dim_names: list[str] = dataclasses.field(default_factory=list)
dim_idx: dict[str, ArrayFlow | RangeFlow] = dataclasses.field(
default_factory=dict
) ## TODO: Rename to dim_idxs
"""Contains dimension and output information characterizing the array produced by a parallel `FuncFlow`.
@functools.cached_property
def dim_has_coords(self) -> dict[str, int]:
return {
dim_name: not (
isinstance(dim_idx, RangeFlow)
and (dim_idx.start.is_infinite or dim_idx.stop.is_infinite)
)
for dim_name, dim_idx in self.dim_idx.items()
}
Functionally speaking, `InfoFlow` provides essential mathematical and physical context to raw array data, with terminology adapted from multilinear algebra.
@functools.cached_property
def dim_lens(self) -> dict[str, int]:
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
# From Arrays to Tensors
The best way to illustrate how it works is to specify how raw-array concepts map to an array described by an `InfoFlow`:
@functools.cached_property
def dim_mathtypes(self) -> dict[str, spux.MathType]:
return {
dim_name: dim_idx.mathtype for dim_name, dim_idx in self.dim_idx.items()
}
- **Index**: In raw arrays, the "index" is generally constrained to an integer ring, and has no semantic meaning.
**(Covariant) Dimension**: The "dimension" is an named "index array", which assigns each integer index a **scalar** value of particular mathematical type, name, and unit (if not unitless).
- **Value**: In raw arrays, the "value" is some particular computational type, or another raw array.
**(Contravariant) Output**: The "output" is a strictly named, sized object that can only be produced
@functools.cached_property
def dim_units(self) -> dict[str, spux.Unit]:
return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()}
In essence, `InfoFlow` allows us to treat raw data as a tensor, then operate on its dimensionality as split into parts whose transform varies _with_ the output (aka. a _covariant_ index), and parts whose transform varies _against_ the output (aka. _contravariant_ value).
@functools.cached_property
def dim_physical_types(self) -> dict[str, spux.PhysicalType]:
return {
dim_name: spux.PhysicalType.from_unit(dim_idx.unit)
for dim_name, dim_idx in self.dim_idx.items()
}
## Benefits
The reasons to do this are numerous:
@functools.cached_property
def dim_idx_arrays(self) -> list[jax.Array]:
return [
dim_idx.realize().values
if isinstance(dim_idx, RangeFlow)
else dim_idx.values
for dim_idx in self.dim_idx.values()
]
- **Clarity**: Using `InfoFlow`, it's easy to understand what the data is, and what can be done to it, making it much easier to implement complex operations in math nodes without sacrificing the user's mental model.
- **Zero-Cost Operations**: Transforming indices, "folding" dimensions into the output, and other such operations don't actually do anything to the data, enabling a lot of operations to feel "free" in terms of performance.
- **Semantic Indexing**: Using `InfoFlow`, it's easy to index and slice arrays using ex. nanometer vacuum wavelengths, instead of arbitrary integers.
"""
####################
# - Contravariant Output
# - Dimensions: Covariant Index
####################
# Output Information
## TODO: Add PhysicalType
output_name: str = dataclasses.field(default_factory=list)
output_shape: tuple[int, ...] | None = dataclasses.field(default=None)
output_mathtype: spux.MathType = dataclasses.field()
output_unit: spux.Unit | None = dataclasses.field()
@property
def output_shape_len(self) -> int:
if self.output_shape is None:
return 0
return len(self.output_shape)
# Pinned Dimension Information
## TODO: Add PhysicalType
pinned_dim_names: list[str] = dataclasses.field(default_factory=list)
pinned_dim_values: dict[str, float | complex] = dataclasses.field(
dims: dict[sim_symbols.SimSymbol, IndexArray] = dataclasses.field(
default_factory=dict
)
pinned_dim_mathtypes: dict[str, spux.MathType] = dataclasses.field(
default_factory=dict
)
pinned_dim_units: dict[str, spux.Unit] = dataclasses.field(default_factory=dict)
@functools.cached_property
def last_dim(self) -> sim_symbols.SimSymbol | None:
"""The integer axis occupied by the dimension.
Can be used to index `.shape` of the represented raw array.
"""
if self.dims:
return next(iter(self.dims.keys()))
return None
@functools.cached_property
def last_dim(self) -> sim_symbols.SimSymbol | None:
"""The integer axis occupied by the dimension.
Can be used to index `.shape` of the represented raw array.
"""
if self.dims:
return list(self.dims.keys())[-1]
return None
def dim_axis(self, dim: sim_symbols.SimSymbol) -> int:
"""The integer axis occupied by the dimension.
Can be used to index `.shape` of the represented raw array.
"""
return list(self.dims.keys()).index(dim)
def has_idx_cont(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the dim's index is continuous, and therefore index array.
This happens when the dimension is generated from a symbolic function, as opposed to from discrete observations.
In these cases, the `SimSymbol.domain` of the dimension should be used to determine the overall domain of validity.
Other than that, it's up to the user to select a particular way of indexing.
"""
return self.dims[dim] is None
def has_idx_discrete(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the (rat|real) dim is indexed by an `ArrayFlow` / `RangeFlow`."""
return isinstance(self.dims[dim], ArrayFlow | RangeFlow)
def has_idx_labels(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the (int) dim is indexed by a `LabelArray`."""
if dim.mathtype is spux.MathType.Integer:
return isinstance(self.dims[dim], list)
return False
####################
# - Methods
# - Output: Contravariant Value
####################
def slice_dim(self, dim_name: str, slice_tuple: tuple[int, int, int]) -> typ.Self:
output: sim_symbols.SimSymbol
####################
# - Pinned Dimension Values
####################
## -> Whenever a dimension is deleted, we retain what that index value was.
## -> This proves to be very helpful for clear visualization.
pinned_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = dataclasses.field(
default_factory=dict
)
####################
# - Operations: Dimensions
####################
def prepend_dim(
self, dim: sim_symbols.SimSymbol, dim_idx: sim_symbols.SimSymbol
) -> typ.Self:
"""Insert a new dimension at index 0."""
return InfoFlow(
# Dimensions
dim_names=self.dim_names,
dim_idx={
_dim_name: (
dim_idx
if _dim_name != dim_name
else dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
)
for _dim_name, dim_idx in self.dim_idx.items()
dims={dim: dim_idx} | self.dims,
output=self.output,
pinned_values=self.pinned_values,
)
def slice_dim(
self, dim: sim_symbols.SimSymbol, slice_tuple: tuple[int, int, int]
) -> typ.Self:
"""Slice a dimensional array by-index along a particular dimension."""
return InfoFlow(
dims={
_dim: dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
if _dim == dim
else _dim
for _dim, dim_idx in self.dims.items()
},
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
output=self.output,
pinned_values=self.pinned_values,
)
def replace_dim(
self, old_dim_name: str, new_dim_idx: tuple[str, ArrayFlow | RangeFlow]
self,
old_dim: sim_symbols.SimSymbol,
new_dim: sim_symbols.SimSymbol,
new_dim_idx: IndexArray,
) -> typ.Self:
"""Replace a dimension (and its indexing) with a new name and index array/range."""
"""Replace a dimension entirely, in-place, including symbol and index array."""
return InfoFlow(
# Dimensions
dim_names=[
dim_name if dim_name != old_dim_name else new_dim_idx[0]
for dim_name in self.dim_names
],
dim_idx={
(dim_name if dim_name != old_dim_name else new_dim_idx[0]): (
dim_idx if dim_name != old_dim_name else new_dim_idx[1]
dims={
(new_dim if _dim == old_dim else _dim): (
new_dim_idx if _dim == old_dim else _dim
)
for dim_name, dim_idx in self.dim_idx.items()
for _dim, dim_idx in self.dims.items()
},
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
output=self.output,
pinned_values=self.pinned_values,
)
def rescale_dim_idxs(self, new_dim_idxs: dict[str, RangeFlow]) -> typ.Self:
def replace_dims(
self, new_dims: dict[sim_symbols.SimSymbol, IndexArray]
) -> typ.Self:
"""Replace several dimensional indices with new index arrays/ranges."""
return InfoFlow(
# Dimensions
dim_names=self.dim_names,
dim_idx={
_dim_name: new_dim_idxs.get(_dim_name, dim_idx)
for _dim_name, dim_idx in self.dim_idx.items()
dims={
dim: new_dims.get(dim, dim_idx) for dim, dim_idx in self.dim_idx.items()
},
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
output=self.output,
pinned_values=self.pinned_values,
)
def delete_dimension(self, dim_name: str) -> typ.Self:
"""Delete a dimension."""
def delete_dim(
self, dim_to_remove: sim_symbols.SimSymbol, pin_idx: int | None = None
) -> typ.Self:
"""Delete a dimension, optionally pinning the value of an index from that dimension."""
new_pin = (
{dim_to_remove: self.dims[dim_to_remove][pin_idx]}
if pin_idx is not None
else {}
)
return InfoFlow(
# Dimensions
dim_names=[
_dim_name for _dim_name in self.dim_names if _dim_name != dim_name
],
dim_idx={
_dim_name: dim_idx
for _dim_name, dim_idx in self.dim_idx.items()
if _dim_name != dim_name
dims={
dim: dim_idx
for dim, dim_idx in self.dims.items()
if dim != dim_to_remove
},
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
output=self.output,
pinned_values=self.pinned_values | new_pin,
)
def swap_dimensions(self, dim_0_name: str, dim_1_name: str) -> typ.Self:
"""Swap the position of two dimensions."""
def swap_dimensions(self, dim_0: str, dim_1: str) -> typ.Self:
"""Swap the positions of two dimensions."""
# Compute Swapped Dimension Name List
# Swapped Dimension Keys
def name_swapper(dim_name):
return (
dim_name
if dim_name not in [dim_0_name, dim_1_name]
else {dim_0_name: dim_1_name, dim_1_name: dim_0_name}[dim_name]
if dim_name not in [dim_0, dim_1]
else {dim_0: dim_1, dim_1: dim_0}[dim_name]
)
dim_names = [name_swapper(dim_name) for dim_name in self.dim_names]
swapped_dim_keys = [name_swapper(dim) for dim in self.dims]
# Compute Info
return InfoFlow(
# Dimensions
dim_names=dim_names,
dim_idx={dim_name: self.dim_idx[dim_name] for dim_name in dim_names},
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
dims={dim_key: self.dims[dim_key] for dim_key in swapped_dim_keys},
output=self.output,
pinned_values=self.pinned_values,
)
def set_output_mathtype(self, output_mathtype: spux.MathType) -> typ.Self:
"""Set the MathType of the output."""
####################
# - Operations: Output
####################
def update_output(self, **kwargs) -> typ.Self:
"""Passthrough to `SimSymbol.update()` method on `self.output`."""
return InfoFlow(
dim_names=self.dim_names,
dim_idx=self.dim_idx,
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=output_mathtype,
output_unit=self.output_unit,
dims=self.dims,
output=self.output.update(**kwargs),
pinned_values=self.pinned_values,
)
def collapse_output(
self,
collapsed_name: str,
collapsed_mathtype: spux.MathType,
collapsed_unit: spux.Unit,
) -> typ.Self:
"""Replace the (scalar) output with the given corrected values."""
return InfoFlow(
# Dimensions
dim_names=self.dim_names,
dim_idx=self.dim_idx,
output_name=collapsed_name,
output_shape=None,
output_mathtype=collapsed_mathtype,
output_unit=collapsed_unit,
)
####################
# - Operations: Fold
####################
def fold_last_input(self):
"""Fold the last input dimension into the output."""
last_key = list(self.dims.keys())[-1]
last_idx = list(self.dims.values())[-1]
rows = self.output.rows
cols = self.output.cols
match (rows, cols):
case (1, 1):
new_output = self.output.set_size(len(last_idx), 1)
case (_, 1):
new_output = self.output.set_size(rows, len(last_idx))
case (1, _):
new_output = self.output.set_size(len(last_idx), cols)
case (_, _):
raise NotImplementedError ## Not yet :)
@functools.cached_property
def shift_last_input(self):
"""Shift the last input dimension to the output."""
return InfoFlow(
# Dimensions
dim_names=self.dim_names[:-1],
dim_idx={
dim_name: dim_idx
for dim_name, dim_idx in self.dim_idx.items()
if dim_name != self.dim_names[-1]
dims={
dim: dim_idx for dim, dim_idx in self.dims.items() if dim != last_key
},
# Outputs
output_name=self.output_name,
output_shape=(
(self.dim_lens[self.dim_names[-1]],)
if self.output_shape is None
else (self.dim_lens[self.dim_names[-1]], *self.output_shape)
),
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
output=new_output,
pinned_values=self.pinned_values,
)

View File

@ -24,6 +24,8 @@ import jax
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger
from .params import ParamsFlow
log = logger.get(__name__)
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any]
@ -307,6 +309,25 @@ class FuncFlow:
msg = 'Can\'t express FuncFlow as JAX function (using jax.jit), since "self.supports_jax" is False'
raise ValueError(msg)
####################
# - Realization
####################
def realize(
self,
params: ParamsFlow,
unit_system: spux.UnitSystem | None = None,
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
) -> typ.Self:
if self.supports_jax:
return self.func_jax(
*params.scaled_func_args(unit_system, symbol_values),
*params.scaled_func_kwargs(unit_system, symbol_values),
)
return self.func(
*params.scaled_func_args(unit_system, symbol_values),
*params.scaled_func_kwargs(unit_system, symbol_values),
)
####################
# - Composition Operations
####################

View File

@ -28,13 +28,20 @@ from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger
from .array import ArrayFlow
from .flow_kinds import FlowKind
from .lazy_func import FuncFlow
log = logger.get(__name__)
class ScalingMode(enum.StrEnum):
"""Identifier for how to space steps between two boundaries.
Attributes:
Lin: Uniform spacing between two endpoints.
Geom: Log spacing between two endpoints, given as values.
Log: Log spacing between two endpoints, given as powers of a common base.
"""
Lin = enum.auto()
Geom = enum.auto()
Log = enum.auto()
@ -55,36 +62,20 @@ class ScalingMode(enum.StrEnum):
@dataclasses.dataclass(frozen=True, kw_only=True)
class RangeFlow:
r"""Represents a linearly/logarithmically spaced array using symbolic boundary expressions, with support for units and lazy evaluation.
r"""Represents a spaced array using symbolic boundary expressions.
# Advantages
Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous.
## Memory
# Memory Scaling
`ArrayFlow` generally has a memory scaling of $O(n)$.
Naturally, `RangeFlow` is always constant, since only the boundaries and steps are stored.
## Symbolic
Both boundary points are symbolic expressions, within which pre-defined `sp.Symbol`s can participate in a constrained manner (ex. an integer symbol).
# Symbolic Bounds
`self.start` and `self.stop` boundary points are symbolic expressions, within which any element of `self.symbols` can participate.
One need not know the value of the symbols immediately - such decisions can be deferred until later in the computational flow.
**It is the user's responsibility** to ensure that `self.start < self.stop`.
## Performant Unit-Aware Operations
While `ArrayFlow`s are also unit-aware, the time-cost of _any_ unit-scaling operation scales with $O(n)$.
`RangeFlow`, by contrast, scales as $O(1)$.
As a result, more complicated operations (like symbolic or unit-based) that might be difficult to perform interactively in real-time on an `ArrayFlow` will work perfectly with this object, even with added complexity
## High-Performance Composition and Gradiant
With `self.as_func`, a `jax` function is produced that generates the array according to the symbolic `start`, `stop` and `steps`.
There are two nice things about this:
- **Gradient**: The gradient of the output array, with respect to any symbols used to define the input bounds, can easily be found using `jax.grad` over `self.as_func`.
- **JIT**: When `self.as_func` is composed with other `jax` functions, and `jax.jit` is run to optimize the entire thing, the "cost of array generation" _will often be optimized away significantly or entirely_.
Thus, as part of larger computations, the performance properties of `RangeFlow` is extremely favorable.
## Numerical Properties
# Numerical Properties
Since the bounds support exact (ex. rational) calculations and symbolic manipulations (_by virtue of being symbolic expressions_), the opportunities for certain kinds of numerical instability are mitigated.
Attributes:
@ -108,8 +99,11 @@ class RangeFlow:
unit: spux.Unit | None = None
symbols: frozenset[spux.IntSymbol] = frozenset()
symbols: frozenset[spux.Symbol] = frozenset()
####################
# - Computed Properties
####################
@functools.cached_property
def sorted_symbols(self) -> list[sp.Symbol]:
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
@ -121,6 +115,19 @@ class RangeFlow:
"""
return sorted(self.symbols, key=lambda sym: sym.name)
@property
def is_symbolic(self) -> bool:
"""Whether the `RangeFlow` has unrealized symbols."""
return len(self.symbols) > 0
def __len__(self) -> int:
"""Compute the length of the array that would be realized.
Returns:
The number of steps.
"""
return self.steps
@functools.cached_property
def mathtype(self) -> spux.MathType:
"""Conservatively compute the most stringent `spux.MathType` that can represent both `self.start` and `self.stop`.
@ -156,13 +163,206 @@ class RangeFlow:
)
return combined_mathtype
def __len__(self):
"""Compute the length of the array to be realized.
####################
# - Methods
####################
def rescale(
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
) -> typ.Self:
"""Apply an order-preserving function to each bound, then (optionally) transform the result w/new unit and/or order.
An optimized expression will be built and applied to `self.values` using `sympy.lambdify()`.
Parameters:
rescale_func: An **order-preserving** function to apply to each array element.
reverse: Whether to reverse the order of the result.
new_unit: An (optional) new unit to scale the result to.
"""
new_pre_start = self.start if not reverse else self.stop
new_pre_stop = self.stop if not reverse else self.start
new_start = rescale_func(new_pre_start * self.unit)
new_stop = rescale_func(new_pre_stop * self.unit)
return RangeFlow(
start=(
spux.scale_to_unit(new_start, new_unit)
if new_unit is not None
else new_start
),
stop=(
spux.scale_to_unit(new_stop, new_unit)
if new_unit is not None
else new_stop
),
steps=self.steps,
scaling=self.scaling,
unit=new_unit,
symbols=self.symbols,
)
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
raise NotImplementedError
####################
# - Exporters
####################
@functools.cached_property
def array_generator(
self,
) -> typ.Callable[
[int | float | complex, int | float | complex, int],
jtyp.Inexact[jtyp.Array, ' steps'],
]:
"""Compute the correct `jnp.*space` array generator, where `*` is one of the supported scaling methods.
Returns:
The number of steps.
A `jax` function that takes a valid `start`, `stop`, and `steps`, and returns a 1D `jax` array.
"""
return self.steps
jnp_nspace = {
ScalingMode.Lin: jnp.linspace,
ScalingMode.Geom: jnp.geomspace,
ScalingMode.Log: jnp.logspace,
}.get(self.scaling)
if jnp_nspace is None:
msg = f'ArrayFlow scaling method {self.scaling} is unsupported'
raise RuntimeError(msg)
return jnp_nspace
@functools.cached_property
def as_func(
self,
) -> typ.Callable[[int | float | complex, ...], jtyp.Inexact[jtyp.Array, ' steps']]:
"""Create a function that can compute the non-lazy output array as a function of the symbols in the expressions for `start` and `stop`.
Notes:
The ordering of the symbols is identical to `self.symbols`, which is guaranteed to be a deterministically sorted list of symbols.
Returns:
A `FuncFlow` that, given the input symbols defined in `self.symbols`,
"""
# Compile JAX Functions for Start/End Expressions
## -> FYI, JAX-in-JAX works perfectly fine.
start_jax = sp.lambdify(self.sorted_symbols, self.start, 'jax')
stop_jax = sp.lambdify(self.sorted_symbols, self.stop, 'jax')
# Compile ArrayGen Function
def gen_array(
*args: list[int | float | complex],
) -> jtyp.Inexact[jtyp.Array, ' steps']:
return self.array_generator(start_jax(*args), stop_jax(*args), self.steps)
# Return ArrayGen Function
return gen_array
@functools.cached_property
def as_lazy_func(self) -> FuncFlow:
"""Creates a `FuncFlow` using the output of `self.as_func`.
This is useful for ex. parameterizing the first array in the node graph, without binding an entire computed array.
Notes:
The the function enclosed in the `FuncFlow` is identical to the one returned by `self.as_func`.
Returns:
A `FuncFlow` containing `self.as_func`, as well as appropriate supporting settings.
"""
return FuncFlow(
func=self.as_func,
func_args=[(spux.MathType.from_expr(sym)) for sym in self.symbols],
supports_jax=True,
)
####################
# - Realization
####################
def realize_start(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> int | float | complex:
"""Realize the start-bound by inserting particular values for each symbol."""
return spux.sympy_to_python(
self.start.subs({sym: symbol_values[sym.name] for sym in self.symbols})
)
def realize_stop(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> int | float | complex:
"""Realize the stop-bound by inserting particular values for each symbol."""
return spux.sympy_to_python(
self.stop.subs({sym: symbol_values[sym.name] for sym in self.symbols})
)
def realize_step_size(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> int | float | complex:
"""Realize the stop-bound by inserting particular values for each symbol."""
if self.scaling is not ScalingMode.Lin:
raise NotImplementedError('Non-linear scaling mode not yet suported')
raw_step_size = (self.realize_stop() - self.realize_start() + 1) / self.steps
if self.mathtype is spux.MathType.Integer and raw_step_size.is_integer():
return int(raw_step_size)
return raw_step_size
def realize(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> ArrayFlow:
"""Realize the array represented by this `RangeFlow` by realizing each bound, then generating all intermediate values as an array.
Parameters:
symbol_values: The particular values for each symbol, which will be inserted into the expression of each bound to realize them.
Returns:
An `ArrayFlow` containing this realized `RangeFlow`.
"""
## TODO: Check symbol values for coverage.
return ArrayFlow(
values=self.as_func(*[symbol_values[sym] for sym in self.sorted_symbols]),
unit=self.unit,
is_sorted=True,
)
@functools.cached_property
def realize_array(self) -> ArrayFlow:
"""Standardized access to `self.realize()` when there are no symbols."""
return self.realize()
def __getitem__(self, subscript: slice):
"""Implement indexing and slicing in a sane way.
- **Integer Index**: Not yet implemented.
- **Slice**: Return the `RangeFlow` that creates the same `ArrayFlow` as would be created by computing `self.realize_array`, then slicing that.
"""
if isinstance(subscript, slice) and self.scaling == ScalingMode.Lin:
# Parse Slice
start = subscript.start if subscript.start is not None else 0
stop = subscript.stop if subscript.stop is not None else self.steps
step = subscript.step if subscript.step is not None else 1
slice_steps = (stop - start + step - 1) // step
# Compute New Start/Stop
step_size = self.realize_step_size()
new_start = step_size * start
new_stop = new_start + step_size * slice_steps
return RangeFlow(
start=sp.S(new_start),
stop=sp.S(new_stop),
steps=slice_steps,
scaling=self.scaling,
unit=self.unit,
symbols=self.symbols,
)
raise NotImplementedError
####################
# - Units
@ -264,231 +464,3 @@ class RangeFlow:
f'Tried to rescale unitless LazyDataValueRange to unit system {unit_system}'
)
raise ValueError(msg)
####################
# - Bound Operations
####################
def rescale(
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
) -> typ.Self:
new_pre_start = self.start if not reverse else self.stop
new_pre_stop = self.stop if not reverse else self.start
new_start = rescale_func(new_pre_start * self.unit)
new_stop = rescale_func(new_pre_stop * self.unit)
return RangeFlow(
start=(
spux.scale_to_unit(new_start, new_unit)
if new_unit is not None
else new_start
),
stop=(
spux.scale_to_unit(new_stop, new_unit)
if new_unit is not None
else new_stop
),
steps=self.steps,
scaling=self.scaling,
unit=new_unit,
symbols=self.symbols,
)
def rescale_bounds(
self,
rescale_func: typ.Callable[
[spux.ScalarUnitlessComplexExpr], spux.ScalarUnitlessComplexExpr
],
reverse: bool = False,
) -> typ.Self:
"""Apply a function to the bounds, effectively rescaling the represented array.
Notes:
**It is presumed that the bounds are scaled with the same factor**.
Breaking this presumption may have unexpected results.
The scalar, unitless, complex-valuedness of the bounds must also be respected; additionally, new symbols must not be introduced.
Parameters:
scaler: The function that scales each bound.
reverse: Whether to reverse the bounds after running the `scaler`.
Returns:
A rescaled `RangeFlow`.
"""
return RangeFlow(
start=rescale_func(self.start if not reverse else self.stop),
stop=rescale_func(self.stop if not reverse else self.start),
steps=self.steps,
scaling=self.scaling,
unit=self.unit,
symbols=self.symbols,
)
####################
# - Lazy Representation
####################
@functools.cached_property
def array_generator(
self,
) -> typ.Callable[
[int | float | complex, int | float | complex, int],
jtyp.Inexact[jtyp.Array, ' steps'],
]:
"""Compute the correct `jnp.*space` array generator, where `*` is one of the supported scaling methods.
Returns:
A `jax` function that takes a valid `start`, `stop`, and `steps`, and returns a 1D `jax` array.
"""
jnp_nspace = {
ScalingMode.Lin: jnp.linspace,
ScalingMode.Geom: jnp.geomspace,
ScalingMode.Log: jnp.logspace,
}.get(self.scaling)
if jnp_nspace is None:
msg = f'ArrayFlow scaling method {self.scaling} is unsupported'
raise RuntimeError(msg)
return jnp_nspace
@functools.cached_property
def as_func(
self,
) -> typ.Callable[[int | float | complex, ...], jtyp.Inexact[jtyp.Array, ' steps']]:
"""Create a function that can compute the non-lazy output array as a function of the symbols in the expressions for `start` and `stop`.
Notes:
The ordering of the symbols is identical to `self.symbols`, which is guaranteed to be a deterministically sorted list of symbols.
Returns:
A `FuncFlow` that, given the input symbols defined in `self.symbols`,
"""
# Compile JAX Functions for Start/End Expressions
## FYI, JAX-in-JAX works perfectly fine.
start_jax = sp.lambdify(self.symbols, self.start, 'jax')
stop_jax = sp.lambdify(self.symbols, self.stop, 'jax')
# Compile ArrayGen Function
def gen_array(
*args: list[int | float | complex],
) -> jtyp.Inexact[jtyp.Array, ' steps']:
return self.array_generator(start_jax(*args), stop_jax(*args), self.steps)
# Return ArrayGen Function
return gen_array
@functools.cached_property
def as_lazy_func(self) -> FuncFlow:
"""Creates a `FuncFlow` using the output of `self.as_func`.
This is useful for ex. parameterizing the first array in the node graph, without binding an entire computed array.
Notes:
The the function enclosed in the `FuncFlow` is identical to the one returned by `self.as_func`.
Returns:
A `FuncFlow` containing `self.as_func`, as well as appropriate supporting settings.
"""
return FuncFlow(
func=self.as_func,
func_args=[(spux.MathType.from_expr(sym)) for sym in self.symbols],
supports_jax=True,
)
####################
# - Realization
####################
def realize_start(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> ArrayFlow | FuncFlow:
return spux.sympy_to_python(
self.start.subs({sym: symbol_values[sym.name] for sym in self.symbols})
)
def realize_stop(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> ArrayFlow | FuncFlow:
return spux.sympy_to_python(
self.stop.subs({sym: symbol_values[sym.name] for sym in self.symbols})
)
def realize_step_size(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> ArrayFlow | FuncFlow:
raw_step_size = (self.realize_stop() - self.realize_start() + 1) / self.steps
if self.mathtype is spux.MathType.Integer and raw_step_size.is_integer():
return int(raw_step_size)
return raw_step_size
def realize(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
kind: typ.Literal[FlowKind.Array, FlowKind.Func] = FlowKind.Array,
) -> ArrayFlow | FuncFlow:
"""Apply a function to the bounds, effectively rescaling the represented array.
Notes:
**It is presumed that the bounds are scaled with the same factor**.
Breaking this presumption may have unexpected results.
The scalar, unitless, complex-valuedness of the bounds must also be respected; additionally, new symbols must not be introduced.
Parameters:
scaler: The function that scales each bound.
reverse: Whether to reverse the bounds after running the `scaler`.
Returns:
A rescaled `RangeFlow`.
"""
if not set(self.symbols).issubset(set(symbol_values.keys())):
msg = f'Provided symbols ({set(symbol_values.keys())}) do not provide values for all expression symbols ({self.symbols}) that may be found in the boundary expressions (start={self.start}, end={self.end})'
raise ValueError(msg)
# Realize Symbols
realized_start = self.realize_start(symbol_values)
realized_stop = self.realize_stop(symbol_values)
# Return Linspace / Logspace
def gen_array() -> jtyp.Inexact[jtyp.Array, ' steps']:
return self.array_generator(realized_start, realized_stop, self.steps)
if kind == FlowKind.Array:
return ArrayFlow(values=gen_array(), unit=self.unit, is_sorted=True)
if kind == FlowKind.Func:
return FuncFlow(func=gen_array, supports_jax=True)
msg = f'Invalid kind: {kind}'
raise TypeError(msg)
@functools.cached_property
def realize_array(self) -> ArrayFlow:
return self.realize()
def __getitem__(self, subscript: slice):
if isinstance(subscript, slice) and self.scaling == ScalingMode.Lin:
# Parse Slice
start = subscript.start if subscript.start is not None else 0
stop = subscript.stop if subscript.stop is not None else self.steps
step = subscript.step if subscript.step is not None else 1
slice_steps = (stop - start + step - 1) // step
# Compute New Start/Stop
step_size = self.realize_step_size()
new_start = step_size * start
new_stop = new_start + step_size * slice_steps
return RangeFlow(
start=sp.S(new_start),
stop=sp.S(new_stop),
steps=slice_steps,
scaling=self.scaling,
unit=self.unit,
symbols=self.symbols,
)
raise NotImplementedError

View File

@ -22,35 +22,22 @@ from types import MappingProxyType
import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils import logger, sim_symbols
from .expr_info import ExprInfo
from .flow_kinds import FlowKind
from .info import InfoFlow
# from .info import InfoFlow
log = logger.get(__name__)
class ExprInfo(typ.TypedDict):
active_kind: FlowKind
size: spux.NumberSize1D
mathtype: spux.MathType
physical_type: spux.PhysicalType
# Value
default_value: spux.SympyExpr
# Range
default_min: spux.SympyExpr
default_max: spux.SympyExpr
default_steps: int
@dataclasses.dataclass(frozen=True, kw_only=True)
class ParamsFlow:
func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list)
func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict)
symbols: frozenset[spux.Symbol] = frozenset()
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
@functools.cached_property
def sorted_symbols(self) -> list[sp.Symbol]:
@ -66,21 +53,18 @@ class ParamsFlow:
####################
def scaled_func_args(
self,
unit_system: spux.UnitSystem,
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
unit_system: spux.UnitSystem | None = None,
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
):
"""Realize the function arguments contained in this `ParamsFlow`, making it ready for insertion into `Func.func()`.
For all `arg`s in `self.func_args`, the following operations are performed:
- **Unit System**: If `arg`
For all `arg`s in `self.func_args`, the following operations are performed.
Notes:
This method is created for the purpose of being able to make this exact call in an `events.on_value_changed` method:
"""
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments."""
if not all(sym in self.symbols for sym in symbol_values):
msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}"
raise ValueError(msg)
@ -97,7 +81,7 @@ class ParamsFlow:
def scaled_func_kwargs(
self,
unit_system: spux.UnitSystem,
unit_system: spux.UnitSystem | None = None,
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
):
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments."""
@ -145,9 +129,7 @@ class ParamsFlow:
####################
# - Generate ExprSocketDef
####################
def sym_expr_infos(
self, info: InfoFlow, use_range: bool = False
) -> dict[str, ExprInfo]:
def sym_expr_infos(self, info, use_range: bool = False) -> dict[str, ExprInfo]:
"""Generate all information needed to define expressions that realize all symbolic parameters in this `ParamsFlow`.
Many nodes need actual data, and as such, they require that the user select actual values for any symbols in the `ParamsFlow`.
@ -169,26 +151,35 @@ class ParamsFlow:
The `ExprInfo`s can be directly defererenced `**expr_info`)
"""
for sim_sym in self.sorted_symbols:
if use_range and sim_sym.mathtype is spux.MathType.Complex:
msg = 'No support for complex range in ExprInfo'
raise NotImplementedError(msg)
if use_range and (sim_sym.rows > 1 or sim_sym.cols > 1):
msg = 'No support for non-scalar elements of range in ExprInfo'
raise NotImplementedError(msg)
if sim_sym.rows > 3 or sim_sym.cols > 1:
msg = 'No support for >Vec3 / Matrix values in ExprInfo'
raise NotImplementedError(msg)
return {
sym.name: {
sim_sym.name: {
# Declare Kind/Size
## -> Kind: Value prevents user-alteration of config.
## -> Size: Always scalar, since symbols are scalar (for now).
'active_kind': FlowKind.Value,
'active_kind': FlowKind.Value if not use_range else FlowKind.Range,
'size': spux.NumberSize1D.Scalar,
# Declare MathType/PhysicalType
## -> MathType: Lookup symbol name in info dimensions.
## -> PhysicalType: Same.
'mathtype': info.dim_mathtypes[sym.name],
'physical_type': info.dim_physical_types[sym.name],
# TODO: Default Values
'mathtype': self.dims[sim_sym].mathtype,
'physical_type': self.dims[sim_sym].physical_type,
# TODO: Default Value
# FlowKind.Value: Default Value
#'default_value':
# FlowKind.Range: Default Min/Max/Steps
#'default_min':
#'default_max':
#'default_steps':
'default_min': sim_sym.domain.start,
'default_max': sim_sym.domain.end,
'default_steps': 50,
}
for sym in self.sorted_symbols
if sym.name in info.dim_names
for sim_sym in self.sorted_symbols
}

View File

@ -489,11 +489,11 @@ class DataFileFormat(enum.StrEnum):
E = DataFileFormat
match self:
case E.Csv:
return len(info.dim_names) + info.output_shape_len <= 2
return len(info.dims) + info.output.rows + info.output.cols - 1 <= 2
case E.Npy:
return True
case E.Txt | E.TxtGz:
return len(info.dim_names) + info.output_shape_len <= 2
return len(info.dims) + info.output.rows + info.output.cols - 1 <= 2
@property
def saver(
@ -510,9 +510,9 @@ class DataFileFormat(enum.StrEnum):
# Extract Input Coordinates
dim_columns = {
dim_name: np.array(info.dim_idx_arrays[i])
for i, dim_name in enumerate(info.dim_names)
}
dim.name: np.array(dim_idx.realize_array)
for i, (dim, dim_idx) in enumerate(info.dims)
} ## TODO: realize_array might not be defined on some index arrays
# Declare Function to Extract Output Values
output_columns = {}
@ -524,14 +524,14 @@ class DataFileFormat(enum.StrEnum):
output_idx_str = f'[{output_idx}]' if use_output_idx else ''
if bool(np.any(np.iscomplex(data_col))):
output_columns |= {
f'{info.output_name}{output_idx_str}_re': np.real(data_col),
f'{info.output_name}{output_idx_str}_im': np.imag(data_col),
f'{info.output.name}{output_idx_str}_re': np.real(data_col),
f'{info.output.name}{output_idx_str}_im': np.imag(data_col),
}
# Else: Use Array Directly
else:
output_columns |= {
f'{info.output_name}{output_idx_str}': data_col,
f'{info.output.name}{output_idx_str}': data_col,
}
## TODO: Maybe a check to ensure dtype!=object?
@ -605,11 +605,11 @@ class DataFileFormat(enum.StrEnum):
E = DataFileFormat
match self:
case E.Csv:
return len(info.dim_names) + (info.output_shape_len + 1) <= 2
return len(info.dims) + (info.output.rows + input.outputs.cols - 1) <= 2
case E.Npy:
return True
case E.Txt | E.TxtGz:
return len(info.dim_names) + (info.output_shape_len + 1) <= 2
return len(info.dims) + (info.output.rows + info.output.cols - 1) <= 2
def supports_metadata(self) -> bool:
E = DataFileFormat

View File

@ -48,7 +48,6 @@ UNITS_BLENDER: spux.UnitSystem = spux.UNITS_SI | {
# Electrodynamics
_PT.CurrentDensity: spu.ampere / spu.um**2,
_PT.Conductivity: spu.siemens / spu.um,
_PT.PoyntingVector: spu.watt / spu.um**2,
_PT.EField: spu.volt / spu.um,
_PT.HField: spu.ampere / spu.um,
# Mechanical
@ -58,7 +57,6 @@ UNITS_BLENDER: spux.UnitSystem = spux.UNITS_SI | {
_PT.Force: spux.micronewton,
# Luminal
# Optics
_PT.PoyntingVector: spu.watt / spu.um**2,
} ## TODO: Load (dynamically?) from addon preferences
UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | {
@ -75,11 +73,9 @@ UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | {
# Electrodynamics
_PT.CurrentDensity: spu.ampere / spu.um**2,
_PT.Conductivity: spu.siemens / spu.um,
_PT.PoyntingVector: spu.watt / spu.um**2,
_PT.EField: spu.volt / spu.um,
_PT.HField: spu.ampere / spu.um,
# Luminal
# Optics
_PT.PoyntingVector: spu.watt / spu.um**2,
## NOTE: w/o source normalization, EField/HField/Modal amps are * 1/Hz
}

View File

@ -17,15 +17,15 @@
"""Implements `ExtractDataNode`."""
import enum
import functools
import typing as typ
import bpy
import jax
import numpy as np
import jax.numpy as jnp
import sympy.physics.units as spu
import tidy3d as td
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from ... import contracts as ct
@ -37,6 +37,176 @@ log = logger.get(__name__)
TDMonitorData: typ.TypeAlias = td.components.data.monitor_data.MonitorData
####################
# - Monitor Label Arrays
####################
def valid_monitor_attrs(sim_data: td.SimulationData, monitor_name: str) -> list[str]:
"""Retrieve the valid attributes of `sim_data.monitor_data' from a valid `sim_data` of type `td.SimulationData`.
Parameters:
monitor_type: The name of the monitor type, with the 'Data' prefix removed.
"""
monitor_data = sim_data.monitor_data[monitor_name]
monitor_type = monitor_data.type
match monitor_type:
case 'Field' | 'FieldTime' | 'Mode':
## TODO: flux, poynting, intensity
return [
field_component
for field_component in ['Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz']
if getattr(monitor_data, field_component, None) is not None
]
case 'Permittivity':
return ['eps_xx', 'eps_yy', 'eps_zz']
case 'Flux' | 'FluxTime':
return ['flux']
case (
'FieldProjectionAngle'
| 'FieldProjectionCartesian'
| 'FieldProjectionKSpace'
| 'Diffraction'
):
return [
'Er',
'Etheta',
'Ephi',
'Hr',
'Htheta',
'Hphi',
]
def extract_info(monitor_data, monitor_attr: str) -> ct.InfoFlow | None: # noqa: PLR0911
"""Extract an InfoFlow encapsulating raw data contained in an attribute of the given monitor data."""
xarr = getattr(monitor_data, monitor_attr, None)
if xarr is None:
return None
def mk_idx_array(axis: str) -> ct.ArrayFlow:
return ct.ArrayFlow(
values=xarr.get_index(axis).values,
unit=symbols[axis].unit,
is_sorted=True,
)
# Compute InfoFlow from XArray
symbols = {
# Cartesian
'x': sim_symbols.space_x(spu.micrometer),
'y': sim_symbols.space_y(spu.micrometer),
'z': sim_symbols.space_z(spu.micrometer),
# Spherical
'r': sim_symbols.ang_r(spu.micrometer),
'theta': sim_symbols.ang_theta(spu.radian),
'phi': sim_symbols.ang_phi(spu.radian),
# Freq|Time
'f': sim_symbols.freq(spu.hertz),
't': sim_symbols.t(spu.second),
# Power Flux
'flux': sim_symbols.flux(spu.watt),
# Cartesian Fields
'Ex': sim_symbols.field_ex(spu.volt / spu.micrometer),
'Ey': sim_symbols.field_ey(spu.volt / spu.micrometer),
'Ez': sim_symbols.field_ez(spu.volt / spu.micrometer),
'Hx': sim_symbols.field_hx(spu.volt / spu.micrometer),
'Hy': sim_symbols.field_hy(spu.volt / spu.micrometer),
'Hz': sim_symbols.field_hz(spu.volt / spu.micrometer),
# Spherical Fields
'Er': sim_symbols.field_er(spu.volt / spu.micrometer),
'Etheta': sim_symbols.ang_theta(spu.volt / spu.micrometer),
'Ephi': sim_symbols.field_ez(spu.volt / spu.micrometer),
'Hr': sim_symbols.field_hr(spu.volt / spu.micrometer),
'Htheta': sim_symbols.field_hy(spu.volt / spu.micrometer),
'Hphi': sim_symbols.field_hz(spu.volt / spu.micrometer),
# Wavevector
'ux': sim_symbols.dir_x(spu.watt),
'uy': sim_symbols.dir_y(spu.watt),
# Diffraction Orders
'orders_x': sim_symbols.diff_order_x(None),
'orders_y': sim_symbols.diff_order_y(None),
}
match monitor_data.type:
case 'Field' | 'FieldProjectionCartesian' | 'Permittivity' | 'Mode':
return ct.InfoFlow(
dims={
symbols['x']: mk_idx_array('x'),
symbols['y']: mk_idx_array('y'),
symbols['z']: mk_idx_array('z'),
symbols['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
)
case 'FieldTime':
return ct.InfoFlow(
dims={
symbols['x']: mk_idx_array('x'),
symbols['y']: mk_idx_array('y'),
symbols['z']: mk_idx_array('z'),
symbols['t']: mk_idx_array('t'),
},
output=symbols[monitor_attr],
)
case 'Flux':
return ct.InfoFlow(
dims={
symbols['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
)
case 'FluxTime':
return ct.InfoFlow(
dims={
symbols['t']: mk_idx_array('t'),
},
output=symbols[monitor_attr],
)
case 'FieldProjectionAngle':
return ct.InfoFlow(
dims={
symbols['r']: mk_idx_array('r'),
symbols['theta']: mk_idx_array('theta'),
symbols['phi']: mk_idx_array('phi'),
symbols['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
)
case 'FieldProjectionKSpace':
return ct.InfoFlow(
dims={
symbols['ux']: mk_idx_array('ux'),
symbols['uy']: mk_idx_array('uy'),
symbols['r']: mk_idx_array('r'),
symbols['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
)
case 'Diffraction':
return ct.InfoFlow(
dims={
symbols['orders_x']: mk_idx_array('orders_x'),
symbols['orders_y']: mk_idx_array('orders_y'),
symbols['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
)
return None
####################
# - Node
####################
class ExtractDataNode(base.MaxwellSimNode):
"""Extract data from sockets for further analysis.
@ -45,31 +215,21 @@ class ExtractDataNode(base.MaxwellSimNode):
Monitor Data: Extract `Expr`s from monitor data by-component.
Attributes:
extract_filter: Identifier for data to extract from the input.
monitor_attr: Identifier for data to extract from the input.
"""
node_type = ct.NodeType.ExtractData
bl_label = 'Extract'
input_socket_sets: typ.ClassVar = {
'Sim Data': {'Sim Data': sockets.MaxwellFDTDSimDataSocketDef()},
'Monitor Data': {'Monitor Data': sockets.MaxwellMonitorDataSocketDef()},
'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(),
}
output_socket_sets: typ.ClassVar = {
'Sim Data': {'Monitor Data': sockets.MaxwellMonitorDataSocketDef()},
'Monitor Data': {'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func)},
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
}
####################
# - Properties
####################
extract_filter: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_extract_filters(),
cb_depends_on={'sim_data_monitor_nametype', 'monitor_data_type'},
)
####################
# - Computed: Sim Data
# - Properties: Monitor Name
####################
@events.on_value_changed(
socket_name='Sim Data',
@ -99,198 +259,49 @@ class ExtractDataNode(base.MaxwellSimNode):
@bl_cache.cached_bl_property(depends_on={'sim_data'})
def sim_data_monitor_nametype(self) -> dict[str, str] | None:
"""For simulation data, deduces a map from the monitor name to the monitor "type".
"""Dictionary from monitor names on `self.sim_data` to their associated type name (with suffix 'Data' removed).
Return:
The name to type of monitors in the simulation data.
"""
if self.sim_data is not None:
return {
monitor_name: monitor_data.type
monitor_name: monitor_data.type.removesuffix('Data')
for monitor_name, monitor_data in self.sim_data.monitor_data.items()
}
return None
####################
# - Computed Properties: Monitor Data
####################
@events.on_value_changed(
socket_name='Monitor Data',
input_sockets={'Monitor Data'},
input_sockets_optional={'Monitor Data': True},
monitor_name: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_monitor_names(),
cb_depends_on={'sim_data_monitor_nametype'},
)
def on_monitor_data_changed(self, input_sockets) -> None: # noqa: D102
has_monitor_data = not ct.FlowSignal.check(input_sockets['Monitor Data'])
if has_monitor_data:
self.monitor_data = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def monitor_data(self) -> TDMonitorData | None:
"""Extracts the monitor data from the input socket.
Return:
Either the monitor data, if available, or None.
"""
monitor_data = self._compute_input(
'Monitor Data', kind=ct.FlowKind.Value, optional=True
)
has_monitor_data = not ct.FlowSignal.check(monitor_data)
if has_monitor_data:
return monitor_data
return None
@bl_cache.cached_bl_property(depends_on={'monitor_data'})
def monitor_data_type(self) -> str | None:
r"""For monitor data, deduces the monitor "type".
- **Field(Time)**: A monitor storing values/pixels/voxels with electromagnetic field values, on the time or frequency domain.
- **Permittivity**: A monitor storing values/pixels/voxels containing the diagonal of the relative permittivity tensor.
- **Flux(Time)**: A monitor storing the directional flux on the time or frequency domain.
For planes, an explicit direction is defined.
For volumes, the the integral of all outgoing energy is stored.
- **FieldProjection(...)**: A monitor storing the spherical-coordinate electromagnetic field components of a near-to-far-field projection.
- **Diffraction**: A monitor storing a near-to-far-field projection by diffraction order.
def search_monitor_names(self) -> list[ct.BLEnumElement]:
"""Compute valid values for `self.monitor_attr`, for a dynamic `EnumProperty`.
Notes:
Should be invalidated with (before) `self.monitor_data_attrs`.
Return:
The "type" of the monitor, if available, else None.
"""
if self.monitor_data is not None:
return self.monitor_data.type.removesuffix('Data')
return None
@bl_cache.cached_bl_property(depends_on={'monitor_data_type'})
def monitor_data_attrs(self) -> list[str] | None:
r"""For monitor data, deduces the valid data-containing attributes.
The output depends entirely on the output of `self.monitor_data_type`, since the valid attributes of each monitor type is well-defined without needing to perform dynamic lookups.
- **Field(Time)**: Whichever `[E|H][x|y|z]` are not `None` on the monitor.
- **Permittivity**: Specifically `['xx', 'yy', 'zz']`.
- **Flux(Time)**: Only `['flux']`.
- **FieldProjection(...)**: All of $r$, $\theta$, $\phi$ for both `E` and `H`.
- **Diffraction**: Same as `FieldProjection`.
Notes:
Should be invalidated after with `self.monitor_data_type`.
Return:
The "type" of the monitor, if available, else None.
"""
if self.monitor_data is not None:
# Field/FieldTime
if self.monitor_data_type in ['Field', 'FieldTime']:
return [
field_component
for field_component in ['Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz']
if hasattr(self.monitor_data, field_component)
]
# Permittivity
if self.monitor_data_type == 'Permittivity':
return ['xx', 'yy', 'zz']
# Flux/FluxTime
if self.monitor_data_type in ['Flux', 'FluxTime']:
return ['flux']
# FieldProjection(Angle/Cartesian/KSpace)/Diffraction
if self.monitor_data_type in [
'FieldProjectionAngle',
'FieldProjectionCartesian',
'FieldProjectionKSpace',
'Diffraction',
]:
return [
'Er',
'Etheta',
'Ephi',
'Hr',
'Htheta',
'Hphi',
]
return None
####################
# - Extraction Filter Search
####################
def search_extract_filters(self) -> list[ct.BLEnumElement]:
"""Compute valid values for `self.extract_filter`, for a dynamic `EnumProperty`.
Notes:
Should be reset (via `self.extract_filter`) with (after) `self.sim_data_monitor_nametype`, `self.monitor_data_attrs`, and (implicitly) `self.monitor_type`.
Should be reset (via `self.monitor_attr`) with (after) `self.sim_data_monitor_nametype`, `self.monitor_data_attrs`, and (implicitly) `self.monitor_type`.
See `bl_cache.BLField` for more on dynamic `EnumProperty`.
Returns:
Valid `self.extract_filter` in a format compatible with dynamic `EnumProperty`.
Valid `self.monitor_attr` in a format compatible with dynamic `EnumProperty`.
"""
if self.sim_data_monitor_nametype is not None:
return [
(monitor_name, monitor_name, monitor_type.removesuffix('Data'), '', i)
(
monitor_name,
monitor_name,
monitor_type + ' Monitor Data',
'',
i,
)
for i, (monitor_name, monitor_type) in enumerate(
self.sim_data_monitor_nametype.items()
)
]
if self.monitor_data_attrs is not None:
# Field/FieldTime
if self.monitor_data_type in ['Field', 'FieldTime']:
return [
(
monitor_attr,
monitor_attr,
f' {monitor_attr[1]}-polarization of the {"electric" if monitor_attr[0] == "E" else "magnetic"} field',
'',
i,
)
for i, monitor_attr in enumerate(self.monitor_data_attrs)
]
# Permittivity
if self.monitor_data_type == 'Permittivity':
return [
(monitor_attr, monitor_attr, f' ε_{monitor_attr}', '', i)
for i, monitor_attr in enumerate(self.monitor_data_attrs)
]
# Flux/FluxTime
if self.monitor_data_type in ['Flux', 'FluxTime']:
return [
(
monitor_attr,
monitor_attr,
'Power flux integral through the plane / out of the volume',
'',
i,
)
for i, monitor_attr in enumerate(self.monitor_data_attrs)
]
# FieldProjection(Angle/Cartesian/KSpace)/Diffraction
if self.monitor_data_type in [
'FieldProjectionAngle',
'FieldProjectionCartesian',
'FieldProjectionKSpace',
'Diffraction',
]:
return [
(
monitor_attr,
monitor_attr,
f' {monitor_attr[1]}-component of the spherical {"electric" if monitor_attr[0] == "E" else "magnetic"} field',
'',
i,
)
for i, monitor_attr in enumerate(self.monitor_data_attrs)
]
return []
####################
@ -303,10 +314,9 @@ class ExtractDataNode(base.MaxwellSimNode):
Called by Blender to determine the text to place in the node's header.
"""
has_sim_data = self.sim_data_monitor_nametype is not None
has_monitor_data = self.monitor_data_attrs is not None
if has_sim_data or has_monitor_data:
return f'Extract: {self.extract_filter}'
if has_sim_data:
return f'Extract: {self.monitor_name}'
return self.bl_label
@ -316,336 +326,115 @@ class ExtractDataNode(base.MaxwellSimNode):
Parameters:
col: UI target for drawing.
"""
col.prop(self, self.blfields['extract_filter'], text='')
col.prop(self, self.blfields['monitor_name'], text='')
####################
# - FlowKind.Value: Sim Data -> Monitor Data
# - FlowKind.Func
####################
@events.computes_output_socket(
'Monitor Data',
kind=ct.FlowKind.Value,
# Loaded
props={'extract_filter'},
input_sockets={'Sim Data'},
input_sockets_optional={'Sim Data': True},
)
def compute_monitor_data(
self, props: dict, input_sockets: dict
) -> TDMonitorData | ct.FlowSignal:
"""Compute `Monitor Data` by querying the attribute of `Sim Data` referenced by the property `self.extract_filter`.
Returns:
Monitor data, if available, else `ct.FlowSignal.FlowPending`.
"""
extract_filter = props['extract_filter']
sim_data = input_sockets['Sim Data']
has_sim_data = not ct.FlowSignal.check(sim_data)
if has_sim_data and extract_filter is not None:
return sim_data.monitor_data[extract_filter]
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Array|Func: Monitor Data -> Expr
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Array,
# Loaded
props={'extract_filter'},
input_sockets={'Monitor Data'},
input_socket_kinds={'Monitor Data': ct.FlowKind.Value},
input_sockets_optional={'Monitor Data': True},
)
def compute_expr(
self, props: dict, input_sockets: dict
) -> jax.Array | ct.FlowSignal:
"""Compute `Expr:Array` by querying an array-like attribute of `Monitor Data`, then constructing an `ct.ArrayFlow` around it.
Uses the internal `xarray` data returned by Tidy3D.
By using `np.array` on the `.data` attribute of the `xarray`, instead of the usual JAX array constructor, we should save a (possibly very big) copy.
Returns:
The data array, if available, else `ct.FlowSignal.FlowPending`.
"""
extract_filter = props['extract_filter']
monitor_data = input_sockets['Monitor Data']
has_monitor_data = not ct.FlowSignal.check(monitor_data)
if has_monitor_data and extract_filter is not None:
xarray_data = getattr(monitor_data, extract_filter)
return ct.ArrayFlow(values=np.array(xarray_data.data), unit=None)
return ct.FlowSignal.FlowPending
@events.computes_output_socket(
# Trigger
'Expr',
kind=ct.FlowKind.Func,
# Loaded
output_sockets={'Expr'},
output_socket_kinds={'Expr': ct.FlowKind.Array},
output_sockets_optional={'Expr': True},
props={'monitor_name'},
input_sockets={'Sim Data'},
input_socket_kinds={'Sim Data': ct.FlowKind.Value},
)
def compute_extracted_data_lazy(self, output_sockets: dict) -> ct.FuncFlow | None:
"""Declare `Expr:Func` by creating a simple function that directly wraps `Expr:Array`.
def compute_expr(
self, props: dict, input_sockets: dict
) -> ct.FuncFlow | ct.FlowSignal:
sim_data = input_sockets['Sim Data']
monitor_name = props['monitor_name']
Returns:
The composable function array, if available, else `ct.FlowSignal.FlowPending`.
"""
output_expr = output_sockets['Expr']
has_output_expr = not ct.FlowSignal.check(output_expr)
has_sim_data = not ct.FlowSignal.check(sim_data)
if has_output_expr:
return ct.FuncFlow(func=lambda: output_expr.values, supports_jax=True)
if has_sim_data and monitor_name is not None:
monitor_data = sim_data.get(monitor_name)
if monitor_data is not None:
# Extract Valid Index Labels
## -> The first output axis will be integer-indexed.
## -> Each integer will have a string label.
## -> Those string labels explain the integer as ex. Ex, Ey, Hy.
idx_labels = valid_monitor_attrs(sim_data, monitor_name)
# Generate FuncFlow Per Index Label
## -> We extract each XArray as an attribute of monitor_data.
## -> We then bind its values into a unique func_flow.
## -> This lets us 'stack' then all along the first axis.
func_flows = []
for idx_label in idx_labels:
xarr = getattr(monitor_data, idx_label)
func_flows.append(
ct.FuncFlow(
func=lambda xarr=xarr: xarr.values,
supports_jax=True,
)
)
# Concatenate and Stack Unified FuncFlow
## -> First, 'reduce' lets us __or__ all the FuncFlows together.
## -> Then, 'compose_within' lets us stack them along axis=0.
## -> The "new" axis=0 is int-indexed axis w/idx_labels labels!
return functools.reduce(lambda a, b: a | b, func_flows).compose_within(
enclosing_func=lambda data: jnp.stack(data, axis=0)
)
return ct.FlowSignal.FlowPending
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params: Monitor Data -> Expr
# - FlowKind.Params
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Params,
input_sockets={'Sim Data'},
input_socket_kinds={'Sim Data': ct.FlowKind.Params},
)
def compute_data_params(self) -> ct.ParamsFlow:
def compute_data_params(self, input_sockets) -> ct.ParamsFlow:
"""Declare an empty `Data:Params`, to indicate the start of a function-composition pipeline.
Returns:
A completely empty `ParamsFlow`, ready to be composed.
"""
sim_params = input_sockets['Sim Data']
has_sim_params = not ct.FlowSignal.check(sim_params)
if has_sim_params:
return sim_params
return ct.ParamsFlow()
####################
# - FlowKind.Info: Monitor Data -> Expr
# - FlowKind.Info
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Info,
# Loaded
props={'monitor_data_type', 'extract_filter'},
input_sockets={'Monitor Data'},
input_socket_kinds={'Monitor Data': ct.FlowKind.Value},
input_sockets_optional={'Monitor Data': True},
props={'monitor_name'},
input_sockets={'Sim Data'},
input_socket_kinds={'Sim Data': ct.FlowKind.Value},
)
def compute_extracted_data_info(
self, props: dict, input_sockets: dict
) -> ct.InfoFlow:
def compute_extracted_data_info(self, props, input_sockets) -> ct.InfoFlow:
"""Declare `Data:Info` by manually selecting appropriate axes, units, etc. for each monitor type.
Returns:
Information describing the `Data:Func`, if available, else `ct.FlowSignal.FlowPending`.
"""
monitor_data = input_sockets['Monitor Data']
monitor_data_type = props['monitor_data_type']
extract_filter = props['extract_filter']
sim_data = input_sockets['Sim Data']
monitor_name = props['monitor_name']
has_monitor_data = not ct.FlowSignal.check(monitor_data)
has_sim_data = not ct.FlowSignal.check(sim_data)
# Edge Case: Dangling 'flux' Access on 'FieldMonitor'
## -> Sometimes works - UNLESS the FieldMonitor doesn't have all fields.
## -> We don't allow 'flux' attribute access, but it can dangle.
## -> (The method is called when updating each depschain component.)
if monitor_data_type == 'Field' and extract_filter == 'flux':
if not has_sim_data or monitor_name is None:
return ct.FlowSignal.FlowPending
# Retrieve XArray
if has_monitor_data and extract_filter is not None:
xarr = getattr(monitor_data, extract_filter, None)
if xarr is None:
return ct.FlowSignal.FlowPending
else:
return ct.FlowSignal.FlowPending
# Extract Data
## -> All monitor_data.<idx_label> have the exact same InfoFlow.
## -> So, just construct an InfoFlow w/prepended labelled dimension.
monitor_data = sim_data.get(monitor_name)
idx_labels = valid_monitor_attrs(sim_data, monitor_name)
info = extract_info(monitor_data, idx_labels[0])
# Compute InfoFlow from XArray
## XYZF: Field / Permittivity / FieldProjectionCartesian
if monitor_data_type in {
'Field',
'Permittivity',
#'FieldProjectionCartesian',
}:
return ct.InfoFlow(
dim_names=['x', 'y', 'z', 'f'],
dim_idx={
axis: ct.ArrayFlow(
values=xarr.get_index(axis).values, unit=spu.um, is_sorted=True
)
for axis in ['x', 'y', 'z']
}
| {
'f': ct.ArrayFlow(
values=xarr.get_index('f').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Complex,
output_unit=(
spu.volt / spu.micrometer if monitor_data_type == 'Field' else None
),
)
## XYZT: FieldTime
if monitor_data_type == 'FieldTime':
return ct.InfoFlow(
dim_names=['x', 'y', 'z', 't'],
dim_idx={
axis: ct.ArrayFlow(
values=xarr.get_index(axis).values, unit=spu.um, is_sorted=True
)
for axis in ['x', 'y', 'z']
}
| {
't': ct.ArrayFlow(
values=xarr.get_index('t').values,
unit=spu.second,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Complex,
output_unit=(
spu.volt / spu.micrometer if monitor_data_type == 'Field' else None
),
)
## F: Flux
if monitor_data_type == 'Flux':
return ct.InfoFlow(
dim_names=['f'],
dim_idx={
'f': ct.ArrayFlow(
values=xarr.get_index('f').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=spu.watt,
)
## T: FluxTime
if monitor_data_type == 'FluxTime':
return ct.InfoFlow(
dim_names=['t'],
dim_idx={
't': ct.ArrayFlow(
values=xarr.get_index('t').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=spu.watt,
)
## RThetaPhiF: FieldProjectionAngle
if monitor_data_type == 'FieldProjectionAngle':
return ct.InfoFlow(
dim_names=['r', 'theta', 'phi', 'f'],
dim_idx={
'r': ct.ArrayFlow(
values=xarr.get_index('r').values,
unit=spu.micrometer,
is_sorted=True,
),
}
| {
c: ct.ArrayFlow(
values=xarr.get_index(c).values,
unit=spu.radian,
is_sorted=True,
)
for c in ['r', 'theta', 'phi']
}
| {
'f': ct.ArrayFlow(
values=xarr.get_index('f').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=(
spu.volt / spu.micrometer
if extract_filter.startswith('E')
else spu.ampere / spu.micrometer
),
)
## UxUyRF: FieldProjectionKSpace
if monitor_data_type == 'FieldProjectionKSpace':
return ct.InfoFlow(
dim_names=['ux', 'uy', 'r', 'f'],
dim_idx={
c: ct.ArrayFlow(
values=xarr.get_index(c).values, unit=None, is_sorted=True
)
for c in ['ux', 'uy']
}
| {
'r': ct.ArrayFlow(
values=xarr.get_index('r').values,
unit=spu.micrometer,
is_sorted=True,
),
'f': ct.ArrayFlow(
values=xarr.get_index('f').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=(
spu.volt / spu.micrometer
if extract_filter.startswith('E')
else spu.ampere / spu.micrometer
),
)
## OrderxOrderyF: Diffraction
if monitor_data_type == 'Diffraction':
return ct.InfoFlow(
dim_names=['orders_x', 'orders_y', 'f'],
dim_idx={
f'orders_{c}': ct.ArrayFlow(
values=xarr.get_index(f'orders_{c}').values,
unit=None,
is_sorted=True,
)
for c in ['x', 'y']
}
| {
'f': ct.ArrayFlow(
values=xarr.get_index('f').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=(
spu.volt / spu.micrometer
if extract_filter.startswith('E')
else spu.ampere / spu.micrometer
),
)
msg = f'Unsupported Monitor Data Type {monitor_data_type} in "FlowKind.Info" of "{self.bl_label}"'
raise RuntimeError(msg)
return info.prepend_dim(sim_symbols.idx, idx_labels)
####################

View File

@ -98,29 +98,29 @@ class FilterOperation(enum.StrEnum):
operations = []
# Slice
if info.dim_names:
if info.dims:
operations.append(FO.SliceIdx)
# Pin
## PinLen1
## -> There must be a dimension with length 1.
if 1 in list(info.dim_lens.values()):
if 1 in [dim_idx for dim_idx in info.dims.values() if dim_idx is not None]:
operations.append(FO.PinLen1)
## Pin | PinIdx
## -> There must be a dimension, full stop.
if info.dim_names:
if info.dims:
operations += [FO.Pin, FO.PinIdx]
# Reinterpret
## Swap
## -> There must be at least two dimensions.
if len(info.dim_names) >= 2: # noqa: PLR2004
if len(info.dims) >= 2: # noqa: PLR2004
operations.append(FO.Swap)
## SetDim
## -> There must be a dimension to correct.
if info.dim_names:
if info.dims:
operations.append(FO.SetDim)
return operations
@ -158,33 +158,33 @@ class FilterOperation(enum.StrEnum):
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
FO = FilterOperation
match self:
case FO.SliceIdx:
return info.dim_names
case FO.SliceIdx | FO.Swap:
return info.dims
# PinLen1: Only allow dimensions with length=1.
case FO.PinLen1:
return [
dim_name
for dim_name in info.dim_names
if info.dim_lens[dim_name] == 1
dim
for dim, dim_idx in info.dims.items()
if dim_idx is not None and len(dim_idx) == 1
]
# Pin: Only allow dimensions with known indexing.
case FO.Pin:
# Pin: Only allow dimensions with discrete index.
## TODO: Shouldn't 'Pin' be allowed to index continuous indices too?
case FO.Pin | FO.PinIdx:
return [
dim_name
for dim_name in info.dim_names
if info.dim_has_coords[dim_name] != 0
dim
for dim, dim_idx in info.dims
if dim_idx is not None and len(dim_idx) > 0
]
case FO.PinIdx | FO.Swap:
return info.dim_names
case FO.SetDim:
return [
dim_name
for dim_name in info.dim_names
if info.dim_mathtypes[dim_name] == spux.MathType.Integer
dim
for dim, dim_idx in info.dims
if dim_idx is not None
and not isinstance(dim_idx, list)
and dim_idx.mathtype == spux.MathType.Integer
]
return []
@ -224,22 +224,22 @@ class FilterOperation(enum.StrEnum):
def transform_info(
self,
info: ct.InfoFlow,
dim_0: str,
dim_1: str,
dim_0: sim_symbols.SimSymbol,
dim_1: sim_symbols.SimSymbol,
pin_idx: int | None = None,
slice_tuple: tuple[int, int, int] | None = None,
corrected_dim: tuple[str, tuple[str, ct.ArrayFlow | ct.RangeFlow]]
| None = None,
replaced_dim: tuple[str, tuple[str, ct.ArrayFlow | ct.RangeFlow]] | None = None,
):
FO = FilterOperation
return {
FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple),
# Pin
FO.PinLen1: lambda: info.delete_dimension(dim_0),
FO.Pin: lambda: info.delete_dimension(dim_0),
FO.PinIdx: lambda: info.delete_dimension(dim_0),
FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
FO.Pin: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
# Reinterpret
FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
FO.SetDim: lambda: info.replace_dim(*corrected_dim),
FO.SetDim: lambda: info.replace_dim(*replaced_dim),
}[self]()
@ -318,11 +318,11 @@ class FilterMathNode(base.MaxwellSimNode):
####################
# - Properties: Dimension Selection
####################
dim_0: enum.StrEnum = bl_cache.BLField(
active_dim_0: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_dims(),
cb_depends_on={'operation', 'expr_info'},
)
dim_1: enum.StrEnum = bl_cache.BLField(
active_dim_1: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_dims(),
cb_depends_on={'operation', 'expr_info'},
)
@ -335,40 +335,23 @@ class FilterMathNode(base.MaxwellSimNode):
]
return []
@bl_cache.cached_bl_property(depends_on={'active_dim_0'})
def dim_0(self) -> sim_symbols.SimSymbol | None:
if self.expr_info is not None and self.active_dim_0 is not None:
return self.expr_info.dim_by_name(self.active_dim_0)
return None
@bl_cache.cached_bl_property(depends_on={'active_dim_1'})
def dim_1(self) -> sim_symbols.SimSymbol | None:
if self.expr_info is not None and self.active_dim_1 is not None:
return self.expr_info.dim_by_name(self.active_dim_1)
return None
####################
# - Properties: Slice
####################
slice_tuple: tuple[int, int, int] = bl_cache.BLField([0, 1, 1])
####################
# - Properties: Unit
####################
set_dim_symbol: sim_symbols.CommonSimSymbol = bl_cache.BLField(
sim_symbols.CommonSimSymbol.X
)
set_dim_active_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_valid_units(),
cb_depends_on={'set_dim_symbol'},
)
def search_valid_units(self) -> list[ct.BLEnumElement]:
"""Compute Blender enum elements of valid units for the current `physical_type`."""
physical_type = self.set_dim_symbol.sim_symbol.physical_type
if physical_type is not spux.PhysicalType.NonPhysical:
return [
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
for i, unit in enumerate(physical_type.valid_units)
]
return []
@bl_cache.cached_bl_property(depends_on={'set_dim_active_unit'})
def set_dim_unit(self) -> spux.Unit | None:
if self.set_dim_active_unit is not None:
return spux.unit_str_to_unit(self.set_dim_active_unit)
return None
####################
# - UI
####################
@ -378,27 +361,27 @@ class FilterMathNode(base.MaxwellSimNode):
# Slice
case FO.SliceIdx:
slice_str = ':'.join([str(v) for v in self.slice_tuple])
return f'Filter: {self.dim_0}[{slice_str}]'
return f'Filter: {self.active_dim_0}[{slice_str}]'
# Pin
case FO.PinLen1:
return f'Filter: Pin {self.dim_0}[0]'
return f'Filter: Pin {self.active_dim_0}[0]'
case FO.Pin:
return f'Filter: Pin {self.dim_0}[...]'
return f'Filter: Pin {self.active_dim_0}[...]'
case FO.PinIdx:
pin_idx_axis = self._compute_input(
'Axis', kind=ct.FlowKind.Value, optional=True
)
has_pin_idx_axis = not ct.FlowSignal.check(pin_idx_axis)
if has_pin_idx_axis:
return f'Filter: Pin {self.dim_0}[{pin_idx_axis}]'
return f'Filter: Pin {self.active_dim_0}[{pin_idx_axis}]'
return self.bl_label
# Reinterpret
case FO.Swap:
return f'Filter: Swap [{self.dim_0}]|[{self.dim_1}]'
return f'Filter: Swap [{self.active_dim_0}]|[{self.active_dim_1}]'
case FO.SetDim:
return f'Filter: Set [{self.dim_0}]'
return f'Filter: Set [{self.active_dim_0}]'
case _:
return self.bl_label
@ -409,20 +392,15 @@ class FilterMathNode(base.MaxwellSimNode):
if self.operation is not None:
match self.operation.num_dim_inputs:
case 1:
layout.prop(self, self.blfields['dim_0'], text='')
layout.prop(self, self.blfields['active_dim_0'], text='')
case 2:
row = layout.row(align=True)
row.prop(self, self.blfields['dim_0'], text='')
row.prop(self, self.blfields['dim_1'], text='')
row.prop(self, self.blfields['active_dim_0'], text='')
row.prop(self, self.blfields['active_dim_1'], text='')
if self.operation is FilterOperation.SliceIdx:
layout.prop(self, self.blfields['slice_tuple'], text='')
if self.operation is FilterOperation.SetDim:
row = layout.row(align=True)
row.prop(self, self.blfields['set_dim_symbol'], text='')
row.prop(self, self.blfields['set_dim_active_unit'], text='')
####################
# - Events
####################
@ -450,50 +428,47 @@ class FilterMathNode(base.MaxwellSimNode):
if not has_info:
return
# Pin Dim by-Value: Synchronize Input Socket
## -> The user will be given a socket w/correct mathtype, unit, etc. .
## -> Internally, this value will map to a particular index.
if props['operation'] is FilterOperation.Pin and props['dim_0'] is not None:
# Deduce Pinned Information
pinned_unit = info.dim_units[props['dim_0']]
pinned_mathtype = info.dim_mathtypes[props['dim_0']]
pinned_physical_type = spux.PhysicalType.from_unit(pinned_unit)
wanted_mathtype = (
spux.MathType.Complex
if pinned_mathtype == spux.MathType.Complex
and spux.MathType.Complex in pinned_physical_type.valid_mathtypes
else spux.MathType.Real
)
dim_0 = props['dim_0']
# Get Current and Wanted Socket Defs
## -> 'Value' may already exist. If not, all is well.
# Loose Sockets: Pin Dim by-Value
## -> Works with continuous / discrete indexes.
## -> The user will be given a socket w/correct mathtype, unit, etc. .
if (
props['operation'] is FilterOperation.Pin
and dim_0 is not None
and (info.has_idx_cont(dim_0) or info.has_idx_discrete(dim_0))
):
dim = dim_0
current_bl_socket = self.loose_input_sockets.get('Value')
# Determine Whether to Construct
## -> If nothing needs to change, then nothing changes.
if (
current_bl_socket is None
or current_bl_socket.active_kind != ct.FlowKind.Value
or current_bl_socket.size is not spux.NumberSize1D.Scalar
or current_bl_socket.physical_type != pinned_physical_type
or current_bl_socket.mathtype != wanted_mathtype
or current_bl_socket.physical_type != dim.physical_type
or current_bl_socket.mathtype != dim.mathtype
):
self.loose_input_sockets = {
'Value': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Value,
physical_type=pinned_physical_type,
mathtype=wanted_mathtype,
default_unit=pinned_unit,
physical_type=dim.physical_type,
mathtype=dim.mathtype,
default_unit=dim.unit,
),
}
# Pin Dim by-Index: Synchronize Input Socket
## -> The user will be given a simple integer socket.
# Loose Sockets: Pin Dim by-Value
## -> Works with discrete points / labelled integers.
elif (
props['operation'] is FilterOperation.PinIdx and props['dim_0'] is not None
props['operation'] is FilterOperation.PinIdx
and dim_0 is not None
and (info.has_idx_discrete(dim_0) or info.has_idx_labels(dim_0))
):
dim = dim_0
current_bl_socket = self.loose_input_sockets.get('Axis')
if (
current_bl_socket is None
or current_bl_socket.active_kind != ct.FlowKind.Value
or current_bl_socket.size is not spux.NumberSize1D.Scalar
or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical
or current_bl_socket.mathtype != spux.MathType.Integer
@ -505,28 +480,26 @@ class FilterMathNode(base.MaxwellSimNode):
)
}
# Set Dim: Synchronize Input Socket
# Loose Sockets: Set Dim
## -> The user must provide a () -> array.
## -> It must be of identical length to the replaced axis.
elif (
props['operation'] is FilterOperation.SetDim
and props['dim_0'] is not None
and info.dim_mathtypes[props['dim_0']] is spux.MathType.Integer
and info.dim_physical_types[props['dim_0']] is spux.PhysicalType.NonPhysical
):
# Deduce Axis Information
elif props['operation'] is FilterOperation.SetDim and dim_0 is not None:
dim = dim_0
current_bl_socket = self.loose_input_sockets.get('Dim')
if (
current_bl_socket is None
or current_bl_socket.active_kind != ct.FlowKind.Func
or current_bl_socket.mathtype != spux.MathType.Real
or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical
or current_bl_socket.size is not spux.NumberSize1D.Scalar
or current_bl_socket.mathtype != dim.mathtype
or current_bl_socket.physical_type != dim.physical_type
):
self.loose_input_sockets = {
'Dim': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Func,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.NonPhysical,
physical_type=dim.physical_type,
mathtype=dim.mathtype,
default_unit=dim.unit,
show_func_ui=False,
show_info_columns=True,
)
}
@ -553,25 +526,20 @@ class FilterMathNode(base.MaxwellSimNode):
has_lazy_func = not ct.FlowSignal.check(lazy_func)
has_info = not ct.FlowSignal.check(info)
# Dimension(s)
dim_0 = props['dim_0']
dim_1 = props['dim_1']
slice_tuple = props['slice_tuple']
if (
has_lazy_func
and has_info
and operation is not None
and operation.are_dims_valid(info, dim_0, dim_1)
):
axis_0 = info.dim_names.index(dim_0) if dim_0 is not None else None
axis_1 = info.dim_names.index(dim_1) if dim_1 is not None else None
slice_tuple = (
props['slice_tuple']
if self.operation is FilterOperation.SliceIdx
else None
)
axis_0 = info.dim_axis(dim_0) if dim_0 is not None else None
axis_1 = info.dim_axis(dim_1) if dim_1 is not None else None
return lazy_func.compose_within(
operation.jax_func(axis_0, axis_1, slice_tuple),
operation.jax_func(axis_0, axis_1, slice_tuple=slice_tuple),
enclosing_func_args=operation.func_args,
supports_jax=True,
)
@ -588,8 +556,6 @@ class FilterMathNode(base.MaxwellSimNode):
'dim_1',
'operation',
'slice_tuple',
'set_dim_symbol',
'set_dim_active_unit',
},
input_sockets={'Expr', 'Dim'},
input_socket_kinds={
@ -601,14 +567,15 @@ class FilterMathNode(base.MaxwellSimNode):
def compute_info(self, props, input_sockets) -> ct.InfoFlow:
operation = props['operation']
info = input_sockets['Expr']
dim_coords = input_sockets['Dim'][ct.FlowKind.Func]
dim_params = input_sockets['Dim'][ct.FlowKind.Params]
dim_info = input_sockets['Dim'][ct.FlowKind.Info]
dim_symbol = props['set_dim_symbol']
dim_active_unit = props['set_dim_active_unit']
has_info = not ct.FlowSignal.check(info)
has_dim_coords = not ct.FlowSignal.check(dim_coords)
# Dim (Op.SetDim)
dim_func = input_sockets['Dim'][ct.FlowKind.Func]
dim_params = input_sockets['Dim'][ct.FlowKind.Params]
dim_info = input_sockets['Dim'][ct.FlowKind.Info]
has_dim_func = not ct.FlowSignal.check(dim_func)
has_dim_params = not ct.FlowSignal.check(dim_params)
has_dim_info = not ct.FlowSignal.check(dim_info)
@ -619,44 +586,42 @@ class FilterMathNode(base.MaxwellSimNode):
if has_info and operation is not None:
# Set Dimension: Retrieve Array
if props['operation'] is FilterOperation.SetDim:
new_dim = (
next(dim_info.dims.keys()) if len(dim_info.dims) >= 1 else None
)
if (
dim_0 is not None
# Check Replaced Dimension
and has_dim_coords
and len(dim_coords.func_args) == 1
and dim_coords.func_args[0] is spux.MathType.Integer
and not dim_coords.func_kwargs
and dim_coords.supports_jax
# Check Params
and has_dim_params
and len(dim_params.func_args) == 1
and not dim_params.func_kwargs
# Check Info
and new_dim is not None
and has_dim_info
and has_dim_params
# Check New Dimension Index Array Sizing
and len(dim_info.dims) == 1
and dim_info.output.rows == 1
and dim_info.output.cols == 1
# Check Lack of Params Symbols
and not dim_params.symbols
# Check Expr Dim | New Dim Compatibility
and info.has_idx_discrete(dim_0)
and dim_info.has_idx_discrete(new_dim)
and len(info.dims[dim_0]) == len(dim_info.dims[new_dim])
):
# Retrieve Dimension Coordinate Array
## -> It must be strictly compatible.
values = dim_coords.func_jax(int(dim_params.func_args[0]))
if (
len(values.shape) != 1
or values.shape[0] != info.dim_lens[dim_0]
):
return ct.FlowSignal.FlowPending
values = dim_func.realize(dim_params, spux.UNITS_SI)
# Transform Info w/Corrected Dimension
## -> The existing dimension will be replaced.
if dim_active_unit is not None:
dim_unit = spux.unit_str_to_unit(dim_active_unit)
else:
dim_unit = None
new_dim_idx = ct.ArrayFlow(
values=values,
unit=dim_unit,
)
corrected_dim = [dim_0, (dim_symbol.name, new_dim_idx)]
unit=spux.convert_to_unit_system(
dim_info.output.unit, spux.UNITS_SI
),
).rescale_to_unit(dim_info.output.unit)
replaced_dim = [dim_0, (dim_info.output.name, new_dim_idx)]
return operation.transform_info(
info, dim_0, dim_1, corrected_dim=corrected_dim
info, dim_0, dim_1, replaced_dim=replaced_dim
)
return ct.FlowSignal.FlowPending
return operation.transform_info(info, dim_0, dim_1, slice_tuple=slice_tuple)
@ -702,7 +667,7 @@ class FilterMathNode(base.MaxwellSimNode):
# Pin by-Value: Compute Nearest IDX
## -> Presume a sorted index array to be able to use binary search.
if props['operation'] is FilterOperation.Pin and has_pinned_value:
nearest_idx_to_value = info.dim_idx[dim_0].nearest_idx_of(
nearest_idx_to_value = info.dims[dim_0].nearest_idx_of(
pinned_value, require_sorted=True
)

View File

@ -23,7 +23,7 @@ import bpy
import jax.numpy as jnp
import sympy as sp
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct
@ -153,40 +153,38 @@ class MapOperation(enum.StrEnum):
# - Ops from Shape
####################
@staticmethod
def by_element_shape(shape: tuple[int, ...] | None) -> list[typ.Self]:
def by_expr_info(info: ct.InfoFlow) -> list[typ.Self]:
## TODO: By info, not shape.
## TODO: Check valid domains/mathtypes for some functions.
MO = MapOperation
element_ops = [
MO.Real,
MO.Imag,
MO.Abs,
MO.Sq,
MO.Sqrt,
MO.InvSqrt,
MO.Cos,
MO.Sin,
MO.Tan,
MO.Acos,
MO.Asin,
MO.Atan,
MO.Sinc,
]
match shape:
case 'noshape':
return []
match (info.output.rows, info.output.cols):
case (1, 1):
return element_ops
# By Number
case None:
return [
MO.Real,
MO.Imag,
MO.Abs,
MO.Sq,
MO.Sqrt,
MO.InvSqrt,
MO.Cos,
MO.Sin,
MO.Tan,
MO.Acos,
MO.Asin,
MO.Atan,
MO.Sinc,
]
case (_, 1):
return [*element_ops, MO.Norm2]
match len(shape):
# By Vector
case 1:
return [
MO.Norm2,
]
# By Matrix
case 2:
case (rows, cols) if rows == cols:
## TODO: Check hermitian/posdef for cholesky.
## - Can we even do this with just the output symbol approach?
return [
*element_ops,
MO.Det,
MO.Cond,
MO.NormFro,
@ -201,6 +199,18 @@ class MapOperation(enum.StrEnum):
MO.Svd,
]
case (rows, cols):
return [
*element_ops,
MO.Cond,
MO.NormFro,
MO.Rank,
MO.SvdVals,
MO.Inv,
MO.Tra,
MO.Svd,
]
return []
####################
@ -288,41 +298,76 @@ class MapOperation(enum.StrEnum):
def transform_info(self, info: ct.InfoFlow):
MO = MapOperation
return {
# By Number
MO.Real: lambda: info.set_output_mathtype(spux.MathType.Real),
MO.Imag: lambda: info.set_output_mathtype(spux.MathType.Real),
MO.Abs: lambda: info.set_output_mathtype(spux.MathType.Real),
MO.Real: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Imag: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Abs: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Sq: lambda: info,
MO.Sqrt: lambda: info,
MO.InvSqrt: lambda: info,
MO.Cos: lambda: info,
MO.Sin: lambda: info,
MO.Tan: lambda: info,
MO.Acos: lambda: info,
MO.Asin: lambda: info,
MO.Atan: lambda: info,
MO.Sinc: lambda: info,
# By Vector
MO.Norm2: lambda: info.collapse_output(
collapsed_name=MO.to_name(self).replace('v', info.output_name),
collapsed_mathtype=spux.MathType.Real,
collapsed_unit=info.output_unit,
MO.Norm2: lambda: info.update_output(
mathtype=spux.MathType.Real,
rows=1,
cols=1,
# Interval
interval_finite_re=(0, sim_symbols.float_max),
interval_inf=(False, True),
interval_closed=(True, False),
),
# By Matrix
MO.Det: lambda: info.collapse_output(
collapsed_name=MO.to_name(self).replace('V', info.output_name),
collapsed_mathtype=info.output_mathtype,
collapsed_unit=info.output_unit,
MO.Det: lambda: info.update_output(
rows=1,
cols=1,
),
MO.Cond: lambda: info.collapse_output(
collapsed_name=MO.to_name(self).replace('V', info.output_name),
collapsed_mathtype=spux.MathType.Real,
collapsed_unit=None,
MO.Cond: lambda: info.update_output(
mathtype=spux.MathType.Real,
rows=1,
cols=1,
physical_type=spux.PhysicalType.NonPhysical,
unit=None,
),
MO.NormFro: lambda: info.collapse_output(
collapsed_name=MO.to_name(self).replace('V', info.output_name),
collapsed_mathtype=spux.MathType.Real,
collapsed_unit=info.output_unit,
MO.NormFro: lambda: info.update_output(
mathtype=spux.MathType.Real,
rows=1,
cols=1,
# Interval
interval_finite_re=(0, sim_symbols.float_max),
interval_inf=(False, True),
interval_closed=(True, False),
),
MO.Rank: lambda: info.collapse_output(
collapsed_name=MO.to_name(self).replace('V', info.output_name),
collapsed_mathtype=spux.MathType.Integer,
collapsed_unit=None,
MO.Rank: lambda: info.update_output(
mathtype=spux.MathType.Integer,
rows=1,
cols=1,
physical_type=spux.PhysicalType.NonPhysical,
unit=None,
# Interval
interval_finite_re=(0, sim_symbols.int_max),
interval_inf=(False, True),
interval_closed=(True, False),
),
## TODO: Matrix -> Vec
## TODO: Matrix -> Matrices
}.get(self, lambda: info)()
# Matrix -> Vector ## TODO: ALL OF THESE
MO.Diag: lambda: info,
MO.EigVals: lambda: info,
MO.SvdVals: lambda: info,
# Matrix -> Matrix ## TODO: ALL OF THESE
MO.Inv: lambda: info,
MO.Tra: lambda: info,
# Matrix -> Matrices ## TODO: ALL OF THESE
MO.Qr: lambda: info,
MO.Chol: lambda: info,
MO.Svd: lambda: info,
}[self]()
####################
@ -435,29 +480,26 @@ class MapMathNode(base.MaxwellSimNode):
)
if has_info and not info_pending:
self.expr_output_shape = bl_cache.Signal.InvalidateCache
self.expr_info = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def expr_output_shape(self) -> ct.InfoFlow | None:
def expr_info(self) -> ct.InfoFlow | None:
info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True)
has_info = not ct.FlowSignal.check(info)
if has_info:
return info.output_shape
return 'noshape'
return info
return None
operation: MapOperation = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_output_shape'},
cb_depends_on={'expr_info'},
)
def search_operations(self) -> list[ct.BLEnumElement]:
if self.expr_output_shape != 'noshape':
if self.info is not None:
return [
operation.bl_enum_element(i)
for i, operation in enumerate(
MapOperation.by_element_shape(self.expr_output_shape)
)
for i, operation in enumerate(MapOperation.by_expr_info(self.expr_info))
]
return []
@ -474,7 +516,7 @@ class MapMathNode(base.MaxwellSimNode):
layout.prop(self, self.blfields['operation'], text='')
####################
# - FlowKind.Value|Func
# - FlowKind.Value
####################
@events.computes_output_socket(
'Expr',
@ -495,6 +537,9 @@ class MapMathNode(base.MaxwellSimNode):
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Func,
@ -518,7 +563,7 @@ class MapMathNode(base.MaxwellSimNode):
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Info|Params
# - FlowKind.Info
####################
@events.computes_output_socket(
'Expr',
@ -538,6 +583,9 @@ class MapMathNode(base.MaxwellSimNode):
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Params,

View File

@ -107,32 +107,31 @@ class TransformOperation(enum.StrEnum):
# Covariant Transform
## Freq <-> VacWL
for dim_name in info.dim_names:
if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq:
for dim in info.dims:
if dim.physical_type == spux.PhysicalType.Freq:
operations.append(TO.FreqToVacWL)
if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq:
if dim.physical_type == spux.PhysicalType.Freq:
operations.append(TO.VacWLToFreq)
# Fold
## (Last) Int Dim (=2) to Complex
if len(info.dim_names) >= 1:
last_dim_name = info.dim_names[-1]
if info.dim_lens[last_dim_name] == 2: # noqa: PLR2004
if len(info.dims) >= 1:
if not info.has_idx_labels(info.last_dim) and len(info.last_dim) == 2: # noqa: PLR2004
operations.append(TO.IntDimToComplex)
## To Vector
if len(info.dim_names) >= 1:
if len(info.dims) >= 1:
operations.append(TO.DimToVec)
## To Matrix
if len(info.dim_names) >= 2: # noqa: PLR2004
if len(info.dims) >= 2: # noqa: PLR2004
operations.append(TO.DimsToMat)
# Fourier
## 1D Fourier
if info.dim_names:
last_physical_type = info.dim_physical_types[info.dim_names[-1]]
if info.dims:
last_physical_type = info.last_dim.physical_type
if last_physical_type == spux.PhysicalType.Time:
operations.append(TO.FFT1D)
if last_physical_type == spux.PhysicalType.Freq:
@ -188,15 +187,15 @@ class TransformOperation(enum.StrEnum):
unit: spux.Unit | None = None,
) -> ct.InfoFlow | None:
TO = TransformOperation
if not info.dim_names:
if not info.dims:
return None
return {
# Index
# Covariant Transform
TO.FreqToVacWL: lambda: info.replace_dim(
(f_dim := info.dim_names[-1]),
(f_dim := info.last_dim),
[
'wl',
info.dim_idx[f_dim].rescale(
sim_symbols.wl(spu.nanometer),
info.dims[f_dim].rescale(
lambda el: sci_constants.vac_speed_of_light / el,
reverse=True,
new_unit=spu.nanometer,
@ -204,10 +203,10 @@ class TransformOperation(enum.StrEnum):
],
),
TO.VacWLToFreq: lambda: info.replace_dim(
(wl_dim := info.dim_names[-1]),
(wl_dim := info.last_dim),
[
'f',
info.dim_idx[wl_dim].rescale(
sim_symbols.freq(spux.THz),
info.dims[wl_dim].rescale(
lambda el: sci_constants.vac_speed_of_light / el,
reverse=True,
new_unit=spux.THz,
@ -215,24 +214,24 @@ class TransformOperation(enum.StrEnum):
],
),
# Fold
TO.IntDimToComplex: lambda: info.delete_dimension(
info.dim_names[-1]
).set_output_mathtype(spux.MathType.Complex),
TO.DimToVec: lambda: info.shift_last_input,
TO.DimsToMat: lambda: info.shift_last_input.shift_last_input,
TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output(
mathtype=spux.MathType.Complex
),
TO.DimToVec: lambda: info.fold_last_input(),
TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(),
# Fourier
TO.FFT1D: lambda: info.replace_dim(
info.dim_names[-1],
info.last_dim,
[
'f',
ct.RangeFlow(start=0, stop=sp.oo, steps=0, unit=spu.hertz),
sim_symbols.freq(spux.THz),
None,
],
),
TO.InvFFT1D: info.replace_dim(
info.dim_names[-1],
info.last_dim,
[
't',
ct.RangeFlow(start=0, stop=sp.oo, steps=0, unit=spu.second),
sim_symbols.t(spu.second),
None,
],
),
}.get(self, lambda: info)()

View File

@ -38,7 +38,6 @@ class VizMode(enum.StrEnum):
**NOTE**: >1D output dimensions currently have no viz.
Plots for `() -> `:
- Hist1D: Bin-summed distribution.
- BoxPlot1D: Box-plot describing the distribution.
Plots for `() -> `:
@ -61,7 +60,6 @@ class VizMode(enum.StrEnum):
- Heatmap3D: Colormapped field with value at each voxel.
"""
Hist1D = enum.auto()
BoxPlot1D = enum.auto()
Curve2D = enum.auto()
@ -78,42 +76,38 @@ class VizMode(enum.StrEnum):
@staticmethod
def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None:
EMPTY = ()
Z = spux.MathType.Integer
R = spux.MathType.Real
VM = VizMode
valid_viz_modes = {
(EMPTY, (None, R)): [VM.Hist1D, VM.BoxPlot1D],
((Z), (None, R)): [
VM.Hist1D,
return {
((Z), (1, 1, R)): [
VM.BoxPlot1D,
],
((R,), (None, R)): [
((R,), (1, 1, R)): [
VM.Curve2D,
VM.Points2D,
VM.Bar,
],
((R, Z), (None, R)): [
((R, Z), (1, 1, R)): [
VM.Curves2D,
VM.FilledCurves2D,
],
((R, R), (None, R)): [
((R, R), (1, 1, R)): [
VM.Heatmap2D,
],
((R, R, R), (None, R)): [VM.SqueezedHeatmap2D, VM.Heatmap3D],
((R, R, R), (1, 1, R)): [
VM.SqueezedHeatmap2D,
VM.Heatmap3D,
],
}.get(
(
tuple(info.dim_mathtypes.values()),
(info.output_shape, info.output_mathtype),
)
tuple([dim.mathtype for dim in info.dims.values()]),
(info.output.rows, info.output.cols, info.output.mathtype),
),
[],
)
if valid_viz_modes is None:
return []
return valid_viz_modes
@staticmethod
def to_plotter(
value: typ.Self,
@ -121,7 +115,6 @@ class VizMode(enum.StrEnum):
[jtyp.Float32[jtyp.Array, '...'], ct.InfoFlow, mpl_ax.Axis], None
]:
return {
VizMode.Hist1D: image_ops.plot_hist_1d,
VizMode.BoxPlot1D: image_ops.plot_box_plot_1d,
VizMode.Curve2D: image_ops.plot_curve_2d,
VizMode.Points2D: image_ops.plot_points_2d,
@ -136,7 +129,6 @@ class VizMode(enum.StrEnum):
@staticmethod
def to_name(value: typ.Self) -> str:
return {
VizMode.Hist1D: 'Histogram',
VizMode.BoxPlot1D: 'Box Plot',
VizMode.Curve2D: 'Curve',
VizMode.Points2D: 'Points',
@ -164,7 +156,6 @@ class VizTarget(enum.StrEnum):
@staticmethod
def valid_targets_for(viz_mode: VizMode) -> list[typ.Self] | None:
return {
VizMode.Hist1D: [VizTarget.Plot2D],
VizMode.BoxPlot1D: [VizTarget.Plot2D],
VizMode.Curve2D: [VizTarget.Plot2D],
VizMode.Points2D: [VizTarget.Plot2D],
@ -333,35 +324,20 @@ class VizNode(base.MaxwellSimNode):
## -> This happens if Params contains not-yet-realized symbols.
if has_info and has_params and params.symbols:
if set(self.loose_input_sockets) != {
sym.name for sym in params.symbols if sym.name in info.dim_names
dim.name for dim in params.symbols if dim in info.dims
}:
self.loose_input_sockets = {
sym.name: sockets.ExprSocketDef(
active_kind=ct.FlowKind.Range,
size=spux.NumberSize1D.Scalar,
mathtype=info.dim_mathtypes[sym.name],
physical_type=info.dim_physical_types[sym.name],
default_min=(
info.dim_idx[sym.name].start
if not sp.S(info.dim_idx[sym.name].start).is_infinite
else sp.S(0)
),
default_max=(
info.dim_idx[sym.name].start
if not sp.S(info.dim_idx[sym.name].stop).is_infinite
else sp.S(1)
),
default_steps=50,
)
for sym in params.sorted_symbols
if sym.name in info.dim_names
dim_name: sockets.ExprSocketDef(**expr_info)
for dim_name, expr_info in params.sym_expr_infos(
info, use_range=True
).items()
}
elif self.loose_input_sockets:
self.loose_input_sockets = {}
#####################
## - Plotting
## - FlowKind.Value
#####################
@events.computes_output_socket(
'Preview',
@ -375,8 +351,12 @@ class VizNode(base.MaxwellSimNode):
all_loose_input_sockets=True,
)
def compute_dummy_value(self, props, input_sockets, loose_input_sockets):
"""Needed for the plot to regenerate in the viewer."""
return ct.FlowSignal.NoFlow
#####################
## - On Show Plot
#####################
@events.on_show_plot(
managed_objs={'plot'},
props={'viz_mode', 'viz_target', 'colormap'},
@ -384,13 +364,12 @@ class VizNode(base.MaxwellSimNode):
input_socket_kinds={
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
},
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
all_loose_input_sockets=True,
stop_propagation=True,
)
def on_show_plot(
self, managed_objs, props, input_sockets, loose_input_sockets, unit_systems
):
self, managed_objs, props, input_sockets, loose_input_sockets
) -> None:
# Retrieve Inputs
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
info = input_sockets['Expr'][ct.FlowKind.Info]
@ -399,8 +378,6 @@ class VizNode(base.MaxwellSimNode):
has_info = not ct.FlowSignal.check(info)
has_params = not ct.FlowSignal.check(params)
# Invalid Mode | Target
## -> To limit branching, return now if things aren't right.
if (
not has_info
or not has_params
@ -410,53 +387,42 @@ class VizNode(base.MaxwellSimNode):
return
# Compute Ranges for Symbols from Loose Sockets
## -> These are the concrete values of the symbol for plotting.
## -> In a quite nice turn of events, all this is cached lookups.
## -> ...Unless something changed, in which case, well. It changed.
symbol_values = {
sym: (
loose_input_sockets[sym.name]
.realize_array.rescale_to_unit(info.dim_units[sym.name])
.values
symbol_array_values = {
sim_syms: (
loose_input_sockets[sim_syms]
.rescale_to_unit(sim_syms.unit)
.realize_array
)
for sym in params.sorted_symbols
for sim_syms in params.sorted_symbols
}
# Realize Func w/Symbolic Values, Unit System
## -> This gives us the actual plot data!
data = lazy_func.func_jax(
*params.scaled_func_args(
unit_systems['BlenderUnits'], symbol_values=symbol_values
),
**params.scaled_func_kwargs(
unit_systems['BlenderUnits'], symbol_values=symbol_values
),
)
data = lazy_func.realize(params, symbol_values=symbol_array_values)
# Replace InfoFlow Indices w/Realized Symbolic Ranges
## -> This ensures correct axis scaling.
if params.symbols:
info = info.rescale_dim_idxs(loose_input_sockets)
info = info.replace_dims(symbol_array_values)
# Visualize by-Target
if props['viz_target'] == VizTarget.Plot2D:
managed_objs['plot'].mpl_plot_to_image(
lambda ax: VizMode.to_plotter(props['viz_mode'])(data, info, ax),
bl_select=True,
)
match props['viz_target']:
case VizTarget.Plot2D:
managed_objs['plot'].mpl_plot_to_image(
lambda ax: VizMode.to_plotter(props['viz_mode'])(data, info, ax),
bl_select=True,
)
if props['viz_target'] == VizTarget.Pixels:
managed_objs['plot'].map_2d_to_image(
data,
colormap=props['colormap'],
bl_select=True,
)
case VizTarget.Pixels:
managed_objs['plot'].map_2d_to_image(
data,
colormap=props['colormap'],
bl_select=True,
)
if props['viz_target'] == VizTarget.PixelsPlane:
raise NotImplementedError
case VizTarget.PixelsPlane:
raise NotImplementedError
if props['viz_target'] == VizTarget.Voxels:
raise NotImplementedError
case VizTarget.Voxels:
raise NotImplementedError
####################

View File

@ -67,7 +67,7 @@ ManagedObjName: typ.TypeAlias = str
PropName: typ.TypeAlias = str
def event_decorator(
def event_decorator( # noqa: PLR0913
event: ct.FlowEvent,
callback_info: EventCallbackInfo | None,
stop_propagation: bool = False,
@ -91,31 +91,42 @@ def event_decorator(
scale_input_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}),
scale_output_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}),
):
"""Returns a decorator for a method of `MaxwellSimNode`, declaring it as able respond to events passing through a node.
"""Low-level decorator declaring a special "event method" of `MaxwellSimNode`, which is able to handle `ct.FlowEvent`s passing through.
Should generally be used via a high-level decorator such as `on_value_changed`.
For more about how event methods are actually registered and run, please refer to the documentation of `MaxwellSimNode`.
Parameters:
event: A name describing which event the decorator should respond to.
Set to `return_method.event`
callback_info: A dictionary that provides the caller with additional per-`event` information.
This might include parameters to help select the most appropriate method(s) to respond to an event with, or events to take after running the callback.
props: Set of `props` to compute, then pass to the decorated method.
stop_propagation: Whether or stop propagating the event through the graph after encountering this method.
Other methods defined on the same node will still run.
managed_objs: Set of `managed_objs` to retrieve, then pass to the decorated method.
props: Set of `props` to compute, then pass to the decorated method.
input_sockets: Set of `input_sockets` to compute, then pass to the decorated method.
input_sockets_optional: Whether an input socket is required to exist.
When True, lack of socket will produce `ct.FlowSignal.NoFlow`, instead of throwing an error.
input_socket_kinds: The `ct.FlowKind` to compute per-input-socket.
If an input socket isn't specified, it defaults to `ct.FlowKind.Value`.
output_sockets: Set of `output_sockets` to compute, then pass to the decorated method.
output_sockets_optional: Whether an output socket is required to exist.
When True, lack of socket will produce `ct.FlowSignal.NoFlow`, instead of throwing an error.
output_socket_kinds: The `ct.FlowKind` to compute per-output-socket.
If an output socket isn't specified, it defaults to `ct.FlowKind.Value`.
all_loose_input_sockets: Whether to compute all loose input sockets and pass them to the decorated method.
Used when the names of the loose input sockets are unknown, but all of their values are needed.
all_loose_output_sockets: Whether to compute all loose output sockets and pass them to the decorated method.
Used when the names of the loose output sockets are unknown, but all of their values are needed.
unit_systems: String identifiers under which to load a unit system, made available to the method.
scale_input_sockets: A mapping of input sockets to unit system string idenfiers, which causes the output of that input socket to be scaled to the given unit system.
This greatly simplifies the conformance of particular sockets to particular unit systems, when the socket value must be used in a unit-unaware manner.
scale_output_sockets: A mapping of output sockets to unit system string idenfiers, which causes the output of that input socket to be scaled to the given unit system.
This greatly simplifies the conformance of particular sockets to particular unit systems, when the socket value must be used in a unit-unaware manner.
Returns:
A decorator, which can be applied to a method of `MaxwellSimNode`.
When a `MaxwellSimNode` subclass initializes, such a decorated method will be picked up on.
When `event` passes through the node, then `callback_info` is used to determine
A decorator, which can be applied to a method of `MaxwellSimNode` to make it an "event method".
"""
req_params = (
{'self'}
@ -375,7 +386,6 @@ def on_value_changed(
)
## TODO: Change name to 'on_output_requested'
def computes_output_socket(
output_socket_name: ct.SocketName | None,
kind: ct.FlowKind = ct.FlowKind.Value,

View File

@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import enum
import typing as typ
from pathlib import Path
@ -21,7 +22,7 @@ import bpy
import sympy as sp
import tidy3d as td
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct
@ -91,6 +92,88 @@ class DataFileImporterNode(base.MaxwellSimNode):
return info
return None
####################
# - Info Guides
####################
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(sim_symbols.SimSymbolName)
output_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
output_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
output_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type),
cb_depends_on={'output_physical_type'},
)
dim_0_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerA
)
dim_0_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_0_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_0_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type),
cb_depends_on={'dim_0_physical_type'},
)
dim_1_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerB
)
dim_1_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_1_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_1_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_1_physical_type),
cb_depends_on={'dim_1_physical_type'},
)
dim_2_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerC
)
dim_2_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_2_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_2_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_2_physical_type),
cb_depends_on={'dim_2_physical_type'},
)
dim_3_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerD
)
dim_3_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_3_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_3_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_3_physical_type),
cb_depends_on={'dim_3_physical_type'},
)
def search_units(self, physical_type: spux.PhysicalType) -> list[ct.BLEnumElement]:
if physical_type is not spux.PhysicalType.NonPhysical:
return [
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
for i, unit in enumerate(physical_type.valid_units)
]
return []
def dim(self, i: int):
dim_name = getattr(self, f'dim_{i}_name')
dim_mathtype = getattr(self, f'dim_{i}_mathtype')
dim_physical_type = getattr(self, f'dim_{i}_physical_type')
dim_unit = getattr(self, f'dim_{i}_unit')
return sim_symbols.SimSymbol(
sym_name=dim_name,
mathtype=dim_mathtype,
physical_type=dim_physical_type,
unit=spux.unit_str_to_unit(dim_unit),
)
####################
# - UI
####################
@ -118,7 +201,20 @@ class DataFileImporterNode(base.MaxwellSimNode):
row.label(text=self.file_path.name)
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
pass
"""Draw loaded properties."""
for i in range(len(self.expr_info.dims)):
col = layout.column(align=True)
row = col.row(align=True)
row.alignment = 'CENTER'
row.label(text=f'Load Dim {i}')
row = col.row(align=True)
row.prop(self, self.blfields[f'dim_{i}_name'], text='')
row.prop(self, self.blfields[f'dim_{i}_mathtype'], text='')
row = col.row(align=True)
row.prop(self, self.blfields[f'dim_{i}_physical_type'], text='')
row.prop(self, self.blfields[f'dim_{i}_unit'], text='')
####################
# - FlowKind.Array|Func
@ -174,10 +270,12 @@ class DataFileImporterNode(base.MaxwellSimNode):
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Info,
# Loaded
props={'output_name', 'output_physical_type', 'output_unit'},
output_sockets={'Expr'},
output_socket_kinds={'Expr': ct.FlowKind.Func},
)
def compute_info(self, output_sockets) -> ct.InfoFlow:
def compute_info(self, props, output_sockets) -> ct.InfoFlow:
"""Declare an `InfoFlow` based on the data shape.
This currently requires computing the data.
@ -196,26 +294,24 @@ class DataFileImporterNode(base.MaxwellSimNode):
# Deduce Dimensionality
_shape = data.shape
shape = _shape if _shape is not None else ()
dim_names = [f'a{i}' for i in range(len(shape))]
dim_syms = [self.dim(i) for i in range(len(shape))]
# Return InfoFlow
## -> TODO: How to interpret the data should be user-defined.
## -> -- This may require those nice dynamic symbols.
return ct.InfoFlow(
dim_names=dim_names, ## TODO: User
dim_idx={
dim_name: ct.RangeFlow(
start=sp.S(0), ## TODO: User
stop=sp.S(shape[i] - 1), ## TODO: User
steps=shape[dim_names.index(dim_name)],
unit=None, ## TODO: User
dims={
dim_sym: ct.RangeFlow(
start=sp.S(0),
stop=sp.S(shape[i] - 1),
steps=shape[i],
unit=self.dim(i).unit,
)
for i, dim_name in enumerate(dim_names)
for i, dim_sym in enumerate(dim_syms)
},
output_name='_',
output_shape=None,
output_mathtype=spux.MathType.Real, ## TODO: User
output_unit=None, ## TODO: User
output=sim_symbols.SimSymbol(
sym_name=props['output_name'],
mathtype=props['output_mathtype'],
physical_type=props['output_physical_type'],
),
)
return ct.FlowSignal.FlowPending

View File

@ -229,11 +229,11 @@ class DataFileExporterNode(base.MaxwellSimNode):
## -> Only happens if Params contains not-yet-realized symbols.
if has_info and has_params and params.symbols:
if set(self.loose_input_sockets) != {
sym.name for sym in params.symbols if sym.name in info.dim_names
dim.name for dim in params.symbols if dim in info.dims
}:
self.loose_input_sockets = {
sym_name: sockets.ExprSocketDef(**expr_info)
for sym_name, expr_info in params.sym_expr_infos(info).items()
dim_name: sockets.ExprSocketDef(**expr_info)
for dim_name, expr_info in params.sym_expr_infos(info).items()
}
elif self.loose_input_sockets:

View File

@ -118,13 +118,15 @@ class ExprBLSocket(base.MaxwellSimSocket):
physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical)
# Symbols
# active_symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([])
symbols: frozenset[sp.Symbol] = bl_cache.BLField(frozenset())
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.Expr
)
active_symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([])
# @property
# def symbols(self) -> set[sp.Symbol]:
# """Current symbols as an unordered set."""
# return {sim_symbol.sp_symbol for sim_symbol in self.active_symbols}
@property
def symbols(self) -> set[sp.Symbol]:
"""Current symbols as an unordered set."""
return {sim_symbol.sp_symbol for sim_symbol in self.active_symbols}
@bl_cache.cached_bl_property(depends_on={'symbols'})
def sorted_symbols(self) -> list[sp.Symbol]:
@ -184,6 +186,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
)
# UI: Info
show_func_ui: bool = bl_cache.BLField(True)
show_info_columns: bool = bl_cache.BLField(False)
info_columns: set[InfoDisplayCol] = bl_cache.BLField(
{InfoDisplayCol.Length, InfoDisplayCol.MathType}
@ -615,35 +618,24 @@ class ExprBLSocket(base.MaxwellSimSocket):
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
"""
output_sim_sym = (
sim_symbols.SimSymbol(
sym_name=self.output_name,
mathtype=self.mathtype,
physical_type=self.physical_type,
unit=self.unit,
rows=self.size.rows,
cols=self.size.cols,
),
)
if self.symbols:
return ct.InfoFlow(
dim_names=[sym.name for sym in self.sorted_symbols],
dim_idx={
sym.name: ct.RangeFlow(
start=-sp.oo if _check_sym_oo(sym) else -sp.zoo,
stop=sp.oo if _check_sym_oo(sym) else sp.zoo,
steps=0,
unit=None, ## Symbols alone are unitless.
)
## TODO: PhysicalTypes for symbols? Or nah?
## TODO: Can we parse some sp.Interval for explicit domains?
## -> We investigated sp.Symbol(..., domain=...).
## -> It's no good. We can't re-extract the interval given to domain.
for sym in self.sorted_symbols
},
output_name='_', ## Use node:socket name? Or something? Ahh
output_shape=self.size.shape,
output_mathtype=self.mathtype,
output_unit=self.unit,
dims={sim_sym: None for sim_sym in self.active_symbols},
output=output_sim_sym,
)
# Constant
return ct.InfoFlow(
output_name='_', ## Use node:socket name? Or something? Ahh
output_shape=self.size.shape,
output_mathtype=self.mathtype,
output_unit=self.unit,
)
return ct.InfoFlow(output=output_sim_sym)
####################
# - FlowKind: Capabilities
@ -847,26 +839,31 @@ class ExprBLSocket(base.MaxwellSimSocket):
Uses `draw_value` to draw the base UI
"""
# Physical Type Selector
## -> Determines whether/which unit-dropdown will be shown.
col.prop(self, self.blfields['physical_type'], text='')
if self.show_func_ui:
# Output Name Selector
## -> The name of the output
col.prop(self, self.blfields['output_name'], text='')
# Non-Symbolic: Size/Mathtype Selector
## -> Symbols imply str expr input.
## -> For arbitrary str exprs, size/mathtype are derived from the expr.
## -> Otherwise, size/mathtype must be pre-specified for a nice UI.
if not self.symbols:
row = col.row(align=True)
row.prop(self, self.blfields['size'], text='')
row.prop(self, self.blfields['mathtype'], text='')
# Physical Type Selector
## -> Determines whether/which unit-dropdown will be shown.
col.prop(self, self.blfields['physical_type'], text='')
# Base UI
## -> Draws the UI appropriate for the above choice of constraints.
self.draw_value(col)
# Non-Symbolic: Size/Mathtype Selector
## -> Symbols imply str expr input.
## -> For arbitrary str exprs, size/mathtype are derived from the expr.
## -> Otherwise, size/mathtype must be pre-specified for a nice UI.
if not self.symbols:
row = col.row(align=True)
row.prop(self, self.blfields['size'], text='')
row.prop(self, self.blfields['mathtype'], text='')
# Symbol UI
## -> Draws the UI appropriate for the above choice of constraints.
## -> TODO
# Base UI
## -> Draws the UI appropriate for the above choice of constraints.
self.draw_value(col)
# Symbol UI
## -> Draws the UI appropriate for the above choice of constraints.
## -> TODO
####################
# - UI: InfoFlow
@ -884,9 +881,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
)
# Dimensions
for dim_name in info.dim_names:
dim_idx = info.dim_idx[dim_name]
grid.label(text=dim_name)
for dim in info.dims:
dim_idx = info.dims[dim]
grid.label(text=dim.name_pretty)
if InfoDisplayCol.Length in self.info_columns:
grid.label(text=str(len(dim_idx)))
if InfoDisplayCol.MathType in self.info_columns:
@ -895,27 +892,27 @@ class ExprBLSocket(base.MaxwellSimSocket):
grid.label(text=spux.sp_to_str(dim_idx.unit))
# Outputs
grid.label(text=info.output_name)
grid.label(text=info.output.name_pretty)
if InfoDisplayCol.Length in self.info_columns:
grid.label(text='', icon=ct.Icon.DataSocketOutput)
if InfoDisplayCol.MathType in self.info_columns:
grid.label(
text=(
spux.MathType.to_str(info.output_mathtype)
spux.MathType.to_str(info.output.mathtype)
+ (
'ˣ'.join(
[
unicode_superscript(out_axis)
for out_axis in info.output_shape
for out_axis in info.output.shape
]
)
if info.output_shape
if info.output.shape
else ''
)
)
)
if InfoDisplayCol.Unit in self.info_columns:
grid.label(text=f'{spux.sp_to_str(info.output_unit)}')
grid.label(text=f'{spux.sp_to_str(info.output.unit)}')
####################
@ -929,6 +926,7 @@ class ExprSocketDef(base.SocketDef):
ct.FlowKind.Array,
ct.FlowKind.Func,
] = ct.FlowKind.Value
output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName
# Socket Interface
size: spux.NumberSize1D = spux.NumberSize1D.Scalar
@ -938,10 +936,6 @@ class ExprSocketDef(base.SocketDef):
default_unit: spux.Unit | None = None
default_symbols: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list)
@property
def symbols(self) -> set[sp.Symbol]:
return {sim_symbol.sp_symbol for sim_symbol in self.default_symbols}
# FlowKind: Value
default_value: spux.SympyExpr = 0
abs_min: spux.SympyExpr | None = None
@ -954,6 +948,7 @@ class ExprSocketDef(base.SocketDef):
default_scaling: ct.ScalingMode = ct.ScalingMode.Lin
# UI
show_func_ui: bool = True
show_info_columns: bool = False
####################
@ -1153,6 +1148,14 @@ class ExprSocketDef(base.SocketDef):
msg = f'ExprSocket: Mathtype {dv_mathtype} of a default Range min or max expression {bound} (type {type(self.default_value)}) is incompatible with socket MathType {self.mathtype}'
raise ValueError(msg)
# Coerce from Infinite
if bound.is_infinite and self.mathtype is spux.MathType.Integer:
new_bounds[i] = sp.S(-1) if i == 0 else sp.S(1)
if bound.is_infinite and self.mathtype is spux.MathType.Rational:
new_bounds[i] = sp.Rational(-1, 1) if i == 0 else sp.Rational(1, 1)
if bound.is_infinite and self.mathtype is spux.MathType.Real:
new_bounds[i] = sp.S(-1.0) if i == 0 else sp.S(1.0)
if new_bounds[0] is not None:
self.default_min = new_bounds[0]
if new_bounds[1] is not None:
@ -1217,13 +1220,14 @@ class ExprSocketDef(base.SocketDef):
####################
def init(self, bl_socket: ExprBLSocket) -> None:
bl_socket.active_kind = self.active_kind
bl_socket.output_name = self.output_name
# Socket Interface
## -> Recall that auto-updates are turned off during init()
bl_socket.size = self.size
bl_socket.mathtype = self.mathtype
bl_socket.physical_type = self.physical_type
bl_socket.symbols = self.symbols
bl_socket.active_symbols = self.symbols
# FlowKind.Value
## -> We must take units into account when setting bl_socket.value
@ -1246,6 +1250,7 @@ class ExprSocketDef(base.SocketDef):
)
# UI
bl_socket.show_func_ui = self.show_func_ui
bl_socket.show_info_columns = self.show_info_columns
# Info Draw

View File

@ -61,7 +61,6 @@ SympyType = (
class MathType(enum.StrEnum):
"""Type identifiers that encompass common sets of mathematical objects."""
Bool = enum.auto()
Integer = enum.auto()
Rational = enum.auto()
Real = enum.auto()
@ -77,8 +76,6 @@ class MathType(enum.StrEnum):
return MathType.Rational
if MathType.Integer in mathtypes:
return MathType.Integer
if MathType.Bool in mathtypes:
return MathType.Bool
msg = f"Can't combine mathtypes {mathtypes}"
raise ValueError(msg)
@ -88,7 +85,6 @@ class MathType(enum.StrEnum):
return (
other
in {
MT.Bool: [MT.Bool],
MT.Integer: [MT.Integer],
MT.Rational: [MT.Integer, MT.Rational],
MT.Real: [MT.Integer, MT.Rational, MT.Real],
@ -98,11 +94,9 @@ class MathType(enum.StrEnum):
def coerce_compatible_pyobj(
self, pyobj: bool | int | Fraction | float | complex
) -> bool | int | Fraction | float | complex:
) -> int | Fraction | float | complex:
MT = MathType
match self:
case MT.Bool:
return pyobj
case MT.Integer:
return int(pyobj)
case MT.Rational if isinstance(pyobj, int):
@ -123,8 +117,6 @@ class MathType(enum.StrEnum):
*[MathType.from_expr(v) for v in sp.flatten(sp_obj)]
)
if isinstance(sp_obj, sp.logic.boolalg.Boolean):
return MathType.Bool
if sp_obj.is_integer:
return MathType.Integer
if sp_obj.is_rational:
@ -146,7 +138,6 @@ class MathType(enum.StrEnum):
@staticmethod
def from_pytype(dtype: type) -> type:
return {
bool: MathType.Bool,
int: MathType.Integer,
Fraction: MathType.Rational,
float: MathType.Real,
@ -166,7 +157,6 @@ class MathType(enum.StrEnum):
def pytype(self) -> type:
MT = MathType
return {
MT.Bool: bool,
MT.Integer: int,
MT.Rational: Fraction,
MT.Real: float,
@ -177,17 +167,25 @@ class MathType(enum.StrEnum):
def symbolic_set(self) -> type:
MT = MathType
return {
MT.Bool: sp.Set([sp.S(False), sp.S(True)]),
MT.Integer: sp.Integers,
MT.Rational: sp.Rationals,
MT.Real: sp.Reals,
MT.Complex: sp.Complexes,
}[self]
@property
def sp_symbol_a(self) -> type:
MT = MathType
return {
MT.Integer: sp.Symbol('a', integer=True),
MT.Rational: sp.Symbol('a', rational=True),
MT.Real: sp.Symbol('a', real=True),
MT.Complex: sp.Symbol('a', complex=True),
}[self]
@staticmethod
def to_str(value: typ.Self) -> type:
return {
MathType.Bool: 'T|F',
MathType.Integer: '',
MathType.Rational: '',
MathType.Real: '',
@ -212,6 +210,9 @@ class MathType(enum.StrEnum):
)
####################
# - Size: 1D
####################
class NumberSize1D(enum.StrEnum):
"""Valid 1D-constrained shape."""
@ -278,6 +279,20 @@ class NumberSize1D(enum.StrEnum):
(4, 1): NS.Vec4,
}[shape]
@property
def rows(self):
NS = NumberSize1D
return {
NS.Scalar: 1,
NS.Vec2: 2,
NS.Vec3: 3,
NS.Vec4: 4,
}[self]
@property
def cols(self):
return 1
@property
def shape(self):
NS = NumberSize1D
@ -297,6 +312,30 @@ def symbol_range(sym: sp.Symbol) -> str:
)
####################
# - Symbol Sizes
####################
class SimpleSize2D(enum.StrEnum):
"""Simple subset of sizes for rank-2 tensors."""
Scalar = enum.auto()
# Vectors
Vec2 = enum.auto() ## 2x1
Vec3 = enum.auto() ## 3x1
Vec4 = enum.auto() ## 4x1
# Covectors
CoVec2 = enum.auto() ## 1x2
CoVec3 = enum.auto() ## 1x3
CoVec4 = enum.auto() ## 1x4
# Square Matrices
Mat22 = enum.auto() ## 2x2
Mat33 = enum.auto() ## 3x3
Mat44 = enum.auto() ## 4x4
####################
# - Unit Dimensions
####################
@ -382,6 +421,8 @@ UNIT_BY_SYMBOL: dict[sp.Symbol, spu.Quantity] = {
unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity)
} | {unit.name: unit for unit in globals().values() if isinstance(unit, spu.Quantity)}
UNIT_TO_1: dict[spu.Quantity, 1] = {unit: 1 for unit in UNIT_BY_SYMBOL.values()}
####################
# - Expr Analysis: Units
@ -907,10 +948,6 @@ class PhysicalType(enum.StrEnum):
LumIntensity = enum.auto()
LumFlux = enum.auto()
Illuminance = enum.auto()
# Optics
OrdinaryWaveVector = enum.auto()
AngularWaveVector = enum.auto()
PoyntingVector = enum.auto()
@functools.cached_property
def unit_dim(self):
@ -956,10 +993,6 @@ class PhysicalType(enum.StrEnum):
PT.LumIntensity: Dims.luminous_intensity,
PT.LumFlux: Dims.luminous_intensity * spu.steradian.dimension,
PT.Illuminance: Dims.luminous_intensity / Dims.length**2,
# Optics
PT.OrdinaryWaveVector: Dims.frequency,
PT.AngularWaveVector: Dims.angle * Dims.frequency,
PT.PoyntingVector: Dims.power / Dims.length**2,
}[self]
@functools.cached_property
@ -1196,10 +1229,6 @@ class PhysicalType(enum.StrEnum):
PT.HField: [None, (2,), (3,)],
# Luminal
PT.LumFlux: [None, (2,), (3,)],
# Optics
PT.OrdinaryWaveVector: [None, (2,), (3,)],
PT.AngularWaveVector: [None, (2,), (3,)],
PT.PoyntingVector: [None, (2,), (3,)],
}
return overrides.get(self, [None])
@ -1222,7 +1251,6 @@ class PhysicalType(enum.StrEnum):
- **Charge**: Generally, it is real.
However, an imaginary phase term seems to have research applications when dealing with high-order harmonics in high-energy pulsed lasers: <https://iopscience.iop.org/article/10.1088/1361-6455/aac787>
- **Conductance**: The imaginary part represents the extinction, in the Drude-model sense.
- **Poynting**: The imaginary part represents the oscillation in the power flux over time.
"""
MT = MathType
@ -1249,10 +1277,6 @@ class PhysicalType(enum.StrEnum):
PT.EField: [MT.Real, MT.Complex], ## Im -> Phase
PT.HField: [MT.Real, MT.Complex], ## Im -> Phase
# Luminal
# Optics
PT.OrdinaryWaveVector: [MT.Real, MT.Complex], ## Im -> Phase
PT.AngularWaveVector: [MT.Real, MT.Complex], ## Im -> Phase
PT.PoyntingVector: [MT.Real, MT.Complex], ## Im -> Reactive Power
}
return overrides.get(self, [MT.Real])
@ -1323,10 +1347,6 @@ UNITS_SI: UnitSystem = {
_PT.LumIntensity: spu.candela,
_PT.LumFlux: lumen,
_PT.Illuminance: spu.lux,
# Optics
_PT.OrdinaryWaveVector: spu.hertz,
_PT.AngularWaveVector: spu.radian * spu.hertz,
_PT.PoyntingVector: spu.watt / spu.meter**2,
}
@ -1380,15 +1400,20 @@ def sympy_to_python(
####################
# - Convert to Unit System
####################
def convert_to_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr:
def convert_to_unit_system(
sp_obj: SympyExpr, unit_system: UnitSystem | None
) -> SympyExpr:
"""Convert an expression to the units of a given unit system, with appropriate scaling."""
if unit_system is None:
return sp_obj
return spu.convert_to(
sp_obj,
{unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)},
)
def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr:
def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem | None) -> SympyExpr:
"""Strip units occurring in the given unit system from the expression.
Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`".
@ -1397,11 +1422,13 @@ def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr:
Notes:
You should probably use `scale_to_unit_system()` or `convert_to_unit_system()`.
"""
if unit_system is None:
return sp_obj.subs(UNIT_TO_1)
return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None})
def scale_to_unit_system(
sp_obj: SympyExpr, unit_system: UnitSystem, use_jax_array: bool = False
sp_obj: SympyExpr, unit_system: UnitSystem | None, use_jax_array: bool = False
) -> int | float | complex | tuple | jax.Array:
"""Convert an expression to the units of a given unit system, then strip all units of the unit system.

View File

@ -29,11 +29,13 @@ import matplotlib.axis as mpl_ax
import matplotlib.backends.backend_agg
import matplotlib.figure
import matplotlib.style as mplstyle
import seaborn as sns
from blender_maxwell import contracts as ct
from blender_maxwell.utils import logger
mplstyle.use('fast') ## TODO: Does this do anything?
sns.set_theme()
log = logger.get(__name__)
@ -149,125 +151,98 @@ def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
####################
# - Plotters
####################
# () ->
def plot_hist_1d(
data: jtyp.Float32[jtyp.Array, ' size'], info, ax: mpl_ax.Axis
) -> None:
y_name = info.output_name
y_unit = info.output_unit
ax.hist(data, bins=30, alpha=0.75)
ax.set_title('Histogram')
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
# () ->
def plot_box_plot_1d(
data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis
) -> None:
x_name = info.dim_names[0]
y_name = info.output_name
y_unit = info.output_unit
x_sym = info.last_dim
y_sym = info.output
ax.boxplot(data)
ax.set_title('Box Plot')
ax.set_xlabel(f'{x_name}')
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
ax.boxplot([data])
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None:
x_sym = info.last_dim
y_sym = info.output
p = ax.bar(info.dims[x_sym], data)
ax.bar_label(p, label_type='center')
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
# () ->
def plot_curve_2d(
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis
) -> None:
times = [time.perf_counter()]
x_sym = info.last_dim
y_sym = info.output
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.output_name
y_unit = info.output_unit
times.append(time.perf_counter() - times[0])
ax.plot(info.dim_idx_arrays[0], data)
times.append(time.perf_counter() - times[0])
ax.set_title('2D Curve')
times.append(time.perf_counter() - times[0])
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
times.append(time.perf_counter() - times[0])
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
times.append(time.perf_counter() - times[0])
# log.critical('Timing of Curve2D: %s', str(times))
ax.plot(info.dims[x_sym].realize_array.values, data)
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
def plot_points_2d(
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis
) -> None:
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.output_name
y_unit = info.output_unit
x_sym = info.last_dim
y_sym = info.output
ax.scatter(info.dim_idx_arrays[0], data, alpha=0.6)
ax.set_title('2D Points')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None:
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.output_name
y_unit = info.output_unit
ax.bar(info.dim_idx_arrays[0], data, alpha=0.7)
ax.set_title('2D Bar')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
ax.scatter(x_sym.realize_array.values, data)
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
# (, ) ->
def plot_curves_2d(
data: jtyp.Float32[jtyp.Array, 'x_size categories'], info, ax: mpl_ax.Axis
) -> None:
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.output_name
y_unit = info.output_unit
x_sym = info.first_dim
y_sym = info.output
for category in range(data.shape[1]):
ax.plot(info.dim_idx_arrays[0], data[:, category])
for i, category in enumerate(info.dims[info.last_dim]):
ax.plot(info.dims[x_sym], data[:, i], label=category)
ax.set_title('2D Curves')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
ax.legend()
def plot_filled_curves_2d(
data: jtyp.Float32[jtyp.Array, 'x_size 2'], info, ax: mpl_ax.Axis
) -> None:
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.output_name
y_unit = info.output_unit
x_sym = info.first_dim
y_sym = info.output
shared_x_idx = info.dim_idx_arrays[0]
shared_x_idx = info.dims[info.last_dim]
ax.fill_between(shared_x_idx, data[:, 0], shared_x_idx, data[:, 1])
ax.set_title('2D Filled Curves')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
ax.legend()
# (, ) ->
def plot_heatmap_2d(
data: jtyp.Float32[jtyp.Array, 'x_size y_size'], info, ax: mpl_ax.Axis
) -> None:
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.dim_names[1]
y_unit = info.dim_units[y_name]
x_sym = info.first_dim
y_sym = info.last_dim
c_sym = info.output
heatmap = ax.imshow(data, aspect='auto', interpolation='none')
# ax.figure.colorbar(heatmap, ax=ax)
ax.set_title('Heatmap')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
heatmap = ax.imshow(data, aspect='equal', interpolation='none')
ax.figure.colorbar(heatmap, cax=ax)
ax.set_title(f'({x_sym.name_pretty}, {y_sym.name_pretty}) -> {c_sym.plot_label}')
ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
ax.legend()

View File

@ -18,26 +18,67 @@ import dataclasses
import enum
import sys
import typing as typ
from fractions import Fraction
import sympy as sp
from . import extra_sympy_units as spux
int_min = -(2**64)
int_max = 2**64
float_min = sys.float_info.min
float_max = sys.float_info.max
####################
# - Simulation Symbols
# - Simulation Symbol Names
####################
class SimSymbolName(enum.StrEnum):
# Lower
LowerA = enum.auto()
LowerB = enum.auto()
LowerC = enum.auto()
LowerD = enum.auto()
LowerI = enum.auto()
LowerT = enum.auto()
LowerX = enum.auto()
LowerY = enum.auto()
LowerZ = enum.auto()
# Physics
# Fields
Ex = enum.auto()
Ey = enum.auto()
Ez = enum.auto()
Hx = enum.auto()
Hy = enum.auto()
Hz = enum.auto()
Er = enum.auto()
Etheta = enum.auto()
Ephi = enum.auto()
Hr = enum.auto()
Htheta = enum.auto()
Hphi = enum.auto()
# Optics
Wavelength = enum.auto()
Frequency = enum.auto()
Flux = enum.auto()
PermXX = enum.auto()
PermYY = enum.auto()
PermZZ = enum.auto()
DiffOrderX = enum.auto()
DiffOrderY = enum.auto()
# Generic
Expr = enum.auto()
####################
# - UI
####################
@staticmethod
def to_name(v: typ.Self) -> str:
"""Convert the enum value to a human-friendly name.
@ -50,27 +91,6 @@ class SimSymbolName(enum.StrEnum):
"""
return SimSymbolName(v).name
@property
def name(self) -> str:
SSN = SimSymbolName
return {
SSN.LowerA: 'a',
SSN.LowerT: 't',
SSN.LowerX: 'x',
SSN.LowerY: 'y',
SSN.LowerZ: 'z',
SSN.Wavelength: 'wl',
SSN.Frequency: 'freq',
}[self]
@property
def name_pretty(self) -> str:
SSN = SimSymbolName
return {
SSN.Wavelength: 'λ',
SSN.Frequency: '𝑓',
}.get(self, self.name)
@staticmethod
def to_icon(_: typ.Self) -> str:
"""Convert the enum value to a Blender icon.
@ -83,6 +103,75 @@ class SimSymbolName(enum.StrEnum):
"""
return ''
####################
# - Computed Properties
####################
@property
def name(self) -> str:
SSN = SimSymbolName
return {
# Lower
SSN.LowerA: 'a',
SSN.LowerB: 'b',
SSN.LowerC: 'c',
SSN.LowerD: 'd',
SSN.LowerI: 'i',
SSN.LowerT: 't',
SSN.LowerX: 'x',
SSN.LowerY: 'y',
SSN.LowerZ: 'z',
# Fields
SSN.Ex: 'Ex',
SSN.Ey: 'Ey',
SSN.Ez: 'Ez',
SSN.Hx: 'Hx',
SSN.Hy: 'Hy',
SSN.Hz: 'Hz',
SSN.Er: 'Ex',
SSN.Etheta: 'Ey',
SSN.Ephi: 'Ez',
SSN.Hr: 'Hx',
SSN.Htheta: 'Hy',
SSN.Hphi: 'Hz',
# Optics
SSN.Wavelength: 'wl',
SSN.Frequency: 'freq',
SSN.Flux: 'flux',
SSN.PermXX: 'eps_xx',
SSN.PermYY: 'eps_yy',
SSN.PermZZ: 'eps_zz',
SSN.DiffOrderX: 'order_x',
SSN.DiffOrderY: 'order_y',
# Generic
SSN.Expr: 'expr',
}[self]
@property
def name_pretty(self) -> str:
SSN = SimSymbolName
return {
SSN.Wavelength: 'λ',
SSN.Frequency: '𝑓',
}.get(self, self.name)
####################
# - Simulation Symbol
####################
def mk_interval(
interval_finite: tuple[int | Fraction | float, int | Fraction | float],
interval_inf: tuple[bool, bool],
interval_closed: tuple[bool, bool],
unit_factor: typ.Literal[1] | spux.Unit,
) -> sp.Interval:
"""Create a symbolic interval from the tuples (and unit) defining it."""
return sp.Interval(
start=(interval_finite[0] * unit_factor if not interval_inf[0] else -sp.oo),
end=(interval_finite[1] * unit_factor if not interval_inf[1] else sp.oo),
left_open=(True if interval_inf[0] else not interval_closed[0]),
right_open=(True if interval_inf[1] else not interval_closed[1]),
)
@dataclasses.dataclass(kw_only=True, frozen=True)
class SimSymbol:
@ -94,66 +183,145 @@ class SimSymbol:
It's easy to persist, easy to transport, and has many helpful properties which greatly simplify working with symbols.
"""
sim_node_name: SimSymbolName = SimSymbolName.LowerX
sym_name: SimSymbolName
mathtype: spux.MathType = spux.MathType.Real
physical_type: spux.PhysicalType = spux.PhysicalType.NonPhysical
## TODO: Shape/size support? Incl. MatrixSymbol.
# Units
## -> 'None' indicates that no particular unit has yet been chosen.
## -> Not exposed in the UI; must be set some other way.
unit: spux.Unit | None = None
# Domain
interval_finite: tuple[float, float] = (0, 1)
# Size
## -> All SimSymbol sizes are "2D", but interpreted by convention.
## -> 1x1: "Scalar".
## -> nx1: "Vector".
## -> 1xn: "Covector".
## -> nxn: "Matrix".
rows: int = 1
cols: int = 1
# Scalar Domain: "Interval"
## -> NOTE: interval_finite_*[0] must be strictly smaller than [1].
## -> See self.domain.
## -> We have to deconstruct symbolic interval semantics a bit for UI.
interval_finite_z: tuple[int, int] = (0, 1)
interval_finite_q: tuple[tuple[int, int], tuple[int, int]] = ((0, 1), (1, 1))
interval_finite_re: tuple[float, float] = (0, 1)
interval_inf: tuple[bool, bool] = (True, True)
interval_closed: tuple[bool, bool] = (False, False)
interval_finite_im: tuple[float, float] = (0, 1)
interval_inf_im: tuple[bool, bool] = (True, True)
interval_closed_im: tuple[bool, bool] = (False, False)
####################
# - Properties
####################
@property
def name(self) -> str:
return self.sim_node_name.name
"""Usable name for the symbol."""
return self.sym_name.name
@property
def name_pretty(self) -> str:
"""Pretty (possibly unicode) name for the thing."""
return self.sym_name.name_pretty
## TODO: Formatting conventions for bolding/etc. of vectors/mats/...
@property
def plot_label(self) -> str:
"""Pretty plot-oriented label."""
return f'{self.name_pretty}' + (
f'({self.unit})' if self.unit is not None else ''
)
@property
def unit_factor(self) -> spux.SympyExpr:
"""Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking."""
return self.unit if self.unit is not None else sp.S(1)
@property
def shape(self) -> tuple[int, ...]:
match (self.rows, self.cols):
case (1, 1):
return ()
case (_, 1):
return (self.rows,)
case (1, _):
return (1, self.rows)
case (_, _):
return (self.rows, self.cols)
@property
def domain(self) -> sp.Interval | sp.Set:
"""Return the domain of valid values for the symbol.
"""Return the scalar domain of valid values for each element of the symbol.
For integer/rational/real symbols, the domain is an interval defined using the `interval_*` properties.
This interval **must** have the property`start <= stop`.
Otherwise, the domain is the symbolic set corresponding to `self.mathtype`.
"""
if self.mathtype in [
spux.MathType.Integer,
spux.MathType.Rational,
spux.MathType.Real,
]:
return sp.Interval(
start=self.interval_finite[0] if not self.interval_inf[0] else -sp.oo,
end=self.interval_finite[1] if not self.interval_inf[1] else sp.oo,
left_open=(
True if self.interval_inf[0] else not self.interval_closed[0]
),
right_open=(
True if self.interval_inf[1] else not self.interval_closed[1]
),
)
match self.mathtype:
case spux.MathType.Integer:
return mk_interval(
self.interval_finite_z,
self.interval_inf,
self.interval_closed,
self.unit_factor,
)
return self.mathtype.symbolic_set
case spux.MathType.Rational:
return mk_interval(
Fraction(*self.interval_finite_q),
self.interval_inf,
self.interval_closed,
self.unit_factor,
)
case spux.MathType.Real:
return mk_interval(
self.interval_finite_re,
self.interval_inf,
self.interval_closed,
self.unit_factor,
)
case spux.MathType.Complex:
return (
mk_interval(
self.interval_finite_re,
self.interval_inf,
self.interval_closed,
self.unit_factor,
),
mk_interval(
self.interval_finite_im,
self.interval_inf_im,
self.interval_closed_im,
self.unit_factor,
),
)
####################
# - Properties
####################
@property
def sp_symbol(self) -> sp.Symbol:
"""Return a symbolic variable corresponding to this `SimSymbol`.
"""Return a symbolic variable w/unit, corresponding to this `SimSymbol`.
As much as possible, appropriate `assumptions` are set in the constructor of `sp.Symbol`, insofar as they can be determined.
However, the assumptions system alone is rather limited, and implementations should therefore also strongly consider transporting `SimSymbols` directly, instead of `sp.Symbol`.
This allows making use of other properties like `self.domain`, when appropriate.
- **MathType**: Depending on `self.mathtype`.
- **Positive/Negative**: Depending on `self.domain`.
- **Nonzero**: Depending on `self.domain`, including open/closed boundary specifications.
Notes:
**The assumptions system is rather limited**, and implementations should strongly consider transporting `SimSymbols` instead of `sp.Symbol`.
This allows tracking ex. the valid interval domain for a symbol.
"""
# MathType Domain Constraint
## -> We must feed the assumptions system.
# MathType Assumption
mathtype_kwargs = {}
match self.mathtype:
case spux.MathType.Integer:
@ -165,53 +333,138 @@ class SimSymbol:
case spux.MathType.Complex:
mathtype_kwargs |= {'complex': True}
# Interval Constraints
if isinstance(self.domain, sp.Interval):
# Assumption: Non-Zero
if (
(
self.domain.left == 0
and self.domain.left_open
or self.domain.right == 0
and self.domain.right_open
)
or self.domain.left > 0
or self.domain.right < 0
):
mathtype_kwargs |= {'nonzero': True}
# Non-Zero Assumption
if (
(
self.domain.left == 0
and self.domain.left_open
or self.domain.right == 0
and self.domain.right_open
)
or self.domain.left > 0
or self.domain.right < 0
):
mathtype_kwargs |= {'nonzero': True}
# Assumption: Positive/Negative
if self.domain.left >= 0:
mathtype_kwargs |= {'positive': True}
elif self.domain.right <= 0:
mathtype_kwargs |= {'negative': True}
# Positive/Negative Assumption
if self.domain.left >= 0:
mathtype_kwargs |= {'positive': True}
elif self.domain.right <= 0:
mathtype_kwargs |= {'negative': True}
# Construct the Symbol
return sp.Symbol(self.sim_node_name.name, **mathtype_kwargs)
return sp.Symbol(self.sym_name.name, **mathtype_kwargs) * self.unit_factor
####################
# - Operations
####################
def update(self, **kwargs) -> typ.Self:
def get_attr(attr: str):
_notfound = 'notfound'
if kwargs.get(attr, _notfound) is _notfound:
return getattr(self, attr)
return kwargs[attr]
return SimSymbol(
sym_name=get_attr('sym_name'),
mathtype=get_attr('mathtype'),
physical_type=get_attr('physical_type'),
unit=get_attr('unit'),
rows=get_attr('rows'),
cols=get_attr('cols'),
interval_finite_z=get_attr('interval_finite_z'),
interval_finite_q=get_attr('interval_finite_q'),
interval_finite_re=get_attr('interval_finite_q'),
interval_inf=get_attr('interval_inf'),
interval_closed=get_attr('interval_closed'),
interval_finite_im=get_attr('interval_finite_im'),
interval_inf_im=get_attr('interval_inf_im'),
interval_closed_im=get_attr('interval_closed_im'),
)
def set_size(self, rows: int, cols: int) -> typ.Self:
return SimSymbol(
sym_name=self.sym_name,
mathtype=self.mathtype,
physical_type=self.physical_type,
unit=self.unit,
rows=rows,
cols=cols,
interval_finite_z=self.interval_finite_z,
interval_finite_q=self.interval_finite_q,
interval_finite_re=self.interval_finite_re,
interval_inf=self.interval_inf,
interval_closed=self.interval_closed,
interval_finite_im=self.interval_finite_im,
interval_inf_im=self.interval_inf_im,
interval_closed_im=self.interval_closed_im,
)
####################
# - Common Sim Symbols
####################
class CommonSimSymbol(enum.StrEnum):
"""A set of pre-defined symbols that might commonly be used in the context of physical simulation.
"""Identifiers for commonly used `SimSymbol`s, with all information about ex. `MathType`, `PhysicalType`, and (in general) valid intervals all pre-loaded.
Each entry maps directly to a particular `SimSymbol`.
The enum is compatible with `BLField`, making it easy to declare a UI-driven dropdown of symbols that behave as expected.
The enum is UI-compatible making it easy to declare a UI-driven dropdown of commonly used symbols that will all behave as expected.
Attributes:
X:
Time: A symbol representing a real-valued wavelength.
Wavelength: A symbol representing a real-valued wavelength.
Implicitly, this symbol often represents "vacuum wavelength" in particular.
Wavelength: A symbol representing a real-valued frequency.
Generally, this is the non-angular frequency.
"""
X = enum.auto()
Index = enum.auto()
# Space|Time
SpaceX = enum.auto()
SpaceY = enum.auto()
SpaceZ = enum.auto()
AngR = enum.auto()
AngTheta = enum.auto()
AngPhi = enum.auto()
DirX = enum.auto()
DirY = enum.auto()
DirZ = enum.auto()
Time = enum.auto()
# Fields
FieldEx = enum.auto()
FieldEy = enum.auto()
FieldEz = enum.auto()
FieldHx = enum.auto()
FieldHy = enum.auto()
FieldHz = enum.auto()
FieldEr = enum.auto()
FieldEtheta = enum.auto()
FieldEphi = enum.auto()
FieldHr = enum.auto()
FieldHtheta = enum.auto()
FieldHphi = enum.auto()
# Optics
Wavelength = enum.auto()
Frequency = enum.auto()
DiffOrderX = enum.auto()
DiffOrderY = enum.auto()
Flux = enum.auto()
WaveVecX = enum.auto()
WaveVecY = enum.auto()
WaveVecZ = enum.auto()
####################
# - UI
####################
@staticmethod
def to_name(v: typ.Self) -> str:
"""Convert the enum value to a human-friendly name.
@ -222,7 +475,7 @@ class CommonSimSymbol(enum.StrEnum):
Returns:
A human-friendly name corresponding to the enum value.
"""
return CommonSimSymbol(v).sim_symbol_name.name
return CommonSimSymbol(v).name
@staticmethod
def to_icon(_: typ.Self) -> str:
@ -241,55 +494,125 @@ class CommonSimSymbol(enum.StrEnum):
####################
@property
def name(self) -> str:
return self.sim_symbol.name
@property
def sim_symbol_name(self) -> str:
SSN = SimSymbolName
CSS = CommonSimSymbol
return {
CSS.X: SSN.LowerX,
CSS.Index: SSN.LowerI,
# Space|Time
CSS.SpaceX: SSN.LowerX,
CSS.SpaceY: SSN.LowerY,
CSS.SpaceZ: SSN.LowerZ,
CSS.AngR: SSN.LowerR,
CSS.AngTheta: SSN.LowerTheta,
CSS.AngPhi: SSN.LowerPhi,
CSS.DirX: SSN.LowerX,
CSS.DirY: SSN.LowerY,
CSS.DirZ: SSN.LowerZ,
CSS.Time: SSN.LowerT,
CSS.Wavelength: SSN.Wavelength,
# Fields
CSS.FieldEx: SSN.Ex,
CSS.FieldEy: SSN.Ey,
CSS.FieldEz: SSN.Ez,
CSS.FieldHx: SSN.Hx,
CSS.FieldHy: SSN.Hy,
CSS.FieldHz: SSN.Hz,
CSS.FieldEr: SSN.Er,
CSS.FieldHr: SSN.Hr,
# Optics
CSS.Frequency: SSN.Frequency,
CSS.Wavelength: SSN.Wavelength,
CSS.DiffOrderX: SSN.DiffOrderX,
CSS.DiffOrderY: SSN.DiffOrderY,
}[self]
@property
def sim_symbol(self) -> SimSymbol:
def sim_symbol(self, unit: spux.Unit | None) -> SimSymbol:
"""Retrieve the `SimSymbol` associated with the `CommonSimSymbol`."""
CSS = CommonSimSymbol
# Space
sym_space = SimSymbol(
sym_name=self.name,
physical_type=spux.PhysicalType.Length,
unit=unit,
)
sym_ang = SimSymbol(
sym_name=self.name,
physical_type=spux.PhysicalType.Angle,
unit=unit,
)
# Fields
def sym_field(eh: typ.Literal['e', 'h']) -> SimSymbol:
return SimSymbol(
sym_name=self.name,
physical_type=spux.PhysicalType.EField
if eh == 'e'
else spux.PhysicalType.HField,
unit=unit,
interval_finite_re=(0, sys.float_info.max),
interval_inf_re=(False, True),
interval_closed_re=(True, False),
interval_finite_im=(sys.float_info.min, sys.float_info.max),
interval_inf_im=(True, True),
)
return {
CSS.X: SimSymbol(
sim_node_name=self.sim_symbol_name,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.NonPhysical,
## TODO: Unit of Picosecond
interval_finite=(sys.float_info.min, sys.float_info.max),
interval_inf=(True, True),
interval_closed=(False, False),
),
CSS.Time: SimSymbol(
sim_node_name=self.sim_symbol_name,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Time,
## TODO: Unit of Picosecond
interval_finite=(0, sys.float_info.max),
CSS.Index: SimSymbol(
sym_name=self.name,
mathtype=spux.MathType.Integer,
interval_finite_z=(0, 2**64),
interval_inf=(False, True),
interval_closed=(True, False),
),
# Space|Time
CSS.SpaceX: sym_space,
CSS.SpaceY: sym_space,
CSS.SpaceZ: sym_space,
CSS.AngR: sym_space,
CSS.AngTheta: sym_ang,
CSS.AngPhi: sym_ang,
CSS.Time: SimSymbol(
sym_name=self.name,
physical_type=spux.PhysicalType.Time,
unit=unit,
interval_finite_re=(0, sys.float_info.max),
interval_inf=(False, True),
interval_closed=(True, False),
),
# Fields
CSS.FieldEx: sym_field('e'),
CSS.FieldEy: sym_field('e'),
CSS.FieldEz: sym_field('e'),
CSS.FieldHx: sym_field('h'),
CSS.FieldHy: sym_field('h'),
CSS.FieldHz: sym_field('h'),
CSS.FieldEr: sym_field('e'),
CSS.FieldEtheta: sym_field('e'),
CSS.FieldEphi: sym_field('e'),
CSS.FieldHr: sym_field('h'),
CSS.FieldHtheta: sym_field('h'),
CSS.FieldHphi: sym_field('h'),
CSS.Flux: SimSymbol(
sym_name=SimSymbolName.Flux,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Power,
unit=unit,
),
# Optics
CSS.Wavelength: SimSymbol(
sim_node_name=self.sim_symbol_name,
sym_name=self.name,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Length,
## TODO: Unit of Picosecond
unit=unit,
interval_finite=(0, sys.float_info.max),
interval_inf=(False, True),
interval_closed=(False, False),
),
CSS.Frequency: SimSymbol(
sim_node_name=self.sim_symbol_name,
sym_name=self.name,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Freq,
unit=unit,
interval_finite=(0, sys.float_info.max),
interval_inf=(False, True),
interval_closed=(False, False),
@ -298,9 +621,33 @@ class CommonSimSymbol(enum.StrEnum):
####################
# - Selected Direct Access
# - Selected Direct-Access to SimSymbols
####################
x = CommonSimSymbol.X.sim_symbol
idx = CommonSimSymbol.Index.sim_symbol
t = CommonSimSymbol.Time.sim_symbol
wl = CommonSimSymbol.Wavelength.sim_symbol
freq = CommonSimSymbol.Frequency.sim_symbol
space_x = CommonSimSymbol.SpaceX.sim_symbol
space_y = CommonSimSymbol.SpaceY.sim_symbol
space_z = CommonSimSymbol.SpaceZ.sim_symbol
dir_x = CommonSimSymbol.DirX.sim_symbol
dir_y = CommonSimSymbol.DirY.sim_symbol
dir_z = CommonSimSymbol.DirZ.sim_symbol
ang_r = CommonSimSymbol.AngR.sim_symbol
ang_theta = CommonSimSymbol.AngTheta.sim_symbol
ang_phi = CommonSimSymbol.AngPhi.sim_symbol
field_ex = CommonSimSymbol.FieldEx.sim_symbol
field_ey = CommonSimSymbol.FieldEy.sim_symbol
field_ez = CommonSimSymbol.FieldEz.sim_symbol
field_hx = CommonSimSymbol.FieldHx.sim_symbol
field_hy = CommonSimSymbol.FieldHx.sim_symbol
field_hz = CommonSimSymbol.FieldHx.sim_symbol
flux = CommonSimSymbol.Flux.sim_symbol
diff_order_x = CommonSimSymbol.DiffOrderX.sim_symbol
diff_order_y = CommonSimSymbol.DiffOrderY.sim_symbol