refactor: end-of-day commit (sim symbol flow for data import/export & inverse design)

main
Sofus Albert Høgsbro Rose 2024-05-21 22:57:56 +02:00
parent dccf952ad3
commit 353a2c997e
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
24 changed files with 2058 additions and 1710 deletions

View File

@ -27,6 +27,7 @@ dependencies = [
#"charset-normalizer==2.0.10", ## Conflict with dev-dep commitizen #"charset-normalizer==2.0.10", ## Conflict with dev-dep commitizen
"certifi==2021.10.8", "certifi==2021.10.8",
"polars>=0.20.26", "polars>=0.20.26",
"seaborn[stats]>=0.13.2",
] ]
## When it comes to dev-dep conflicts: ## When it comes to dev-dep conflicts:
## -> It's okay to leave Blender-pinned deps out of prod; Blender still has them. ## -> It's okay to leave Blender-pinned deps out of prod; Blender still has them.

View File

@ -81,6 +81,7 @@ locket==1.0.0
markupsafe==2.1.5 markupsafe==2.1.5
# via jinja2 # via jinja2
matplotlib==3.8.3 matplotlib==3.8.3
# via seaborn
# via tidy3d # via tidy3d
ml-dtypes==0.4.0 ml-dtypes==0.4.0
# via jax # via jax
@ -102,8 +103,11 @@ numpy==1.24.3
# via ml-dtypes # via ml-dtypes
# via numba # via numba
# via opt-einsum # via opt-einsum
# via patsy
# via scipy # via scipy
# via seaborn
# via shapely # via shapely
# via statsmodels
# via tidy3d # via tidy3d
# via trimesh # via trimesh
# via xarray # via xarray
@ -114,11 +118,16 @@ packaging==24.0
# via dask # via dask
# via h5netcdf # via h5netcdf
# via matplotlib # via matplotlib
# via statsmodels
# via xarray # via xarray
pandas==2.2.1 pandas==2.2.1
# via seaborn
# via statsmodels
# via xarray # via xarray
partd==1.4.1 partd==1.4.1
# via dask # via dask
patsy==0.5.6
# via statsmodels
pillow==10.2.0 pillow==10.2.0
# via matplotlib # via matplotlib
platformdirs==4.2.1 platformdirs==4.2.1
@ -167,13 +176,19 @@ s3transfer==0.5.2
scipy==1.12.0 scipy==1.12.0
# via jax # via jax
# via jaxlib # via jaxlib
# via seaborn
# via statsmodels
# via tidy3d # via tidy3d
seaborn==0.13.2
setuptools==69.5.1 setuptools==69.5.1
# via nodeenv # via nodeenv
shapely==2.0.3 shapely==2.0.3
# via tidy3d # via tidy3d
six==1.16.0 six==1.16.0
# via patsy
# via python-dateutil # via python-dateutil
statsmodels==0.14.2
# via seaborn
sympy==1.12 sympy==1.12
termcolor==2.4.0 termcolor==2.4.0
# via commitizen # via commitizen

View File

@ -59,6 +59,7 @@ llvmlite==0.42.0
locket==1.0.0 locket==1.0.0
# via partd # via partd
matplotlib==3.8.3 matplotlib==3.8.3
# via seaborn
# via tidy3d # via tidy3d
ml-dtypes==0.4.0 ml-dtypes==0.4.0
# via jax # via jax
@ -78,8 +79,11 @@ numpy==1.24.3
# via ml-dtypes # via ml-dtypes
# via numba # via numba
# via opt-einsum # via opt-einsum
# via patsy
# via scipy # via scipy
# via seaborn
# via shapely # via shapely
# via statsmodels
# via tidy3d # via tidy3d
# via trimesh # via trimesh
# via xarray # via xarray
@ -89,11 +93,16 @@ packaging==24.0
# via dask # via dask
# via h5netcdf # via h5netcdf
# via matplotlib # via matplotlib
# via statsmodels
# via xarray # via xarray
pandas==2.2.1 pandas==2.2.1
# via seaborn
# via statsmodels
# via xarray # via xarray
partd==1.4.1 partd==1.4.1
# via dask # via dask
patsy==0.5.6
# via statsmodels
pillow==10.2.0 pillow==10.2.0
# via matplotlib # via matplotlib
polars==0.20.26 polars==0.20.26
@ -132,11 +141,17 @@ s3transfer==0.5.2
scipy==1.12.0 scipy==1.12.0
# via jax # via jax
# via jaxlib # via jaxlib
# via seaborn
# via statsmodels
# via tidy3d # via tidy3d
seaborn==0.13.2
shapely==2.0.3 shapely==2.0.3
# via tidy3d # via tidy3d
six==1.16.0 six==1.16.0
# via patsy
# via python-dateutil # via python-dateutil
statsmodels==0.14.2
# via seaborn
sympy==1.12 sympy==1.12
tidy3d==2.6.3 tidy3d==2.6.3
toml==0.10.2 toml==0.10.2

View File

@ -29,9 +29,12 @@ from blender_maxwell.utils import logger
log = logger.get(__name__) log = logger.get(__name__)
# TODO: Our handling of 'is_sorted' is sloppy and probably wrong.
@dataclasses.dataclass(frozen=True, kw_only=True) @dataclasses.dataclass(frozen=True, kw_only=True)
class ArrayFlow: class ArrayFlow:
"""A simple, flat array of values with an optionally-attached unit. """A homogeneous, realized array of numerical values with an optionally-attached unit and sort-tracking.
While the principle is simple, arrays-with-units ends up being a powerful basis for derived and computed features/methods/processing.
Attributes: Attributes:
values: An ND array-like object of arbitrary numerical type. values: An ND array-like object of arbitrary numerical type.
@ -44,13 +47,97 @@ class ArrayFlow:
is_sorted: bool = False is_sorted: bool = False
####################
# - Computed Properties
####################
@property
def is_symbolic(self) -> bool:
"""Always False, as ArrayFlows are never unrealized."""
return False
def __len__(self) -> int: def __len__(self) -> int:
"""Outer length of the contained array."""
return len(self.values) return len(self.values)
@functools.cached_property @functools.cached_property
def mathtype(self) -> spux.MathType: def mathtype(self) -> spux.MathType:
"""Deduce the `spux.MathType` of the first element of the contained array.
This is generally a heuristic, but because `jax` enforces homogeneous arrays, this is actually a well-defined approach.
"""
return spux.MathType.from_pytype(type(self.values.item(0))) return spux.MathType.from_pytype(type(self.values.item(0)))
@functools.cached_property
def physical_type(self) -> spux.MathType:
"""Deduce the `spux.PhysicalType` of the unit."""
return spux.PhysicalType.from_unit(self.unit)
####################
# - Array Features
####################
@property
def realize_array(self) -> jtyp.Shaped[jtyp.Array, '...']:
"""Standardized access to `self.values`."""
return self.values
@functools.cached_property
def shape(self) -> int:
"""Shape of the contained array."""
return self.values.shape
def __getitem__(self, subscript: slice) -> typ.Self | spux.SympyExpr:
"""Implement indexing and slicing in a sane way.
- **Integer Index**: For scalar output, return a `sympy` expression of the scalar multiplied by the unit, else just a sympy expression of the value.
- **Slice**: Slice the internal array directly, and wrap the result in a new `ArrayFlow`.
"""
if isinstance(subscript, slice):
return ArrayFlow(
values=self.values[subscript],
unit=self.unit,
is_sorted=self.is_sorted,
)
if isinstance(subscript, int):
value = self.values[subscript]
if len(value.shape) == 0:
return value * self.unit if self.unit is not None else sp.S(value)
return ArrayFlow(values=value, unit=self.unit, is_sorted=self.is_sorted)
raise NotImplementedError
####################
# - Methods
####################
def rescale(
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
) -> typ.Self:
"""Apply an order-preserving function to each element of the array, then (optionally) transform the result w/new unit and/or order.
An optimized expression will be built and applied to `self.values` using `sympy.lambdify()`.
Parameters:
rescale_func: An **order-preserving** function to apply to each array element.
reverse: Whether to reverse the order of the result.
new_unit: An (optional) new unit to scale the result to.
"""
# Compile JAX-Compatible Rescale Function
a = self.mathtype.sp_symbol_a
rescale_expr = (
spux.scale_to_unit(rescale_func(a * self.unit), new_unit)
if self.unit is not None
else rescale_func(a)
)
_rescale_func = sp.lambdify(a, rescale_expr, 'jax')
values = _rescale_func(self.values)
# Return ArrayFlow
return ArrayFlow(
values=values[::-1] if reverse else values,
unit=new_unit,
is_sorted=self.is_sorted,
)
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int: def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
"""Find the index of the value that is closest to the given value. """Find the index of the value that is closest to the given value.
@ -88,56 +175,26 @@ class ArrayFlow:
return right_idx return right_idx
def correct_unit(self, corrected_unit: spu.Quantity) -> typ.Self: ####################
if self.unit is not None: # - Unit Transforms
return ArrayFlow( ####################
values=self.values, unit=corrected_unit, is_sorted=self.is_sorted def correct_unit(self, unit: spux.Unit) -> typ.Self:
) """Simply replace the existing unit with the given one.
msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"' Parameters:
raise ValueError(msg) corrected_unit: The new unit to insert.
**MUST** be associable with a well-defined `PhysicalType`.
"""
return ArrayFlow(values=self.values, unit=unit, is_sorted=self.is_sorted)
def rescale_to_unit(self, unit: spu.Quantity | None) -> typ.Self: def rescale_to_unit(self, new_unit: spux.Unit | None) -> typ.Self:
## TODO: Cache by unit would be a very nice speedup for Viz node. """Rescale the `ArrayFlow` to be expressed in the given unit.
if self.unit is not None:
return ArrayFlow(
values=float(spux.scaling_factor(self.unit, unit)) * self.values,
unit=unit,
is_sorted=self.is_sorted,
)
if unit is None: Parameters:
return self corrected_unit: The new unit to insert.
**MUST** be associable with a well-defined `PhysicalType`.
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}' """
raise ValueError(msg) return self.rescale(lambda v: v, new_unit=new_unit)
def rescale(
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
) -> typ.Self:
# Compile JAX-Compatible Rescale Function
a = sp.Symbol('a')
rescale_expr = (
spux.scale_to_unit(rescale_func(a * self.unit), new_unit)
if self.unit is not None
else rescale_func(a * self.unit)
)
_rescale_func = sp.lambdify(a, rescale_expr, 'jax')
values = _rescale_func(self.values)
# Return ArrayFlow
return ArrayFlow(
values=values[::-1] if reverse else values,
unit=new_unit,
is_sorted=self.is_sorted,
)
def __getitem__(self, subscript: slice):
if isinstance(subscript, slice):
return ArrayFlow(
values=self.values[subscript],
unit=self.unit,
is_sorted=self.is_sorted,
)
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self:
raise NotImplementedError raise NotImplementedError

View File

@ -0,0 +1,36 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import typing as typ
from blender_maxwell.utils import extra_sympy_units as spux
from . import FlowKind
class ExprInfo(typ.TypedDict):
active_kind: FlowKind
size: spux.NumberSize1D
mathtype: spux.MathType
physical_type: spux.PhysicalType
# Value
default_value: spux.SympyExpr
# Range
default_min: spux.SympyExpr
default_max: spux.SympyExpr
default_steps: int

View File

@ -19,6 +19,7 @@ import typing as typ
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger from blender_maxwell.utils import logger
from blender_maxwell.utils.staticproperty import staticproperty
log = logger.get(__name__) log = logger.get(__name__)
@ -51,16 +52,71 @@ class FlowKind(enum.StrEnum):
Capabilities = enum.auto() Capabilities = enum.auto()
# Values # Values
Value = enum.auto() Value = enum.auto() ## 'value'
Array = enum.auto() Array = enum.auto() ## 'array'
# Lazy # Lazy
Func = enum.auto() Func = enum.auto() ## 'lazy_func'
Range = enum.auto() Range = enum.auto() ## 'lazy_range'
# Auxiliary # Auxiliary
Params = enum.auto() Params = enum.auto() ## 'params'
Info = enum.auto() Info = enum.auto() ## 'info'
####################
# - UI
####################
@staticmethod
def to_name(v: typ.Self) -> str:
return {
FlowKind.Capabilities: 'Capabilities',
# Values
FlowKind.Value: 'Value',
FlowKind.Array: 'Array',
# Lazy
FlowKind.Range: 'Range',
FlowKind.Func: 'Func',
# Auxiliary
FlowKind.Params: 'Params',
FlowKind.Info: 'Info',
}[v]
@staticmethod
def to_icon(_: typ.Self) -> str:
return ''
####################
# - Static Properties
####################
@staticproperty
def active_kinds() -> list[typ.Self]:
"""Return a list of `FlowKind`s that are able to be considered "active".
"Active" `FlowKind`s are considered the primary data type of a socket's flow.
For example, for sockets to be linkeable, their active `FlowKind` must generally match.
"""
return [
FlowKind.Value,
FlowKind.Array,
FlowKind.Range,
FlowKind.Func,
]
@property
def socket_shape(self) -> str:
"""Return the socket shape associated with this `FlowKind`.
**ONLY** valid for `FlowKind`s that can be considered "active".
Raises:
ValueError: If this `FlowKind` cannot ever be considered "active".
"""
return {
FlowKind.Value: 'CIRCLE',
FlowKind.Array: 'SQUARE',
FlowKind.Range: 'SQUARE',
FlowKind.Func: 'DIAMOND',
}[self]
#################### ####################
# - Class Methods # - Class Methods
@ -69,7 +125,7 @@ class FlowKind(enum.StrEnum):
def scale_to_unit_system( def scale_to_unit_system(
cls, cls,
kind: typ.Self, kind: typ.Self,
flow_obj, flow_obj: spux.SympyExpr,
unit_system: spux.UnitSystem, unit_system: spux.UnitSystem,
): ):
# log.debug('%s: Scaling "%s" to Unit System', kind, str(flow_obj)) # log.debug('%s: Scaling "%s" to Unit System', kind, str(flow_obj))
@ -87,43 +143,3 @@ class FlowKind(enum.StrEnum):
msg = 'Tried to scale unknown kind' msg = 'Tried to scale unknown kind'
raise ValueError(msg) raise ValueError(msg)
####################
# - Computed
####################
@property
def flow_kind(self) -> str:
return {
FlowKind.Value: FlowKind.Value,
FlowKind.Array: FlowKind.Array,
FlowKind.Func: FlowKind.Func,
FlowKind.Range: FlowKind.Range,
}[self]
@property
def socket_shape(self) -> str:
return {
FlowKind.Value: 'CIRCLE',
FlowKind.Array: 'SQUARE',
FlowKind.Range: 'SQUARE',
FlowKind.Func: 'DIAMOND',
}[self]
####################
# - Blender Enum
####################
@staticmethod
def to_name(v: typ.Self) -> str:
return {
FlowKind.Capabilities: 'Capabilities',
FlowKind.Value: 'Value',
FlowKind.Array: 'Array',
FlowKind.Range: 'Range',
FlowKind.Func: 'Func',
FlowKind.Params: 'Parameters',
FlowKind.Info: 'Information',
}[v]
@staticmethod
def to_icon(_: typ.Self) -> str:
return ''

View File

@ -18,246 +18,247 @@ import dataclasses
import functools import functools
import typing as typ import typing as typ
import jax
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger from blender_maxwell.utils import logger, sim_symbols
from .array import ArrayFlow from .array import ArrayFlow
from .lazy_range import RangeFlow from .lazy_range import RangeFlow
log = logger.get(__name__) log = logger.get(__name__)
LabelArray: typ.TypeAlias = list[str]
# IndexArray: Identifies Discrete Dimension Values
## -> ArrayFlow (rat|real): Index by particular, not-guaranteed-uniform index.
## -> RangeFlow (rat|real): Index by unrealized array scaled between boundaries.
## -> LabelArray (int): For int-index arrays, interpret using these labels.
## -> None: Non-Discrete/unrealized indexing; use 'dim.domain'.
IndexArray: typ.TypeAlias = ArrayFlow | RangeFlow | LabelArray | None
@dataclasses.dataclass(frozen=True, kw_only=True) @dataclasses.dataclass(frozen=True, kw_only=True)
class InfoFlow: class InfoFlow:
#################### """Contains dimension and output information characterizing the array produced by a parallel `FuncFlow`.
# - Covariant Input
####################
dim_names: list[str] = dataclasses.field(default_factory=list)
dim_idx: dict[str, ArrayFlow | RangeFlow] = dataclasses.field(
default_factory=dict
) ## TODO: Rename to dim_idxs
@functools.cached_property Functionally speaking, `InfoFlow` provides essential mathematical and physical context to raw array data, with terminology adapted from multilinear algebra.
def dim_has_coords(self) -> dict[str, int]:
return {
dim_name: not (
isinstance(dim_idx, RangeFlow)
and (dim_idx.start.is_infinite or dim_idx.stop.is_infinite)
)
for dim_name, dim_idx in self.dim_idx.items()
}
@functools.cached_property # From Arrays to Tensors
def dim_lens(self) -> dict[str, int]: The best way to illustrate how it works is to specify how raw-array concepts map to an array described by an `InfoFlow`:
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
@functools.cached_property - **Index**: In raw arrays, the "index" is generally constrained to an integer ring, and has no semantic meaning.
def dim_mathtypes(self) -> dict[str, spux.MathType]: **(Covariant) Dimension**: The "dimension" is an named "index array", which assigns each integer index a **scalar** value of particular mathematical type, name, and unit (if not unitless).
return { - **Value**: In raw arrays, the "value" is some particular computational type, or another raw array.
dim_name: dim_idx.mathtype for dim_name, dim_idx in self.dim_idx.items() **(Contravariant) Output**: The "output" is a strictly named, sized object that can only be produced
}
@functools.cached_property In essence, `InfoFlow` allows us to treat raw data as a tensor, then operate on its dimensionality as split into parts whose transform varies _with_ the output (aka. a _covariant_ index), and parts whose transform varies _against_ the output (aka. _contravariant_ value).
def dim_units(self) -> dict[str, spux.Unit]:
return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()}
@functools.cached_property ## Benefits
def dim_physical_types(self) -> dict[str, spux.PhysicalType]: The reasons to do this are numerous:
return {
dim_name: spux.PhysicalType.from_unit(dim_idx.unit)
for dim_name, dim_idx in self.dim_idx.items()
}
@functools.cached_property - **Clarity**: Using `InfoFlow`, it's easy to understand what the data is, and what can be done to it, making it much easier to implement complex operations in math nodes without sacrificing the user's mental model.
def dim_idx_arrays(self) -> list[jax.Array]: - **Zero-Cost Operations**: Transforming indices, "folding" dimensions into the output, and other such operations don't actually do anything to the data, enabling a lot of operations to feel "free" in terms of performance.
return [ - **Semantic Indexing**: Using `InfoFlow`, it's easy to index and slice arrays using ex. nanometer vacuum wavelengths, instead of arbitrary integers.
dim_idx.realize().values """
if isinstance(dim_idx, RangeFlow)
else dim_idx.values
for dim_idx in self.dim_idx.values()
]
#################### ####################
# - Contravariant Output # - Dimensions: Covariant Index
#################### ####################
# Output Information dims: dict[sim_symbols.SimSymbol, IndexArray] = dataclasses.field(
## TODO: Add PhysicalType
output_name: str = dataclasses.field(default_factory=list)
output_shape: tuple[int, ...] | None = dataclasses.field(default=None)
output_mathtype: spux.MathType = dataclasses.field()
output_unit: spux.Unit | None = dataclasses.field()
@property
def output_shape_len(self) -> int:
if self.output_shape is None:
return 0
return len(self.output_shape)
# Pinned Dimension Information
## TODO: Add PhysicalType
pinned_dim_names: list[str] = dataclasses.field(default_factory=list)
pinned_dim_values: dict[str, float | complex] = dataclasses.field(
default_factory=dict default_factory=dict
) )
pinned_dim_mathtypes: dict[str, spux.MathType] = dataclasses.field(
default_factory=dict @functools.cached_property
) def last_dim(self) -> sim_symbols.SimSymbol | None:
pinned_dim_units: dict[str, spux.Unit] = dataclasses.field(default_factory=dict) """The integer axis occupied by the dimension.
Can be used to index `.shape` of the represented raw array.
"""
if self.dims:
return next(iter(self.dims.keys()))
return None
@functools.cached_property
def last_dim(self) -> sim_symbols.SimSymbol | None:
"""The integer axis occupied by the dimension.
Can be used to index `.shape` of the represented raw array.
"""
if self.dims:
return list(self.dims.keys())[-1]
return None
def dim_axis(self, dim: sim_symbols.SimSymbol) -> int:
"""The integer axis occupied by the dimension.
Can be used to index `.shape` of the represented raw array.
"""
return list(self.dims.keys()).index(dim)
def has_idx_cont(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the dim's index is continuous, and therefore index array.
This happens when the dimension is generated from a symbolic function, as opposed to from discrete observations.
In these cases, the `SimSymbol.domain` of the dimension should be used to determine the overall domain of validity.
Other than that, it's up to the user to select a particular way of indexing.
"""
return self.dims[dim] is None
def has_idx_discrete(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the (rat|real) dim is indexed by an `ArrayFlow` / `RangeFlow`."""
return isinstance(self.dims[dim], ArrayFlow | RangeFlow)
def has_idx_labels(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the (int) dim is indexed by a `LabelArray`."""
if dim.mathtype is spux.MathType.Integer:
return isinstance(self.dims[dim], list)
return False
#################### ####################
# - Methods # - Output: Contravariant Value
#################### ####################
def slice_dim(self, dim_name: str, slice_tuple: tuple[int, int, int]) -> typ.Self: output: sim_symbols.SimSymbol
####################
# - Pinned Dimension Values
####################
## -> Whenever a dimension is deleted, we retain what that index value was.
## -> This proves to be very helpful for clear visualization.
pinned_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = dataclasses.field(
default_factory=dict
)
####################
# - Operations: Dimensions
####################
def prepend_dim(
self, dim: sim_symbols.SimSymbol, dim_idx: sim_symbols.SimSymbol
) -> typ.Self:
"""Insert a new dimension at index 0."""
return InfoFlow( return InfoFlow(
# Dimensions dims={dim: dim_idx} | self.dims,
dim_names=self.dim_names, output=self.output,
dim_idx={ pinned_values=self.pinned_values,
_dim_name: (
dim_idx
if _dim_name != dim_name
else dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
) )
for _dim_name, dim_idx in self.dim_idx.items()
def slice_dim(
self, dim: sim_symbols.SimSymbol, slice_tuple: tuple[int, int, int]
) -> typ.Self:
"""Slice a dimensional array by-index along a particular dimension."""
return InfoFlow(
dims={
_dim: dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
if _dim == dim
else _dim
for _dim, dim_idx in self.dims.items()
}, },
# Outputs output=self.output,
output_name=self.output_name, pinned_values=self.pinned_values,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
) )
def replace_dim( def replace_dim(
self, old_dim_name: str, new_dim_idx: tuple[str, ArrayFlow | RangeFlow] self,
old_dim: sim_symbols.SimSymbol,
new_dim: sim_symbols.SimSymbol,
new_dim_idx: IndexArray,
) -> typ.Self: ) -> typ.Self:
"""Replace a dimension (and its indexing) with a new name and index array/range.""" """Replace a dimension entirely, in-place, including symbol and index array."""
return InfoFlow( return InfoFlow(
# Dimensions dims={
dim_names=[ (new_dim if _dim == old_dim else _dim): (
dim_name if dim_name != old_dim_name else new_dim_idx[0] new_dim_idx if _dim == old_dim else _dim
for dim_name in self.dim_names
],
dim_idx={
(dim_name if dim_name != old_dim_name else new_dim_idx[0]): (
dim_idx if dim_name != old_dim_name else new_dim_idx[1]
) )
for dim_name, dim_idx in self.dim_idx.items() for _dim, dim_idx in self.dims.items()
}, },
# Outputs output=self.output,
output_name=self.output_name, pinned_values=self.pinned_values,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
) )
def rescale_dim_idxs(self, new_dim_idxs: dict[str, RangeFlow]) -> typ.Self: def replace_dims(
self, new_dims: dict[sim_symbols.SimSymbol, IndexArray]
) -> typ.Self:
"""Replace several dimensional indices with new index arrays/ranges.""" """Replace several dimensional indices with new index arrays/ranges."""
return InfoFlow( return InfoFlow(
# Dimensions dims={
dim_names=self.dim_names, dim: new_dims.get(dim, dim_idx) for dim, dim_idx in self.dim_idx.items()
dim_idx={
_dim_name: new_dim_idxs.get(_dim_name, dim_idx)
for _dim_name, dim_idx in self.dim_idx.items()
}, },
# Outputs output=self.output,
output_name=self.output_name, pinned_values=self.pinned_values,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
) )
def delete_dimension(self, dim_name: str) -> typ.Self: def delete_dim(
"""Delete a dimension.""" self, dim_to_remove: sim_symbols.SimSymbol, pin_idx: int | None = None
) -> typ.Self:
"""Delete a dimension, optionally pinning the value of an index from that dimension."""
new_pin = (
{dim_to_remove: self.dims[dim_to_remove][pin_idx]}
if pin_idx is not None
else {}
)
return InfoFlow( return InfoFlow(
# Dimensions dims={
dim_names=[ dim: dim_idx
_dim_name for _dim_name in self.dim_names if _dim_name != dim_name for dim, dim_idx in self.dims.items()
], if dim != dim_to_remove
dim_idx={
_dim_name: dim_idx
for _dim_name, dim_idx in self.dim_idx.items()
if _dim_name != dim_name
}, },
# Outputs output=self.output,
output_name=self.output_name, pinned_values=self.pinned_values | new_pin,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
) )
def swap_dimensions(self, dim_0_name: str, dim_1_name: str) -> typ.Self: def swap_dimensions(self, dim_0: str, dim_1: str) -> typ.Self:
"""Swap the position of two dimensions.""" """Swap the positions of two dimensions."""
# Compute Swapped Dimension Name List # Swapped Dimension Keys
def name_swapper(dim_name): def name_swapper(dim_name):
return ( return (
dim_name dim_name
if dim_name not in [dim_0_name, dim_1_name] if dim_name not in [dim_0, dim_1]
else {dim_0_name: dim_1_name, dim_1_name: dim_0_name}[dim_name] else {dim_0: dim_1, dim_1: dim_0}[dim_name]
) )
dim_names = [name_swapper(dim_name) for dim_name in self.dim_names] swapped_dim_keys = [name_swapper(dim) for dim in self.dims]
# Compute Info
return InfoFlow( return InfoFlow(
# Dimensions dims={dim_key: self.dims[dim_key] for dim_key in swapped_dim_keys},
dim_names=dim_names, output=self.output,
dim_idx={dim_name: self.dim_idx[dim_name] for dim_name in dim_names}, pinned_values=self.pinned_values,
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
) )
def set_output_mathtype(self, output_mathtype: spux.MathType) -> typ.Self: ####################
"""Set the MathType of the output.""" # - Operations: Output
####################
def update_output(self, **kwargs) -> typ.Self:
"""Passthrough to `SimSymbol.update()` method on `self.output`."""
return InfoFlow( return InfoFlow(
dim_names=self.dim_names, dims=self.dims,
dim_idx=self.dim_idx, output=self.output.update(**kwargs),
# Outputs pinned_values=self.pinned_values,
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=output_mathtype,
output_unit=self.output_unit,
) )
def collapse_output( ####################
self, # - Operations: Fold
collapsed_name: str, ####################
collapsed_mathtype: spux.MathType, def fold_last_input(self):
collapsed_unit: spux.Unit, """Fold the last input dimension into the output."""
) -> typ.Self: last_key = list(self.dims.keys())[-1]
"""Replace the (scalar) output with the given corrected values.""" last_idx = list(self.dims.values())[-1]
return InfoFlow(
# Dimensions rows = self.output.rows
dim_names=self.dim_names, cols = self.output.cols
dim_idx=self.dim_idx, match (rows, cols):
output_name=collapsed_name, case (1, 1):
output_shape=None, new_output = self.output.set_size(len(last_idx), 1)
output_mathtype=collapsed_mathtype, case (_, 1):
output_unit=collapsed_unit, new_output = self.output.set_size(rows, len(last_idx))
) case (1, _):
new_output = self.output.set_size(len(last_idx), cols)
case (_, _):
raise NotImplementedError ## Not yet :)
@functools.cached_property
def shift_last_input(self):
"""Shift the last input dimension to the output."""
return InfoFlow( return InfoFlow(
# Dimensions dims={
dim_names=self.dim_names[:-1], dim: dim_idx for dim, dim_idx in self.dims.items() if dim != last_key
dim_idx={
dim_name: dim_idx
for dim_name, dim_idx in self.dim_idx.items()
if dim_name != self.dim_names[-1]
}, },
# Outputs output=new_output,
output_name=self.output_name, pinned_values=self.pinned_values,
output_shape=(
(self.dim_lens[self.dim_names[-1]],)
if self.output_shape is None
else (self.dim_lens[self.dim_names[-1]], *self.output_shape)
),
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
) )

View File

@ -24,6 +24,8 @@ import jax
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger from blender_maxwell.utils import logger
from .params import ParamsFlow
log = logger.get(__name__) log = logger.get(__name__)
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any] LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any]
@ -307,6 +309,25 @@ class FuncFlow:
msg = 'Can\'t express FuncFlow as JAX function (using jax.jit), since "self.supports_jax" is False' msg = 'Can\'t express FuncFlow as JAX function (using jax.jit), since "self.supports_jax" is False'
raise ValueError(msg) raise ValueError(msg)
####################
# - Realization
####################
def realize(
self,
params: ParamsFlow,
unit_system: spux.UnitSystem | None = None,
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
) -> typ.Self:
if self.supports_jax:
return self.func_jax(
*params.scaled_func_args(unit_system, symbol_values),
*params.scaled_func_kwargs(unit_system, symbol_values),
)
return self.func(
*params.scaled_func_args(unit_system, symbol_values),
*params.scaled_func_kwargs(unit_system, symbol_values),
)
#################### ####################
# - Composition Operations # - Composition Operations
#################### ####################

View File

@ -28,13 +28,20 @@ from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger from blender_maxwell.utils import logger
from .array import ArrayFlow from .array import ArrayFlow
from .flow_kinds import FlowKind
from .lazy_func import FuncFlow from .lazy_func import FuncFlow
log = logger.get(__name__) log = logger.get(__name__)
class ScalingMode(enum.StrEnum): class ScalingMode(enum.StrEnum):
"""Identifier for how to space steps between two boundaries.
Attributes:
Lin: Uniform spacing between two endpoints.
Geom: Log spacing between two endpoints, given as values.
Log: Log spacing between two endpoints, given as powers of a common base.
"""
Lin = enum.auto() Lin = enum.auto()
Geom = enum.auto() Geom = enum.auto()
Log = enum.auto() Log = enum.auto()
@ -55,36 +62,20 @@ class ScalingMode(enum.StrEnum):
@dataclasses.dataclass(frozen=True, kw_only=True) @dataclasses.dataclass(frozen=True, kw_only=True)
class RangeFlow: class RangeFlow:
r"""Represents a linearly/logarithmically spaced array using symbolic boundary expressions, with support for units and lazy evaluation. r"""Represents a spaced array using symbolic boundary expressions.
# Advantages
Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous. Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous.
## Memory # Memory Scaling
`ArrayFlow` generally has a memory scaling of $O(n)$. `ArrayFlow` generally has a memory scaling of $O(n)$.
Naturally, `RangeFlow` is always constant, since only the boundaries and steps are stored. Naturally, `RangeFlow` is always constant, since only the boundaries and steps are stored.
## Symbolic # Symbolic Bounds
Both boundary points are symbolic expressions, within which pre-defined `sp.Symbol`s can participate in a constrained manner (ex. an integer symbol). `self.start` and `self.stop` boundary points are symbolic expressions, within which any element of `self.symbols` can participate.
One need not know the value of the symbols immediately - such decisions can be deferred until later in the computational flow. **It is the user's responsibility** to ensure that `self.start < self.stop`.
## Performant Unit-Aware Operations # Numerical Properties
While `ArrayFlow`s are also unit-aware, the time-cost of _any_ unit-scaling operation scales with $O(n)$.
`RangeFlow`, by contrast, scales as $O(1)$.
As a result, more complicated operations (like symbolic or unit-based) that might be difficult to perform interactively in real-time on an `ArrayFlow` will work perfectly with this object, even with added complexity
## High-Performance Composition and Gradiant
With `self.as_func`, a `jax` function is produced that generates the array according to the symbolic `start`, `stop` and `steps`.
There are two nice things about this:
- **Gradient**: The gradient of the output array, with respect to any symbols used to define the input bounds, can easily be found using `jax.grad` over `self.as_func`.
- **JIT**: When `self.as_func` is composed with other `jax` functions, and `jax.jit` is run to optimize the entire thing, the "cost of array generation" _will often be optimized away significantly or entirely_.
Thus, as part of larger computations, the performance properties of `RangeFlow` is extremely favorable.
## Numerical Properties
Since the bounds support exact (ex. rational) calculations and symbolic manipulations (_by virtue of being symbolic expressions_), the opportunities for certain kinds of numerical instability are mitigated. Since the bounds support exact (ex. rational) calculations and symbolic manipulations (_by virtue of being symbolic expressions_), the opportunities for certain kinds of numerical instability are mitigated.
Attributes: Attributes:
@ -108,8 +99,11 @@ class RangeFlow:
unit: spux.Unit | None = None unit: spux.Unit | None = None
symbols: frozenset[spux.IntSymbol] = frozenset() symbols: frozenset[spux.Symbol] = frozenset()
####################
# - Computed Properties
####################
@functools.cached_property @functools.cached_property
def sorted_symbols(self) -> list[sp.Symbol]: def sorted_symbols(self) -> list[sp.Symbol]:
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name. """Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
@ -121,6 +115,19 @@ class RangeFlow:
""" """
return sorted(self.symbols, key=lambda sym: sym.name) return sorted(self.symbols, key=lambda sym: sym.name)
@property
def is_symbolic(self) -> bool:
"""Whether the `RangeFlow` has unrealized symbols."""
return len(self.symbols) > 0
def __len__(self) -> int:
"""Compute the length of the array that would be realized.
Returns:
The number of steps.
"""
return self.steps
@functools.cached_property @functools.cached_property
def mathtype(self) -> spux.MathType: def mathtype(self) -> spux.MathType:
"""Conservatively compute the most stringent `spux.MathType` that can represent both `self.start` and `self.stop`. """Conservatively compute the most stringent `spux.MathType` that can represent both `self.start` and `self.stop`.
@ -156,13 +163,206 @@ class RangeFlow:
) )
return combined_mathtype return combined_mathtype
def __len__(self): ####################
"""Compute the length of the array to be realized. # - Methods
####################
def rescale(
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
) -> typ.Self:
"""Apply an order-preserving function to each bound, then (optionally) transform the result w/new unit and/or order.
An optimized expression will be built and applied to `self.values` using `sympy.lambdify()`.
Parameters:
rescale_func: An **order-preserving** function to apply to each array element.
reverse: Whether to reverse the order of the result.
new_unit: An (optional) new unit to scale the result to.
"""
new_pre_start = self.start if not reverse else self.stop
new_pre_stop = self.stop if not reverse else self.start
new_start = rescale_func(new_pre_start * self.unit)
new_stop = rescale_func(new_pre_stop * self.unit)
return RangeFlow(
start=(
spux.scale_to_unit(new_start, new_unit)
if new_unit is not None
else new_start
),
stop=(
spux.scale_to_unit(new_stop, new_unit)
if new_unit is not None
else new_stop
),
steps=self.steps,
scaling=self.scaling,
unit=new_unit,
symbols=self.symbols,
)
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
raise NotImplementedError
####################
# - Exporters
####################
@functools.cached_property
def array_generator(
self,
) -> typ.Callable[
[int | float | complex, int | float | complex, int],
jtyp.Inexact[jtyp.Array, ' steps'],
]:
"""Compute the correct `jnp.*space` array generator, where `*` is one of the supported scaling methods.
Returns: Returns:
The number of steps. A `jax` function that takes a valid `start`, `stop`, and `steps`, and returns a 1D `jax` array.
""" """
return self.steps jnp_nspace = {
ScalingMode.Lin: jnp.linspace,
ScalingMode.Geom: jnp.geomspace,
ScalingMode.Log: jnp.logspace,
}.get(self.scaling)
if jnp_nspace is None:
msg = f'ArrayFlow scaling method {self.scaling} is unsupported'
raise RuntimeError(msg)
return jnp_nspace
@functools.cached_property
def as_func(
self,
) -> typ.Callable[[int | float | complex, ...], jtyp.Inexact[jtyp.Array, ' steps']]:
"""Create a function that can compute the non-lazy output array as a function of the symbols in the expressions for `start` and `stop`.
Notes:
The ordering of the symbols is identical to `self.symbols`, which is guaranteed to be a deterministically sorted list of symbols.
Returns:
A `FuncFlow` that, given the input symbols defined in `self.symbols`,
"""
# Compile JAX Functions for Start/End Expressions
## -> FYI, JAX-in-JAX works perfectly fine.
start_jax = sp.lambdify(self.sorted_symbols, self.start, 'jax')
stop_jax = sp.lambdify(self.sorted_symbols, self.stop, 'jax')
# Compile ArrayGen Function
def gen_array(
*args: list[int | float | complex],
) -> jtyp.Inexact[jtyp.Array, ' steps']:
return self.array_generator(start_jax(*args), stop_jax(*args), self.steps)
# Return ArrayGen Function
return gen_array
@functools.cached_property
def as_lazy_func(self) -> FuncFlow:
"""Creates a `FuncFlow` using the output of `self.as_func`.
This is useful for ex. parameterizing the first array in the node graph, without binding an entire computed array.
Notes:
The the function enclosed in the `FuncFlow` is identical to the one returned by `self.as_func`.
Returns:
A `FuncFlow` containing `self.as_func`, as well as appropriate supporting settings.
"""
return FuncFlow(
func=self.as_func,
func_args=[(spux.MathType.from_expr(sym)) for sym in self.symbols],
supports_jax=True,
)
####################
# - Realization
####################
def realize_start(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> int | float | complex:
"""Realize the start-bound by inserting particular values for each symbol."""
return spux.sympy_to_python(
self.start.subs({sym: symbol_values[sym.name] for sym in self.symbols})
)
def realize_stop(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> int | float | complex:
"""Realize the stop-bound by inserting particular values for each symbol."""
return spux.sympy_to_python(
self.stop.subs({sym: symbol_values[sym.name] for sym in self.symbols})
)
def realize_step_size(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> int | float | complex:
"""Realize the stop-bound by inserting particular values for each symbol."""
if self.scaling is not ScalingMode.Lin:
raise NotImplementedError('Non-linear scaling mode not yet suported')
raw_step_size = (self.realize_stop() - self.realize_start() + 1) / self.steps
if self.mathtype is spux.MathType.Integer and raw_step_size.is_integer():
return int(raw_step_size)
return raw_step_size
def realize(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> ArrayFlow:
"""Realize the array represented by this `RangeFlow` by realizing each bound, then generating all intermediate values as an array.
Parameters:
symbol_values: The particular values for each symbol, which will be inserted into the expression of each bound to realize them.
Returns:
An `ArrayFlow` containing this realized `RangeFlow`.
"""
## TODO: Check symbol values for coverage.
return ArrayFlow(
values=self.as_func(*[symbol_values[sym] for sym in self.sorted_symbols]),
unit=self.unit,
is_sorted=True,
)
@functools.cached_property
def realize_array(self) -> ArrayFlow:
"""Standardized access to `self.realize()` when there are no symbols."""
return self.realize()
def __getitem__(self, subscript: slice):
"""Implement indexing and slicing in a sane way.
- **Integer Index**: Not yet implemented.
- **Slice**: Return the `RangeFlow` that creates the same `ArrayFlow` as would be created by computing `self.realize_array`, then slicing that.
"""
if isinstance(subscript, slice) and self.scaling == ScalingMode.Lin:
# Parse Slice
start = subscript.start if subscript.start is not None else 0
stop = subscript.stop if subscript.stop is not None else self.steps
step = subscript.step if subscript.step is not None else 1
slice_steps = (stop - start + step - 1) // step
# Compute New Start/Stop
step_size = self.realize_step_size()
new_start = step_size * start
new_stop = new_start + step_size * slice_steps
return RangeFlow(
start=sp.S(new_start),
stop=sp.S(new_stop),
steps=slice_steps,
scaling=self.scaling,
unit=self.unit,
symbols=self.symbols,
)
raise NotImplementedError
#################### ####################
# - Units # - Units
@ -264,231 +464,3 @@ class RangeFlow:
f'Tried to rescale unitless LazyDataValueRange to unit system {unit_system}' f'Tried to rescale unitless LazyDataValueRange to unit system {unit_system}'
) )
raise ValueError(msg) raise ValueError(msg)
####################
# - Bound Operations
####################
def rescale(
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
) -> typ.Self:
new_pre_start = self.start if not reverse else self.stop
new_pre_stop = self.stop if not reverse else self.start
new_start = rescale_func(new_pre_start * self.unit)
new_stop = rescale_func(new_pre_stop * self.unit)
return RangeFlow(
start=(
spux.scale_to_unit(new_start, new_unit)
if new_unit is not None
else new_start
),
stop=(
spux.scale_to_unit(new_stop, new_unit)
if new_unit is not None
else new_stop
),
steps=self.steps,
scaling=self.scaling,
unit=new_unit,
symbols=self.symbols,
)
def rescale_bounds(
self,
rescale_func: typ.Callable[
[spux.ScalarUnitlessComplexExpr], spux.ScalarUnitlessComplexExpr
],
reverse: bool = False,
) -> typ.Self:
"""Apply a function to the bounds, effectively rescaling the represented array.
Notes:
**It is presumed that the bounds are scaled with the same factor**.
Breaking this presumption may have unexpected results.
The scalar, unitless, complex-valuedness of the bounds must also be respected; additionally, new symbols must not be introduced.
Parameters:
scaler: The function that scales each bound.
reverse: Whether to reverse the bounds after running the `scaler`.
Returns:
A rescaled `RangeFlow`.
"""
return RangeFlow(
start=rescale_func(self.start if not reverse else self.stop),
stop=rescale_func(self.stop if not reverse else self.start),
steps=self.steps,
scaling=self.scaling,
unit=self.unit,
symbols=self.symbols,
)
####################
# - Lazy Representation
####################
@functools.cached_property
def array_generator(
self,
) -> typ.Callable[
[int | float | complex, int | float | complex, int],
jtyp.Inexact[jtyp.Array, ' steps'],
]:
"""Compute the correct `jnp.*space` array generator, where `*` is one of the supported scaling methods.
Returns:
A `jax` function that takes a valid `start`, `stop`, and `steps`, and returns a 1D `jax` array.
"""
jnp_nspace = {
ScalingMode.Lin: jnp.linspace,
ScalingMode.Geom: jnp.geomspace,
ScalingMode.Log: jnp.logspace,
}.get(self.scaling)
if jnp_nspace is None:
msg = f'ArrayFlow scaling method {self.scaling} is unsupported'
raise RuntimeError(msg)
return jnp_nspace
@functools.cached_property
def as_func(
self,
) -> typ.Callable[[int | float | complex, ...], jtyp.Inexact[jtyp.Array, ' steps']]:
"""Create a function that can compute the non-lazy output array as a function of the symbols in the expressions for `start` and `stop`.
Notes:
The ordering of the symbols is identical to `self.symbols`, which is guaranteed to be a deterministically sorted list of symbols.
Returns:
A `FuncFlow` that, given the input symbols defined in `self.symbols`,
"""
# Compile JAX Functions for Start/End Expressions
## FYI, JAX-in-JAX works perfectly fine.
start_jax = sp.lambdify(self.symbols, self.start, 'jax')
stop_jax = sp.lambdify(self.symbols, self.stop, 'jax')
# Compile ArrayGen Function
def gen_array(
*args: list[int | float | complex],
) -> jtyp.Inexact[jtyp.Array, ' steps']:
return self.array_generator(start_jax(*args), stop_jax(*args), self.steps)
# Return ArrayGen Function
return gen_array
@functools.cached_property
def as_lazy_func(self) -> FuncFlow:
"""Creates a `FuncFlow` using the output of `self.as_func`.
This is useful for ex. parameterizing the first array in the node graph, without binding an entire computed array.
Notes:
The the function enclosed in the `FuncFlow` is identical to the one returned by `self.as_func`.
Returns:
A `FuncFlow` containing `self.as_func`, as well as appropriate supporting settings.
"""
return FuncFlow(
func=self.as_func,
func_args=[(spux.MathType.from_expr(sym)) for sym in self.symbols],
supports_jax=True,
)
####################
# - Realization
####################
def realize_start(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> ArrayFlow | FuncFlow:
return spux.sympy_to_python(
self.start.subs({sym: symbol_values[sym.name] for sym in self.symbols})
)
def realize_stop(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> ArrayFlow | FuncFlow:
return spux.sympy_to_python(
self.stop.subs({sym: symbol_values[sym.name] for sym in self.symbols})
)
def realize_step_size(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> ArrayFlow | FuncFlow:
raw_step_size = (self.realize_stop() - self.realize_start() + 1) / self.steps
if self.mathtype is spux.MathType.Integer and raw_step_size.is_integer():
return int(raw_step_size)
return raw_step_size
def realize(
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
kind: typ.Literal[FlowKind.Array, FlowKind.Func] = FlowKind.Array,
) -> ArrayFlow | FuncFlow:
"""Apply a function to the bounds, effectively rescaling the represented array.
Notes:
**It is presumed that the bounds are scaled with the same factor**.
Breaking this presumption may have unexpected results.
The scalar, unitless, complex-valuedness of the bounds must also be respected; additionally, new symbols must not be introduced.
Parameters:
scaler: The function that scales each bound.
reverse: Whether to reverse the bounds after running the `scaler`.
Returns:
A rescaled `RangeFlow`.
"""
if not set(self.symbols).issubset(set(symbol_values.keys())):
msg = f'Provided symbols ({set(symbol_values.keys())}) do not provide values for all expression symbols ({self.symbols}) that may be found in the boundary expressions (start={self.start}, end={self.end})'
raise ValueError(msg)
# Realize Symbols
realized_start = self.realize_start(symbol_values)
realized_stop = self.realize_stop(symbol_values)
# Return Linspace / Logspace
def gen_array() -> jtyp.Inexact[jtyp.Array, ' steps']:
return self.array_generator(realized_start, realized_stop, self.steps)
if kind == FlowKind.Array:
return ArrayFlow(values=gen_array(), unit=self.unit, is_sorted=True)
if kind == FlowKind.Func:
return FuncFlow(func=gen_array, supports_jax=True)
msg = f'Invalid kind: {kind}'
raise TypeError(msg)
@functools.cached_property
def realize_array(self) -> ArrayFlow:
return self.realize()
def __getitem__(self, subscript: slice):
if isinstance(subscript, slice) and self.scaling == ScalingMode.Lin:
# Parse Slice
start = subscript.start if subscript.start is not None else 0
stop = subscript.stop if subscript.stop is not None else self.steps
step = subscript.step if subscript.step is not None else 1
slice_steps = (stop - start + step - 1) // step
# Compute New Start/Stop
step_size = self.realize_step_size()
new_start = step_size * start
new_stop = new_start + step_size * slice_steps
return RangeFlow(
start=sp.S(new_start),
stop=sp.S(new_stop),
steps=slice_steps,
scaling=self.scaling,
unit=self.unit,
symbols=self.symbols,
)
raise NotImplementedError

View File

@ -22,35 +22,22 @@ from types import MappingProxyType
import sympy as sp import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger from blender_maxwell.utils import logger, sim_symbols
from .expr_info import ExprInfo
from .flow_kinds import FlowKind from .flow_kinds import FlowKind
from .info import InfoFlow
# from .info import InfoFlow
log = logger.get(__name__) log = logger.get(__name__)
class ExprInfo(typ.TypedDict):
active_kind: FlowKind
size: spux.NumberSize1D
mathtype: spux.MathType
physical_type: spux.PhysicalType
# Value
default_value: spux.SympyExpr
# Range
default_min: spux.SympyExpr
default_max: spux.SympyExpr
default_steps: int
@dataclasses.dataclass(frozen=True, kw_only=True) @dataclasses.dataclass(frozen=True, kw_only=True)
class ParamsFlow: class ParamsFlow:
func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list) func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list)
func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict) func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict)
symbols: frozenset[spux.Symbol] = frozenset() symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
@functools.cached_property @functools.cached_property
def sorted_symbols(self) -> list[sp.Symbol]: def sorted_symbols(self) -> list[sp.Symbol]:
@ -66,21 +53,18 @@ class ParamsFlow:
#################### ####################
def scaled_func_args( def scaled_func_args(
self, self,
unit_system: spux.UnitSystem, unit_system: spux.UnitSystem | None = None,
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}), symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
): ):
"""Realize the function arguments contained in this `ParamsFlow`, making it ready for insertion into `Func.func()`. """Realize the function arguments contained in this `ParamsFlow`, making it ready for insertion into `Func.func()`.
For all `arg`s in `self.func_args`, the following operations are performed: For all `arg`s in `self.func_args`, the following operations are performed.
- **Unit System**: If `arg`
Notes: Notes:
This method is created for the purpose of being able to make this exact call in an `events.on_value_changed` method: This method is created for the purpose of being able to make this exact call in an `events.on_value_changed` method:
""" """
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments."""
if not all(sym in self.symbols for sym in symbol_values): if not all(sym in self.symbols for sym in symbol_values):
msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}" msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}"
raise ValueError(msg) raise ValueError(msg)
@ -97,7 +81,7 @@ class ParamsFlow:
def scaled_func_kwargs( def scaled_func_kwargs(
self, self,
unit_system: spux.UnitSystem, unit_system: spux.UnitSystem | None = None,
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}), symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
): ):
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments.""" """Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments."""
@ -145,9 +129,7 @@ class ParamsFlow:
#################### ####################
# - Generate ExprSocketDef # - Generate ExprSocketDef
#################### ####################
def sym_expr_infos( def sym_expr_infos(self, info, use_range: bool = False) -> dict[str, ExprInfo]:
self, info: InfoFlow, use_range: bool = False
) -> dict[str, ExprInfo]:
"""Generate all information needed to define expressions that realize all symbolic parameters in this `ParamsFlow`. """Generate all information needed to define expressions that realize all symbolic parameters in this `ParamsFlow`.
Many nodes need actual data, and as such, they require that the user select actual values for any symbols in the `ParamsFlow`. Many nodes need actual data, and as such, they require that the user select actual values for any symbols in the `ParamsFlow`.
@ -169,26 +151,35 @@ class ParamsFlow:
The `ExprInfo`s can be directly defererenced `**expr_info`) The `ExprInfo`s can be directly defererenced `**expr_info`)
""" """
for sim_sym in self.sorted_symbols:
if use_range and sim_sym.mathtype is spux.MathType.Complex:
msg = 'No support for complex range in ExprInfo'
raise NotImplementedError(msg)
if use_range and (sim_sym.rows > 1 or sim_sym.cols > 1):
msg = 'No support for non-scalar elements of range in ExprInfo'
raise NotImplementedError(msg)
if sim_sym.rows > 3 or sim_sym.cols > 1:
msg = 'No support for >Vec3 / Matrix values in ExprInfo'
raise NotImplementedError(msg)
return { return {
sym.name: { sim_sym.name: {
# Declare Kind/Size # Declare Kind/Size
## -> Kind: Value prevents user-alteration of config. ## -> Kind: Value prevents user-alteration of config.
## -> Size: Always scalar, since symbols are scalar (for now). ## -> Size: Always scalar, since symbols are scalar (for now).
'active_kind': FlowKind.Value, 'active_kind': FlowKind.Value if not use_range else FlowKind.Range,
'size': spux.NumberSize1D.Scalar, 'size': spux.NumberSize1D.Scalar,
# Declare MathType/PhysicalType # Declare MathType/PhysicalType
## -> MathType: Lookup symbol name in info dimensions. ## -> MathType: Lookup symbol name in info dimensions.
## -> PhysicalType: Same. ## -> PhysicalType: Same.
'mathtype': info.dim_mathtypes[sym.name], 'mathtype': self.dims[sim_sym].mathtype,
'physical_type': info.dim_physical_types[sym.name], 'physical_type': self.dims[sim_sym].physical_type,
# TODO: Default Values # TODO: Default Value
# FlowKind.Value: Default Value # FlowKind.Value: Default Value
#'default_value': #'default_value':
# FlowKind.Range: Default Min/Max/Steps # FlowKind.Range: Default Min/Max/Steps
#'default_min': 'default_min': sim_sym.domain.start,
#'default_max': 'default_max': sim_sym.domain.end,
#'default_steps': 'default_steps': 50,
} }
for sym in self.sorted_symbols for sim_sym in self.sorted_symbols
if sym.name in info.dim_names
} }

View File

@ -489,11 +489,11 @@ class DataFileFormat(enum.StrEnum):
E = DataFileFormat E = DataFileFormat
match self: match self:
case E.Csv: case E.Csv:
return len(info.dim_names) + info.output_shape_len <= 2 return len(info.dims) + info.output.rows + info.output.cols - 1 <= 2
case E.Npy: case E.Npy:
return True return True
case E.Txt | E.TxtGz: case E.Txt | E.TxtGz:
return len(info.dim_names) + info.output_shape_len <= 2 return len(info.dims) + info.output.rows + info.output.cols - 1 <= 2
@property @property
def saver( def saver(
@ -510,9 +510,9 @@ class DataFileFormat(enum.StrEnum):
# Extract Input Coordinates # Extract Input Coordinates
dim_columns = { dim_columns = {
dim_name: np.array(info.dim_idx_arrays[i]) dim.name: np.array(dim_idx.realize_array)
for i, dim_name in enumerate(info.dim_names) for i, (dim, dim_idx) in enumerate(info.dims)
} } ## TODO: realize_array might not be defined on some index arrays
# Declare Function to Extract Output Values # Declare Function to Extract Output Values
output_columns = {} output_columns = {}
@ -524,14 +524,14 @@ class DataFileFormat(enum.StrEnum):
output_idx_str = f'[{output_idx}]' if use_output_idx else '' output_idx_str = f'[{output_idx}]' if use_output_idx else ''
if bool(np.any(np.iscomplex(data_col))): if bool(np.any(np.iscomplex(data_col))):
output_columns |= { output_columns |= {
f'{info.output_name}{output_idx_str}_re': np.real(data_col), f'{info.output.name}{output_idx_str}_re': np.real(data_col),
f'{info.output_name}{output_idx_str}_im': np.imag(data_col), f'{info.output.name}{output_idx_str}_im': np.imag(data_col),
} }
# Else: Use Array Directly # Else: Use Array Directly
else: else:
output_columns |= { output_columns |= {
f'{info.output_name}{output_idx_str}': data_col, f'{info.output.name}{output_idx_str}': data_col,
} }
## TODO: Maybe a check to ensure dtype!=object? ## TODO: Maybe a check to ensure dtype!=object?
@ -605,11 +605,11 @@ class DataFileFormat(enum.StrEnum):
E = DataFileFormat E = DataFileFormat
match self: match self:
case E.Csv: case E.Csv:
return len(info.dim_names) + (info.output_shape_len + 1) <= 2 return len(info.dims) + (info.output.rows + input.outputs.cols - 1) <= 2
case E.Npy: case E.Npy:
return True return True
case E.Txt | E.TxtGz: case E.Txt | E.TxtGz:
return len(info.dim_names) + (info.output_shape_len + 1) <= 2 return len(info.dims) + (info.output.rows + info.output.cols - 1) <= 2
def supports_metadata(self) -> bool: def supports_metadata(self) -> bool:
E = DataFileFormat E = DataFileFormat

View File

@ -48,7 +48,6 @@ UNITS_BLENDER: spux.UnitSystem = spux.UNITS_SI | {
# Electrodynamics # Electrodynamics
_PT.CurrentDensity: spu.ampere / spu.um**2, _PT.CurrentDensity: spu.ampere / spu.um**2,
_PT.Conductivity: spu.siemens / spu.um, _PT.Conductivity: spu.siemens / spu.um,
_PT.PoyntingVector: spu.watt / spu.um**2,
_PT.EField: spu.volt / spu.um, _PT.EField: spu.volt / spu.um,
_PT.HField: spu.ampere / spu.um, _PT.HField: spu.ampere / spu.um,
# Mechanical # Mechanical
@ -58,7 +57,6 @@ UNITS_BLENDER: spux.UnitSystem = spux.UNITS_SI | {
_PT.Force: spux.micronewton, _PT.Force: spux.micronewton,
# Luminal # Luminal
# Optics # Optics
_PT.PoyntingVector: spu.watt / spu.um**2,
} ## TODO: Load (dynamically?) from addon preferences } ## TODO: Load (dynamically?) from addon preferences
UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | { UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | {
@ -75,11 +73,9 @@ UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | {
# Electrodynamics # Electrodynamics
_PT.CurrentDensity: spu.ampere / spu.um**2, _PT.CurrentDensity: spu.ampere / spu.um**2,
_PT.Conductivity: spu.siemens / spu.um, _PT.Conductivity: spu.siemens / spu.um,
_PT.PoyntingVector: spu.watt / spu.um**2,
_PT.EField: spu.volt / spu.um, _PT.EField: spu.volt / spu.um,
_PT.HField: spu.ampere / spu.um, _PT.HField: spu.ampere / spu.um,
# Luminal # Luminal
# Optics # Optics
_PT.PoyntingVector: spu.watt / spu.um**2,
## NOTE: w/o source normalization, EField/HField/Modal amps are * 1/Hz ## NOTE: w/o source normalization, EField/HField/Modal amps are * 1/Hz
} }

View File

@ -17,15 +17,15 @@
"""Implements `ExtractDataNode`.""" """Implements `ExtractDataNode`."""
import enum import enum
import functools
import typing as typ import typing as typ
import bpy import bpy
import jax import jax.numpy as jnp
import numpy as np
import sympy.physics.units as spu import sympy.physics.units as spu
import tidy3d as td import tidy3d as td
from blender_maxwell.utils import bl_cache, logger from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from ... import contracts as ct from ... import contracts as ct
@ -37,6 +37,176 @@ log = logger.get(__name__)
TDMonitorData: typ.TypeAlias = td.components.data.monitor_data.MonitorData TDMonitorData: typ.TypeAlias = td.components.data.monitor_data.MonitorData
####################
# - Monitor Label Arrays
####################
def valid_monitor_attrs(sim_data: td.SimulationData, monitor_name: str) -> list[str]:
"""Retrieve the valid attributes of `sim_data.monitor_data' from a valid `sim_data` of type `td.SimulationData`.
Parameters:
monitor_type: The name of the monitor type, with the 'Data' prefix removed.
"""
monitor_data = sim_data.monitor_data[monitor_name]
monitor_type = monitor_data.type
match monitor_type:
case 'Field' | 'FieldTime' | 'Mode':
## TODO: flux, poynting, intensity
return [
field_component
for field_component in ['Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz']
if getattr(monitor_data, field_component, None) is not None
]
case 'Permittivity':
return ['eps_xx', 'eps_yy', 'eps_zz']
case 'Flux' | 'FluxTime':
return ['flux']
case (
'FieldProjectionAngle'
| 'FieldProjectionCartesian'
| 'FieldProjectionKSpace'
| 'Diffraction'
):
return [
'Er',
'Etheta',
'Ephi',
'Hr',
'Htheta',
'Hphi',
]
def extract_info(monitor_data, monitor_attr: str) -> ct.InfoFlow | None: # noqa: PLR0911
"""Extract an InfoFlow encapsulating raw data contained in an attribute of the given monitor data."""
xarr = getattr(monitor_data, monitor_attr, None)
if xarr is None:
return None
def mk_idx_array(axis: str) -> ct.ArrayFlow:
return ct.ArrayFlow(
values=xarr.get_index(axis).values,
unit=symbols[axis].unit,
is_sorted=True,
)
# Compute InfoFlow from XArray
symbols = {
# Cartesian
'x': sim_symbols.space_x(spu.micrometer),
'y': sim_symbols.space_y(spu.micrometer),
'z': sim_symbols.space_z(spu.micrometer),
# Spherical
'r': sim_symbols.ang_r(spu.micrometer),
'theta': sim_symbols.ang_theta(spu.radian),
'phi': sim_symbols.ang_phi(spu.radian),
# Freq|Time
'f': sim_symbols.freq(spu.hertz),
't': sim_symbols.t(spu.second),
# Power Flux
'flux': sim_symbols.flux(spu.watt),
# Cartesian Fields
'Ex': sim_symbols.field_ex(spu.volt / spu.micrometer),
'Ey': sim_symbols.field_ey(spu.volt / spu.micrometer),
'Ez': sim_symbols.field_ez(spu.volt / spu.micrometer),
'Hx': sim_symbols.field_hx(spu.volt / spu.micrometer),
'Hy': sim_symbols.field_hy(spu.volt / spu.micrometer),
'Hz': sim_symbols.field_hz(spu.volt / spu.micrometer),
# Spherical Fields
'Er': sim_symbols.field_er(spu.volt / spu.micrometer),
'Etheta': sim_symbols.ang_theta(spu.volt / spu.micrometer),
'Ephi': sim_symbols.field_ez(spu.volt / spu.micrometer),
'Hr': sim_symbols.field_hr(spu.volt / spu.micrometer),
'Htheta': sim_symbols.field_hy(spu.volt / spu.micrometer),
'Hphi': sim_symbols.field_hz(spu.volt / spu.micrometer),
# Wavevector
'ux': sim_symbols.dir_x(spu.watt),
'uy': sim_symbols.dir_y(spu.watt),
# Diffraction Orders
'orders_x': sim_symbols.diff_order_x(None),
'orders_y': sim_symbols.diff_order_y(None),
}
match monitor_data.type:
case 'Field' | 'FieldProjectionCartesian' | 'Permittivity' | 'Mode':
return ct.InfoFlow(
dims={
symbols['x']: mk_idx_array('x'),
symbols['y']: mk_idx_array('y'),
symbols['z']: mk_idx_array('z'),
symbols['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
)
case 'FieldTime':
return ct.InfoFlow(
dims={
symbols['x']: mk_idx_array('x'),
symbols['y']: mk_idx_array('y'),
symbols['z']: mk_idx_array('z'),
symbols['t']: mk_idx_array('t'),
},
output=symbols[monitor_attr],
)
case 'Flux':
return ct.InfoFlow(
dims={
symbols['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
)
case 'FluxTime':
return ct.InfoFlow(
dims={
symbols['t']: mk_idx_array('t'),
},
output=symbols[monitor_attr],
)
case 'FieldProjectionAngle':
return ct.InfoFlow(
dims={
symbols['r']: mk_idx_array('r'),
symbols['theta']: mk_idx_array('theta'),
symbols['phi']: mk_idx_array('phi'),
symbols['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
)
case 'FieldProjectionKSpace':
return ct.InfoFlow(
dims={
symbols['ux']: mk_idx_array('ux'),
symbols['uy']: mk_idx_array('uy'),
symbols['r']: mk_idx_array('r'),
symbols['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
)
case 'Diffraction':
return ct.InfoFlow(
dims={
symbols['orders_x']: mk_idx_array('orders_x'),
symbols['orders_y']: mk_idx_array('orders_y'),
symbols['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
)
return None
####################
# - Node
####################
class ExtractDataNode(base.MaxwellSimNode): class ExtractDataNode(base.MaxwellSimNode):
"""Extract data from sockets for further analysis. """Extract data from sockets for further analysis.
@ -45,31 +215,21 @@ class ExtractDataNode(base.MaxwellSimNode):
Monitor Data: Extract `Expr`s from monitor data by-component. Monitor Data: Extract `Expr`s from monitor data by-component.
Attributes: Attributes:
extract_filter: Identifier for data to extract from the input. monitor_attr: Identifier for data to extract from the input.
""" """
node_type = ct.NodeType.ExtractData node_type = ct.NodeType.ExtractData
bl_label = 'Extract' bl_label = 'Extract'
input_socket_sets: typ.ClassVar = { input_socket_sets: typ.ClassVar = {
'Sim Data': {'Sim Data': sockets.MaxwellFDTDSimDataSocketDef()}, 'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(),
'Monitor Data': {'Monitor Data': sockets.MaxwellMonitorDataSocketDef()},
} }
output_socket_sets: typ.ClassVar = { output_socket_sets: typ.ClassVar = {
'Sim Data': {'Monitor Data': sockets.MaxwellMonitorDataSocketDef()}, 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
'Monitor Data': {'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func)},
} }
#################### ####################
# - Properties # - Properties: Monitor Name
####################
extract_filter: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_extract_filters(),
cb_depends_on={'sim_data_monitor_nametype', 'monitor_data_type'},
)
####################
# - Computed: Sim Data
#################### ####################
@events.on_value_changed( @events.on_value_changed(
socket_name='Sim Data', socket_name='Sim Data',
@ -99,198 +259,49 @@ class ExtractDataNode(base.MaxwellSimNode):
@bl_cache.cached_bl_property(depends_on={'sim_data'}) @bl_cache.cached_bl_property(depends_on={'sim_data'})
def sim_data_monitor_nametype(self) -> dict[str, str] | None: def sim_data_monitor_nametype(self) -> dict[str, str] | None:
"""For simulation data, deduces a map from the monitor name to the monitor "type". """Dictionary from monitor names on `self.sim_data` to their associated type name (with suffix 'Data' removed).
Return: Return:
The name to type of monitors in the simulation data. The name to type of monitors in the simulation data.
""" """
if self.sim_data is not None: if self.sim_data is not None:
return { return {
monitor_name: monitor_data.type monitor_name: monitor_data.type.removesuffix('Data')
for monitor_name, monitor_data in self.sim_data.monitor_data.items() for monitor_name, monitor_data in self.sim_data.monitor_data.items()
} }
return None return None
#################### monitor_name: enum.StrEnum = bl_cache.BLField(
# - Computed Properties: Monitor Data enum_cb=lambda self, _: self.search_monitor_names(),
#################### cb_depends_on={'sim_data_monitor_nametype'},
@events.on_value_changed(
socket_name='Monitor Data',
input_sockets={'Monitor Data'},
input_sockets_optional={'Monitor Data': True},
) )
def on_monitor_data_changed(self, input_sockets) -> None: # noqa: D102
has_monitor_data = not ct.FlowSignal.check(input_sockets['Monitor Data'])
if has_monitor_data:
self.monitor_data = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property() def search_monitor_names(self) -> list[ct.BLEnumElement]:
def monitor_data(self) -> TDMonitorData | None: """Compute valid values for `self.monitor_attr`, for a dynamic `EnumProperty`.
"""Extracts the monitor data from the input socket.
Return:
Either the monitor data, if available, or None.
"""
monitor_data = self._compute_input(
'Monitor Data', kind=ct.FlowKind.Value, optional=True
)
has_monitor_data = not ct.FlowSignal.check(monitor_data)
if has_monitor_data:
return monitor_data
return None
@bl_cache.cached_bl_property(depends_on={'monitor_data'})
def monitor_data_type(self) -> str | None:
r"""For monitor data, deduces the monitor "type".
- **Field(Time)**: A monitor storing values/pixels/voxels with electromagnetic field values, on the time or frequency domain.
- **Permittivity**: A monitor storing values/pixels/voxels containing the diagonal of the relative permittivity tensor.
- **Flux(Time)**: A monitor storing the directional flux on the time or frequency domain.
For planes, an explicit direction is defined.
For volumes, the the integral of all outgoing energy is stored.
- **FieldProjection(...)**: A monitor storing the spherical-coordinate electromagnetic field components of a near-to-far-field projection.
- **Diffraction**: A monitor storing a near-to-far-field projection by diffraction order.
Notes: Notes:
Should be invalidated with (before) `self.monitor_data_attrs`. Should be reset (via `self.monitor_attr`) with (after) `self.sim_data_monitor_nametype`, `self.monitor_data_attrs`, and (implicitly) `self.monitor_type`.
Return:
The "type" of the monitor, if available, else None.
"""
if self.monitor_data is not None:
return self.monitor_data.type.removesuffix('Data')
return None
@bl_cache.cached_bl_property(depends_on={'monitor_data_type'})
def monitor_data_attrs(self) -> list[str] | None:
r"""For monitor data, deduces the valid data-containing attributes.
The output depends entirely on the output of `self.monitor_data_type`, since the valid attributes of each monitor type is well-defined without needing to perform dynamic lookups.
- **Field(Time)**: Whichever `[E|H][x|y|z]` are not `None` on the monitor.
- **Permittivity**: Specifically `['xx', 'yy', 'zz']`.
- **Flux(Time)**: Only `['flux']`.
- **FieldProjection(...)**: All of $r$, $\theta$, $\phi$ for both `E` and `H`.
- **Diffraction**: Same as `FieldProjection`.
Notes:
Should be invalidated after with `self.monitor_data_type`.
Return:
The "type" of the monitor, if available, else None.
"""
if self.monitor_data is not None:
# Field/FieldTime
if self.monitor_data_type in ['Field', 'FieldTime']:
return [
field_component
for field_component in ['Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz']
if hasattr(self.monitor_data, field_component)
]
# Permittivity
if self.monitor_data_type == 'Permittivity':
return ['xx', 'yy', 'zz']
# Flux/FluxTime
if self.monitor_data_type in ['Flux', 'FluxTime']:
return ['flux']
# FieldProjection(Angle/Cartesian/KSpace)/Diffraction
if self.monitor_data_type in [
'FieldProjectionAngle',
'FieldProjectionCartesian',
'FieldProjectionKSpace',
'Diffraction',
]:
return [
'Er',
'Etheta',
'Ephi',
'Hr',
'Htheta',
'Hphi',
]
return None
####################
# - Extraction Filter Search
####################
def search_extract_filters(self) -> list[ct.BLEnumElement]:
"""Compute valid values for `self.extract_filter`, for a dynamic `EnumProperty`.
Notes:
Should be reset (via `self.extract_filter`) with (after) `self.sim_data_monitor_nametype`, `self.monitor_data_attrs`, and (implicitly) `self.monitor_type`.
See `bl_cache.BLField` for more on dynamic `EnumProperty`. See `bl_cache.BLField` for more on dynamic `EnumProperty`.
Returns: Returns:
Valid `self.extract_filter` in a format compatible with dynamic `EnumProperty`. Valid `self.monitor_attr` in a format compatible with dynamic `EnumProperty`.
""" """
if self.sim_data_monitor_nametype is not None: if self.sim_data_monitor_nametype is not None:
return [ return [
(monitor_name, monitor_name, monitor_type.removesuffix('Data'), '', i) (
monitor_name,
monitor_name,
monitor_type + ' Monitor Data',
'',
i,
)
for i, (monitor_name, monitor_type) in enumerate( for i, (monitor_name, monitor_type) in enumerate(
self.sim_data_monitor_nametype.items() self.sim_data_monitor_nametype.items()
) )
] ]
if self.monitor_data_attrs is not None:
# Field/FieldTime
if self.monitor_data_type in ['Field', 'FieldTime']:
return [
(
monitor_attr,
monitor_attr,
f' {monitor_attr[1]}-polarization of the {"electric" if monitor_attr[0] == "E" else "magnetic"} field',
'',
i,
)
for i, monitor_attr in enumerate(self.monitor_data_attrs)
]
# Permittivity
if self.monitor_data_type == 'Permittivity':
return [
(monitor_attr, monitor_attr, f' ε_{monitor_attr}', '', i)
for i, monitor_attr in enumerate(self.monitor_data_attrs)
]
# Flux/FluxTime
if self.monitor_data_type in ['Flux', 'FluxTime']:
return [
(
monitor_attr,
monitor_attr,
'Power flux integral through the plane / out of the volume',
'',
i,
)
for i, monitor_attr in enumerate(self.monitor_data_attrs)
]
# FieldProjection(Angle/Cartesian/KSpace)/Diffraction
if self.monitor_data_type in [
'FieldProjectionAngle',
'FieldProjectionCartesian',
'FieldProjectionKSpace',
'Diffraction',
]:
return [
(
monitor_attr,
monitor_attr,
f' {monitor_attr[1]}-component of the spherical {"electric" if monitor_attr[0] == "E" else "magnetic"} field',
'',
i,
)
for i, monitor_attr in enumerate(self.monitor_data_attrs)
]
return [] return []
#################### ####################
@ -303,10 +314,9 @@ class ExtractDataNode(base.MaxwellSimNode):
Called by Blender to determine the text to place in the node's header. Called by Blender to determine the text to place in the node's header.
""" """
has_sim_data = self.sim_data_monitor_nametype is not None has_sim_data = self.sim_data_monitor_nametype is not None
has_monitor_data = self.monitor_data_attrs is not None
if has_sim_data or has_monitor_data: if has_sim_data:
return f'Extract: {self.extract_filter}' return f'Extract: {self.monitor_name}'
return self.bl_label return self.bl_label
@ -316,336 +326,115 @@ class ExtractDataNode(base.MaxwellSimNode):
Parameters: Parameters:
col: UI target for drawing. col: UI target for drawing.
""" """
col.prop(self, self.blfields['extract_filter'], text='') col.prop(self, self.blfields['monitor_name'], text='')
#################### ####################
# - FlowKind.Value: Sim Data -> Monitor Data # - FlowKind.Func
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Monitor Data',
kind=ct.FlowKind.Value,
# Loaded
props={'extract_filter'},
input_sockets={'Sim Data'},
input_sockets_optional={'Sim Data': True},
)
def compute_monitor_data(
self, props: dict, input_sockets: dict
) -> TDMonitorData | ct.FlowSignal:
"""Compute `Monitor Data` by querying the attribute of `Sim Data` referenced by the property `self.extract_filter`.
Returns:
Monitor data, if available, else `ct.FlowSignal.FlowPending`.
"""
extract_filter = props['extract_filter']
sim_data = input_sockets['Sim Data']
has_sim_data = not ct.FlowSignal.check(sim_data)
if has_sim_data and extract_filter is not None:
return sim_data.monitor_data[extract_filter]
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Array|Func: Monitor Data -> Expr
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Array,
# Loaded
props={'extract_filter'},
input_sockets={'Monitor Data'},
input_socket_kinds={'Monitor Data': ct.FlowKind.Value},
input_sockets_optional={'Monitor Data': True},
)
def compute_expr(
self, props: dict, input_sockets: dict
) -> jax.Array | ct.FlowSignal:
"""Compute `Expr:Array` by querying an array-like attribute of `Monitor Data`, then constructing an `ct.ArrayFlow` around it.
Uses the internal `xarray` data returned by Tidy3D.
By using `np.array` on the `.data` attribute of the `xarray`, instead of the usual JAX array constructor, we should save a (possibly very big) copy.
Returns:
The data array, if available, else `ct.FlowSignal.FlowPending`.
"""
extract_filter = props['extract_filter']
monitor_data = input_sockets['Monitor Data']
has_monitor_data = not ct.FlowSignal.check(monitor_data)
if has_monitor_data and extract_filter is not None:
xarray_data = getattr(monitor_data, extract_filter)
return ct.ArrayFlow(values=np.array(xarray_data.data), unit=None)
return ct.FlowSignal.FlowPending
@events.computes_output_socket(
# Trigger
'Expr', 'Expr',
kind=ct.FlowKind.Func, kind=ct.FlowKind.Func,
# Loaded # Loaded
output_sockets={'Expr'}, props={'monitor_name'},
output_socket_kinds={'Expr': ct.FlowKind.Array}, input_sockets={'Sim Data'},
output_sockets_optional={'Expr': True}, input_socket_kinds={'Sim Data': ct.FlowKind.Value},
) )
def compute_extracted_data_lazy(self, output_sockets: dict) -> ct.FuncFlow | None: def compute_expr(
"""Declare `Expr:Func` by creating a simple function that directly wraps `Expr:Array`. self, props: dict, input_sockets: dict
) -> ct.FuncFlow | ct.FlowSignal:
sim_data = input_sockets['Sim Data']
monitor_name = props['monitor_name']
Returns: has_sim_data = not ct.FlowSignal.check(sim_data)
The composable function array, if available, else `ct.FlowSignal.FlowPending`.
"""
output_expr = output_sockets['Expr']
has_output_expr = not ct.FlowSignal.check(output_expr)
if has_output_expr: if has_sim_data and monitor_name is not None:
return ct.FuncFlow(func=lambda: output_expr.values, supports_jax=True) monitor_data = sim_data.get(monitor_name)
if monitor_data is not None:
# Extract Valid Index Labels
## -> The first output axis will be integer-indexed.
## -> Each integer will have a string label.
## -> Those string labels explain the integer as ex. Ex, Ey, Hy.
idx_labels = valid_monitor_attrs(sim_data, monitor_name)
# Generate FuncFlow Per Index Label
## -> We extract each XArray as an attribute of monitor_data.
## -> We then bind its values into a unique func_flow.
## -> This lets us 'stack' then all along the first axis.
func_flows = []
for idx_label in idx_labels:
xarr = getattr(monitor_data, idx_label)
func_flows.append(
ct.FuncFlow(
func=lambda xarr=xarr: xarr.values,
supports_jax=True,
)
)
# Concatenate and Stack Unified FuncFlow
## -> First, 'reduce' lets us __or__ all the FuncFlows together.
## -> Then, 'compose_within' lets us stack them along axis=0.
## -> The "new" axis=0 is int-indexed axis w/idx_labels labels!
return functools.reduce(lambda a, b: a | b, func_flows).compose_within(
enclosing_func=lambda data: jnp.stack(data, axis=0)
)
return ct.FlowSignal.FlowPending
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
#################### ####################
# - FlowKind.Params: Monitor Data -> Expr # - FlowKind.Params
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.Params, kind=ct.FlowKind.Params,
input_sockets={'Sim Data'},
input_socket_kinds={'Sim Data': ct.FlowKind.Params},
) )
def compute_data_params(self) -> ct.ParamsFlow: def compute_data_params(self, input_sockets) -> ct.ParamsFlow:
"""Declare an empty `Data:Params`, to indicate the start of a function-composition pipeline. """Declare an empty `Data:Params`, to indicate the start of a function-composition pipeline.
Returns: Returns:
A completely empty `ParamsFlow`, ready to be composed. A completely empty `ParamsFlow`, ready to be composed.
""" """
sim_params = input_sockets['Sim Data']
has_sim_params = not ct.FlowSignal.check(sim_params)
if has_sim_params:
return sim_params
return ct.ParamsFlow() return ct.ParamsFlow()
#################### ####################
# - FlowKind.Info: Monitor Data -> Expr # - FlowKind.Info
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.Info, kind=ct.FlowKind.Info,
# Loaded # Loaded
props={'monitor_data_type', 'extract_filter'}, props={'monitor_name'},
input_sockets={'Monitor Data'}, input_sockets={'Sim Data'},
input_socket_kinds={'Monitor Data': ct.FlowKind.Value}, input_socket_kinds={'Sim Data': ct.FlowKind.Value},
input_sockets_optional={'Monitor Data': True},
) )
def compute_extracted_data_info( def compute_extracted_data_info(self, props, input_sockets) -> ct.InfoFlow:
self, props: dict, input_sockets: dict
) -> ct.InfoFlow:
"""Declare `Data:Info` by manually selecting appropriate axes, units, etc. for each monitor type. """Declare `Data:Info` by manually selecting appropriate axes, units, etc. for each monitor type.
Returns: Returns:
Information describing the `Data:Func`, if available, else `ct.FlowSignal.FlowPending`. Information describing the `Data:Func`, if available, else `ct.FlowSignal.FlowPending`.
""" """
monitor_data = input_sockets['Monitor Data'] sim_data = input_sockets['Sim Data']
monitor_data_type = props['monitor_data_type'] monitor_name = props['monitor_name']
extract_filter = props['extract_filter']
has_monitor_data = not ct.FlowSignal.check(monitor_data) has_sim_data = not ct.FlowSignal.check(sim_data)
# Edge Case: Dangling 'flux' Access on 'FieldMonitor' if not has_sim_data or monitor_name is None:
## -> Sometimes works - UNLESS the FieldMonitor doesn't have all fields.
## -> We don't allow 'flux' attribute access, but it can dangle.
## -> (The method is called when updating each depschain component.)
if monitor_data_type == 'Field' and extract_filter == 'flux':
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
# Retrieve XArray # Extract Data
if has_monitor_data and extract_filter is not None: ## -> All monitor_data.<idx_label> have the exact same InfoFlow.
xarr = getattr(monitor_data, extract_filter, None) ## -> So, just construct an InfoFlow w/prepended labelled dimension.
if xarr is None: monitor_data = sim_data.get(monitor_name)
return ct.FlowSignal.FlowPending idx_labels = valid_monitor_attrs(sim_data, monitor_name)
else: info = extract_info(monitor_data, idx_labels[0])
return ct.FlowSignal.FlowPending
# Compute InfoFlow from XArray return info.prepend_dim(sim_symbols.idx, idx_labels)
## XYZF: Field / Permittivity / FieldProjectionCartesian
if monitor_data_type in {
'Field',
'Permittivity',
#'FieldProjectionCartesian',
}:
return ct.InfoFlow(
dim_names=['x', 'y', 'z', 'f'],
dim_idx={
axis: ct.ArrayFlow(
values=xarr.get_index(axis).values, unit=spu.um, is_sorted=True
)
for axis in ['x', 'y', 'z']
}
| {
'f': ct.ArrayFlow(
values=xarr.get_index('f').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Complex,
output_unit=(
spu.volt / spu.micrometer if monitor_data_type == 'Field' else None
),
)
## XYZT: FieldTime
if monitor_data_type == 'FieldTime':
return ct.InfoFlow(
dim_names=['x', 'y', 'z', 't'],
dim_idx={
axis: ct.ArrayFlow(
values=xarr.get_index(axis).values, unit=spu.um, is_sorted=True
)
for axis in ['x', 'y', 'z']
}
| {
't': ct.ArrayFlow(
values=xarr.get_index('t').values,
unit=spu.second,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Complex,
output_unit=(
spu.volt / spu.micrometer if monitor_data_type == 'Field' else None
),
)
## F: Flux
if monitor_data_type == 'Flux':
return ct.InfoFlow(
dim_names=['f'],
dim_idx={
'f': ct.ArrayFlow(
values=xarr.get_index('f').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=spu.watt,
)
## T: FluxTime
if monitor_data_type == 'FluxTime':
return ct.InfoFlow(
dim_names=['t'],
dim_idx={
't': ct.ArrayFlow(
values=xarr.get_index('t').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=spu.watt,
)
## RThetaPhiF: FieldProjectionAngle
if monitor_data_type == 'FieldProjectionAngle':
return ct.InfoFlow(
dim_names=['r', 'theta', 'phi', 'f'],
dim_idx={
'r': ct.ArrayFlow(
values=xarr.get_index('r').values,
unit=spu.micrometer,
is_sorted=True,
),
}
| {
c: ct.ArrayFlow(
values=xarr.get_index(c).values,
unit=spu.radian,
is_sorted=True,
)
for c in ['r', 'theta', 'phi']
}
| {
'f': ct.ArrayFlow(
values=xarr.get_index('f').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=(
spu.volt / spu.micrometer
if extract_filter.startswith('E')
else spu.ampere / spu.micrometer
),
)
## UxUyRF: FieldProjectionKSpace
if monitor_data_type == 'FieldProjectionKSpace':
return ct.InfoFlow(
dim_names=['ux', 'uy', 'r', 'f'],
dim_idx={
c: ct.ArrayFlow(
values=xarr.get_index(c).values, unit=None, is_sorted=True
)
for c in ['ux', 'uy']
}
| {
'r': ct.ArrayFlow(
values=xarr.get_index('r').values,
unit=spu.micrometer,
is_sorted=True,
),
'f': ct.ArrayFlow(
values=xarr.get_index('f').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=(
spu.volt / spu.micrometer
if extract_filter.startswith('E')
else spu.ampere / spu.micrometer
),
)
## OrderxOrderyF: Diffraction
if monitor_data_type == 'Diffraction':
return ct.InfoFlow(
dim_names=['orders_x', 'orders_y', 'f'],
dim_idx={
f'orders_{c}': ct.ArrayFlow(
values=xarr.get_index(f'orders_{c}').values,
unit=None,
is_sorted=True,
)
for c in ['x', 'y']
}
| {
'f': ct.ArrayFlow(
values=xarr.get_index('f').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=(
spu.volt / spu.micrometer
if extract_filter.startswith('E')
else spu.ampere / spu.micrometer
),
)
msg = f'Unsupported Monitor Data Type {monitor_data_type} in "FlowKind.Info" of "{self.bl_label}"'
raise RuntimeError(msg)
#################### ####################

View File

@ -98,29 +98,29 @@ class FilterOperation(enum.StrEnum):
operations = [] operations = []
# Slice # Slice
if info.dim_names: if info.dims:
operations.append(FO.SliceIdx) operations.append(FO.SliceIdx)
# Pin # Pin
## PinLen1 ## PinLen1
## -> There must be a dimension with length 1. ## -> There must be a dimension with length 1.
if 1 in list(info.dim_lens.values()): if 1 in [dim_idx for dim_idx in info.dims.values() if dim_idx is not None]:
operations.append(FO.PinLen1) operations.append(FO.PinLen1)
## Pin | PinIdx ## Pin | PinIdx
## -> There must be a dimension, full stop. ## -> There must be a dimension, full stop.
if info.dim_names: if info.dims:
operations += [FO.Pin, FO.PinIdx] operations += [FO.Pin, FO.PinIdx]
# Reinterpret # Reinterpret
## Swap ## Swap
## -> There must be at least two dimensions. ## -> There must be at least two dimensions.
if len(info.dim_names) >= 2: # noqa: PLR2004 if len(info.dims) >= 2: # noqa: PLR2004
operations.append(FO.Swap) operations.append(FO.Swap)
## SetDim ## SetDim
## -> There must be a dimension to correct. ## -> There must be a dimension to correct.
if info.dim_names: if info.dims:
operations.append(FO.SetDim) operations.append(FO.SetDim)
return operations return operations
@ -158,33 +158,33 @@ class FilterOperation(enum.StrEnum):
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]: def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
FO = FilterOperation FO = FilterOperation
match self: match self:
case FO.SliceIdx: case FO.SliceIdx | FO.Swap:
return info.dim_names return info.dims
# PinLen1: Only allow dimensions with length=1. # PinLen1: Only allow dimensions with length=1.
case FO.PinLen1: case FO.PinLen1:
return [ return [
dim_name dim
for dim_name in info.dim_names for dim, dim_idx in info.dims.items()
if info.dim_lens[dim_name] == 1 if dim_idx is not None and len(dim_idx) == 1
] ]
# Pin: Only allow dimensions with known indexing. # Pin: Only allow dimensions with discrete index.
case FO.Pin: ## TODO: Shouldn't 'Pin' be allowed to index continuous indices too?
case FO.Pin | FO.PinIdx:
return [ return [
dim_name dim
for dim_name in info.dim_names for dim, dim_idx in info.dims
if info.dim_has_coords[dim_name] != 0 if dim_idx is not None and len(dim_idx) > 0
] ]
case FO.PinIdx | FO.Swap:
return info.dim_names
case FO.SetDim: case FO.SetDim:
return [ return [
dim_name dim
for dim_name in info.dim_names for dim, dim_idx in info.dims
if info.dim_mathtypes[dim_name] == spux.MathType.Integer if dim_idx is not None
and not isinstance(dim_idx, list)
and dim_idx.mathtype == spux.MathType.Integer
] ]
return [] return []
@ -224,22 +224,22 @@ class FilterOperation(enum.StrEnum):
def transform_info( def transform_info(
self, self,
info: ct.InfoFlow, info: ct.InfoFlow,
dim_0: str, dim_0: sim_symbols.SimSymbol,
dim_1: str, dim_1: sim_symbols.SimSymbol,
pin_idx: int | None = None,
slice_tuple: tuple[int, int, int] | None = None, slice_tuple: tuple[int, int, int] | None = None,
corrected_dim: tuple[str, tuple[str, ct.ArrayFlow | ct.RangeFlow]] replaced_dim: tuple[str, tuple[str, ct.ArrayFlow | ct.RangeFlow]] | None = None,
| None = None,
): ):
FO = FilterOperation FO = FilterOperation
return { return {
FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple), FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple),
# Pin # Pin
FO.PinLen1: lambda: info.delete_dimension(dim_0), FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
FO.Pin: lambda: info.delete_dimension(dim_0), FO.Pin: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
FO.PinIdx: lambda: info.delete_dimension(dim_0), FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
# Reinterpret # Reinterpret
FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1), FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
FO.SetDim: lambda: info.replace_dim(*corrected_dim), FO.SetDim: lambda: info.replace_dim(*replaced_dim),
}[self]() }[self]()
@ -318,11 +318,11 @@ class FilterMathNode(base.MaxwellSimNode):
#################### ####################
# - Properties: Dimension Selection # - Properties: Dimension Selection
#################### ####################
dim_0: enum.StrEnum = bl_cache.BLField( active_dim_0: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_dims(), enum_cb=lambda self, _: self.search_dims(),
cb_depends_on={'operation', 'expr_info'}, cb_depends_on={'operation', 'expr_info'},
) )
dim_1: enum.StrEnum = bl_cache.BLField( active_dim_1: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_dims(), enum_cb=lambda self, _: self.search_dims(),
cb_depends_on={'operation', 'expr_info'}, cb_depends_on={'operation', 'expr_info'},
) )
@ -335,40 +335,23 @@ class FilterMathNode(base.MaxwellSimNode):
] ]
return [] return []
@bl_cache.cached_bl_property(depends_on={'active_dim_0'})
def dim_0(self) -> sim_symbols.SimSymbol | None:
if self.expr_info is not None and self.active_dim_0 is not None:
return self.expr_info.dim_by_name(self.active_dim_0)
return None
@bl_cache.cached_bl_property(depends_on={'active_dim_1'})
def dim_1(self) -> sim_symbols.SimSymbol | None:
if self.expr_info is not None and self.active_dim_1 is not None:
return self.expr_info.dim_by_name(self.active_dim_1)
return None
#################### ####################
# - Properties: Slice # - Properties: Slice
#################### ####################
slice_tuple: tuple[int, int, int] = bl_cache.BLField([0, 1, 1]) slice_tuple: tuple[int, int, int] = bl_cache.BLField([0, 1, 1])
####################
# - Properties: Unit
####################
set_dim_symbol: sim_symbols.CommonSimSymbol = bl_cache.BLField(
sim_symbols.CommonSimSymbol.X
)
set_dim_active_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_valid_units(),
cb_depends_on={'set_dim_symbol'},
)
def search_valid_units(self) -> list[ct.BLEnumElement]:
"""Compute Blender enum elements of valid units for the current `physical_type`."""
physical_type = self.set_dim_symbol.sim_symbol.physical_type
if physical_type is not spux.PhysicalType.NonPhysical:
return [
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
for i, unit in enumerate(physical_type.valid_units)
]
return []
@bl_cache.cached_bl_property(depends_on={'set_dim_active_unit'})
def set_dim_unit(self) -> spux.Unit | None:
if self.set_dim_active_unit is not None:
return spux.unit_str_to_unit(self.set_dim_active_unit)
return None
#################### ####################
# - UI # - UI
#################### ####################
@ -378,27 +361,27 @@ class FilterMathNode(base.MaxwellSimNode):
# Slice # Slice
case FO.SliceIdx: case FO.SliceIdx:
slice_str = ':'.join([str(v) for v in self.slice_tuple]) slice_str = ':'.join([str(v) for v in self.slice_tuple])
return f'Filter: {self.dim_0}[{slice_str}]' return f'Filter: {self.active_dim_0}[{slice_str}]'
# Pin # Pin
case FO.PinLen1: case FO.PinLen1:
return f'Filter: Pin {self.dim_0}[0]' return f'Filter: Pin {self.active_dim_0}[0]'
case FO.Pin: case FO.Pin:
return f'Filter: Pin {self.dim_0}[...]' return f'Filter: Pin {self.active_dim_0}[...]'
case FO.PinIdx: case FO.PinIdx:
pin_idx_axis = self._compute_input( pin_idx_axis = self._compute_input(
'Axis', kind=ct.FlowKind.Value, optional=True 'Axis', kind=ct.FlowKind.Value, optional=True
) )
has_pin_idx_axis = not ct.FlowSignal.check(pin_idx_axis) has_pin_idx_axis = not ct.FlowSignal.check(pin_idx_axis)
if has_pin_idx_axis: if has_pin_idx_axis:
return f'Filter: Pin {self.dim_0}[{pin_idx_axis}]' return f'Filter: Pin {self.active_dim_0}[{pin_idx_axis}]'
return self.bl_label return self.bl_label
# Reinterpret # Reinterpret
case FO.Swap: case FO.Swap:
return f'Filter: Swap [{self.dim_0}]|[{self.dim_1}]' return f'Filter: Swap [{self.active_dim_0}]|[{self.active_dim_1}]'
case FO.SetDim: case FO.SetDim:
return f'Filter: Set [{self.dim_0}]' return f'Filter: Set [{self.active_dim_0}]'
case _: case _:
return self.bl_label return self.bl_label
@ -409,20 +392,15 @@ class FilterMathNode(base.MaxwellSimNode):
if self.operation is not None: if self.operation is not None:
match self.operation.num_dim_inputs: match self.operation.num_dim_inputs:
case 1: case 1:
layout.prop(self, self.blfields['dim_0'], text='') layout.prop(self, self.blfields['active_dim_0'], text='')
case 2: case 2:
row = layout.row(align=True) row = layout.row(align=True)
row.prop(self, self.blfields['dim_0'], text='') row.prop(self, self.blfields['active_dim_0'], text='')
row.prop(self, self.blfields['dim_1'], text='') row.prop(self, self.blfields['active_dim_1'], text='')
if self.operation is FilterOperation.SliceIdx: if self.operation is FilterOperation.SliceIdx:
layout.prop(self, self.blfields['slice_tuple'], text='') layout.prop(self, self.blfields['slice_tuple'], text='')
if self.operation is FilterOperation.SetDim:
row = layout.row(align=True)
row.prop(self, self.blfields['set_dim_symbol'], text='')
row.prop(self, self.blfields['set_dim_active_unit'], text='')
#################### ####################
# - Events # - Events
#################### ####################
@ -450,50 +428,47 @@ class FilterMathNode(base.MaxwellSimNode):
if not has_info: if not has_info:
return return
# Pin Dim by-Value: Synchronize Input Socket dim_0 = props['dim_0']
## -> The user will be given a socket w/correct mathtype, unit, etc. .
## -> Internally, this value will map to a particular index.
if props['operation'] is FilterOperation.Pin and props['dim_0'] is not None:
# Deduce Pinned Information
pinned_unit = info.dim_units[props['dim_0']]
pinned_mathtype = info.dim_mathtypes[props['dim_0']]
pinned_physical_type = spux.PhysicalType.from_unit(pinned_unit)
wanted_mathtype = (
spux.MathType.Complex
if pinned_mathtype == spux.MathType.Complex
and spux.MathType.Complex in pinned_physical_type.valid_mathtypes
else spux.MathType.Real
)
# Get Current and Wanted Socket Defs # Loose Sockets: Pin Dim by-Value
## -> 'Value' may already exist. If not, all is well. ## -> Works with continuous / discrete indexes.
## -> The user will be given a socket w/correct mathtype, unit, etc. .
if (
props['operation'] is FilterOperation.Pin
and dim_0 is not None
and (info.has_idx_cont(dim_0) or info.has_idx_discrete(dim_0))
):
dim = dim_0
current_bl_socket = self.loose_input_sockets.get('Value') current_bl_socket = self.loose_input_sockets.get('Value')
# Determine Whether to Construct
## -> If nothing needs to change, then nothing changes.
if ( if (
current_bl_socket is None current_bl_socket is None
or current_bl_socket.active_kind != ct.FlowKind.Value
or current_bl_socket.size is not spux.NumberSize1D.Scalar or current_bl_socket.size is not spux.NumberSize1D.Scalar
or current_bl_socket.physical_type != pinned_physical_type or current_bl_socket.physical_type != dim.physical_type
or current_bl_socket.mathtype != wanted_mathtype or current_bl_socket.mathtype != dim.mathtype
): ):
self.loose_input_sockets = { self.loose_input_sockets = {
'Value': sockets.ExprSocketDef( 'Value': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Value, active_kind=ct.FlowKind.Value,
physical_type=pinned_physical_type, physical_type=dim.physical_type,
mathtype=wanted_mathtype, mathtype=dim.mathtype,
default_unit=pinned_unit, default_unit=dim.unit,
), ),
} }
# Pin Dim by-Index: Synchronize Input Socket # Loose Sockets: Pin Dim by-Value
## -> The user will be given a simple integer socket. ## -> Works with discrete points / labelled integers.
elif ( elif (
props['operation'] is FilterOperation.PinIdx and props['dim_0'] is not None props['operation'] is FilterOperation.PinIdx
and dim_0 is not None
and (info.has_idx_discrete(dim_0) or info.has_idx_labels(dim_0))
): ):
dim = dim_0
current_bl_socket = self.loose_input_sockets.get('Axis') current_bl_socket = self.loose_input_sockets.get('Axis')
if ( if (
current_bl_socket is None current_bl_socket is None
or current_bl_socket.active_kind != ct.FlowKind.Value
or current_bl_socket.size is not spux.NumberSize1D.Scalar or current_bl_socket.size is not spux.NumberSize1D.Scalar
or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical
or current_bl_socket.mathtype != spux.MathType.Integer or current_bl_socket.mathtype != spux.MathType.Integer
@ -505,28 +480,26 @@ class FilterMathNode(base.MaxwellSimNode):
) )
} }
# Set Dim: Synchronize Input Socket # Loose Sockets: Set Dim
## -> The user must provide a () -> array. ## -> The user must provide a () -> array.
## -> It must be of identical length to the replaced axis. ## -> It must be of identical length to the replaced axis.
elif ( elif props['operation'] is FilterOperation.SetDim and dim_0 is not None:
props['operation'] is FilterOperation.SetDim dim = dim_0
and props['dim_0'] is not None
and info.dim_mathtypes[props['dim_0']] is spux.MathType.Integer
and info.dim_physical_types[props['dim_0']] is spux.PhysicalType.NonPhysical
):
# Deduce Axis Information
current_bl_socket = self.loose_input_sockets.get('Dim') current_bl_socket = self.loose_input_sockets.get('Dim')
if ( if (
current_bl_socket is None current_bl_socket is None
or current_bl_socket.active_kind != ct.FlowKind.Func or current_bl_socket.active_kind != ct.FlowKind.Func
or current_bl_socket.mathtype != spux.MathType.Real or current_bl_socket.size is not spux.NumberSize1D.Scalar
or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical or current_bl_socket.mathtype != dim.mathtype
or current_bl_socket.physical_type != dim.physical_type
): ):
self.loose_input_sockets = { self.loose_input_sockets = {
'Dim': sockets.ExprSocketDef( 'Dim': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Func, active_kind=ct.FlowKind.Func,
mathtype=spux.MathType.Real, physical_type=dim.physical_type,
physical_type=spux.PhysicalType.NonPhysical, mathtype=dim.mathtype,
default_unit=dim.unit,
show_func_ui=False,
show_info_columns=True, show_info_columns=True,
) )
} }
@ -553,25 +526,20 @@ class FilterMathNode(base.MaxwellSimNode):
has_lazy_func = not ct.FlowSignal.check(lazy_func) has_lazy_func = not ct.FlowSignal.check(lazy_func)
has_info = not ct.FlowSignal.check(info) has_info = not ct.FlowSignal.check(info)
# Dimension(s)
dim_0 = props['dim_0'] dim_0 = props['dim_0']
dim_1 = props['dim_1'] dim_1 = props['dim_1']
slice_tuple = props['slice_tuple']
if ( if (
has_lazy_func has_lazy_func
and has_info and has_info
and operation is not None and operation is not None
and operation.are_dims_valid(info, dim_0, dim_1) and operation.are_dims_valid(info, dim_0, dim_1)
): ):
axis_0 = info.dim_names.index(dim_0) if dim_0 is not None else None axis_0 = info.dim_axis(dim_0) if dim_0 is not None else None
axis_1 = info.dim_names.index(dim_1) if dim_1 is not None else None axis_1 = info.dim_axis(dim_1) if dim_1 is not None else None
slice_tuple = (
props['slice_tuple']
if self.operation is FilterOperation.SliceIdx
else None
)
return lazy_func.compose_within( return lazy_func.compose_within(
operation.jax_func(axis_0, axis_1, slice_tuple), operation.jax_func(axis_0, axis_1, slice_tuple=slice_tuple),
enclosing_func_args=operation.func_args, enclosing_func_args=operation.func_args,
supports_jax=True, supports_jax=True,
) )
@ -588,8 +556,6 @@ class FilterMathNode(base.MaxwellSimNode):
'dim_1', 'dim_1',
'operation', 'operation',
'slice_tuple', 'slice_tuple',
'set_dim_symbol',
'set_dim_active_unit',
}, },
input_sockets={'Expr', 'Dim'}, input_sockets={'Expr', 'Dim'},
input_socket_kinds={ input_socket_kinds={
@ -601,14 +567,15 @@ class FilterMathNode(base.MaxwellSimNode):
def compute_info(self, props, input_sockets) -> ct.InfoFlow: def compute_info(self, props, input_sockets) -> ct.InfoFlow:
operation = props['operation'] operation = props['operation']
info = input_sockets['Expr'] info = input_sockets['Expr']
dim_coords = input_sockets['Dim'][ct.FlowKind.Func]
dim_params = input_sockets['Dim'][ct.FlowKind.Params]
dim_info = input_sockets['Dim'][ct.FlowKind.Info]
dim_symbol = props['set_dim_symbol']
dim_active_unit = props['set_dim_active_unit']
has_info = not ct.FlowSignal.check(info) has_info = not ct.FlowSignal.check(info)
has_dim_coords = not ct.FlowSignal.check(dim_coords)
# Dim (Op.SetDim)
dim_func = input_sockets['Dim'][ct.FlowKind.Func]
dim_params = input_sockets['Dim'][ct.FlowKind.Params]
dim_info = input_sockets['Dim'][ct.FlowKind.Info]
has_dim_func = not ct.FlowSignal.check(dim_func)
has_dim_params = not ct.FlowSignal.check(dim_params) has_dim_params = not ct.FlowSignal.check(dim_params)
has_dim_info = not ct.FlowSignal.check(dim_info) has_dim_info = not ct.FlowSignal.check(dim_info)
@ -619,44 +586,42 @@ class FilterMathNode(base.MaxwellSimNode):
if has_info and operation is not None: if has_info and operation is not None:
# Set Dimension: Retrieve Array # Set Dimension: Retrieve Array
if props['operation'] is FilterOperation.SetDim: if props['operation'] is FilterOperation.SetDim:
new_dim = (
next(dim_info.dims.keys()) if len(dim_info.dims) >= 1 else None
)
if ( if (
dim_0 is not None dim_0 is not None
# Check Replaced Dimension and new_dim is not None
and has_dim_coords
and len(dim_coords.func_args) == 1
and dim_coords.func_args[0] is spux.MathType.Integer
and not dim_coords.func_kwargs
and dim_coords.supports_jax
# Check Params
and has_dim_params
and len(dim_params.func_args) == 1
and not dim_params.func_kwargs
# Check Info
and has_dim_info and has_dim_info
and has_dim_params
# Check New Dimension Index Array Sizing
and len(dim_info.dims) == 1
and dim_info.output.rows == 1
and dim_info.output.cols == 1
# Check Lack of Params Symbols
and not dim_params.symbols
# Check Expr Dim | New Dim Compatibility
and info.has_idx_discrete(dim_0)
and dim_info.has_idx_discrete(new_dim)
and len(info.dims[dim_0]) == len(dim_info.dims[new_dim])
): ):
# Retrieve Dimension Coordinate Array # Retrieve Dimension Coordinate Array
## -> It must be strictly compatible. ## -> It must be strictly compatible.
values = dim_coords.func_jax(int(dim_params.func_args[0])) values = dim_func.realize(dim_params, spux.UNITS_SI)
if (
len(values.shape) != 1
or values.shape[0] != info.dim_lens[dim_0]
):
return ct.FlowSignal.FlowPending
# Transform Info w/Corrected Dimension # Transform Info w/Corrected Dimension
## -> The existing dimension will be replaced. ## -> The existing dimension will be replaced.
if dim_active_unit is not None:
dim_unit = spux.unit_str_to_unit(dim_active_unit)
else:
dim_unit = None
new_dim_idx = ct.ArrayFlow( new_dim_idx = ct.ArrayFlow(
values=values, values=values,
unit=dim_unit, unit=spux.convert_to_unit_system(
) dim_info.output.unit, spux.UNITS_SI
corrected_dim = [dim_0, (dim_symbol.name, new_dim_idx)] ),
).rescale_to_unit(dim_info.output.unit)
replaced_dim = [dim_0, (dim_info.output.name, new_dim_idx)]
return operation.transform_info( return operation.transform_info(
info, dim_0, dim_1, corrected_dim=corrected_dim info, dim_0, dim_1, replaced_dim=replaced_dim
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
return operation.transform_info(info, dim_0, dim_1, slice_tuple=slice_tuple) return operation.transform_info(info, dim_0, dim_1, slice_tuple=slice_tuple)
@ -702,7 +667,7 @@ class FilterMathNode(base.MaxwellSimNode):
# Pin by-Value: Compute Nearest IDX # Pin by-Value: Compute Nearest IDX
## -> Presume a sorted index array to be able to use binary search. ## -> Presume a sorted index array to be able to use binary search.
if props['operation'] is FilterOperation.Pin and has_pinned_value: if props['operation'] is FilterOperation.Pin and has_pinned_value:
nearest_idx_to_value = info.dim_idx[dim_0].nearest_idx_of( nearest_idx_to_value = info.dims[dim_0].nearest_idx_of(
pinned_value, require_sorted=True pinned_value, require_sorted=True
) )

View File

@ -23,7 +23,7 @@ import bpy
import jax.numpy as jnp import jax.numpy as jnp
import sympy as sp import sympy as sp
from blender_maxwell.utils import bl_cache, logger from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct from .... import contracts as ct
@ -153,16 +153,11 @@ class MapOperation(enum.StrEnum):
# - Ops from Shape # - Ops from Shape
#################### ####################
@staticmethod @staticmethod
def by_element_shape(shape: tuple[int, ...] | None) -> list[typ.Self]: def by_expr_info(info: ct.InfoFlow) -> list[typ.Self]:
## TODO: By info, not shape.
## TODO: Check valid domains/mathtypes for some functions.
MO = MapOperation MO = MapOperation
element_ops = [
match shape:
case 'noshape':
return []
# By Number
case None:
return [
MO.Real, MO.Real,
MO.Imag, MO.Imag,
MO.Abs, MO.Abs,
@ -178,15 +173,18 @@ class MapOperation(enum.StrEnum):
MO.Sinc, MO.Sinc,
] ]
match len(shape): match (info.output.rows, info.output.cols):
# By Vector case (1, 1):
case 1: return element_ops
return [
MO.Norm2, case (_, 1):
] return [*element_ops, MO.Norm2]
# By Matrix
case 2: case (rows, cols) if rows == cols:
## TODO: Check hermitian/posdef for cholesky.
## - Can we even do this with just the output symbol approach?
return [ return [
*element_ops,
MO.Det, MO.Det,
MO.Cond, MO.Cond,
MO.NormFro, MO.NormFro,
@ -201,6 +199,18 @@ class MapOperation(enum.StrEnum):
MO.Svd, MO.Svd,
] ]
case (rows, cols):
return [
*element_ops,
MO.Cond,
MO.NormFro,
MO.Rank,
MO.SvdVals,
MO.Inv,
MO.Tra,
MO.Svd,
]
return [] return []
#################### ####################
@ -288,41 +298,76 @@ class MapOperation(enum.StrEnum):
def transform_info(self, info: ct.InfoFlow): def transform_info(self, info: ct.InfoFlow):
MO = MapOperation MO = MapOperation
return { return {
# By Number # By Number
MO.Real: lambda: info.set_output_mathtype(spux.MathType.Real), MO.Real: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Imag: lambda: info.set_output_mathtype(spux.MathType.Real), MO.Imag: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Abs: lambda: info.set_output_mathtype(spux.MathType.Real), MO.Abs: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Sq: lambda: info,
MO.Sqrt: lambda: info,
MO.InvSqrt: lambda: info,
MO.Cos: lambda: info,
MO.Sin: lambda: info,
MO.Tan: lambda: info,
MO.Acos: lambda: info,
MO.Asin: lambda: info,
MO.Atan: lambda: info,
MO.Sinc: lambda: info,
# By Vector # By Vector
MO.Norm2: lambda: info.collapse_output( MO.Norm2: lambda: info.update_output(
collapsed_name=MO.to_name(self).replace('v', info.output_name), mathtype=spux.MathType.Real,
collapsed_mathtype=spux.MathType.Real, rows=1,
collapsed_unit=info.output_unit, cols=1,
# Interval
interval_finite_re=(0, sim_symbols.float_max),
interval_inf=(False, True),
interval_closed=(True, False),
), ),
# By Matrix # By Matrix
MO.Det: lambda: info.collapse_output( MO.Det: lambda: info.update_output(
collapsed_name=MO.to_name(self).replace('V', info.output_name), rows=1,
collapsed_mathtype=info.output_mathtype, cols=1,
collapsed_unit=info.output_unit,
), ),
MO.Cond: lambda: info.collapse_output( MO.Cond: lambda: info.update_output(
collapsed_name=MO.to_name(self).replace('V', info.output_name), mathtype=spux.MathType.Real,
collapsed_mathtype=spux.MathType.Real, rows=1,
collapsed_unit=None, cols=1,
physical_type=spux.PhysicalType.NonPhysical,
unit=None,
), ),
MO.NormFro: lambda: info.collapse_output( MO.NormFro: lambda: info.update_output(
collapsed_name=MO.to_name(self).replace('V', info.output_name), mathtype=spux.MathType.Real,
collapsed_mathtype=spux.MathType.Real, rows=1,
collapsed_unit=info.output_unit, cols=1,
# Interval
interval_finite_re=(0, sim_symbols.float_max),
interval_inf=(False, True),
interval_closed=(True, False),
), ),
MO.Rank: lambda: info.collapse_output( MO.Rank: lambda: info.update_output(
collapsed_name=MO.to_name(self).replace('V', info.output_name), mathtype=spux.MathType.Integer,
collapsed_mathtype=spux.MathType.Integer, rows=1,
collapsed_unit=None, cols=1,
physical_type=spux.PhysicalType.NonPhysical,
unit=None,
# Interval
interval_finite_re=(0, sim_symbols.int_max),
interval_inf=(False, True),
interval_closed=(True, False),
), ),
## TODO: Matrix -> Vec # Matrix -> Vector ## TODO: ALL OF THESE
## TODO: Matrix -> Matrices MO.Diag: lambda: info,
}.get(self, lambda: info)() MO.EigVals: lambda: info,
MO.SvdVals: lambda: info,
# Matrix -> Matrix ## TODO: ALL OF THESE
MO.Inv: lambda: info,
MO.Tra: lambda: info,
# Matrix -> Matrices ## TODO: ALL OF THESE
MO.Qr: lambda: info,
MO.Chol: lambda: info,
MO.Svd: lambda: info,
}[self]()
#################### ####################
@ -435,29 +480,26 @@ class MapMathNode(base.MaxwellSimNode):
) )
if has_info and not info_pending: if has_info and not info_pending:
self.expr_output_shape = bl_cache.Signal.InvalidateCache self.expr_info = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property() @bl_cache.cached_bl_property()
def expr_output_shape(self) -> ct.InfoFlow | None: def expr_info(self) -> ct.InfoFlow | None:
info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True) info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True)
has_info = not ct.FlowSignal.check(info) has_info = not ct.FlowSignal.check(info)
if has_info: if has_info:
return info.output_shape return info
return None
return 'noshape'
operation: MapOperation = bl_cache.BLField( operation: MapOperation = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(), enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_output_shape'}, cb_depends_on={'expr_info'},
) )
def search_operations(self) -> list[ct.BLEnumElement]: def search_operations(self) -> list[ct.BLEnumElement]:
if self.expr_output_shape != 'noshape': if self.info is not None:
return [ return [
operation.bl_enum_element(i) operation.bl_enum_element(i)
for i, operation in enumerate( for i, operation in enumerate(MapOperation.by_expr_info(self.expr_info))
MapOperation.by_element_shape(self.expr_output_shape)
)
] ]
return [] return []
@ -474,7 +516,7 @@ class MapMathNode(base.MaxwellSimNode):
layout.prop(self, self.blfields['operation'], text='') layout.prop(self, self.blfields['operation'], text='')
#################### ####################
# - FlowKind.Value|Func # - FlowKind.Value
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
@ -495,6 +537,9 @@ class MapMathNode(base.MaxwellSimNode):
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.Func, kind=ct.FlowKind.Func,
@ -518,7 +563,7 @@ class MapMathNode(base.MaxwellSimNode):
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
#################### ####################
# - FlowKind.Info|Params # - FlowKind.Info
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
@ -538,6 +583,9 @@ class MapMathNode(base.MaxwellSimNode):
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.Params, kind=ct.FlowKind.Params,

View File

@ -107,32 +107,31 @@ class TransformOperation(enum.StrEnum):
# Covariant Transform # Covariant Transform
## Freq <-> VacWL ## Freq <-> VacWL
for dim_name in info.dim_names: for dim in info.dims:
if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq: if dim.physical_type == spux.PhysicalType.Freq:
operations.append(TO.FreqToVacWL) operations.append(TO.FreqToVacWL)
if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq: if dim.physical_type == spux.PhysicalType.Freq:
operations.append(TO.VacWLToFreq) operations.append(TO.VacWLToFreq)
# Fold # Fold
## (Last) Int Dim (=2) to Complex ## (Last) Int Dim (=2) to Complex
if len(info.dim_names) >= 1: if len(info.dims) >= 1:
last_dim_name = info.dim_names[-1] if not info.has_idx_labels(info.last_dim) and len(info.last_dim) == 2: # noqa: PLR2004
if info.dim_lens[last_dim_name] == 2: # noqa: PLR2004
operations.append(TO.IntDimToComplex) operations.append(TO.IntDimToComplex)
## To Vector ## To Vector
if len(info.dim_names) >= 1: if len(info.dims) >= 1:
operations.append(TO.DimToVec) operations.append(TO.DimToVec)
## To Matrix ## To Matrix
if len(info.dim_names) >= 2: # noqa: PLR2004 if len(info.dims) >= 2: # noqa: PLR2004
operations.append(TO.DimsToMat) operations.append(TO.DimsToMat)
# Fourier # Fourier
## 1D Fourier ## 1D Fourier
if info.dim_names: if info.dims:
last_physical_type = info.dim_physical_types[info.dim_names[-1]] last_physical_type = info.last_dim.physical_type
if last_physical_type == spux.PhysicalType.Time: if last_physical_type == spux.PhysicalType.Time:
operations.append(TO.FFT1D) operations.append(TO.FFT1D)
if last_physical_type == spux.PhysicalType.Freq: if last_physical_type == spux.PhysicalType.Freq:
@ -188,15 +187,15 @@ class TransformOperation(enum.StrEnum):
unit: spux.Unit | None = None, unit: spux.Unit | None = None,
) -> ct.InfoFlow | None: ) -> ct.InfoFlow | None:
TO = TransformOperation TO = TransformOperation
if not info.dim_names: if not info.dims:
return None return None
return { return {
# Index # Covariant Transform
TO.FreqToVacWL: lambda: info.replace_dim( TO.FreqToVacWL: lambda: info.replace_dim(
(f_dim := info.dim_names[-1]), (f_dim := info.last_dim),
[ [
'wl', sim_symbols.wl(spu.nanometer),
info.dim_idx[f_dim].rescale( info.dims[f_dim].rescale(
lambda el: sci_constants.vac_speed_of_light / el, lambda el: sci_constants.vac_speed_of_light / el,
reverse=True, reverse=True,
new_unit=spu.nanometer, new_unit=spu.nanometer,
@ -204,10 +203,10 @@ class TransformOperation(enum.StrEnum):
], ],
), ),
TO.VacWLToFreq: lambda: info.replace_dim( TO.VacWLToFreq: lambda: info.replace_dim(
(wl_dim := info.dim_names[-1]), (wl_dim := info.last_dim),
[ [
'f', sim_symbols.freq(spux.THz),
info.dim_idx[wl_dim].rescale( info.dims[wl_dim].rescale(
lambda el: sci_constants.vac_speed_of_light / el, lambda el: sci_constants.vac_speed_of_light / el,
reverse=True, reverse=True,
new_unit=spux.THz, new_unit=spux.THz,
@ -215,24 +214,24 @@ class TransformOperation(enum.StrEnum):
], ],
), ),
# Fold # Fold
TO.IntDimToComplex: lambda: info.delete_dimension( TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output(
info.dim_names[-1] mathtype=spux.MathType.Complex
).set_output_mathtype(spux.MathType.Complex), ),
TO.DimToVec: lambda: info.shift_last_input, TO.DimToVec: lambda: info.fold_last_input(),
TO.DimsToMat: lambda: info.shift_last_input.shift_last_input, TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(),
# Fourier # Fourier
TO.FFT1D: lambda: info.replace_dim( TO.FFT1D: lambda: info.replace_dim(
info.dim_names[-1], info.last_dim,
[ [
'f', sim_symbols.freq(spux.THz),
ct.RangeFlow(start=0, stop=sp.oo, steps=0, unit=spu.hertz), None,
], ],
), ),
TO.InvFFT1D: info.replace_dim( TO.InvFFT1D: info.replace_dim(
info.dim_names[-1], info.last_dim,
[ [
't', sim_symbols.t(spu.second),
ct.RangeFlow(start=0, stop=sp.oo, steps=0, unit=spu.second), None,
], ],
), ),
}.get(self, lambda: info)() }.get(self, lambda: info)()

View File

@ -38,7 +38,6 @@ class VizMode(enum.StrEnum):
**NOTE**: >1D output dimensions currently have no viz. **NOTE**: >1D output dimensions currently have no viz.
Plots for `() -> `: Plots for `() -> `:
- Hist1D: Bin-summed distribution.
- BoxPlot1D: Box-plot describing the distribution. - BoxPlot1D: Box-plot describing the distribution.
Plots for `() -> `: Plots for `() -> `:
@ -61,7 +60,6 @@ class VizMode(enum.StrEnum):
- Heatmap3D: Colormapped field with value at each voxel. - Heatmap3D: Colormapped field with value at each voxel.
""" """
Hist1D = enum.auto()
BoxPlot1D = enum.auto() BoxPlot1D = enum.auto()
Curve2D = enum.auto() Curve2D = enum.auto()
@ -78,41 +76,37 @@ class VizMode(enum.StrEnum):
@staticmethod @staticmethod
def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None: def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None:
EMPTY = ()
Z = spux.MathType.Integer Z = spux.MathType.Integer
R = spux.MathType.Real R = spux.MathType.Real
VM = VizMode VM = VizMode
valid_viz_modes = { return {
(EMPTY, (None, R)): [VM.Hist1D, VM.BoxPlot1D], ((Z), (1, 1, R)): [
((Z), (None, R)): [
VM.Hist1D,
VM.BoxPlot1D, VM.BoxPlot1D,
], ],
((R,), (None, R)): [ ((R,), (1, 1, R)): [
VM.Curve2D, VM.Curve2D,
VM.Points2D, VM.Points2D,
VM.Bar, VM.Bar,
], ],
((R, Z), (None, R)): [ ((R, Z), (1, 1, R)): [
VM.Curves2D, VM.Curves2D,
VM.FilledCurves2D, VM.FilledCurves2D,
], ],
((R, R), (None, R)): [ ((R, R), (1, 1, R)): [
VM.Heatmap2D, VM.Heatmap2D,
], ],
((R, R, R), (None, R)): [VM.SqueezedHeatmap2D, VM.Heatmap3D], ((R, R, R), (1, 1, R)): [
VM.SqueezedHeatmap2D,
VM.Heatmap3D,
],
}.get( }.get(
( (
tuple(info.dim_mathtypes.values()), tuple([dim.mathtype for dim in info.dims.values()]),
(info.output_shape, info.output_mathtype), (info.output.rows, info.output.cols, info.output.mathtype),
),
[],
) )
)
if valid_viz_modes is None:
return []
return valid_viz_modes
@staticmethod @staticmethod
def to_plotter( def to_plotter(
@ -121,7 +115,6 @@ class VizMode(enum.StrEnum):
[jtyp.Float32[jtyp.Array, '...'], ct.InfoFlow, mpl_ax.Axis], None [jtyp.Float32[jtyp.Array, '...'], ct.InfoFlow, mpl_ax.Axis], None
]: ]:
return { return {
VizMode.Hist1D: image_ops.plot_hist_1d,
VizMode.BoxPlot1D: image_ops.plot_box_plot_1d, VizMode.BoxPlot1D: image_ops.plot_box_plot_1d,
VizMode.Curve2D: image_ops.plot_curve_2d, VizMode.Curve2D: image_ops.plot_curve_2d,
VizMode.Points2D: image_ops.plot_points_2d, VizMode.Points2D: image_ops.plot_points_2d,
@ -136,7 +129,6 @@ class VizMode(enum.StrEnum):
@staticmethod @staticmethod
def to_name(value: typ.Self) -> str: def to_name(value: typ.Self) -> str:
return { return {
VizMode.Hist1D: 'Histogram',
VizMode.BoxPlot1D: 'Box Plot', VizMode.BoxPlot1D: 'Box Plot',
VizMode.Curve2D: 'Curve', VizMode.Curve2D: 'Curve',
VizMode.Points2D: 'Points', VizMode.Points2D: 'Points',
@ -164,7 +156,6 @@ class VizTarget(enum.StrEnum):
@staticmethod @staticmethod
def valid_targets_for(viz_mode: VizMode) -> list[typ.Self] | None: def valid_targets_for(viz_mode: VizMode) -> list[typ.Self] | None:
return { return {
VizMode.Hist1D: [VizTarget.Plot2D],
VizMode.BoxPlot1D: [VizTarget.Plot2D], VizMode.BoxPlot1D: [VizTarget.Plot2D],
VizMode.Curve2D: [VizTarget.Plot2D], VizMode.Curve2D: [VizTarget.Plot2D],
VizMode.Points2D: [VizTarget.Plot2D], VizMode.Points2D: [VizTarget.Plot2D],
@ -333,35 +324,20 @@ class VizNode(base.MaxwellSimNode):
## -> This happens if Params contains not-yet-realized symbols. ## -> This happens if Params contains not-yet-realized symbols.
if has_info and has_params and params.symbols: if has_info and has_params and params.symbols:
if set(self.loose_input_sockets) != { if set(self.loose_input_sockets) != {
sym.name for sym in params.symbols if sym.name in info.dim_names dim.name for dim in params.symbols if dim in info.dims
}: }:
self.loose_input_sockets = { self.loose_input_sockets = {
sym.name: sockets.ExprSocketDef( dim_name: sockets.ExprSocketDef(**expr_info)
active_kind=ct.FlowKind.Range, for dim_name, expr_info in params.sym_expr_infos(
size=spux.NumberSize1D.Scalar, info, use_range=True
mathtype=info.dim_mathtypes[sym.name], ).items()
physical_type=info.dim_physical_types[sym.name],
default_min=(
info.dim_idx[sym.name].start
if not sp.S(info.dim_idx[sym.name].start).is_infinite
else sp.S(0)
),
default_max=(
info.dim_idx[sym.name].start
if not sp.S(info.dim_idx[sym.name].stop).is_infinite
else sp.S(1)
),
default_steps=50,
)
for sym in params.sorted_symbols
if sym.name in info.dim_names
} }
elif self.loose_input_sockets: elif self.loose_input_sockets:
self.loose_input_sockets = {} self.loose_input_sockets = {}
##################### #####################
## - Plotting ## - FlowKind.Value
##################### #####################
@events.computes_output_socket( @events.computes_output_socket(
'Preview', 'Preview',
@ -375,8 +351,12 @@ class VizNode(base.MaxwellSimNode):
all_loose_input_sockets=True, all_loose_input_sockets=True,
) )
def compute_dummy_value(self, props, input_sockets, loose_input_sockets): def compute_dummy_value(self, props, input_sockets, loose_input_sockets):
"""Needed for the plot to regenerate in the viewer."""
return ct.FlowSignal.NoFlow return ct.FlowSignal.NoFlow
#####################
## - On Show Plot
#####################
@events.on_show_plot( @events.on_show_plot(
managed_objs={'plot'}, managed_objs={'plot'},
props={'viz_mode', 'viz_target', 'colormap'}, props={'viz_mode', 'viz_target', 'colormap'},
@ -384,13 +364,12 @@ class VizNode(base.MaxwellSimNode):
input_socket_kinds={ input_socket_kinds={
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params} 'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
}, },
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
all_loose_input_sockets=True, all_loose_input_sockets=True,
stop_propagation=True, stop_propagation=True,
) )
def on_show_plot( def on_show_plot(
self, managed_objs, props, input_sockets, loose_input_sockets, unit_systems self, managed_objs, props, input_sockets, loose_input_sockets
): ) -> None:
# Retrieve Inputs # Retrieve Inputs
lazy_func = input_sockets['Expr'][ct.FlowKind.Func] lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
info = input_sockets['Expr'][ct.FlowKind.Info] info = input_sockets['Expr'][ct.FlowKind.Info]
@ -399,8 +378,6 @@ class VizNode(base.MaxwellSimNode):
has_info = not ct.FlowSignal.check(info) has_info = not ct.FlowSignal.check(info)
has_params = not ct.FlowSignal.check(params) has_params = not ct.FlowSignal.check(params)
# Invalid Mode | Target
## -> To limit branching, return now if things aren't right.
if ( if (
not has_info not has_info
or not has_params or not has_params
@ -410,52 +387,41 @@ class VizNode(base.MaxwellSimNode):
return return
# Compute Ranges for Symbols from Loose Sockets # Compute Ranges for Symbols from Loose Sockets
## -> These are the concrete values of the symbol for plotting.
## -> In a quite nice turn of events, all this is cached lookups. ## -> In a quite nice turn of events, all this is cached lookups.
## -> ...Unless something changed, in which case, well. It changed. ## -> ...Unless something changed, in which case, well. It changed.
symbol_values = { symbol_array_values = {
sym: ( sim_syms: (
loose_input_sockets[sym.name] loose_input_sockets[sim_syms]
.realize_array.rescale_to_unit(info.dim_units[sym.name]) .rescale_to_unit(sim_syms.unit)
.values .realize_array
) )
for sym in params.sorted_symbols for sim_syms in params.sorted_symbols
} }
data = lazy_func.realize(params, symbol_values=symbol_array_values)
# Realize Func w/Symbolic Values, Unit System
## -> This gives us the actual plot data!
data = lazy_func.func_jax(
*params.scaled_func_args(
unit_systems['BlenderUnits'], symbol_values=symbol_values
),
**params.scaled_func_kwargs(
unit_systems['BlenderUnits'], symbol_values=symbol_values
),
)
# Replace InfoFlow Indices w/Realized Symbolic Ranges # Replace InfoFlow Indices w/Realized Symbolic Ranges
## -> This ensures correct axis scaling. ## -> This ensures correct axis scaling.
if params.symbols: if params.symbols:
info = info.rescale_dim_idxs(loose_input_sockets) info = info.replace_dims(symbol_array_values)
# Visualize by-Target match props['viz_target']:
if props['viz_target'] == VizTarget.Plot2D: case VizTarget.Plot2D:
managed_objs['plot'].mpl_plot_to_image( managed_objs['plot'].mpl_plot_to_image(
lambda ax: VizMode.to_plotter(props['viz_mode'])(data, info, ax), lambda ax: VizMode.to_plotter(props['viz_mode'])(data, info, ax),
bl_select=True, bl_select=True,
) )
if props['viz_target'] == VizTarget.Pixels: case VizTarget.Pixels:
managed_objs['plot'].map_2d_to_image( managed_objs['plot'].map_2d_to_image(
data, data,
colormap=props['colormap'], colormap=props['colormap'],
bl_select=True, bl_select=True,
) )
if props['viz_target'] == VizTarget.PixelsPlane: case VizTarget.PixelsPlane:
raise NotImplementedError raise NotImplementedError
if props['viz_target'] == VizTarget.Voxels: case VizTarget.Voxels:
raise NotImplementedError raise NotImplementedError

View File

@ -67,7 +67,7 @@ ManagedObjName: typ.TypeAlias = str
PropName: typ.TypeAlias = str PropName: typ.TypeAlias = str
def event_decorator( def event_decorator( # noqa: PLR0913
event: ct.FlowEvent, event: ct.FlowEvent,
callback_info: EventCallbackInfo | None, callback_info: EventCallbackInfo | None,
stop_propagation: bool = False, stop_propagation: bool = False,
@ -91,31 +91,42 @@ def event_decorator(
scale_input_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}), scale_input_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}),
scale_output_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}), scale_output_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}),
): ):
"""Returns a decorator for a method of `MaxwellSimNode`, declaring it as able respond to events passing through a node. """Low-level decorator declaring a special "event method" of `MaxwellSimNode`, which is able to handle `ct.FlowEvent`s passing through.
Should generally be used via a high-level decorator such as `on_value_changed`.
For more about how event methods are actually registered and run, please refer to the documentation of `MaxwellSimNode`.
Parameters: Parameters:
event: A name describing which event the decorator should respond to. event: A name describing which event the decorator should respond to.
Set to `return_method.event`
callback_info: A dictionary that provides the caller with additional per-`event` information. callback_info: A dictionary that provides the caller with additional per-`event` information.
This might include parameters to help select the most appropriate method(s) to respond to an event with, or events to take after running the callback. This might include parameters to help select the most appropriate method(s) to respond to an event with, or events to take after running the callback.
props: Set of `props` to compute, then pass to the decorated method.
stop_propagation: Whether or stop propagating the event through the graph after encountering this method. stop_propagation: Whether or stop propagating the event through the graph after encountering this method.
Other methods defined on the same node will still run. Other methods defined on the same node will still run.
managed_objs: Set of `managed_objs` to retrieve, then pass to the decorated method. managed_objs: Set of `managed_objs` to retrieve, then pass to the decorated method.
props: Set of `props` to compute, then pass to the decorated method.
input_sockets: Set of `input_sockets` to compute, then pass to the decorated method. input_sockets: Set of `input_sockets` to compute, then pass to the decorated method.
input_sockets_optional: Whether an input socket is required to exist.
When True, lack of socket will produce `ct.FlowSignal.NoFlow`, instead of throwing an error.
input_socket_kinds: The `ct.FlowKind` to compute per-input-socket. input_socket_kinds: The `ct.FlowKind` to compute per-input-socket.
If an input socket isn't specified, it defaults to `ct.FlowKind.Value`. If an input socket isn't specified, it defaults to `ct.FlowKind.Value`.
output_sockets: Set of `output_sockets` to compute, then pass to the decorated method. output_sockets: Set of `output_sockets` to compute, then pass to the decorated method.
output_sockets_optional: Whether an output socket is required to exist.
When True, lack of socket will produce `ct.FlowSignal.NoFlow`, instead of throwing an error.
output_socket_kinds: The `ct.FlowKind` to compute per-output-socket.
If an output socket isn't specified, it defaults to `ct.FlowKind.Value`.
all_loose_input_sockets: Whether to compute all loose input sockets and pass them to the decorated method. all_loose_input_sockets: Whether to compute all loose input sockets and pass them to the decorated method.
Used when the names of the loose input sockets are unknown, but all of their values are needed. Used when the names of the loose input sockets are unknown, but all of their values are needed.
all_loose_output_sockets: Whether to compute all loose output sockets and pass them to the decorated method. all_loose_output_sockets: Whether to compute all loose output sockets and pass them to the decorated method.
Used when the names of the loose output sockets are unknown, but all of their values are needed. Used when the names of the loose output sockets are unknown, but all of their values are needed.
unit_systems: String identifiers under which to load a unit system, made available to the method.
scale_input_sockets: A mapping of input sockets to unit system string idenfiers, which causes the output of that input socket to be scaled to the given unit system.
This greatly simplifies the conformance of particular sockets to particular unit systems, when the socket value must be used in a unit-unaware manner.
scale_output_sockets: A mapping of output sockets to unit system string idenfiers, which causes the output of that input socket to be scaled to the given unit system.
This greatly simplifies the conformance of particular sockets to particular unit systems, when the socket value must be used in a unit-unaware manner.
Returns: Returns:
A decorator, which can be applied to a method of `MaxwellSimNode`. A decorator, which can be applied to a method of `MaxwellSimNode` to make it an "event method".
When a `MaxwellSimNode` subclass initializes, such a decorated method will be picked up on.
When `event` passes through the node, then `callback_info` is used to determine
""" """
req_params = ( req_params = (
{'self'} {'self'}
@ -375,7 +386,6 @@ def on_value_changed(
) )
## TODO: Change name to 'on_output_requested'
def computes_output_socket( def computes_output_socket(
output_socket_name: ct.SocketName | None, output_socket_name: ct.SocketName | None,
kind: ct.FlowKind = ct.FlowKind.Value, kind: ct.FlowKind = ct.FlowKind.Value,

View File

@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import enum
import typing as typ import typing as typ
from pathlib import Path from pathlib import Path
@ -21,7 +22,7 @@ import bpy
import sympy as sp import sympy as sp
import tidy3d as td import tidy3d as td
from blender_maxwell.utils import bl_cache, logger from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct from .... import contracts as ct
@ -91,6 +92,88 @@ class DataFileImporterNode(base.MaxwellSimNode):
return info return info
return None return None
####################
# - Info Guides
####################
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(sim_symbols.SimSymbolName)
output_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
output_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
output_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type),
cb_depends_on={'output_physical_type'},
)
dim_0_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerA
)
dim_0_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_0_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_0_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type),
cb_depends_on={'dim_0_physical_type'},
)
dim_1_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerB
)
dim_1_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_1_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_1_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_1_physical_type),
cb_depends_on={'dim_1_physical_type'},
)
dim_2_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerC
)
dim_2_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_2_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_2_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_2_physical_type),
cb_depends_on={'dim_2_physical_type'},
)
dim_3_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerD
)
dim_3_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_3_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_3_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_3_physical_type),
cb_depends_on={'dim_3_physical_type'},
)
def search_units(self, physical_type: spux.PhysicalType) -> list[ct.BLEnumElement]:
if physical_type is not spux.PhysicalType.NonPhysical:
return [
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
for i, unit in enumerate(physical_type.valid_units)
]
return []
def dim(self, i: int):
dim_name = getattr(self, f'dim_{i}_name')
dim_mathtype = getattr(self, f'dim_{i}_mathtype')
dim_physical_type = getattr(self, f'dim_{i}_physical_type')
dim_unit = getattr(self, f'dim_{i}_unit')
return sim_symbols.SimSymbol(
sym_name=dim_name,
mathtype=dim_mathtype,
physical_type=dim_physical_type,
unit=spux.unit_str_to_unit(dim_unit),
)
#################### ####################
# - UI # - UI
#################### ####################
@ -118,7 +201,20 @@ class DataFileImporterNode(base.MaxwellSimNode):
row.label(text=self.file_path.name) row.label(text=self.file_path.name)
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
pass """Draw loaded properties."""
for i in range(len(self.expr_info.dims)):
col = layout.column(align=True)
row = col.row(align=True)
row.alignment = 'CENTER'
row.label(text=f'Load Dim {i}')
row = col.row(align=True)
row.prop(self, self.blfields[f'dim_{i}_name'], text='')
row.prop(self, self.blfields[f'dim_{i}_mathtype'], text='')
row = col.row(align=True)
row.prop(self, self.blfields[f'dim_{i}_physical_type'], text='')
row.prop(self, self.blfields[f'dim_{i}_unit'], text='')
#################### ####################
# - FlowKind.Array|Func # - FlowKind.Array|Func
@ -174,10 +270,12 @@ class DataFileImporterNode(base.MaxwellSimNode):
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.Info, kind=ct.FlowKind.Info,
# Loaded
props={'output_name', 'output_physical_type', 'output_unit'},
output_sockets={'Expr'}, output_sockets={'Expr'},
output_socket_kinds={'Expr': ct.FlowKind.Func}, output_socket_kinds={'Expr': ct.FlowKind.Func},
) )
def compute_info(self, output_sockets) -> ct.InfoFlow: def compute_info(self, props, output_sockets) -> ct.InfoFlow:
"""Declare an `InfoFlow` based on the data shape. """Declare an `InfoFlow` based on the data shape.
This currently requires computing the data. This currently requires computing the data.
@ -196,26 +294,24 @@ class DataFileImporterNode(base.MaxwellSimNode):
# Deduce Dimensionality # Deduce Dimensionality
_shape = data.shape _shape = data.shape
shape = _shape if _shape is not None else () shape = _shape if _shape is not None else ()
dim_names = [f'a{i}' for i in range(len(shape))] dim_syms = [self.dim(i) for i in range(len(shape))]
# Return InfoFlow # Return InfoFlow
## -> TODO: How to interpret the data should be user-defined.
## -> -- This may require those nice dynamic symbols.
return ct.InfoFlow( return ct.InfoFlow(
dim_names=dim_names, ## TODO: User dims={
dim_idx={ dim_sym: ct.RangeFlow(
dim_name: ct.RangeFlow( start=sp.S(0),
start=sp.S(0), ## TODO: User stop=sp.S(shape[i] - 1),
stop=sp.S(shape[i] - 1), ## TODO: User steps=shape[i],
steps=shape[dim_names.index(dim_name)], unit=self.dim(i).unit,
unit=None, ## TODO: User
) )
for i, dim_name in enumerate(dim_names) for i, dim_sym in enumerate(dim_syms)
}, },
output_name='_', output=sim_symbols.SimSymbol(
output_shape=None, sym_name=props['output_name'],
output_mathtype=spux.MathType.Real, ## TODO: User mathtype=props['output_mathtype'],
output_unit=None, ## TODO: User physical_type=props['output_physical_type'],
),
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending

View File

@ -229,11 +229,11 @@ class DataFileExporterNode(base.MaxwellSimNode):
## -> Only happens if Params contains not-yet-realized symbols. ## -> Only happens if Params contains not-yet-realized symbols.
if has_info and has_params and params.symbols: if has_info and has_params and params.symbols:
if set(self.loose_input_sockets) != { if set(self.loose_input_sockets) != {
sym.name for sym in params.symbols if sym.name in info.dim_names dim.name for dim in params.symbols if dim in info.dims
}: }:
self.loose_input_sockets = { self.loose_input_sockets = {
sym_name: sockets.ExprSocketDef(**expr_info) dim_name: sockets.ExprSocketDef(**expr_info)
for sym_name, expr_info in params.sym_expr_infos(info).items() for dim_name, expr_info in params.sym_expr_infos(info).items()
} }
elif self.loose_input_sockets: elif self.loose_input_sockets:

View File

@ -118,13 +118,15 @@ class ExprBLSocket(base.MaxwellSimSocket):
physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical) physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical)
# Symbols # Symbols
# active_symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([]) output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
symbols: frozenset[sp.Symbol] = bl_cache.BLField(frozenset()) sim_symbols.SimSymbolName.Expr
)
active_symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([])
# @property @property
# def symbols(self) -> set[sp.Symbol]: def symbols(self) -> set[sp.Symbol]:
# """Current symbols as an unordered set.""" """Current symbols as an unordered set."""
# return {sim_symbol.sp_symbol for sim_symbol in self.active_symbols} return {sim_symbol.sp_symbol for sim_symbol in self.active_symbols}
@bl_cache.cached_bl_property(depends_on={'symbols'}) @bl_cache.cached_bl_property(depends_on={'symbols'})
def sorted_symbols(self) -> list[sp.Symbol]: def sorted_symbols(self) -> list[sp.Symbol]:
@ -184,6 +186,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
) )
# UI: Info # UI: Info
show_func_ui: bool = bl_cache.BLField(True)
show_info_columns: bool = bl_cache.BLField(False) show_info_columns: bool = bl_cache.BLField(False)
info_columns: set[InfoDisplayCol] = bl_cache.BLField( info_columns: set[InfoDisplayCol] = bl_cache.BLField(
{InfoDisplayCol.Length, InfoDisplayCol.MathType} {InfoDisplayCol.Length, InfoDisplayCol.MathType}
@ -615,35 +618,24 @@ class ExprBLSocket(base.MaxwellSimSocket):
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along. Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
""" """
output_sim_sym = (
sim_symbols.SimSymbol(
sym_name=self.output_name,
mathtype=self.mathtype,
physical_type=self.physical_type,
unit=self.unit,
rows=self.size.rows,
cols=self.size.cols,
),
)
if self.symbols: if self.symbols:
return ct.InfoFlow( return ct.InfoFlow(
dim_names=[sym.name for sym in self.sorted_symbols], dims={sim_sym: None for sim_sym in self.active_symbols},
dim_idx={ output=output_sim_sym,
sym.name: ct.RangeFlow(
start=-sp.oo if _check_sym_oo(sym) else -sp.zoo,
stop=sp.oo if _check_sym_oo(sym) else sp.zoo,
steps=0,
unit=None, ## Symbols alone are unitless.
)
## TODO: PhysicalTypes for symbols? Or nah?
## TODO: Can we parse some sp.Interval for explicit domains?
## -> We investigated sp.Symbol(..., domain=...).
## -> It's no good. We can't re-extract the interval given to domain.
for sym in self.sorted_symbols
},
output_name='_', ## Use node:socket name? Or something? Ahh
output_shape=self.size.shape,
output_mathtype=self.mathtype,
output_unit=self.unit,
) )
# Constant # Constant
return ct.InfoFlow( return ct.InfoFlow(output=output_sim_sym)
output_name='_', ## Use node:socket name? Or something? Ahh
output_shape=self.size.shape,
output_mathtype=self.mathtype,
output_unit=self.unit,
)
#################### ####################
# - FlowKind: Capabilities # - FlowKind: Capabilities
@ -847,6 +839,11 @@ class ExprBLSocket(base.MaxwellSimSocket):
Uses `draw_value` to draw the base UI Uses `draw_value` to draw the base UI
""" """
if self.show_func_ui:
# Output Name Selector
## -> The name of the output
col.prop(self, self.blfields['output_name'], text='')
# Physical Type Selector # Physical Type Selector
## -> Determines whether/which unit-dropdown will be shown. ## -> Determines whether/which unit-dropdown will be shown.
col.prop(self, self.blfields['physical_type'], text='') col.prop(self, self.blfields['physical_type'], text='')
@ -884,9 +881,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
) )
# Dimensions # Dimensions
for dim_name in info.dim_names: for dim in info.dims:
dim_idx = info.dim_idx[dim_name] dim_idx = info.dims[dim]
grid.label(text=dim_name) grid.label(text=dim.name_pretty)
if InfoDisplayCol.Length in self.info_columns: if InfoDisplayCol.Length in self.info_columns:
grid.label(text=str(len(dim_idx))) grid.label(text=str(len(dim_idx)))
if InfoDisplayCol.MathType in self.info_columns: if InfoDisplayCol.MathType in self.info_columns:
@ -895,27 +892,27 @@ class ExprBLSocket(base.MaxwellSimSocket):
grid.label(text=spux.sp_to_str(dim_idx.unit)) grid.label(text=spux.sp_to_str(dim_idx.unit))
# Outputs # Outputs
grid.label(text=info.output_name) grid.label(text=info.output.name_pretty)
if InfoDisplayCol.Length in self.info_columns: if InfoDisplayCol.Length in self.info_columns:
grid.label(text='', icon=ct.Icon.DataSocketOutput) grid.label(text='', icon=ct.Icon.DataSocketOutput)
if InfoDisplayCol.MathType in self.info_columns: if InfoDisplayCol.MathType in self.info_columns:
grid.label( grid.label(
text=( text=(
spux.MathType.to_str(info.output_mathtype) spux.MathType.to_str(info.output.mathtype)
+ ( + (
'ˣ'.join( 'ˣ'.join(
[ [
unicode_superscript(out_axis) unicode_superscript(out_axis)
for out_axis in info.output_shape for out_axis in info.output.shape
] ]
) )
if info.output_shape if info.output.shape
else '' else ''
) )
) )
) )
if InfoDisplayCol.Unit in self.info_columns: if InfoDisplayCol.Unit in self.info_columns:
grid.label(text=f'{spux.sp_to_str(info.output_unit)}') grid.label(text=f'{spux.sp_to_str(info.output.unit)}')
#################### ####################
@ -929,6 +926,7 @@ class ExprSocketDef(base.SocketDef):
ct.FlowKind.Array, ct.FlowKind.Array,
ct.FlowKind.Func, ct.FlowKind.Func,
] = ct.FlowKind.Value ] = ct.FlowKind.Value
output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName
# Socket Interface # Socket Interface
size: spux.NumberSize1D = spux.NumberSize1D.Scalar size: spux.NumberSize1D = spux.NumberSize1D.Scalar
@ -938,10 +936,6 @@ class ExprSocketDef(base.SocketDef):
default_unit: spux.Unit | None = None default_unit: spux.Unit | None = None
default_symbols: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list) default_symbols: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list)
@property
def symbols(self) -> set[sp.Symbol]:
return {sim_symbol.sp_symbol for sim_symbol in self.default_symbols}
# FlowKind: Value # FlowKind: Value
default_value: spux.SympyExpr = 0 default_value: spux.SympyExpr = 0
abs_min: spux.SympyExpr | None = None abs_min: spux.SympyExpr | None = None
@ -954,6 +948,7 @@ class ExprSocketDef(base.SocketDef):
default_scaling: ct.ScalingMode = ct.ScalingMode.Lin default_scaling: ct.ScalingMode = ct.ScalingMode.Lin
# UI # UI
show_func_ui: bool = True
show_info_columns: bool = False show_info_columns: bool = False
#################### ####################
@ -1153,6 +1148,14 @@ class ExprSocketDef(base.SocketDef):
msg = f'ExprSocket: Mathtype {dv_mathtype} of a default Range min or max expression {bound} (type {type(self.default_value)}) is incompatible with socket MathType {self.mathtype}' msg = f'ExprSocket: Mathtype {dv_mathtype} of a default Range min or max expression {bound} (type {type(self.default_value)}) is incompatible with socket MathType {self.mathtype}'
raise ValueError(msg) raise ValueError(msg)
# Coerce from Infinite
if bound.is_infinite and self.mathtype is spux.MathType.Integer:
new_bounds[i] = sp.S(-1) if i == 0 else sp.S(1)
if bound.is_infinite and self.mathtype is spux.MathType.Rational:
new_bounds[i] = sp.Rational(-1, 1) if i == 0 else sp.Rational(1, 1)
if bound.is_infinite and self.mathtype is spux.MathType.Real:
new_bounds[i] = sp.S(-1.0) if i == 0 else sp.S(1.0)
if new_bounds[0] is not None: if new_bounds[0] is not None:
self.default_min = new_bounds[0] self.default_min = new_bounds[0]
if new_bounds[1] is not None: if new_bounds[1] is not None:
@ -1217,13 +1220,14 @@ class ExprSocketDef(base.SocketDef):
#################### ####################
def init(self, bl_socket: ExprBLSocket) -> None: def init(self, bl_socket: ExprBLSocket) -> None:
bl_socket.active_kind = self.active_kind bl_socket.active_kind = self.active_kind
bl_socket.output_name = self.output_name
# Socket Interface # Socket Interface
## -> Recall that auto-updates are turned off during init() ## -> Recall that auto-updates are turned off during init()
bl_socket.size = self.size bl_socket.size = self.size
bl_socket.mathtype = self.mathtype bl_socket.mathtype = self.mathtype
bl_socket.physical_type = self.physical_type bl_socket.physical_type = self.physical_type
bl_socket.symbols = self.symbols bl_socket.active_symbols = self.symbols
# FlowKind.Value # FlowKind.Value
## -> We must take units into account when setting bl_socket.value ## -> We must take units into account when setting bl_socket.value
@ -1246,6 +1250,7 @@ class ExprSocketDef(base.SocketDef):
) )
# UI # UI
bl_socket.show_func_ui = self.show_func_ui
bl_socket.show_info_columns = self.show_info_columns bl_socket.show_info_columns = self.show_info_columns
# Info Draw # Info Draw

View File

@ -61,7 +61,6 @@ SympyType = (
class MathType(enum.StrEnum): class MathType(enum.StrEnum):
"""Type identifiers that encompass common sets of mathematical objects.""" """Type identifiers that encompass common sets of mathematical objects."""
Bool = enum.auto()
Integer = enum.auto() Integer = enum.auto()
Rational = enum.auto() Rational = enum.auto()
Real = enum.auto() Real = enum.auto()
@ -77,8 +76,6 @@ class MathType(enum.StrEnum):
return MathType.Rational return MathType.Rational
if MathType.Integer in mathtypes: if MathType.Integer in mathtypes:
return MathType.Integer return MathType.Integer
if MathType.Bool in mathtypes:
return MathType.Bool
msg = f"Can't combine mathtypes {mathtypes}" msg = f"Can't combine mathtypes {mathtypes}"
raise ValueError(msg) raise ValueError(msg)
@ -88,7 +85,6 @@ class MathType(enum.StrEnum):
return ( return (
other other
in { in {
MT.Bool: [MT.Bool],
MT.Integer: [MT.Integer], MT.Integer: [MT.Integer],
MT.Rational: [MT.Integer, MT.Rational], MT.Rational: [MT.Integer, MT.Rational],
MT.Real: [MT.Integer, MT.Rational, MT.Real], MT.Real: [MT.Integer, MT.Rational, MT.Real],
@ -98,11 +94,9 @@ class MathType(enum.StrEnum):
def coerce_compatible_pyobj( def coerce_compatible_pyobj(
self, pyobj: bool | int | Fraction | float | complex self, pyobj: bool | int | Fraction | float | complex
) -> bool | int | Fraction | float | complex: ) -> int | Fraction | float | complex:
MT = MathType MT = MathType
match self: match self:
case MT.Bool:
return pyobj
case MT.Integer: case MT.Integer:
return int(pyobj) return int(pyobj)
case MT.Rational if isinstance(pyobj, int): case MT.Rational if isinstance(pyobj, int):
@ -123,8 +117,6 @@ class MathType(enum.StrEnum):
*[MathType.from_expr(v) for v in sp.flatten(sp_obj)] *[MathType.from_expr(v) for v in sp.flatten(sp_obj)]
) )
if isinstance(sp_obj, sp.logic.boolalg.Boolean):
return MathType.Bool
if sp_obj.is_integer: if sp_obj.is_integer:
return MathType.Integer return MathType.Integer
if sp_obj.is_rational: if sp_obj.is_rational:
@ -146,7 +138,6 @@ class MathType(enum.StrEnum):
@staticmethod @staticmethod
def from_pytype(dtype: type) -> type: def from_pytype(dtype: type) -> type:
return { return {
bool: MathType.Bool,
int: MathType.Integer, int: MathType.Integer,
Fraction: MathType.Rational, Fraction: MathType.Rational,
float: MathType.Real, float: MathType.Real,
@ -166,7 +157,6 @@ class MathType(enum.StrEnum):
def pytype(self) -> type: def pytype(self) -> type:
MT = MathType MT = MathType
return { return {
MT.Bool: bool,
MT.Integer: int, MT.Integer: int,
MT.Rational: Fraction, MT.Rational: Fraction,
MT.Real: float, MT.Real: float,
@ -177,17 +167,25 @@ class MathType(enum.StrEnum):
def symbolic_set(self) -> type: def symbolic_set(self) -> type:
MT = MathType MT = MathType
return { return {
MT.Bool: sp.Set([sp.S(False), sp.S(True)]),
MT.Integer: sp.Integers, MT.Integer: sp.Integers,
MT.Rational: sp.Rationals, MT.Rational: sp.Rationals,
MT.Real: sp.Reals, MT.Real: sp.Reals,
MT.Complex: sp.Complexes, MT.Complex: sp.Complexes,
}[self] }[self]
@property
def sp_symbol_a(self) -> type:
MT = MathType
return {
MT.Integer: sp.Symbol('a', integer=True),
MT.Rational: sp.Symbol('a', rational=True),
MT.Real: sp.Symbol('a', real=True),
MT.Complex: sp.Symbol('a', complex=True),
}[self]
@staticmethod @staticmethod
def to_str(value: typ.Self) -> type: def to_str(value: typ.Self) -> type:
return { return {
MathType.Bool: 'T|F',
MathType.Integer: '', MathType.Integer: '',
MathType.Rational: '', MathType.Rational: '',
MathType.Real: '', MathType.Real: '',
@ -212,6 +210,9 @@ class MathType(enum.StrEnum):
) )
####################
# - Size: 1D
####################
class NumberSize1D(enum.StrEnum): class NumberSize1D(enum.StrEnum):
"""Valid 1D-constrained shape.""" """Valid 1D-constrained shape."""
@ -278,6 +279,20 @@ class NumberSize1D(enum.StrEnum):
(4, 1): NS.Vec4, (4, 1): NS.Vec4,
}[shape] }[shape]
@property
def rows(self):
NS = NumberSize1D
return {
NS.Scalar: 1,
NS.Vec2: 2,
NS.Vec3: 3,
NS.Vec4: 4,
}[self]
@property
def cols(self):
return 1
@property @property
def shape(self): def shape(self):
NS = NumberSize1D NS = NumberSize1D
@ -297,6 +312,30 @@ def symbol_range(sym: sp.Symbol) -> str:
) )
####################
# - Symbol Sizes
####################
class SimpleSize2D(enum.StrEnum):
"""Simple subset of sizes for rank-2 tensors."""
Scalar = enum.auto()
# Vectors
Vec2 = enum.auto() ## 2x1
Vec3 = enum.auto() ## 3x1
Vec4 = enum.auto() ## 4x1
# Covectors
CoVec2 = enum.auto() ## 1x2
CoVec3 = enum.auto() ## 1x3
CoVec4 = enum.auto() ## 1x4
# Square Matrices
Mat22 = enum.auto() ## 2x2
Mat33 = enum.auto() ## 3x3
Mat44 = enum.auto() ## 4x4
#################### ####################
# - Unit Dimensions # - Unit Dimensions
#################### ####################
@ -382,6 +421,8 @@ UNIT_BY_SYMBOL: dict[sp.Symbol, spu.Quantity] = {
unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity) unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity)
} | {unit.name: unit for unit in globals().values() if isinstance(unit, spu.Quantity)} } | {unit.name: unit for unit in globals().values() if isinstance(unit, spu.Quantity)}
UNIT_TO_1: dict[spu.Quantity, 1] = {unit: 1 for unit in UNIT_BY_SYMBOL.values()}
#################### ####################
# - Expr Analysis: Units # - Expr Analysis: Units
@ -907,10 +948,6 @@ class PhysicalType(enum.StrEnum):
LumIntensity = enum.auto() LumIntensity = enum.auto()
LumFlux = enum.auto() LumFlux = enum.auto()
Illuminance = enum.auto() Illuminance = enum.auto()
# Optics
OrdinaryWaveVector = enum.auto()
AngularWaveVector = enum.auto()
PoyntingVector = enum.auto()
@functools.cached_property @functools.cached_property
def unit_dim(self): def unit_dim(self):
@ -956,10 +993,6 @@ class PhysicalType(enum.StrEnum):
PT.LumIntensity: Dims.luminous_intensity, PT.LumIntensity: Dims.luminous_intensity,
PT.LumFlux: Dims.luminous_intensity * spu.steradian.dimension, PT.LumFlux: Dims.luminous_intensity * spu.steradian.dimension,
PT.Illuminance: Dims.luminous_intensity / Dims.length**2, PT.Illuminance: Dims.luminous_intensity / Dims.length**2,
# Optics
PT.OrdinaryWaveVector: Dims.frequency,
PT.AngularWaveVector: Dims.angle * Dims.frequency,
PT.PoyntingVector: Dims.power / Dims.length**2,
}[self] }[self]
@functools.cached_property @functools.cached_property
@ -1196,10 +1229,6 @@ class PhysicalType(enum.StrEnum):
PT.HField: [None, (2,), (3,)], PT.HField: [None, (2,), (3,)],
# Luminal # Luminal
PT.LumFlux: [None, (2,), (3,)], PT.LumFlux: [None, (2,), (3,)],
# Optics
PT.OrdinaryWaveVector: [None, (2,), (3,)],
PT.AngularWaveVector: [None, (2,), (3,)],
PT.PoyntingVector: [None, (2,), (3,)],
} }
return overrides.get(self, [None]) return overrides.get(self, [None])
@ -1222,7 +1251,6 @@ class PhysicalType(enum.StrEnum):
- **Charge**: Generally, it is real. - **Charge**: Generally, it is real.
However, an imaginary phase term seems to have research applications when dealing with high-order harmonics in high-energy pulsed lasers: <https://iopscience.iop.org/article/10.1088/1361-6455/aac787> However, an imaginary phase term seems to have research applications when dealing with high-order harmonics in high-energy pulsed lasers: <https://iopscience.iop.org/article/10.1088/1361-6455/aac787>
- **Conductance**: The imaginary part represents the extinction, in the Drude-model sense. - **Conductance**: The imaginary part represents the extinction, in the Drude-model sense.
- **Poynting**: The imaginary part represents the oscillation in the power flux over time.
""" """
MT = MathType MT = MathType
@ -1249,10 +1277,6 @@ class PhysicalType(enum.StrEnum):
PT.EField: [MT.Real, MT.Complex], ## Im -> Phase PT.EField: [MT.Real, MT.Complex], ## Im -> Phase
PT.HField: [MT.Real, MT.Complex], ## Im -> Phase PT.HField: [MT.Real, MT.Complex], ## Im -> Phase
# Luminal # Luminal
# Optics
PT.OrdinaryWaveVector: [MT.Real, MT.Complex], ## Im -> Phase
PT.AngularWaveVector: [MT.Real, MT.Complex], ## Im -> Phase
PT.PoyntingVector: [MT.Real, MT.Complex], ## Im -> Reactive Power
} }
return overrides.get(self, [MT.Real]) return overrides.get(self, [MT.Real])
@ -1323,10 +1347,6 @@ UNITS_SI: UnitSystem = {
_PT.LumIntensity: spu.candela, _PT.LumIntensity: spu.candela,
_PT.LumFlux: lumen, _PT.LumFlux: lumen,
_PT.Illuminance: spu.lux, _PT.Illuminance: spu.lux,
# Optics
_PT.OrdinaryWaveVector: spu.hertz,
_PT.AngularWaveVector: spu.radian * spu.hertz,
_PT.PoyntingVector: spu.watt / spu.meter**2,
} }
@ -1380,15 +1400,20 @@ def sympy_to_python(
#################### ####################
# - Convert to Unit System # - Convert to Unit System
#################### ####################
def convert_to_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr: def convert_to_unit_system(
sp_obj: SympyExpr, unit_system: UnitSystem | None
) -> SympyExpr:
"""Convert an expression to the units of a given unit system, with appropriate scaling.""" """Convert an expression to the units of a given unit system, with appropriate scaling."""
if unit_system is None:
return sp_obj
return spu.convert_to( return spu.convert_to(
sp_obj, sp_obj,
{unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)}, {unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)},
) )
def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr: def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem | None) -> SympyExpr:
"""Strip units occurring in the given unit system from the expression. """Strip units occurring in the given unit system from the expression.
Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`". Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`".
@ -1397,11 +1422,13 @@ def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr:
Notes: Notes:
You should probably use `scale_to_unit_system()` or `convert_to_unit_system()`. You should probably use `scale_to_unit_system()` or `convert_to_unit_system()`.
""" """
if unit_system is None:
return sp_obj.subs(UNIT_TO_1)
return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None}) return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None})
def scale_to_unit_system( def scale_to_unit_system(
sp_obj: SympyExpr, unit_system: UnitSystem, use_jax_array: bool = False sp_obj: SympyExpr, unit_system: UnitSystem | None, use_jax_array: bool = False
) -> int | float | complex | tuple | jax.Array: ) -> int | float | complex | tuple | jax.Array:
"""Convert an expression to the units of a given unit system, then strip all units of the unit system. """Convert an expression to the units of a given unit system, then strip all units of the unit system.

View File

@ -29,11 +29,13 @@ import matplotlib.axis as mpl_ax
import matplotlib.backends.backend_agg import matplotlib.backends.backend_agg
import matplotlib.figure import matplotlib.figure
import matplotlib.style as mplstyle import matplotlib.style as mplstyle
import seaborn as sns
from blender_maxwell import contracts as ct from blender_maxwell import contracts as ct
from blender_maxwell.utils import logger from blender_maxwell.utils import logger
mplstyle.use('fast') ## TODO: Does this do anything? mplstyle.use('fast') ## TODO: Does this do anything?
sns.set_theme()
log = logger.get(__name__) log = logger.get(__name__)
@ -149,125 +151,98 @@ def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
#################### ####################
# - Plotters # - Plotters
#################### ####################
# () ->
def plot_hist_1d(
data: jtyp.Float32[jtyp.Array, ' size'], info, ax: mpl_ax.Axis
) -> None:
y_name = info.output_name
y_unit = info.output_unit
ax.hist(data, bins=30, alpha=0.75)
ax.set_title('Histogram')
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
# () -> # () ->
def plot_box_plot_1d( def plot_box_plot_1d(
data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis
) -> None: ) -> None:
x_name = info.dim_names[0] x_sym = info.last_dim
y_name = info.output_name y_sym = info.output
y_unit = info.output_unit
ax.boxplot(data) ax.boxplot([data])
ax.set_title('Box Plot') ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_xlabel(f'{x_name}') ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) ax.set_xlabel(y_sym.plot_label)
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None:
x_sym = info.last_dim
y_sym = info.output
p = ax.bar(info.dims[x_sym], data)
ax.bar_label(p, label_type='center')
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
# () -> # () ->
def plot_curve_2d( def plot_curve_2d(
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis
) -> None: ) -> None:
times = [time.perf_counter()] x_sym = info.last_dim
y_sym = info.output
x_name = info.dim_names[0] ax.plot(info.dims[x_sym].realize_array.values, data)
x_unit = info.dim_units[x_name] ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
y_name = info.output_name ax.set_xlabel(x_sym.plot_label)
y_unit = info.output_unit ax.set_xlabel(y_sym.plot_label)
times.append(time.perf_counter() - times[0])
ax.plot(info.dim_idx_arrays[0], data)
times.append(time.perf_counter() - times[0])
ax.set_title('2D Curve')
times.append(time.perf_counter() - times[0])
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
times.append(time.perf_counter() - times[0])
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
times.append(time.perf_counter() - times[0])
# log.critical('Timing of Curve2D: %s', str(times))
def plot_points_2d( def plot_points_2d(
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis
) -> None: ) -> None:
x_name = info.dim_names[0] x_sym = info.last_dim
x_unit = info.dim_units[x_name] y_sym = info.output
y_name = info.output_name
y_unit = info.output_unit
ax.scatter(info.dim_idx_arrays[0], data, alpha=0.6) ax.scatter(x_sym.realize_array.values, data)
ax.set_title('2D Points') ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) ax.set_xlabel(y_sym.plot_label)
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None:
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.output_name
y_unit = info.output_unit
ax.bar(info.dim_idx_arrays[0], data, alpha=0.7)
ax.set_title('2D Bar')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
# (, ) -> # (, ) ->
def plot_curves_2d( def plot_curves_2d(
data: jtyp.Float32[jtyp.Array, 'x_size categories'], info, ax: mpl_ax.Axis data: jtyp.Float32[jtyp.Array, 'x_size categories'], info, ax: mpl_ax.Axis
) -> None: ) -> None:
x_name = info.dim_names[0] x_sym = info.first_dim
x_unit = info.dim_units[x_name] y_sym = info.output
y_name = info.output_name
y_unit = info.output_unit
for category in range(data.shape[1]): for i, category in enumerate(info.dims[info.last_dim]):
ax.plot(info.dim_idx_arrays[0], data[:, category]) ax.plot(info.dims[x_sym], data[:, i], label=category)
ax.set_title('2D Curves') ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) ax.set_xlabel(y_sym.plot_label)
ax.legend() ax.legend()
def plot_filled_curves_2d( def plot_filled_curves_2d(
data: jtyp.Float32[jtyp.Array, 'x_size 2'], info, ax: mpl_ax.Axis data: jtyp.Float32[jtyp.Array, 'x_size 2'], info, ax: mpl_ax.Axis
) -> None: ) -> None:
x_name = info.dim_names[0] x_sym = info.first_dim
x_unit = info.dim_units[x_name] y_sym = info.output
y_name = info.output_name
y_unit = info.output_unit
shared_x_idx = info.dim_idx_arrays[0] shared_x_idx = info.dims[info.last_dim]
ax.fill_between(shared_x_idx, data[:, 0], shared_x_idx, data[:, 1]) ax.fill_between(shared_x_idx, data[:, 0], shared_x_idx, data[:, 1])
ax.set_title('2D Filled Curves') ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) ax.set_xlabel(y_sym.plot_label)
ax.legend()
# (, ) -> # (, ) ->
def plot_heatmap_2d( def plot_heatmap_2d(
data: jtyp.Float32[jtyp.Array, 'x_size y_size'], info, ax: mpl_ax.Axis data: jtyp.Float32[jtyp.Array, 'x_size y_size'], info, ax: mpl_ax.Axis
) -> None: ) -> None:
x_name = info.dim_names[0] x_sym = info.first_dim
x_unit = info.dim_units[x_name] y_sym = info.last_dim
y_name = info.dim_names[1] c_sym = info.output
y_unit = info.dim_units[y_name]
heatmap = ax.imshow(data, aspect='auto', interpolation='none') heatmap = ax.imshow(data, aspect='equal', interpolation='none')
# ax.figure.colorbar(heatmap, ax=ax) ax.figure.colorbar(heatmap, cax=ax)
ax.set_title('Heatmap')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) ax.set_title(f'({x_sym.name_pretty}, {y_sym.name_pretty}) -> {c_sym.plot_label}')
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
ax.legend()

View File

@ -18,26 +18,67 @@ import dataclasses
import enum import enum
import sys import sys
import typing as typ import typing as typ
from fractions import Fraction
import sympy as sp import sympy as sp
from . import extra_sympy_units as spux from . import extra_sympy_units as spux
int_min = -(2**64)
int_max = 2**64
float_min = sys.float_info.min
float_max = sys.float_info.max
#################### ####################
# - Simulation Symbols # - Simulation Symbol Names
#################### ####################
class SimSymbolName(enum.StrEnum): class SimSymbolName(enum.StrEnum):
# Lower
LowerA = enum.auto() LowerA = enum.auto()
LowerB = enum.auto()
LowerC = enum.auto()
LowerD = enum.auto()
LowerI = enum.auto()
LowerT = enum.auto() LowerT = enum.auto()
LowerX = enum.auto() LowerX = enum.auto()
LowerY = enum.auto() LowerY = enum.auto()
LowerZ = enum.auto() LowerZ = enum.auto()
# Physics # Fields
Ex = enum.auto()
Ey = enum.auto()
Ez = enum.auto()
Hx = enum.auto()
Hy = enum.auto()
Hz = enum.auto()
Er = enum.auto()
Etheta = enum.auto()
Ephi = enum.auto()
Hr = enum.auto()
Htheta = enum.auto()
Hphi = enum.auto()
# Optics
Wavelength = enum.auto() Wavelength = enum.auto()
Frequency = enum.auto() Frequency = enum.auto()
Flux = enum.auto()
PermXX = enum.auto()
PermYY = enum.auto()
PermZZ = enum.auto()
DiffOrderX = enum.auto()
DiffOrderY = enum.auto()
# Generic
Expr = enum.auto()
####################
# - UI
####################
@staticmethod @staticmethod
def to_name(v: typ.Self) -> str: def to_name(v: typ.Self) -> str:
"""Convert the enum value to a human-friendly name. """Convert the enum value to a human-friendly name.
@ -50,27 +91,6 @@ class SimSymbolName(enum.StrEnum):
""" """
return SimSymbolName(v).name return SimSymbolName(v).name
@property
def name(self) -> str:
SSN = SimSymbolName
return {
SSN.LowerA: 'a',
SSN.LowerT: 't',
SSN.LowerX: 'x',
SSN.LowerY: 'y',
SSN.LowerZ: 'z',
SSN.Wavelength: 'wl',
SSN.Frequency: 'freq',
}[self]
@property
def name_pretty(self) -> str:
SSN = SimSymbolName
return {
SSN.Wavelength: 'λ',
SSN.Frequency: '𝑓',
}.get(self, self.name)
@staticmethod @staticmethod
def to_icon(_: typ.Self) -> str: def to_icon(_: typ.Self) -> str:
"""Convert the enum value to a Blender icon. """Convert the enum value to a Blender icon.
@ -83,6 +103,75 @@ class SimSymbolName(enum.StrEnum):
""" """
return '' return ''
####################
# - Computed Properties
####################
@property
def name(self) -> str:
SSN = SimSymbolName
return {
# Lower
SSN.LowerA: 'a',
SSN.LowerB: 'b',
SSN.LowerC: 'c',
SSN.LowerD: 'd',
SSN.LowerI: 'i',
SSN.LowerT: 't',
SSN.LowerX: 'x',
SSN.LowerY: 'y',
SSN.LowerZ: 'z',
# Fields
SSN.Ex: 'Ex',
SSN.Ey: 'Ey',
SSN.Ez: 'Ez',
SSN.Hx: 'Hx',
SSN.Hy: 'Hy',
SSN.Hz: 'Hz',
SSN.Er: 'Ex',
SSN.Etheta: 'Ey',
SSN.Ephi: 'Ez',
SSN.Hr: 'Hx',
SSN.Htheta: 'Hy',
SSN.Hphi: 'Hz',
# Optics
SSN.Wavelength: 'wl',
SSN.Frequency: 'freq',
SSN.Flux: 'flux',
SSN.PermXX: 'eps_xx',
SSN.PermYY: 'eps_yy',
SSN.PermZZ: 'eps_zz',
SSN.DiffOrderX: 'order_x',
SSN.DiffOrderY: 'order_y',
# Generic
SSN.Expr: 'expr',
}[self]
@property
def name_pretty(self) -> str:
SSN = SimSymbolName
return {
SSN.Wavelength: 'λ',
SSN.Frequency: '𝑓',
}.get(self, self.name)
####################
# - Simulation Symbol
####################
def mk_interval(
interval_finite: tuple[int | Fraction | float, int | Fraction | float],
interval_inf: tuple[bool, bool],
interval_closed: tuple[bool, bool],
unit_factor: typ.Literal[1] | spux.Unit,
) -> sp.Interval:
"""Create a symbolic interval from the tuples (and unit) defining it."""
return sp.Interval(
start=(interval_finite[0] * unit_factor if not interval_inf[0] else -sp.oo),
end=(interval_finite[1] * unit_factor if not interval_inf[1] else sp.oo),
left_open=(True if interval_inf[0] else not interval_closed[0]),
right_open=(True if interval_inf[1] else not interval_closed[1]),
)
@dataclasses.dataclass(kw_only=True, frozen=True) @dataclasses.dataclass(kw_only=True, frozen=True)
class SimSymbol: class SimSymbol:
@ -94,66 +183,145 @@ class SimSymbol:
It's easy to persist, easy to transport, and has many helpful properties which greatly simplify working with symbols. It's easy to persist, easy to transport, and has many helpful properties which greatly simplify working with symbols.
""" """
sim_node_name: SimSymbolName = SimSymbolName.LowerX sym_name: SimSymbolName
mathtype: spux.MathType = spux.MathType.Real mathtype: spux.MathType = spux.MathType.Real
physical_type: spux.PhysicalType = spux.PhysicalType.NonPhysical physical_type: spux.PhysicalType = spux.PhysicalType.NonPhysical
## TODO: Shape/size support? Incl. MatrixSymbol. # Units
## -> 'None' indicates that no particular unit has yet been chosen.
## -> Not exposed in the UI; must be set some other way.
unit: spux.Unit | None = None
# Domain # Size
interval_finite: tuple[float, float] = (0, 1) ## -> All SimSymbol sizes are "2D", but interpreted by convention.
## -> 1x1: "Scalar".
## -> nx1: "Vector".
## -> 1xn: "Covector".
## -> nxn: "Matrix".
rows: int = 1
cols: int = 1
# Scalar Domain: "Interval"
## -> NOTE: interval_finite_*[0] must be strictly smaller than [1].
## -> See self.domain.
## -> We have to deconstruct symbolic interval semantics a bit for UI.
interval_finite_z: tuple[int, int] = (0, 1)
interval_finite_q: tuple[tuple[int, int], tuple[int, int]] = ((0, 1), (1, 1))
interval_finite_re: tuple[float, float] = (0, 1)
interval_inf: tuple[bool, bool] = (True, True) interval_inf: tuple[bool, bool] = (True, True)
interval_closed: tuple[bool, bool] = (False, False) interval_closed: tuple[bool, bool] = (False, False)
interval_finite_im: tuple[float, float] = (0, 1)
interval_inf_im: tuple[bool, bool] = (True, True)
interval_closed_im: tuple[bool, bool] = (False, False)
#################### ####################
# - Properties # - Properties
#################### ####################
@property @property
def name(self) -> str: def name(self) -> str:
return self.sim_node_name.name """Usable name for the symbol."""
return self.sym_name.name
@property
def name_pretty(self) -> str:
"""Pretty (possibly unicode) name for the thing."""
return self.sym_name.name_pretty
## TODO: Formatting conventions for bolding/etc. of vectors/mats/...
@property
def plot_label(self) -> str:
"""Pretty plot-oriented label."""
return f'{self.name_pretty}' + (
f'({self.unit})' if self.unit is not None else ''
)
@property
def unit_factor(self) -> spux.SympyExpr:
"""Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking."""
return self.unit if self.unit is not None else sp.S(1)
@property
def shape(self) -> tuple[int, ...]:
match (self.rows, self.cols):
case (1, 1):
return ()
case (_, 1):
return (self.rows,)
case (1, _):
return (1, self.rows)
case (_, _):
return (self.rows, self.cols)
@property @property
def domain(self) -> sp.Interval | sp.Set: def domain(self) -> sp.Interval | sp.Set:
"""Return the domain of valid values for the symbol. """Return the scalar domain of valid values for each element of the symbol.
For integer/rational/real symbols, the domain is an interval defined using the `interval_*` properties. For integer/rational/real symbols, the domain is an interval defined using the `interval_*` properties.
This interval **must** have the property`start <= stop`. This interval **must** have the property`start <= stop`.
Otherwise, the domain is the symbolic set corresponding to `self.mathtype`. Otherwise, the domain is the symbolic set corresponding to `self.mathtype`.
""" """
if self.mathtype in [ match self.mathtype:
spux.MathType.Integer, case spux.MathType.Integer:
spux.MathType.Rational, return mk_interval(
spux.MathType.Real, self.interval_finite_z,
]: self.interval_inf,
return sp.Interval( self.interval_closed,
start=self.interval_finite[0] if not self.interval_inf[0] else -sp.oo, self.unit_factor,
end=self.interval_finite[1] if not self.interval_inf[1] else sp.oo,
left_open=(
True if self.interval_inf[0] else not self.interval_closed[0]
),
right_open=(
True if self.interval_inf[1] else not self.interval_closed[1]
),
) )
return self.mathtype.symbolic_set case spux.MathType.Rational:
return mk_interval(
Fraction(*self.interval_finite_q),
self.interval_inf,
self.interval_closed,
self.unit_factor,
)
case spux.MathType.Real:
return mk_interval(
self.interval_finite_re,
self.interval_inf,
self.interval_closed,
self.unit_factor,
)
case spux.MathType.Complex:
return (
mk_interval(
self.interval_finite_re,
self.interval_inf,
self.interval_closed,
self.unit_factor,
),
mk_interval(
self.interval_finite_im,
self.interval_inf_im,
self.interval_closed_im,
self.unit_factor,
),
)
#################### ####################
# - Properties # - Properties
#################### ####################
@property @property
def sp_symbol(self) -> sp.Symbol: def sp_symbol(self) -> sp.Symbol:
"""Return a symbolic variable corresponding to this `SimSymbol`. """Return a symbolic variable w/unit, corresponding to this `SimSymbol`.
As much as possible, appropriate `assumptions` are set in the constructor of `sp.Symbol`, insofar as they can be determined. As much as possible, appropriate `assumptions` are set in the constructor of `sp.Symbol`, insofar as they can be determined.
However, the assumptions system alone is rather limited, and implementations should therefore also strongly consider transporting `SimSymbols` directly, instead of `sp.Symbol`. - **MathType**: Depending on `self.mathtype`.
This allows making use of other properties like `self.domain`, when appropriate. - **Positive/Negative**: Depending on `self.domain`.
- **Nonzero**: Depending on `self.domain`, including open/closed boundary specifications.
Notes:
**The assumptions system is rather limited**, and implementations should strongly consider transporting `SimSymbols` instead of `sp.Symbol`.
This allows tracking ex. the valid interval domain for a symbol.
""" """
# MathType Domain Constraint # MathType Assumption
## -> We must feed the assumptions system.
mathtype_kwargs = {} mathtype_kwargs = {}
match self.mathtype: match self.mathtype:
case spux.MathType.Integer: case spux.MathType.Integer:
@ -165,9 +333,7 @@ class SimSymbol:
case spux.MathType.Complex: case spux.MathType.Complex:
mathtype_kwargs |= {'complex': True} mathtype_kwargs |= {'complex': True}
# Interval Constraints # Non-Zero Assumption
if isinstance(self.domain, sp.Interval):
# Assumption: Non-Zero
if ( if (
( (
self.domain.left == 0 self.domain.left == 0
@ -180,38 +346,125 @@ class SimSymbol:
): ):
mathtype_kwargs |= {'nonzero': True} mathtype_kwargs |= {'nonzero': True}
# Assumption: Positive/Negative # Positive/Negative Assumption
if self.domain.left >= 0: if self.domain.left >= 0:
mathtype_kwargs |= {'positive': True} mathtype_kwargs |= {'positive': True}
elif self.domain.right <= 0: elif self.domain.right <= 0:
mathtype_kwargs |= {'negative': True} mathtype_kwargs |= {'negative': True}
# Construct the Symbol return sp.Symbol(self.sym_name.name, **mathtype_kwargs) * self.unit_factor
return sp.Symbol(self.sim_node_name.name, **mathtype_kwargs)
####################
# - Operations
####################
def update(self, **kwargs) -> typ.Self:
def get_attr(attr: str):
_notfound = 'notfound'
if kwargs.get(attr, _notfound) is _notfound:
return getattr(self, attr)
return kwargs[attr]
return SimSymbol(
sym_name=get_attr('sym_name'),
mathtype=get_attr('mathtype'),
physical_type=get_attr('physical_type'),
unit=get_attr('unit'),
rows=get_attr('rows'),
cols=get_attr('cols'),
interval_finite_z=get_attr('interval_finite_z'),
interval_finite_q=get_attr('interval_finite_q'),
interval_finite_re=get_attr('interval_finite_q'),
interval_inf=get_attr('interval_inf'),
interval_closed=get_attr('interval_closed'),
interval_finite_im=get_attr('interval_finite_im'),
interval_inf_im=get_attr('interval_inf_im'),
interval_closed_im=get_attr('interval_closed_im'),
)
def set_size(self, rows: int, cols: int) -> typ.Self:
return SimSymbol(
sym_name=self.sym_name,
mathtype=self.mathtype,
physical_type=self.physical_type,
unit=self.unit,
rows=rows,
cols=cols,
interval_finite_z=self.interval_finite_z,
interval_finite_q=self.interval_finite_q,
interval_finite_re=self.interval_finite_re,
interval_inf=self.interval_inf,
interval_closed=self.interval_closed,
interval_finite_im=self.interval_finite_im,
interval_inf_im=self.interval_inf_im,
interval_closed_im=self.interval_closed_im,
)
#################### ####################
# - Common Sim Symbols # - Common Sim Symbols
#################### ####################
class CommonSimSymbol(enum.StrEnum): class CommonSimSymbol(enum.StrEnum):
"""A set of pre-defined symbols that might commonly be used in the context of physical simulation. """Identifiers for commonly used `SimSymbol`s, with all information about ex. `MathType`, `PhysicalType`, and (in general) valid intervals all pre-loaded.
Each entry maps directly to a particular `SimSymbol`. The enum is UI-compatible making it easy to declare a UI-driven dropdown of commonly used symbols that will all behave as expected.
The enum is compatible with `BLField`, making it easy to declare a UI-driven dropdown of symbols that behave as expected.
Attributes: Attributes:
X:
Time: A symbol representing a real-valued wavelength.
Wavelength: A symbol representing a real-valued wavelength. Wavelength: A symbol representing a real-valued wavelength.
Implicitly, this symbol often represents "vacuum wavelength" in particular. Implicitly, this symbol often represents "vacuum wavelength" in particular.
Wavelength: A symbol representing a real-valued frequency. Wavelength: A symbol representing a real-valued frequency.
Generally, this is the non-angular frequency. Generally, this is the non-angular frequency.
""" """
X = enum.auto() Index = enum.auto()
# Space|Time
SpaceX = enum.auto()
SpaceY = enum.auto()
SpaceZ = enum.auto()
AngR = enum.auto()
AngTheta = enum.auto()
AngPhi = enum.auto()
DirX = enum.auto()
DirY = enum.auto()
DirZ = enum.auto()
Time = enum.auto() Time = enum.auto()
# Fields
FieldEx = enum.auto()
FieldEy = enum.auto()
FieldEz = enum.auto()
FieldHx = enum.auto()
FieldHy = enum.auto()
FieldHz = enum.auto()
FieldEr = enum.auto()
FieldEtheta = enum.auto()
FieldEphi = enum.auto()
FieldHr = enum.auto()
FieldHtheta = enum.auto()
FieldHphi = enum.auto()
# Optics
Wavelength = enum.auto() Wavelength = enum.auto()
Frequency = enum.auto() Frequency = enum.auto()
DiffOrderX = enum.auto()
DiffOrderY = enum.auto()
Flux = enum.auto()
WaveVecX = enum.auto()
WaveVecY = enum.auto()
WaveVecZ = enum.auto()
####################
# - UI
####################
@staticmethod @staticmethod
def to_name(v: typ.Self) -> str: def to_name(v: typ.Self) -> str:
"""Convert the enum value to a human-friendly name. """Convert the enum value to a human-friendly name.
@ -222,7 +475,7 @@ class CommonSimSymbol(enum.StrEnum):
Returns: Returns:
A human-friendly name corresponding to the enum value. A human-friendly name corresponding to the enum value.
""" """
return CommonSimSymbol(v).sim_symbol_name.name return CommonSimSymbol(v).name
@staticmethod @staticmethod
def to_icon(_: typ.Self) -> str: def to_icon(_: typ.Self) -> str:
@ -241,55 +494,125 @@ class CommonSimSymbol(enum.StrEnum):
#################### ####################
@property @property
def name(self) -> str: def name(self) -> str:
return self.sim_symbol.name
@property
def sim_symbol_name(self) -> str:
SSN = SimSymbolName SSN = SimSymbolName
CSS = CommonSimSymbol CSS = CommonSimSymbol
return { return {
CSS.X: SSN.LowerX, CSS.Index: SSN.LowerI,
# Space|Time
CSS.SpaceX: SSN.LowerX,
CSS.SpaceY: SSN.LowerY,
CSS.SpaceZ: SSN.LowerZ,
CSS.AngR: SSN.LowerR,
CSS.AngTheta: SSN.LowerTheta,
CSS.AngPhi: SSN.LowerPhi,
CSS.DirX: SSN.LowerX,
CSS.DirY: SSN.LowerY,
CSS.DirZ: SSN.LowerZ,
CSS.Time: SSN.LowerT, CSS.Time: SSN.LowerT,
CSS.Wavelength: SSN.Wavelength, # Fields
CSS.FieldEx: SSN.Ex,
CSS.FieldEy: SSN.Ey,
CSS.FieldEz: SSN.Ez,
CSS.FieldHx: SSN.Hx,
CSS.FieldHy: SSN.Hy,
CSS.FieldHz: SSN.Hz,
CSS.FieldEr: SSN.Er,
CSS.FieldHr: SSN.Hr,
# Optics
CSS.Frequency: SSN.Frequency, CSS.Frequency: SSN.Frequency,
CSS.Wavelength: SSN.Wavelength,
CSS.DiffOrderX: SSN.DiffOrderX,
CSS.DiffOrderY: SSN.DiffOrderY,
}[self] }[self]
@property def sim_symbol(self, unit: spux.Unit | None) -> SimSymbol:
def sim_symbol(self) -> SimSymbol:
"""Retrieve the `SimSymbol` associated with the `CommonSimSymbol`.""" """Retrieve the `SimSymbol` associated with the `CommonSimSymbol`."""
CSS = CommonSimSymbol CSS = CommonSimSymbol
# Space
sym_space = SimSymbol(
sym_name=self.name,
physical_type=spux.PhysicalType.Length,
unit=unit,
)
sym_ang = SimSymbol(
sym_name=self.name,
physical_type=spux.PhysicalType.Angle,
unit=unit,
)
# Fields
def sym_field(eh: typ.Literal['e', 'h']) -> SimSymbol:
return SimSymbol(
sym_name=self.name,
physical_type=spux.PhysicalType.EField
if eh == 'e'
else spux.PhysicalType.HField,
unit=unit,
interval_finite_re=(0, sys.float_info.max),
interval_inf_re=(False, True),
interval_closed_re=(True, False),
interval_finite_im=(sys.float_info.min, sys.float_info.max),
interval_inf_im=(True, True),
)
return { return {
CSS.X: SimSymbol( CSS.Index: SimSymbol(
sim_node_name=self.sim_symbol_name, sym_name=self.name,
mathtype=spux.MathType.Real, mathtype=spux.MathType.Integer,
physical_type=spux.PhysicalType.NonPhysical, interval_finite_z=(0, 2**64),
## TODO: Unit of Picosecond
interval_finite=(sys.float_info.min, sys.float_info.max),
interval_inf=(True, True),
interval_closed=(False, False),
),
CSS.Time: SimSymbol(
sim_node_name=self.sim_symbol_name,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Time,
## TODO: Unit of Picosecond
interval_finite=(0, sys.float_info.max),
interval_inf=(False, True), interval_inf=(False, True),
interval_closed=(True, False), interval_closed=(True, False),
), ),
# Space|Time
CSS.SpaceX: sym_space,
CSS.SpaceY: sym_space,
CSS.SpaceZ: sym_space,
CSS.AngR: sym_space,
CSS.AngTheta: sym_ang,
CSS.AngPhi: sym_ang,
CSS.Time: SimSymbol(
sym_name=self.name,
physical_type=spux.PhysicalType.Time,
unit=unit,
interval_finite_re=(0, sys.float_info.max),
interval_inf=(False, True),
interval_closed=(True, False),
),
# Fields
CSS.FieldEx: sym_field('e'),
CSS.FieldEy: sym_field('e'),
CSS.FieldEz: sym_field('e'),
CSS.FieldHx: sym_field('h'),
CSS.FieldHy: sym_field('h'),
CSS.FieldHz: sym_field('h'),
CSS.FieldEr: sym_field('e'),
CSS.FieldEtheta: sym_field('e'),
CSS.FieldEphi: sym_field('e'),
CSS.FieldHr: sym_field('h'),
CSS.FieldHtheta: sym_field('h'),
CSS.FieldHphi: sym_field('h'),
CSS.Flux: SimSymbol(
sym_name=SimSymbolName.Flux,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Power,
unit=unit,
),
# Optics
CSS.Wavelength: SimSymbol( CSS.Wavelength: SimSymbol(
sim_node_name=self.sim_symbol_name, sym_name=self.name,
mathtype=spux.MathType.Real, mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Length, physical_type=spux.PhysicalType.Length,
## TODO: Unit of Picosecond unit=unit,
interval_finite=(0, sys.float_info.max), interval_finite=(0, sys.float_info.max),
interval_inf=(False, True), interval_inf=(False, True),
interval_closed=(False, False), interval_closed=(False, False),
), ),
CSS.Frequency: SimSymbol( CSS.Frequency: SimSymbol(
sim_node_name=self.sim_symbol_name, sym_name=self.name,
mathtype=spux.MathType.Real, mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Freq, physical_type=spux.PhysicalType.Freq,
unit=unit,
interval_finite=(0, sys.float_info.max), interval_finite=(0, sys.float_info.max),
interval_inf=(False, True), interval_inf=(False, True),
interval_closed=(False, False), interval_closed=(False, False),
@ -298,9 +621,33 @@ class CommonSimSymbol(enum.StrEnum):
#################### ####################
# - Selected Direct Access # - Selected Direct-Access to SimSymbols
#################### ####################
x = CommonSimSymbol.X.sim_symbol idx = CommonSimSymbol.Index.sim_symbol
t = CommonSimSymbol.Time.sim_symbol t = CommonSimSymbol.Time.sim_symbol
wl = CommonSimSymbol.Wavelength.sim_symbol wl = CommonSimSymbol.Wavelength.sim_symbol
freq = CommonSimSymbol.Frequency.sim_symbol freq = CommonSimSymbol.Frequency.sim_symbol
space_x = CommonSimSymbol.SpaceX.sim_symbol
space_y = CommonSimSymbol.SpaceY.sim_symbol
space_z = CommonSimSymbol.SpaceZ.sim_symbol
dir_x = CommonSimSymbol.DirX.sim_symbol
dir_y = CommonSimSymbol.DirY.sim_symbol
dir_z = CommonSimSymbol.DirZ.sim_symbol
ang_r = CommonSimSymbol.AngR.sim_symbol
ang_theta = CommonSimSymbol.AngTheta.sim_symbol
ang_phi = CommonSimSymbol.AngPhi.sim_symbol
field_ex = CommonSimSymbol.FieldEx.sim_symbol
field_ey = CommonSimSymbol.FieldEy.sim_symbol
field_ez = CommonSimSymbol.FieldEz.sim_symbol
field_hx = CommonSimSymbol.FieldHx.sim_symbol
field_hy = CommonSimSymbol.FieldHx.sim_symbol
field_hz = CommonSimSymbol.FieldHx.sim_symbol
flux = CommonSimSymbol.Flux.sim_symbol
diff_order_x = CommonSimSymbol.DiffOrderX.sim_symbol
diff_order_y = CommonSimSymbol.DiffOrderY.sim_symbol