Compare commits

...

4 Commits

Author SHA1 Message Date
Sofus Albert Høgsbro Rose 353a2c997e
refactor: end-of-day commit (sim symbol flow for data import/export & inverse design) 2024-05-21 22:57:56 +02:00
Sofus Albert Høgsbro Rose dccf952ad3
refactor: renamed LazyValueFunc to Func 2024-05-21 09:04:05 +02:00
Sofus Albert Høgsbro Rose 84825c2642
refactor: Renamed LazyArrayRange to Range 2024-05-21 09:00:23 +02:00
Sofus Albert Høgsbro Rose f5d19abecd
feat: data file exporter node
As with many things, there seems to be an obvious convergent design
philosophy here wrt. data flow.
We should definitely be using `SimSymbol` for a lot more things; it
solves a lot of the pain points related to figuring out what on Earth
should go into the `InfoFlow` in which situations.

We desperately need to iron out the `*Flow` object semantics

The surprise MVP of the day is `Polars`.
What a gorgeous and fast dataframe library.
We initially wrote it off as being unsuited to multidimensional data,
but case in point, a whole lot of useful data can indeed be expressed as 2D.
For all of these cases, be it loading/saving or processing, `Polars`
is truly an ideal choice.

Work continues.
2024-05-21 08:51:26 +02:00
42 changed files with 3232 additions and 2146 deletions

View File

@ -21,11 +21,13 @@ dependencies = [
# Pin Blender 4.1.0-Compatible Versions # Pin Blender 4.1.0-Compatible Versions
## The dependency resolver will report if anything is wonky. ## The dependency resolver will report if anything is wonky.
"urllib3==1.26.8", "urllib3==1.26.8",
#"requests==2.27.1", ## Conflict with dev-dep commitizen #"requests==2.27.1", ## Conflict with dev-dep commitizen
"numpy==1.24.3", "numpy==1.24.3",
"idna==3.3", "idna==3.3",
#"charset-normalizer==2.0.10", ## Conflict with dev-dep commitizen #"charset-normalizer==2.0.10", ## Conflict with dev-dep commitizen
"certifi==2021.10.8", "certifi==2021.10.8",
"polars>=0.20.26",
"seaborn[stats]>=0.13.2",
] ]
## When it comes to dev-dep conflicts: ## When it comes to dev-dep conflicts:
## -> It's okay to leave Blender-pinned deps out of prod; Blender still has them. ## -> It's okay to leave Blender-pinned deps out of prod; Blender still has them.

View File

@ -81,6 +81,7 @@ locket==1.0.0
markupsafe==2.1.5 markupsafe==2.1.5
# via jinja2 # via jinja2
matplotlib==3.8.3 matplotlib==3.8.3
# via seaborn
# via tidy3d # via tidy3d
ml-dtypes==0.4.0 ml-dtypes==0.4.0
# via jax # via jax
@ -102,8 +103,11 @@ numpy==1.24.3
# via ml-dtypes # via ml-dtypes
# via numba # via numba
# via opt-einsum # via opt-einsum
# via patsy
# via scipy # via scipy
# via seaborn
# via shapely # via shapely
# via statsmodels
# via tidy3d # via tidy3d
# via trimesh # via trimesh
# via xarray # via xarray
@ -114,15 +118,21 @@ packaging==24.0
# via dask # via dask
# via h5netcdf # via h5netcdf
# via matplotlib # via matplotlib
# via statsmodels
# via xarray # via xarray
pandas==2.2.1 pandas==2.2.1
# via seaborn
# via statsmodels
# via xarray # via xarray
partd==1.4.1 partd==1.4.1
# via dask # via dask
patsy==0.5.6
# via statsmodels
pillow==10.2.0 pillow==10.2.0
# via matplotlib # via matplotlib
platformdirs==4.2.1 platformdirs==4.2.1
# via virtualenv # via virtualenv
polars==0.20.26
pre-commit==3.7.0 pre-commit==3.7.0
prompt-toolkit==3.0.36 prompt-toolkit==3.0.36
# via questionary # via questionary
@ -166,13 +176,19 @@ s3transfer==0.5.2
scipy==1.12.0 scipy==1.12.0
# via jax # via jax
# via jaxlib # via jaxlib
# via seaborn
# via statsmodels
# via tidy3d # via tidy3d
seaborn==0.13.2
setuptools==69.5.1 setuptools==69.5.1
# via nodeenv # via nodeenv
shapely==2.0.3 shapely==2.0.3
# via tidy3d # via tidy3d
six==1.16.0 six==1.16.0
# via patsy
# via python-dateutil # via python-dateutil
statsmodels==0.14.2
# via seaborn
sympy==1.12 sympy==1.12
termcolor==2.4.0 termcolor==2.4.0
# via commitizen # via commitizen

View File

@ -59,6 +59,7 @@ llvmlite==0.42.0
locket==1.0.0 locket==1.0.0
# via partd # via partd
matplotlib==3.8.3 matplotlib==3.8.3
# via seaborn
# via tidy3d # via tidy3d
ml-dtypes==0.4.0 ml-dtypes==0.4.0
# via jax # via jax
@ -78,8 +79,11 @@ numpy==1.24.3
# via ml-dtypes # via ml-dtypes
# via numba # via numba
# via opt-einsum # via opt-einsum
# via patsy
# via scipy # via scipy
# via seaborn
# via shapely # via shapely
# via statsmodels
# via tidy3d # via tidy3d
# via trimesh # via trimesh
# via xarray # via xarray
@ -89,13 +93,19 @@ packaging==24.0
# via dask # via dask
# via h5netcdf # via h5netcdf
# via matplotlib # via matplotlib
# via statsmodels
# via xarray # via xarray
pandas==2.2.1 pandas==2.2.1
# via seaborn
# via statsmodels
# via xarray # via xarray
partd==1.4.1 partd==1.4.1
# via dask # via dask
patsy==0.5.6
# via statsmodels
pillow==10.2.0 pillow==10.2.0
# via matplotlib # via matplotlib
polars==0.20.26
pydantic==2.7.1 pydantic==2.7.1
# via tidy3d # via tidy3d
pydantic-core==2.18.2 pydantic-core==2.18.2
@ -131,11 +141,17 @@ s3transfer==0.5.2
scipy==1.12.0 scipy==1.12.0
# via jax # via jax
# via jaxlib # via jaxlib
# via seaborn
# via statsmodels
# via tidy3d # via tidy3d
seaborn==0.13.2
shapely==2.0.3 shapely==2.0.3
# via tidy3d # via tidy3d
six==1.16.0 six==1.16.0
# via patsy
# via python-dateutil # via python-dateutil
statsmodels==0.14.2
# via seaborn
sympy==1.12 sympy==1.12
tidy3d==2.6.3 tidy3d==2.6.3
toml==0.10.2 toml==0.10.2

View File

@ -41,6 +41,9 @@ class OperatorType(enum.StrEnum):
SocketCloudAuthenticate = enum.auto() SocketCloudAuthenticate = enum.auto()
SocketReloadCloudFolderList = enum.auto() SocketReloadCloudFolderList = enum.auto()
# Node: ExportDataFile
NodeExportDataFile = enum.auto()
# Node: Tidy3DWebImporter # Node: Tidy3DWebImporter
NodeLoadCloudSim = enum.auto() NodeLoadCloudSim = enum.auto()

View File

@ -47,8 +47,8 @@ from .flow_kinds import (
CapabilitiesFlow, CapabilitiesFlow,
FlowKind, FlowKind,
InfoFlow, InfoFlow,
LazyArrayRangeFlow, RangeFlow,
LazyValueFuncFlow, FuncFlow,
ParamsFlow, ParamsFlow,
ScalingMode, ScalingMode,
ValueFlow, ValueFlow,
@ -59,6 +59,7 @@ from .mobj_types import ManagedObjType
from .node_types import NodeType from .node_types import NodeType
from .sim_types import ( from .sim_types import (
BoundCondType, BoundCondType,
DataFileFormat,
NewSimCloudTask, NewSimCloudTask,
SimAxisDir, SimAxisDir,
SimFieldPols, SimFieldPols,
@ -103,6 +104,7 @@ __all__ = [
'BLSocketType', 'BLSocketType',
'NodeType', 'NodeType',
'BoundCondType', 'BoundCondType',
'DataFileFormat',
'NewSimCloudTask', 'NewSimCloudTask',
'SimAxisDir', 'SimAxisDir',
'SimFieldPols', 'SimFieldPols',
@ -116,8 +118,8 @@ __all__ = [
'CapabilitiesFlow', 'CapabilitiesFlow',
'FlowKind', 'FlowKind',
'InfoFlow', 'InfoFlow',
'LazyArrayRangeFlow', 'RangeFlow',
'LazyValueFuncFlow', 'FuncFlow',
'ParamsFlow', 'ParamsFlow',
'ScalingMode', 'ScalingMode',
'ValueFlow', 'ValueFlow',

View File

@ -18,8 +18,8 @@ from .array import ArrayFlow
from .capabilities import CapabilitiesFlow from .capabilities import CapabilitiesFlow
from .flow_kinds import FlowKind from .flow_kinds import FlowKind
from .info import InfoFlow from .info import InfoFlow
from .lazy_array_range import LazyArrayRangeFlow, ScalingMode from .lazy_range import RangeFlow, ScalingMode
from .lazy_value_func import LazyValueFuncFlow from .lazy_func import FuncFlow
from .params import ParamsFlow from .params import ParamsFlow
from .value import ValueFlow from .value import ValueFlow
@ -28,9 +28,9 @@ __all__ = [
'CapabilitiesFlow', 'CapabilitiesFlow',
'FlowKind', 'FlowKind',
'InfoFlow', 'InfoFlow',
'LazyArrayRangeFlow', 'RangeFlow',
'ScalingMode', 'ScalingMode',
'LazyValueFuncFlow', 'FuncFlow',
'ParamsFlow', 'ParamsFlow',
'ValueFlow', 'ValueFlow',
] ]

View File

@ -29,9 +29,12 @@ from blender_maxwell.utils import logger
log = logger.get(__name__) log = logger.get(__name__)
# TODO: Our handling of 'is_sorted' is sloppy and probably wrong.
@dataclasses.dataclass(frozen=True, kw_only=True) @dataclasses.dataclass(frozen=True, kw_only=True)
class ArrayFlow: class ArrayFlow:
"""A simple, flat array of values with an optionally-attached unit. """A homogeneous, realized array of numerical values with an optionally-attached unit and sort-tracking.
While the principle is simple, arrays-with-units ends up being a powerful basis for derived and computed features/methods/processing.
Attributes: Attributes:
values: An ND array-like object of arbitrary numerical type. values: An ND array-like object of arbitrary numerical type.
@ -44,13 +47,97 @@ class ArrayFlow:
is_sorted: bool = False is_sorted: bool = False
####################
# - Computed Properties
####################
@property
def is_symbolic(self) -> bool:
"""Always False, as ArrayFlows are never unrealized."""
return False
def __len__(self) -> int: def __len__(self) -> int:
"""Outer length of the contained array."""
return len(self.values) return len(self.values)
@functools.cached_property @functools.cached_property
def mathtype(self) -> spux.MathType: def mathtype(self) -> spux.MathType:
"""Deduce the `spux.MathType` of the first element of the contained array.
This is generally a heuristic, but because `jax` enforces homogeneous arrays, this is actually a well-defined approach.
"""
return spux.MathType.from_pytype(type(self.values.item(0))) return spux.MathType.from_pytype(type(self.values.item(0)))
@functools.cached_property
def physical_type(self) -> spux.MathType:
"""Deduce the `spux.PhysicalType` of the unit."""
return spux.PhysicalType.from_unit(self.unit)
####################
# - Array Features
####################
@property
def realize_array(self) -> jtyp.Shaped[jtyp.Array, '...']:
"""Standardized access to `self.values`."""
return self.values
@functools.cached_property
def shape(self) -> int:
"""Shape of the contained array."""
return self.values.shape
def __getitem__(self, subscript: slice) -> typ.Self | spux.SympyExpr:
"""Implement indexing and slicing in a sane way.
- **Integer Index**: For scalar output, return a `sympy` expression of the scalar multiplied by the unit, else just a sympy expression of the value.
- **Slice**: Slice the internal array directly, and wrap the result in a new `ArrayFlow`.
"""
if isinstance(subscript, slice):
return ArrayFlow(
values=self.values[subscript],
unit=self.unit,
is_sorted=self.is_sorted,
)
if isinstance(subscript, int):
value = self.values[subscript]
if len(value.shape) == 0:
return value * self.unit if self.unit is not None else sp.S(value)
return ArrayFlow(values=value, unit=self.unit, is_sorted=self.is_sorted)
raise NotImplementedError
####################
# - Methods
####################
def rescale(
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
) -> typ.Self:
"""Apply an order-preserving function to each element of the array, then (optionally) transform the result w/new unit and/or order.
An optimized expression will be built and applied to `self.values` using `sympy.lambdify()`.
Parameters:
rescale_func: An **order-preserving** function to apply to each array element.
reverse: Whether to reverse the order of the result.
new_unit: An (optional) new unit to scale the result to.
"""
# Compile JAX-Compatible Rescale Function
a = self.mathtype.sp_symbol_a
rescale_expr = (
spux.scale_to_unit(rescale_func(a * self.unit), new_unit)
if self.unit is not None
else rescale_func(a)
)
_rescale_func = sp.lambdify(a, rescale_expr, 'jax')
values = _rescale_func(self.values)
# Return ArrayFlow
return ArrayFlow(
values=values[::-1] if reverse else values,
unit=new_unit,
is_sorted=self.is_sorted,
)
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int: def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
"""Find the index of the value that is closest to the given value. """Find the index of the value that is closest to the given value.
@ -88,56 +175,26 @@ class ArrayFlow:
return right_idx return right_idx
def correct_unit(self, corrected_unit: spu.Quantity) -> typ.Self: ####################
if self.unit is not None: # - Unit Transforms
return ArrayFlow( ####################
values=self.values, unit=corrected_unit, is_sorted=self.is_sorted def correct_unit(self, unit: spux.Unit) -> typ.Self:
) """Simply replace the existing unit with the given one.
msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"' Parameters:
raise ValueError(msg) corrected_unit: The new unit to insert.
**MUST** be associable with a well-defined `PhysicalType`.
"""
return ArrayFlow(values=self.values, unit=unit, is_sorted=self.is_sorted)
def rescale_to_unit(self, unit: spu.Quantity | None) -> typ.Self: def rescale_to_unit(self, new_unit: spux.Unit | None) -> typ.Self:
## TODO: Cache by unit would be a very nice speedup for Viz node. """Rescale the `ArrayFlow` to be expressed in the given unit.
if self.unit is not None:
return ArrayFlow(
values=float(spux.scaling_factor(self.unit, unit)) * self.values,
unit=unit,
is_sorted=self.is_sorted,
)
if unit is None: Parameters:
return self corrected_unit: The new unit to insert.
**MUST** be associable with a well-defined `PhysicalType`.
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}' """
raise ValueError(msg) return self.rescale(lambda v: v, new_unit=new_unit)
def rescale(
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
) -> typ.Self:
# Compile JAX-Compatible Rescale Function
a = sp.Symbol('a')
rescale_expr = (
spux.scale_to_unit(rescale_func(a * self.unit), new_unit)
if self.unit is not None
else rescale_func(a * self.unit)
)
_rescale_func = sp.lambdify(a, rescale_expr, 'jax')
values = _rescale_func(self.values)
# Return ArrayFlow
return ArrayFlow(
values=values[::-1] if reverse else values,
unit=new_unit,
is_sorted=self.is_sorted,
)
def __getitem__(self, subscript: slice):
if isinstance(subscript, slice):
return ArrayFlow(
values=self.values[subscript],
unit=self.unit,
is_sorted=self.is_sorted,
)
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self:
raise NotImplementedError raise NotImplementedError

View File

@ -0,0 +1,36 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import typing as typ
from blender_maxwell.utils import extra_sympy_units as spux
from . import FlowKind
class ExprInfo(typ.TypedDict):
active_kind: FlowKind
size: spux.NumberSize1D
mathtype: spux.MathType
physical_type: spux.PhysicalType
# Value
default_value: spux.SympyExpr
# Range
default_min: spux.SympyExpr
default_max: spux.SympyExpr
default_steps: int

View File

@ -19,6 +19,7 @@ import typing as typ
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger from blender_maxwell.utils import logger
from blender_maxwell.utils.staticproperty import staticproperty
log = logger.get(__name__) log = logger.get(__name__)
@ -40,9 +41,9 @@ class FlowKind(enum.StrEnum):
Array: An object with dimensions, and possibly a unit. Array: An object with dimensions, and possibly a unit.
Whenever a `Value` is defined, a single-element `list` will also be generated by default as `Array` Whenever a `Value` is defined, a single-element `list` will also be generated by default as `Array`
However, for any other array-like variants (or sockets that only represent array-like objects), `Array` should be defined manually. However, for any other array-like variants (or sockets that only represent array-like objects), `Array` should be defined manually.
LazyValueFunc: A composable function. Func: A composable function.
Can be used to represent computations for which all data is not yet known, or for which just-in-time compilation can drastically increase performance. Can be used to represent computations for which all data is not yet known, or for which just-in-time compilation can drastically increase performance.
LazyArrayRange: An object that generates an `Array` from range information (start/stop/step/spacing). Range: An object that generates an `Array` from range information (start/stop/step/spacing).
This should be used instead of `Array` whenever possible. This should be used instead of `Array` whenever possible.
Param: A dictionary providing particular parameters for a lazy value. Param: A dictionary providing particular parameters for a lazy value.
Info: An dictionary providing extra context about any aspect of flow. Info: An dictionary providing extra context about any aspect of flow.
@ -51,16 +52,71 @@ class FlowKind(enum.StrEnum):
Capabilities = enum.auto() Capabilities = enum.auto()
# Values # Values
Value = enum.auto() Value = enum.auto() ## 'value'
Array = enum.auto() Array = enum.auto() ## 'array'
# Lazy # Lazy
LazyValueFunc = enum.auto() Func = enum.auto() ## 'lazy_func'
LazyArrayRange = enum.auto() Range = enum.auto() ## 'lazy_range'
# Auxiliary # Auxiliary
Params = enum.auto() Params = enum.auto() ## 'params'
Info = enum.auto() Info = enum.auto() ## 'info'
####################
# - UI
####################
@staticmethod
def to_name(v: typ.Self) -> str:
return {
FlowKind.Capabilities: 'Capabilities',
# Values
FlowKind.Value: 'Value',
FlowKind.Array: 'Array',
# Lazy
FlowKind.Range: 'Range',
FlowKind.Func: 'Func',
# Auxiliary
FlowKind.Params: 'Params',
FlowKind.Info: 'Info',
}[v]
@staticmethod
def to_icon(_: typ.Self) -> str:
return ''
####################
# - Static Properties
####################
@staticproperty
def active_kinds() -> list[typ.Self]:
"""Return a list of `FlowKind`s that are able to be considered "active".
"Active" `FlowKind`s are considered the primary data type of a socket's flow.
For example, for sockets to be linkeable, their active `FlowKind` must generally match.
"""
return [
FlowKind.Value,
FlowKind.Array,
FlowKind.Range,
FlowKind.Func,
]
@property
def socket_shape(self) -> str:
"""Return the socket shape associated with this `FlowKind`.
**ONLY** valid for `FlowKind`s that can be considered "active".
Raises:
ValueError: If this `FlowKind` cannot ever be considered "active".
"""
return {
FlowKind.Value: 'CIRCLE',
FlowKind.Array: 'SQUARE',
FlowKind.Range: 'SQUARE',
FlowKind.Func: 'DIAMOND',
}[self]
#################### ####################
# - Class Methods # - Class Methods
@ -69,7 +125,7 @@ class FlowKind(enum.StrEnum):
def scale_to_unit_system( def scale_to_unit_system(
cls, cls,
kind: typ.Self, kind: typ.Self,
flow_obj, flow_obj: spux.SympyExpr,
unit_system: spux.UnitSystem, unit_system: spux.UnitSystem,
): ):
# log.debug('%s: Scaling "%s" to Unit System', kind, str(flow_obj)) # log.debug('%s: Scaling "%s" to Unit System', kind, str(flow_obj))
@ -79,7 +135,7 @@ class FlowKind(enum.StrEnum):
flow_obj, flow_obj,
unit_system, unit_system,
) )
if kind == FlowKind.LazyArrayRange: if kind == FlowKind.Range:
return flow_obj.rescale_to_unit_system(unit_system) return flow_obj.rescale_to_unit_system(unit_system)
if kind == FlowKind.Params: if kind == FlowKind.Params:
@ -87,43 +143,3 @@ class FlowKind(enum.StrEnum):
msg = 'Tried to scale unknown kind' msg = 'Tried to scale unknown kind'
raise ValueError(msg) raise ValueError(msg)
####################
# - Computed
####################
@property
def flow_kind(self) -> str:
return {
FlowKind.Value: FlowKind.Value,
FlowKind.Array: FlowKind.Array,
FlowKind.LazyValueFunc: FlowKind.LazyValueFunc,
FlowKind.LazyArrayRange: FlowKind.LazyArrayRange,
}[self]
@property
def socket_shape(self) -> str:
return {
FlowKind.Value: 'CIRCLE',
FlowKind.Array: 'SQUARE',
FlowKind.LazyArrayRange: 'SQUARE',
FlowKind.LazyValueFunc: 'DIAMOND',
}[self]
####################
# - Blender Enum
####################
@staticmethod
def to_name(v: typ.Self) -> str:
return {
FlowKind.Capabilities: 'Capabilities',
FlowKind.Value: 'Value',
FlowKind.Array: 'Array',
FlowKind.LazyArrayRange: 'Range',
FlowKind.LazyValueFunc: 'Func',
FlowKind.Params: 'Parameters',
FlowKind.Info: 'Information',
}[v]
@staticmethod
def to_icon(_: typ.Self) -> str:
return ''

View File

@ -18,246 +18,247 @@ import dataclasses
import functools import functools
import typing as typ import typing as typ
import jax
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger from blender_maxwell.utils import logger, sim_symbols
from .array import ArrayFlow from .array import ArrayFlow
from .lazy_array_range import LazyArrayRangeFlow from .lazy_range import RangeFlow
log = logger.get(__name__) log = logger.get(__name__)
LabelArray: typ.TypeAlias = list[str]
# IndexArray: Identifies Discrete Dimension Values
## -> ArrayFlow (rat|real): Index by particular, not-guaranteed-uniform index.
## -> RangeFlow (rat|real): Index by unrealized array scaled between boundaries.
## -> LabelArray (int): For int-index arrays, interpret using these labels.
## -> None: Non-Discrete/unrealized indexing; use 'dim.domain'.
IndexArray: typ.TypeAlias = ArrayFlow | RangeFlow | LabelArray | None
@dataclasses.dataclass(frozen=True, kw_only=True) @dataclasses.dataclass(frozen=True, kw_only=True)
class InfoFlow: class InfoFlow:
#################### """Contains dimension and output information characterizing the array produced by a parallel `FuncFlow`.
# - Covariant Input
####################
dim_names: list[str] = dataclasses.field(default_factory=list)
dim_idx: dict[str, ArrayFlow | LazyArrayRangeFlow] = dataclasses.field(
default_factory=dict
) ## TODO: Rename to dim_idxs
@functools.cached_property Functionally speaking, `InfoFlow` provides essential mathematical and physical context to raw array data, with terminology adapted from multilinear algebra.
def dim_has_coords(self) -> dict[str, int]:
return {
dim_name: not (
isinstance(dim_idx, LazyArrayRangeFlow)
and (dim_idx.start.is_infinite or dim_idx.stop.is_infinite)
)
for dim_name, dim_idx in self.dim_idx.items()
}
@functools.cached_property # From Arrays to Tensors
def dim_lens(self) -> dict[str, int]: The best way to illustrate how it works is to specify how raw-array concepts map to an array described by an `InfoFlow`:
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
@functools.cached_property - **Index**: In raw arrays, the "index" is generally constrained to an integer ring, and has no semantic meaning.
def dim_mathtypes(self) -> dict[str, spux.MathType]: **(Covariant) Dimension**: The "dimension" is an named "index array", which assigns each integer index a **scalar** value of particular mathematical type, name, and unit (if not unitless).
return { - **Value**: In raw arrays, the "value" is some particular computational type, or another raw array.
dim_name: dim_idx.mathtype for dim_name, dim_idx in self.dim_idx.items() **(Contravariant) Output**: The "output" is a strictly named, sized object that can only be produced
}
@functools.cached_property In essence, `InfoFlow` allows us to treat raw data as a tensor, then operate on its dimensionality as split into parts whose transform varies _with_ the output (aka. a _covariant_ index), and parts whose transform varies _against_ the output (aka. _contravariant_ value).
def dim_units(self) -> dict[str, spux.Unit]:
return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()}
@functools.cached_property ## Benefits
def dim_physical_types(self) -> dict[str, spux.PhysicalType]: The reasons to do this are numerous:
return {
dim_name: spux.PhysicalType.from_unit(dim_idx.unit)
for dim_name, dim_idx in self.dim_idx.items()
}
@functools.cached_property - **Clarity**: Using `InfoFlow`, it's easy to understand what the data is, and what can be done to it, making it much easier to implement complex operations in math nodes without sacrificing the user's mental model.
def dim_idx_arrays(self) -> list[jax.Array]: - **Zero-Cost Operations**: Transforming indices, "folding" dimensions into the output, and other such operations don't actually do anything to the data, enabling a lot of operations to feel "free" in terms of performance.
return [ - **Semantic Indexing**: Using `InfoFlow`, it's easy to index and slice arrays using ex. nanometer vacuum wavelengths, instead of arbitrary integers.
dim_idx.realize().values """
if isinstance(dim_idx, LazyArrayRangeFlow)
else dim_idx.values
for dim_idx in self.dim_idx.values()
]
#################### ####################
# - Contravariant Output # - Dimensions: Covariant Index
#################### ####################
# Output Information dims: dict[sim_symbols.SimSymbol, IndexArray] = dataclasses.field(
## TODO: Add PhysicalType
output_name: str = dataclasses.field(default_factory=list)
output_shape: tuple[int, ...] | None = dataclasses.field(default=None)
output_mathtype: spux.MathType = dataclasses.field()
output_unit: spux.Unit | None = dataclasses.field()
@property
def output_shape_len(self) -> int:
if self.output_shape is None:
return 0
return len(self.output_shape)
# Pinned Dimension Information
## TODO: Add PhysicalType
pinned_dim_names: list[str] = dataclasses.field(default_factory=list)
pinned_dim_values: dict[str, float | complex] = dataclasses.field(
default_factory=dict default_factory=dict
) )
pinned_dim_mathtypes: dict[str, spux.MathType] = dataclasses.field(
default_factory=dict @functools.cached_property
) def last_dim(self) -> sim_symbols.SimSymbol | None:
pinned_dim_units: dict[str, spux.Unit] = dataclasses.field(default_factory=dict) """The integer axis occupied by the dimension.
Can be used to index `.shape` of the represented raw array.
"""
if self.dims:
return next(iter(self.dims.keys()))
return None
@functools.cached_property
def last_dim(self) -> sim_symbols.SimSymbol | None:
"""The integer axis occupied by the dimension.
Can be used to index `.shape` of the represented raw array.
"""
if self.dims:
return list(self.dims.keys())[-1]
return None
def dim_axis(self, dim: sim_symbols.SimSymbol) -> int:
"""The integer axis occupied by the dimension.
Can be used to index `.shape` of the represented raw array.
"""
return list(self.dims.keys()).index(dim)
def has_idx_cont(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the dim's index is continuous, and therefore index array.
This happens when the dimension is generated from a symbolic function, as opposed to from discrete observations.
In these cases, the `SimSymbol.domain` of the dimension should be used to determine the overall domain of validity.
Other than that, it's up to the user to select a particular way of indexing.
"""
return self.dims[dim] is None
def has_idx_discrete(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the (rat|real) dim is indexed by an `ArrayFlow` / `RangeFlow`."""
return isinstance(self.dims[dim], ArrayFlow | RangeFlow)
def has_idx_labels(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the (int) dim is indexed by a `LabelArray`."""
if dim.mathtype is spux.MathType.Integer:
return isinstance(self.dims[dim], list)
return False
#################### ####################
# - Methods # - Output: Contravariant Value
#################### ####################
def slice_dim(self, dim_name: str, slice_tuple: tuple[int, int, int]) -> typ.Self: output: sim_symbols.SimSymbol
####################
# - Pinned Dimension Values
####################
## -> Whenever a dimension is deleted, we retain what that index value was.
## -> This proves to be very helpful for clear visualization.
pinned_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = dataclasses.field(
default_factory=dict
)
####################
# - Operations: Dimensions
####################
def prepend_dim(
self, dim: sim_symbols.SimSymbol, dim_idx: sim_symbols.SimSymbol
) -> typ.Self:
"""Insert a new dimension at index 0."""
return InfoFlow( return InfoFlow(
# Dimensions dims={dim: dim_idx} | self.dims,
dim_names=self.dim_names, output=self.output,
dim_idx={ pinned_values=self.pinned_values,
_dim_name: ( )
dim_idx
if _dim_name != dim_name def slice_dim(
else dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]] self, dim: sim_symbols.SimSymbol, slice_tuple: tuple[int, int, int]
) ) -> typ.Self:
for _dim_name, dim_idx in self.dim_idx.items() """Slice a dimensional array by-index along a particular dimension."""
return InfoFlow(
dims={
_dim: dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
if _dim == dim
else _dim
for _dim, dim_idx in self.dims.items()
}, },
# Outputs output=self.output,
output_name=self.output_name, pinned_values=self.pinned_values,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
) )
def replace_dim( def replace_dim(
self, old_dim_name: str, new_dim_idx: tuple[str, ArrayFlow | LazyArrayRangeFlow] self,
old_dim: sim_symbols.SimSymbol,
new_dim: sim_symbols.SimSymbol,
new_dim_idx: IndexArray,
) -> typ.Self: ) -> typ.Self:
"""Replace a dimension (and its indexing) with a new name and index array/range.""" """Replace a dimension entirely, in-place, including symbol and index array."""
return InfoFlow( return InfoFlow(
# Dimensions dims={
dim_names=[ (new_dim if _dim == old_dim else _dim): (
dim_name if dim_name != old_dim_name else new_dim_idx[0] new_dim_idx if _dim == old_dim else _dim
for dim_name in self.dim_names
],
dim_idx={
(dim_name if dim_name != old_dim_name else new_dim_idx[0]): (
dim_idx if dim_name != old_dim_name else new_dim_idx[1]
) )
for dim_name, dim_idx in self.dim_idx.items() for _dim, dim_idx in self.dims.items()
}, },
# Outputs output=self.output,
output_name=self.output_name, pinned_values=self.pinned_values,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
) )
def rescale_dim_idxs(self, new_dim_idxs: dict[str, LazyArrayRangeFlow]) -> typ.Self: def replace_dims(
self, new_dims: dict[sim_symbols.SimSymbol, IndexArray]
) -> typ.Self:
"""Replace several dimensional indices with new index arrays/ranges.""" """Replace several dimensional indices with new index arrays/ranges."""
return InfoFlow( return InfoFlow(
# Dimensions dims={
dim_names=self.dim_names, dim: new_dims.get(dim, dim_idx) for dim, dim_idx in self.dim_idx.items()
dim_idx={
_dim_name: new_dim_idxs.get(_dim_name, dim_idx)
for _dim_name, dim_idx in self.dim_idx.items()
}, },
# Outputs output=self.output,
output_name=self.output_name, pinned_values=self.pinned_values,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
) )
def delete_dimension(self, dim_name: str) -> typ.Self: def delete_dim(
"""Delete a dimension.""" self, dim_to_remove: sim_symbols.SimSymbol, pin_idx: int | None = None
) -> typ.Self:
"""Delete a dimension, optionally pinning the value of an index from that dimension."""
new_pin = (
{dim_to_remove: self.dims[dim_to_remove][pin_idx]}
if pin_idx is not None
else {}
)
return InfoFlow( return InfoFlow(
# Dimensions dims={
dim_names=[ dim: dim_idx
_dim_name for _dim_name in self.dim_names if _dim_name != dim_name for dim, dim_idx in self.dims.items()
], if dim != dim_to_remove
dim_idx={
_dim_name: dim_idx
for _dim_name, dim_idx in self.dim_idx.items()
if _dim_name != dim_name
}, },
# Outputs output=self.output,
output_name=self.output_name, pinned_values=self.pinned_values | new_pin,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
) )
def swap_dimensions(self, dim_0_name: str, dim_1_name: str) -> typ.Self: def swap_dimensions(self, dim_0: str, dim_1: str) -> typ.Self:
"""Swap the position of two dimensions.""" """Swap the positions of two dimensions."""
# Compute Swapped Dimension Name List # Swapped Dimension Keys
def name_swapper(dim_name): def name_swapper(dim_name):
return ( return (
dim_name dim_name
if dim_name not in [dim_0_name, dim_1_name] if dim_name not in [dim_0, dim_1]
else {dim_0_name: dim_1_name, dim_1_name: dim_0_name}[dim_name] else {dim_0: dim_1, dim_1: dim_0}[dim_name]
) )
dim_names = [name_swapper(dim_name) for dim_name in self.dim_names] swapped_dim_keys = [name_swapper(dim) for dim in self.dims]
# Compute Info
return InfoFlow( return InfoFlow(
# Dimensions dims={dim_key: self.dims[dim_key] for dim_key in swapped_dim_keys},
dim_names=dim_names, output=self.output,
dim_idx={dim_name: self.dim_idx[dim_name] for dim_name in dim_names}, pinned_values=self.pinned_values,
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
) )
def set_output_mathtype(self, output_mathtype: spux.MathType) -> typ.Self: ####################
"""Set the MathType of the output.""" # - Operations: Output
####################
def update_output(self, **kwargs) -> typ.Self:
"""Passthrough to `SimSymbol.update()` method on `self.output`."""
return InfoFlow( return InfoFlow(
dim_names=self.dim_names, dims=self.dims,
dim_idx=self.dim_idx, output=self.output.update(**kwargs),
# Outputs pinned_values=self.pinned_values,
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=output_mathtype,
output_unit=self.output_unit,
) )
def collapse_output( ####################
self, # - Operations: Fold
collapsed_name: str, ####################
collapsed_mathtype: spux.MathType, def fold_last_input(self):
collapsed_unit: spux.Unit, """Fold the last input dimension into the output."""
) -> typ.Self: last_key = list(self.dims.keys())[-1]
"""Replace the (scalar) output with the given corrected values.""" last_idx = list(self.dims.values())[-1]
return InfoFlow(
# Dimensions rows = self.output.rows
dim_names=self.dim_names, cols = self.output.cols
dim_idx=self.dim_idx, match (rows, cols):
output_name=collapsed_name, case (1, 1):
output_shape=None, new_output = self.output.set_size(len(last_idx), 1)
output_mathtype=collapsed_mathtype, case (_, 1):
output_unit=collapsed_unit, new_output = self.output.set_size(rows, len(last_idx))
) case (1, _):
new_output = self.output.set_size(len(last_idx), cols)
case (_, _):
raise NotImplementedError ## Not yet :)
@functools.cached_property
def shift_last_input(self):
"""Shift the last input dimension to the output."""
return InfoFlow( return InfoFlow(
# Dimensions dims={
dim_names=self.dim_names[:-1], dim: dim_idx for dim, dim_idx in self.dims.items() if dim != last_key
dim_idx={
dim_name: dim_idx
for dim_name, dim_idx in self.dim_idx.items()
if dim_name != self.dim_names[-1]
}, },
# Outputs output=new_output,
output_name=self.output_name, pinned_values=self.pinned_values,
output_shape=(
(self.dim_lens[self.dim_names[-1]],)
if self.output_shape is None
else (self.dim_lens[self.dim_names[-1]], *self.output_shape)
),
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
) )

View File

@ -0,0 +1,440 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import dataclasses
import functools
import typing as typ
from types import MappingProxyType
import jax
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger
from .params import ParamsFlow
log = logger.get(__name__)
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any]
@dataclasses.dataclass(frozen=True, kw_only=True)
class FuncFlow:
r"""Defines a flow of data as incremental function composition.
For specific math system usage instructions, please consult the documentation of relevant nodes.
# Introduction
When using nodes to do math, it becomes immediately obvious to express **flows of data as composed function chains**.
Doing so has several advantages:
- **Interactive**: Since no large-array math is being done, the UI can be designed to feel fast and snappy.
- **Symbolic**: Since no numerical math is being done yet, we can choose to keep our input parameters as symbolic variables with no performance impact.
- **Performant**: Since no operations are happening, the UI feels fast and snappy.
## Strongly Related FlowKinds
For doing math, `Func` relies on two other `FlowKind`s, which must run in parallel:
- `FlowKind.Info`: Tracks the name, `spux.MathType`, unit (if any), length, and index coordinates for the raw data object produced by `Func`.
- `FlowKind.Params`: Tracks the particular values of input parameters to the lazy function, each of which can also be symbolic.
For more, please see the documentation for each.
## Non-Mathematical Use
Of course, there are many interesting uses of incremental function composition that aren't mathematical.
For such cases, the usage is identical, but the complexity is lessened; for example, `Info` no longer effectively needs to flow in parallel.
# Lazy Math: Theoretical Foundation
This `FlowKind` is the critical component of a functional-inspired system for lazy multilinear math.
Thus, it makes sense to describe the math system here.
## `depth=0`: Root Function
To start a composition chain, a function with no inputs must be defined as the "root", or "bottom".
$$
f_0:\ \ \ \ \biggl(
\underbrace{a_1, a_2, ..., a_p}_{\texttt{args}},\
\underbrace{
\begin{bmatrix} k_1 \\ v_1\end{bmatrix},
\begin{bmatrix} k_2 \\ v_2\end{bmatrix},
...,
\begin{bmatrix} k_q \\ v_q\end{bmatrix}
}_{\texttt{kwargs}}
\biggr) \to \text{output}_0
$$
In Python, such a construction would look like this:
```python
# Presume 'A0', 'KV0' contain only the args/kwargs for f_0
## 'A0', 'KV0' are of length 'p' and 'q'
def f_0(*args, **kwargs): ...
lazy_func_0 = FuncFlow(
func=f_0,
func_args=[(a_i, type(a_i)) for a_i in A0],
func_kwargs={k: v for k,v in KV0},
)
output_0 = lazy_func.func(*A0_computed, **KV0_computed)
```
## `depth>0`: Composition Chaining
So far, so easy.
Now, let's add a function that uses the result of $f_0$, without yet computing it.
$$
f_1:\ \ \ \ \biggl(
f_0(...),\ \
\underbrace{\{a_i\}_p^{p+r}}_{\texttt{args[p:]}},\
\underbrace{\biggl\{
\begin{bmatrix} k_i \\ v_i\end{bmatrix}
\biggr\}_q^{q+s}}_{\texttt{kwargs[p:]}}
\biggr) \to \text{output}_1
$$
Note:
- $f_1$ must take the arguments of both $f_0$ and $f_1$.
- The complexity is getting notationally complex; we already have to use `...` to represent "the last function's arguments".
In other words, **there's suddenly a lot to manage**.
Even worse, the bigger the $n$, the more complexity we must real with.
This is where the Python version starts to show its purpose:
```python
# Presume 'A1', 'K1' contain only the args/kwarg names for f_1
## 'A1', 'KV1' are therefore of length 'r' and 's'
def f_1(output_0, *args, **kwargs): ...
lazy_func_1 = lazy_func_0.compose_within(
enclosing_func=f_1,
enclosing_func_args=[(a_i, type(a_i)) for a_i in A1],
enclosing_func_kwargs={k: type(v) for k,v in K1},
)
A_computed = A0_computed + A1_computed
KW_computed = KV0_computed + KV1_computed
output_1 = lazy_func_1.func(*A_computed, **KW_computed)
```
By using `Func`, we've guaranteed that even hugely deep $n$s won't ever look more complicated than this.
## `max depth`: "Realization"
So, we've composed a bunch of functions of functions of ...
We've also tracked their arguments, either manually (as above), or with the help of a handy `ParamsFlow` object.
But it'd be pointless to just compose away forever.
We do actually need the data that they claim to compute now:
```python
# A_all and KW_all must be tracked on the side.
output_n = lazy_func_n.func(*A_all, **KW_all)
```
Of course, this comes with enormous overhead.
Aside from the function calls themselves (which can be non-trivial), we must also contend with the enormous inefficiency of performing array operations sequentially.
That brings us to the killer feature of `FuncFlow`, and the motivating reason for doing any of this at all:
```python
output_n = lazy_func_n.func_jax(*A_all, **KW_all)
```
What happened was, **the entire pipeline** was compiled, optimized, and computed with bare-metal performance on either a CPU, GPU, or TPU.
With the help of the `jax` library (and its underlying OpenXLA bytecode), all of that inefficiency has been optimized based on _what we're trying to do_, not _exactly how we're doing it_, in order to maximize the use of modern massively-parallel devices.
See the documentation of `Func.func_jax()` for more information on this process.
# Lazy Math: Practical Considerations
By using nodes to express a lazily-composed chain of mathematical operations on tensor-like data, we strike a difficult balance between UX, flexibility, and performance.
## UX
UX is often more a matter of art/taste than science, so don't trust these philosophies too much - a lot of the analysis is entirely personal and subjective.
The goal for our UX is to minimize the "frictions" that cause cascading, small-scale _user anxiety_.
Especially of concern in a visual math system on large data volumes is **UX latency** - also known as **lag**.
In particular, the most important facet to minimize is _emotional burden_ rather than quantitative milliseconds.
Any repeated moment-to-moment friction can be very damaging to a user's ability to be productive in a piece of software.
Unfortunately, in a node-based architecture, data must generally be computed one step at a time, whenever any part of it is needed, and it must do so before any feedback can be provided.
In a math system like this, that data is presumed "big", and as such we're left with the unfortunate experience of even the most well-cached, high-performance operations causing _just about anything_ to **feel** like a highly unpleasant slog as soon as the data gets big enough.
**This discomfort scales with the size of data**, by the way, which might just cause users to never even attempt working with the data volume that they actually need.
For electrodynamic field analysis, it's not uncommon for toy examples to expend hundreds of megabytes of memory, all of which needs all manner of interesting things done to it.
It can therefore be very easy to stumble across that feeling of "slogging through" any program that does real-world EM field analysis.
This has consequences: The user tries fewer ideas, becomes more easily frustrated, and might ultimately accomplish less.
Lazy evaluation allows _delaying_ a computation to a point in time where the user both expects and understands the time that the computation takes.
For example, the user experience of pressing a button clearly marked with terminology like "load", "save", "compute", "run", seems to be paired to a greatly increased emotional tolerance towards the latency introduced by pressing that button (so long as it is only clickable when it works).
To a lesser degree, attaching a node link also seems to have this property, though that tolerance seems to fall as proficiency with the node-based tool rises.
As a more nuanced example, when lag occurs due to the computing an image-based plot based on live-computed math, then the visual feedback of _the plot actually changing_ seems to have a similar effect, not least because it's emotionally well-understood that detaching the `Viewer` node would also remove the lag.
In short: Even if lazy evaluation didn't make any math faster, it will still _feel_ faster (to a point - raw performance obviously still matters).
Without `FuncFlow`, the point of evaluation cannot be chosen at all, which is a huge issue for all the named reasons.
With `FuncFlow`, better-chosen evaluation points can be chosen to cause the _user experience_ of high performance, simply because we were able to shift the exact same computation to a point in time where the user either understands or tolerates the delay better.
## Flexibility
Large-scale math is done on tensors, whether one knows (or likes!) it or not.
To this end, the indexed arrays produced by `FuncFlow.func_jax` aren't quite sufficient for most operations we want to do:
- **Naming**: What _is_ each axis?
Unnamed index axes are sometimes easy to decode, but in general, names have an unexpectedly critical function when operating on arrays.
Lack of names is a huge part of why perfectly elegant array math in ex. `MATLAB` or `numpy` can so easily feel so incredibly convoluted.
_Sometimes arrays with named axes are called "structured arrays".
- **Coordinates**: What do the indices of each axis really _mean_?
For example, an array of $500$ by-wavelength observations of power (watts) can't be limited to between $200nm$ to $700nm$.
But they can be limited to between index `23` to `298`.
I'm **just supposed to know** that `23` means $200nm$, and that `298` indicates the observation just after $700nm$, and _hope_ that this is exact enough.
Not only do we endeavor to track these, but we also introduce unit-awareness to the coordinates, and design the entire math system to visually communicate the state of arrays before/after every single computation, as well as only expose operations that this tracked data indicates possible.
In practice, this happens in `FlowKind.Info`, which due to having its own `FlowKind` "lane" can be adjusted without triggering changes to (and therefore recompilation of) the `FlowKind.Func` chain.
**Please consult the `InfoFlow` documentation for more**.
## Performance
All values introduced while processing are kept in a seperate `FlowKind` lane, with its own incremental caching: `FlowKind.Params`.
It's a simple mechanism, but for the cost of introducing an extra `FlowKind` "lane", all of the values used to process data can be live-adjusted without the overhead of recompiling the entire `Func` every time anything changes.
Moreover, values used to process data don't even have to be numbers yet: They can be expressions of symbolic variables, complete with units, which are only realized at the very end of the chain, by the node that absolutely cannot function without the actual numerical data.
See the `ParamFlow` documentation for more information.
# Conclusion
There is, of course, a lot more to say about the math system in general.
A few teasers of what nodes can do with this system:
**Auto-Differentiation**: `jax.jit` isn't even really the killer feature of `jax`.
`jax` can automatically differentiate `FuncFlow.func_jax` with respect to any input parameter, including for fwd/bck jacobians/hessians, with robust numerical stability.
When used in
**Symbolic Interop**: Any `sympy` expression containing symbolic variables can be compiled, by `sympy`, into a `jax`-compatible function which takes
We make use of this in the `Expr` socket, enabling true symbolic math to be used in high-performance lazy `jax` computations.
**Tidy3D Interop**: For some parameters of some simulation objects, `tidy3d` actually supports adjoint-driven differentiation _through the cloud simulation_.
This enables our humble interface to implement fully functional **inverse design** of parameterized structures, using only nodes.
But above all, we hope that this math system is fun, practical, and maybe even interesting.
Attributes:
func: The function that generates the represented value.
func_args: The constrained identity of all positional arguments to the function.
func_kwargs: The constrained identity of all keyword arguments to the function.
supports_jax: Whether `self.func` can be compiled with JAX's JIT compiler.
See the documentation of `self.func_jax()`.
"""
func: LazyFunction
func_args: list[spux.MathType | spux.PhysicalType] = dataclasses.field(
default_factory=list
)
func_kwargs: dict[str, spux.MathType | spux.PhysicalType] = dataclasses.field(
default_factory=dict
)
## TODO: Use SimSymbol instead of the MathType|PT union.
## -- SimSymbol is an ideal pivot point for both, as well as valid domains.
## -- SimSymbol has more semantic meaning, including a name.
## -- If desired, SimSymbols could maybe even require a specific unit.
## It could greatly simplify a whole lot of pain associated with func_args.
supports_jax: bool = False
####################
# - Functions
####################
@functools.cached_property
def func_jax(self) -> LazyFunction:
"""Compile `self.func` into optimized XLA bytecode using `jax.jit`.
Not all functions can be compiled like this by `jax`.
A few critical criteria include:
- **Only JAX Ops**: All operations performed within the function must be explicitly compatible with `jax`, which generally means only using functions in `jax.lax`, `jax.numpy`
- **Known Shape**: The exact dimensions of the output, and of the inputs, must be known at `jit`-time.
In return, one receives:
- **Automatic Differentiation**: `jax` can robustly differentiate this function with respect to _any_ parameter.
This includes Jacobians and Hessians, forwards and backwards, real or complex, all with good numerical stability.
Since `tidy3d`'s simulator registers itself as `jax`-differentiable (using the adjoint method), this "autodiff" support can extend all the way from parameters in the simulation definition, to gradients of the simulation output.
When using these gradients for optimization, one achieves what is called "inverse design", where the desired properties of the output fields are used to automatically select simulation input parameters.
- **Performance**: XLA is a cross-industry project with the singular goal of providing a high-performance compilation target for data-driven programs.
Published architects of OpenXLA include Alibaba, Amazon Web Services, AMD, Apple, Arm, Google, Intel, Meta, and NVIDIA.
- **Device Agnosticism**: XLA bytecode runs not just on CPUs, but on massively parallel devices like GPUs and TPUs as well.
This enables massive speedups, and greatly expands the amount of data that is practical to work with at one time.
Notes:
The property `self.supports_jax` manually tracks whether these criteria are satisfied.
**As much as possible**, the _entirety of `blender_maxwell`_ is designed to maximize the ability to set `self.supports_jax = True` as often as possible.
**However**, there are many cases where a lazily-evaluated value is desirable, but `jax` isn't supported.
These include design space exploration, where any particular parameter might vary for the purpose of producing batched simulations.
In these cases, trying to compile a `self.func_jax` will raise a `ValueError`.
Returns:
The `jit`-compiled function, ready to run on CPU, GPU, or XLA.
Raises:
ValueError: If `self.supports_jax` is `False`.
References:
JAX JIT: <https://jax.readthedocs.io/en/latest/jit-compilation.html>
OpenXLA: <https://openxla.org/xla>
"""
if self.supports_jax:
return jax.jit(self.func)
msg = 'Can\'t express FuncFlow as JAX function (using jax.jit), since "self.supports_jax" is False'
raise ValueError(msg)
####################
# - Realization
####################
def realize(
self,
params: ParamsFlow,
unit_system: spux.UnitSystem | None = None,
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
) -> typ.Self:
if self.supports_jax:
return self.func_jax(
*params.scaled_func_args(unit_system, symbol_values),
*params.scaled_func_kwargs(unit_system, symbol_values),
)
return self.func(
*params.scaled_func_args(unit_system, symbol_values),
*params.scaled_func_kwargs(unit_system, symbol_values),
)
####################
# - Composition Operations
####################
def compose_within(
self,
enclosing_func: LazyFunction,
enclosing_func_args: list[type] = (),
enclosing_func_kwargs: dict[str, type] = MappingProxyType({}),
supports_jax: bool = False,
) -> typ.Self:
"""Compose `self.func` within the given enclosing function, which itself takes arguments, and create a new `FuncFlow` to contain it.
This is the fundamental operation used to "chain" functions together.
Examples:
Consider a simple composition based on two expressions:
```python
R = spux.MathType.Real
C = spux.MathType.Complex
x, y = sp.symbols('x y', real=True)
# Prepare "Root" FuncFlow w/x,y args
expr_root = 3*x + y**2 - 100
expr_root_func = sp.lambdify([x, y], expr, 'jax')
func_root = FuncFlow(func=expr_root_func, func_args=[R,R], supports_jax=True)
# Compose "Enclosing" FuncFlow w/z arg
r = sp.Symbol('z', real=True)
z = sp.Symbol('z', complex=True)
expr = 10*sp.re(z) / (z + r)
expr_func = sp.lambdify([r, z], expr, 'jax')
func = func_root.compose_within(enclosing_func=expr_func, enclosing_func_args=[C])
# Compute 'expr_func(expr_root_func(10.0, -500.0), 1+8j)'
f.func_jax(10.0, -500.0, 1+8j)
```
Using this function, it's easy to "keep adding" symbolic functions of any kind to the chain, without introducing extraneous complexity or compromising the ease of calling the final function.
Returns:
A lazy function that takes both the enclosed and enclosing arguments, and returns the value of the enclosing function (whose first argument is the output value of the enclosed function).
"""
return FuncFlow(
func=lambda *args, **kwargs: enclosing_func(
self.func(
*list(args[: len(self.func_args)]),
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
),
*args[len(self.func_args) :],
**{k: v for k, v in kwargs.items() if k not in self.func_kwargs},
),
func_args=self.func_args + list(enclosing_func_args),
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
supports_jax=self.supports_jax and supports_jax,
)
def __or__(
self,
other: typ.Self,
) -> typ.Self:
"""Create a lazy function that takes all arguments of both lazy-function inputs, and itself promises to return a 2-tuple containing the outputs of both inputs.
Generally, `self.func` produces a single array as output (when doing math, at least).
But sometimes (as in the `OperateMathNode`), we need to perform a binary operation between two arrays, like say, $+$.
Without realizing both `FuncFlow`s, it's not immediately obvious how one might accomplish this.
This overloaded function of the `|` operator (used as `left | right`) solves that problem.
A new `FuncFlow` is created, which takes the arguments of both inputs, and which produces a single output value: A 2-tuple, where each element if the output of each function.
Examples:
Consider this illustrative (pseudocode) example:
```python
# Presume a,b are values, and that A,B are their identifiers.
func_1 = FuncFlow(func=compute_big_data_1, func_args=[A])
func_2 = FuncFlow(func=compute_big_data_2, func_args=[B])
f = (func_1 | func_2).compose_within(func=lambda D: D[0] + D[1])
f.func(a, b) ## Computes big_data_1 + big_data_2 @A=a, B=b
```
Because of `__or__` (the operator `|`), the difficult and non-obvious task of adding the outputs of these unrealized functions because quite simple.
Notes:
**Order matters**.
`self` will be available in the new function's output as index `0`, while `other` will be available as index `1`.
As with anything lazy-composition-y, it can seem a bit strange at first.
When reading the source code, pay special attention to the way that `args` is sliced to segment the positional arguments.
Returns:
A lazy function that takes all arguments of both inputs, and returns a 2-tuple containing both output arguments.
"""
return FuncFlow(
func=lambda *args, **kwargs: (
self.func(
*list(args[: len(self.func_args)]),
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
),
other.func(
*list(args[len(self.func_args) :]),
**{k: v for k, v in kwargs.items() if k in other.func_kwargs},
),
),
func_args=self.func_args + other.func_args,
func_kwargs=self.func_kwargs | other.func_kwargs,
supports_jax=self.supports_jax and other.supports_jax,
)

View File

@ -28,13 +28,20 @@ from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger from blender_maxwell.utils import logger
from .array import ArrayFlow from .array import ArrayFlow
from .flow_kinds import FlowKind from .lazy_func import FuncFlow
from .lazy_value_func import LazyValueFuncFlow
log = logger.get(__name__) log = logger.get(__name__)
class ScalingMode(enum.StrEnum): class ScalingMode(enum.StrEnum):
"""Identifier for how to space steps between two boundaries.
Attributes:
Lin: Uniform spacing between two endpoints.
Geom: Log spacing between two endpoints, given as values.
Log: Log spacing between two endpoints, given as powers of a common base.
"""
Lin = enum.auto() Lin = enum.auto()
Geom = enum.auto() Geom = enum.auto()
Log = enum.auto() Log = enum.auto()
@ -54,37 +61,21 @@ class ScalingMode(enum.StrEnum):
@dataclasses.dataclass(frozen=True, kw_only=True) @dataclasses.dataclass(frozen=True, kw_only=True)
class LazyArrayRangeFlow: class RangeFlow:
r"""Represents a linearly/logarithmically spaced array using symbolic boundary expressions, with support for units and lazy evaluation. r"""Represents a spaced array using symbolic boundary expressions.
# Advantages
Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous. Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous.
## Memory # Memory Scaling
`ArrayFlow` generally has a memory scaling of $O(n)$. `ArrayFlow` generally has a memory scaling of $O(n)$.
Naturally, `LazyArrayRangeFlow` is always constant, since only the boundaries and steps are stored. Naturally, `RangeFlow` is always constant, since only the boundaries and steps are stored.
## Symbolic # Symbolic Bounds
Both boundary points are symbolic expressions, within which pre-defined `sp.Symbol`s can participate in a constrained manner (ex. an integer symbol). `self.start` and `self.stop` boundary points are symbolic expressions, within which any element of `self.symbols` can participate.
One need not know the value of the symbols immediately - such decisions can be deferred until later in the computational flow. **It is the user's responsibility** to ensure that `self.start < self.stop`.
## Performant Unit-Aware Operations # Numerical Properties
While `ArrayFlow`s are also unit-aware, the time-cost of _any_ unit-scaling operation scales with $O(n)$.
`LazyArrayRangeFlow`, by contrast, scales as $O(1)$.
As a result, more complicated operations (like symbolic or unit-based) that might be difficult to perform interactively in real-time on an `ArrayFlow` will work perfectly with this object, even with added complexity
## High-Performance Composition and Gradiant
With `self.as_func`, a `jax` function is produced that generates the array according to the symbolic `start`, `stop` and `steps`.
There are two nice things about this:
- **Gradient**: The gradient of the output array, with respect to any symbols used to define the input bounds, can easily be found using `jax.grad` over `self.as_func`.
- **JIT**: When `self.as_func` is composed with other `jax` functions, and `jax.jit` is run to optimize the entire thing, the "cost of array generation" _will often be optimized away significantly or entirely_.
Thus, as part of larger computations, the performance properties of `LazyArrayRangeFlow` is extremely favorable.
## Numerical Properties
Since the bounds support exact (ex. rational) calculations and symbolic manipulations (_by virtue of being symbolic expressions_), the opportunities for certain kinds of numerical instability are mitigated. Since the bounds support exact (ex. rational) calculations and symbolic manipulations (_by virtue of being symbolic expressions_), the opportunities for certain kinds of numerical instability are mitigated.
Attributes: Attributes:
@ -108,8 +99,11 @@ class LazyArrayRangeFlow:
unit: spux.Unit | None = None unit: spux.Unit | None = None
symbols: frozenset[spux.IntSymbol] = frozenset() symbols: frozenset[spux.Symbol] = frozenset()
####################
# - Computed Properties
####################
@functools.cached_property @functools.cached_property
def sorted_symbols(self) -> list[sp.Symbol]: def sorted_symbols(self) -> list[sp.Symbol]:
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name. """Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
@ -121,6 +115,19 @@ class LazyArrayRangeFlow:
""" """
return sorted(self.symbols, key=lambda sym: sym.name) return sorted(self.symbols, key=lambda sym: sym.name)
@property
def is_symbolic(self) -> bool:
"""Whether the `RangeFlow` has unrealized symbols."""
return len(self.symbols) > 0
def __len__(self) -> int:
"""Compute the length of the array that would be realized.
Returns:
The number of steps.
"""
return self.steps
@functools.cached_property @functools.cached_property
def mathtype(self) -> spux.MathType: def mathtype(self) -> spux.MathType:
"""Conservatively compute the most stringent `spux.MathType` that can represent both `self.start` and `self.stop`. """Conservatively compute the most stringent `spux.MathType` that can represent both `self.start` and `self.stop`.
@ -156,128 +163,28 @@ class LazyArrayRangeFlow:
) )
return combined_mathtype return combined_mathtype
def __len__(self):
"""Compute the length of the array to be realized.
Returns:
The number of steps.
"""
return self.steps
#################### ####################
# - Units # - Methods
####################
def correct_unit(self, corrected_unit: spux.Unit) -> typ.Self:
"""Replaces the unit without rescaling the unitless bounds.
Parameters:
corrected_unit: The unit to replace the current unit with.
Returns:
A new `LazyArrayRangeFlow` with replaced unit.
Raises:
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
"""
if self.unit is not None:
log.debug(
'%s: Corrected unit to %s',
self,
corrected_unit,
)
return LazyArrayRangeFlow(
start=self.start,
stop=self.stop,
steps=self.steps,
scaling=self.scaling,
unit=corrected_unit,
symbols=self.symbols,
)
msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"'
raise ValueError(msg)
def rescale_to_unit(self, unit: spux.Unit) -> typ.Self:
"""Replaces the unit, **with** rescaling of the bounds.
Parameters:
unit: The unit to convert the bounds to.
Returns:
A new `LazyArrayRangeFlow` with replaced unit.
Raises:
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
"""
if self.unit is not None:
log.debug(
'%s: Scaled to unit %s',
self,
unit,
)
return LazyArrayRangeFlow(
start=spux.scale_to_unit(self.start * self.unit, unit),
stop=spux.scale_to_unit(self.stop * self.unit, unit),
steps=self.steps,
scaling=self.scaling,
unit=unit,
symbols=self.symbols,
)
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}'
raise ValueError(msg)
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self:
"""Replaces the units, **with** rescaling of the bounds.
Parameters:
unit: The unit to convert the bounds to.
Returns:
A new `LazyArrayRangeFlow` with replaced unit.
Raises:
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
"""
if self.unit is not None:
log.debug(
'%s: Scaled to new unit system (new unit = %s)',
self,
unit_system[spux.PhysicalType.from_unit(self.unit)],
)
return LazyArrayRangeFlow(
start=spux.strip_unit_system(
spux.convert_to_unit_system(self.start * self.unit, unit_system),
unit_system,
),
stop=spux.strip_unit_system(
spux.convert_to_unit_system(self.stop * self.unit, unit_system),
unit_system,
),
steps=self.steps,
scaling=self.scaling,
unit=unit_system[spux.PhysicalType.from_unit(self.unit)],
symbols=self.symbols,
)
msg = (
f'Tried to rescale unitless LazyDataValueRange to unit system {unit_system}'
)
raise ValueError(msg)
####################
# - Bound Operations
#################### ####################
def rescale( def rescale(
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
) -> typ.Self: ) -> typ.Self:
"""Apply an order-preserving function to each bound, then (optionally) transform the result w/new unit and/or order.
An optimized expression will be built and applied to `self.values` using `sympy.lambdify()`.
Parameters:
rescale_func: An **order-preserving** function to apply to each array element.
reverse: Whether to reverse the order of the result.
new_unit: An (optional) new unit to scale the result to.
"""
new_pre_start = self.start if not reverse else self.stop new_pre_start = self.start if not reverse else self.stop
new_pre_stop = self.stop if not reverse else self.start new_pre_stop = self.stop if not reverse else self.start
new_start = rescale_func(new_pre_start * self.unit) new_start = rescale_func(new_pre_start * self.unit)
new_stop = rescale_func(new_pre_stop * self.unit) new_stop = rescale_func(new_pre_stop * self.unit)
return LazyArrayRangeFlow( return RangeFlow(
start=( start=(
spux.scale_to_unit(new_start, new_unit) spux.scale_to_unit(new_start, new_unit)
if new_unit is not None if new_unit is not None
@ -294,39 +201,11 @@ class LazyArrayRangeFlow:
symbols=self.symbols, symbols=self.symbols,
) )
def rescale_bounds( def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
self, raise NotImplementedError
rescale_func: typ.Callable[
[spux.ScalarUnitlessComplexExpr], spux.ScalarUnitlessComplexExpr
],
reverse: bool = False,
) -> typ.Self:
"""Apply a function to the bounds, effectively rescaling the represented array.
Notes:
**It is presumed that the bounds are scaled with the same factor**.
Breaking this presumption may have unexpected results.
The scalar, unitless, complex-valuedness of the bounds must also be respected; additionally, new symbols must not be introduced.
Parameters:
scaler: The function that scales each bound.
reverse: Whether to reverse the bounds after running the `scaler`.
Returns:
A rescaled `LazyArrayRangeFlow`.
"""
return LazyArrayRangeFlow(
start=rescale_func(self.start if not reverse else self.stop),
stop=rescale_func(self.stop if not reverse else self.start),
steps=self.steps,
scaling=self.scaling,
unit=self.unit,
symbols=self.symbols,
)
#################### ####################
# - Lazy Representation # - Exporters
#################### ####################
@functools.cached_property @functools.cached_property
def array_generator( def array_generator(
@ -345,10 +224,10 @@ class LazyArrayRangeFlow:
ScalingMode.Geom: jnp.geomspace, ScalingMode.Geom: jnp.geomspace,
ScalingMode.Log: jnp.logspace, ScalingMode.Log: jnp.logspace,
}.get(self.scaling) }.get(self.scaling)
if jnp_nspace is None: if jnp_nspace is None:
msg = f'ArrayFlow scaling method {self.scaling} is unsupported' msg = f'ArrayFlow scaling method {self.scaling} is unsupported'
raise RuntimeError(msg) raise RuntimeError(msg)
return jnp_nspace return jnp_nspace
@functools.cached_property @functools.cached_property
@ -361,12 +240,12 @@ class LazyArrayRangeFlow:
The ordering of the symbols is identical to `self.symbols`, which is guaranteed to be a deterministically sorted list of symbols. The ordering of the symbols is identical to `self.symbols`, which is guaranteed to be a deterministically sorted list of symbols.
Returns: Returns:
A `LazyValueFuncFlow` that, given the input symbols defined in `self.symbols`, A `FuncFlow` that, given the input symbols defined in `self.symbols`,
""" """
# Compile JAX Functions for Start/End Expressions # Compile JAX Functions for Start/End Expressions
## FYI, JAX-in-JAX works perfectly fine. ## -> FYI, JAX-in-JAX works perfectly fine.
start_jax = sp.lambdify(self.symbols, self.start, 'jax') start_jax = sp.lambdify(self.sorted_symbols, self.start, 'jax')
stop_jax = sp.lambdify(self.symbols, self.stop, 'jax') stop_jax = sp.lambdify(self.sorted_symbols, self.stop, 'jax')
# Compile ArrayGen Function # Compile ArrayGen Function
def gen_array( def gen_array(
@ -378,18 +257,18 @@ class LazyArrayRangeFlow:
return gen_array return gen_array
@functools.cached_property @functools.cached_property
def as_lazy_value_func(self) -> LazyValueFuncFlow: def as_lazy_func(self) -> FuncFlow:
"""Creates a `LazyValueFuncFlow` using the output of `self.as_func`. """Creates a `FuncFlow` using the output of `self.as_func`.
This is useful for ex. parameterizing the first array in the node graph, without binding an entire computed array. This is useful for ex. parameterizing the first array in the node graph, without binding an entire computed array.
Notes: Notes:
The the function enclosed in the `LazyValueFuncFlow` is identical to the one returned by `self.as_func`. The the function enclosed in the `FuncFlow` is identical to the one returned by `self.as_func`.
Returns: Returns:
A `LazyValueFuncFlow` containing `self.as_func`, as well as appropriate supporting settings. A `FuncFlow` containing `self.as_func`, as well as appropriate supporting settings.
""" """
return LazyValueFuncFlow( return FuncFlow(
func=self.as_func, func=self.as_func,
func_args=[(spux.MathType.from_expr(sym)) for sym in self.symbols], func_args=[(spux.MathType.from_expr(sym)) for sym in self.symbols],
supports_jax=True, supports_jax=True,
@ -401,7 +280,8 @@ class LazyArrayRangeFlow:
def realize_start( def realize_start(
self, self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}), symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> ArrayFlow | LazyValueFuncFlow: ) -> int | float | complex:
"""Realize the start-bound by inserting particular values for each symbol."""
return spux.sympy_to_python( return spux.sympy_to_python(
self.start.subs({sym: symbol_values[sym.name] for sym in self.symbols}) self.start.subs({sym: symbol_values[sym.name] for sym in self.symbols})
) )
@ -409,7 +289,8 @@ class LazyArrayRangeFlow:
def realize_stop( def realize_stop(
self, self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}), symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> ArrayFlow | LazyValueFuncFlow: ) -> int | float | complex:
"""Realize the stop-bound by inserting particular values for each symbol."""
return spux.sympy_to_python( return spux.sympy_to_python(
self.stop.subs({sym: symbol_values[sym.name] for sym in self.symbols}) self.stop.subs({sym: symbol_values[sym.name] for sym in self.symbols})
) )
@ -417,7 +298,11 @@ class LazyArrayRangeFlow:
def realize_step_size( def realize_step_size(
self, self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}), symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> ArrayFlow | LazyValueFuncFlow: ) -> int | float | complex:
"""Realize the stop-bound by inserting particular values for each symbol."""
if self.scaling is not ScalingMode.Lin:
raise NotImplementedError('Non-linear scaling mode not yet suported')
raw_step_size = (self.realize_stop() - self.realize_start() + 1) / self.steps raw_step_size = (self.realize_stop() - self.realize_start() + 1) / self.steps
if self.mathtype is spux.MathType.Integer and raw_step_size.is_integer(): if self.mathtype is spux.MathType.Integer and raw_step_size.is_integer():
@ -427,48 +312,34 @@ class LazyArrayRangeFlow:
def realize( def realize(
self, self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}), symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
kind: typ.Literal[FlowKind.Array, FlowKind.LazyValueFunc] = FlowKind.Array, ) -> ArrayFlow:
) -> ArrayFlow | LazyValueFuncFlow: """Realize the array represented by this `RangeFlow` by realizing each bound, then generating all intermediate values as an array.
"""Apply a function to the bounds, effectively rescaling the represented array.
Notes:
**It is presumed that the bounds are scaled with the same factor**.
Breaking this presumption may have unexpected results.
The scalar, unitless, complex-valuedness of the bounds must also be respected; additionally, new symbols must not be introduced.
Parameters: Parameters:
scaler: The function that scales each bound. symbol_values: The particular values for each symbol, which will be inserted into the expression of each bound to realize them.
reverse: Whether to reverse the bounds after running the `scaler`.
Returns: Returns:
A rescaled `LazyArrayRangeFlow`. An `ArrayFlow` containing this realized `RangeFlow`.
""" """
if not set(self.symbols).issubset(set(symbol_values.keys())): ## TODO: Check symbol values for coverage.
msg = f'Provided symbols ({set(symbol_values.keys())}) do not provide values for all expression symbols ({self.symbols}) that may be found in the boundary expressions (start={self.start}, end={self.end})'
raise ValueError(msg)
# Realize Symbols return ArrayFlow(
realized_start = self.realize_start(symbol_values) values=self.as_func(*[symbol_values[sym] for sym in self.sorted_symbols]),
realized_stop = self.realize_stop(symbol_values) unit=self.unit,
is_sorted=True,
# Return Linspace / Logspace )
def gen_array() -> jtyp.Inexact[jtyp.Array, ' steps']:
return self.array_generator(realized_start, realized_stop, self.steps)
if kind == FlowKind.Array:
return ArrayFlow(values=gen_array(), unit=self.unit, is_sorted=True)
if kind == FlowKind.LazyValueFunc:
return LazyValueFuncFlow(func=gen_array, supports_jax=True)
msg = f'Invalid kind: {kind}'
raise TypeError(msg)
@functools.cached_property @functools.cached_property
def realize_array(self) -> ArrayFlow: def realize_array(self) -> ArrayFlow:
"""Standardized access to `self.realize()` when there are no symbols."""
return self.realize() return self.realize()
def __getitem__(self, subscript: slice): def __getitem__(self, subscript: slice):
"""Implement indexing and slicing in a sane way.
- **Integer Index**: Not yet implemented.
- **Slice**: Return the `RangeFlow` that creates the same `ArrayFlow` as would be created by computing `self.realize_array`, then slicing that.
"""
if isinstance(subscript, slice) and self.scaling == ScalingMode.Lin: if isinstance(subscript, slice) and self.scaling == ScalingMode.Lin:
# Parse Slice # Parse Slice
start = subscript.start if subscript.start is not None else 0 start = subscript.start if subscript.start is not None else 0
@ -482,7 +353,7 @@ class LazyArrayRangeFlow:
new_start = step_size * start new_start = step_size * start
new_stop = new_start + step_size * slice_steps new_stop = new_start + step_size * slice_steps
return LazyArrayRangeFlow( return RangeFlow(
start=sp.S(new_start), start=sp.S(new_start),
stop=sp.S(new_stop), stop=sp.S(new_stop),
steps=slice_steps, steps=slice_steps,
@ -492,3 +363,104 @@ class LazyArrayRangeFlow:
) )
raise NotImplementedError raise NotImplementedError
####################
# - Units
####################
def correct_unit(self, corrected_unit: spux.Unit) -> typ.Self:
"""Replaces the unit without rescaling the unitless bounds.
Parameters:
corrected_unit: The unit to replace the current unit with.
Returns:
A new `RangeFlow` with replaced unit.
Raises:
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
"""
if self.unit is not None:
log.debug(
'%s: Corrected unit to %s',
self,
corrected_unit,
)
return RangeFlow(
start=self.start,
stop=self.stop,
steps=self.steps,
scaling=self.scaling,
unit=corrected_unit,
symbols=self.symbols,
)
msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"'
raise ValueError(msg)
def rescale_to_unit(self, unit: spux.Unit) -> typ.Self:
"""Replaces the unit, **with** rescaling of the bounds.
Parameters:
unit: The unit to convert the bounds to.
Returns:
A new `RangeFlow` with replaced unit.
Raises:
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
"""
if self.unit is not None:
log.debug(
'%s: Scaled to unit %s',
self,
unit,
)
return RangeFlow(
start=spux.scale_to_unit(self.start * self.unit, unit),
stop=spux.scale_to_unit(self.stop * self.unit, unit),
steps=self.steps,
scaling=self.scaling,
unit=unit,
symbols=self.symbols,
)
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}'
raise ValueError(msg)
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self:
"""Replaces the units, **with** rescaling of the bounds.
Parameters:
unit: The unit to convert the bounds to.
Returns:
A new `RangeFlow` with replaced unit.
Raises:
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
"""
if self.unit is not None:
log.debug(
'%s: Scaled to new unit system (new unit = %s)',
self,
unit_system[spux.PhysicalType.from_unit(self.unit)],
)
return RangeFlow(
start=spux.strip_unit_system(
spux.convert_to_unit_system(self.start * self.unit, unit_system),
unit_system,
),
stop=spux.strip_unit_system(
spux.convert_to_unit_system(self.stop * self.unit, unit_system),
unit_system,
),
steps=self.steps,
scaling=self.scaling,
unit=unit_system[spux.PhysicalType.from_unit(self.unit)],
symbols=self.symbols,
)
msg = (
f'Tried to rescale unitless LazyDataValueRange to unit system {unit_system}'
)
raise ValueError(msg)

View File

@ -1,210 +0,0 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import dataclasses
import functools
import typing as typ
from types import MappingProxyType
import jax
from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger
log = logger.get(__name__)
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any]
@dataclasses.dataclass(frozen=True, kw_only=True)
class LazyValueFuncFlow:
r"""Wraps a composable function, providing useful information and operations.
# Data Flow as Function Composition
When using nodes to do math, it can be a good idea to express a **flow of data as the composition of functions**.
Each node creates a new function, which uses the still-unknown (aka. **lazy**) output of the previous function to plan some calculations.
Some new arguments may also be added, of course.
## Root Function
Of course, one needs to select a "bottom" function, which has no previous function as input.
Thus, the first step is to define this **root function**:
$$
f_0:\ \ \ \ \biggl(
\underbrace{a_1, a_2, ..., a_p}_{\texttt{args}},\
\underbrace{
\begin{bmatrix} k_1 \\ v_1\end{bmatrix},
\begin{bmatrix} k_2 \\ v_2\end{bmatrix},
...,
\begin{bmatrix} k_q \\ v_q\end{bmatrix}
}_{\texttt{kwargs}}
\biggr) \to \text{output}_0
$$
We'll express this simple snippet like so:
```python
# Presume 'A0', 'KV0' contain only the args/kwargs for f_0
## 'A0', 'KV0' are of length 'p' and 'q'
def f_0(*args, **kwargs): ...
lazy_value_func_0 = LazyValueFuncFlow(
func=f_0,
func_args=[(a_i, type(a_i)) for a_i in A0],
func_kwargs={k: v for k,v in KV0},
)
output_0 = lazy_value_func.func(*A0_computed, **KV0_computed)
```
So far so good.
But of course, nothing interesting has really happened yet.
## Composing Functions
The key thing is the next step: The function that uses the result of $f_0$!
$$
f_1:\ \ \ \ \biggl(
f_0(...),\ \
\underbrace{\{a_i\}_p^{p+r}}_{\texttt{args[p:]}},\
\underbrace{\biggl\{
\begin{bmatrix} k_i \\ v_i\end{bmatrix}
\biggr\}_q^{q+s}}_{\texttt{kwargs[p:]}}
\biggr) \to \text{output}_1
$$
Notice that _$f_1$ needs the arguments of both $f_0$ and $f_1$_.
Tracking arguments is already getting out of hand; we already have to use `...` to keep it readeable!
But doing so with `LazyValueFunc` is not so complex:
```python
# Presume 'A1', 'K1' contain only the args/kwarg names for f_1
## 'A1', 'KV1' are therefore of length 'r' and 's'
def f_1(output_0, *args, **kwargs): ...
lazy_value_func_1 = lazy_value_func_0.compose_within(
enclosing_func=f_1,
enclosing_func_args=[(a_i, type(a_i)) for a_i in A1],
enclosing_func_kwargs={k: type(v) for k,v in K1},
)
A_computed = A0_computed + A1_computed
KW_computed = KV0_computed + KV1_computed
output_1 = lazy_value_func_1.func(*A_computed, **KW_computed)
```
We only need the arguments to $f_1$, and `LazyValueFunc` figures out how to make one function with enough arguments to call both.
## Isn't Laying Functions Slow/Hard?
Imagine that each function represents the action of a node, each of which performs expensive calculations on huge `numpy` arrays (**as one does when processing electromagnetic field data**).
At the end, a node might run the entire procedure with all arguments:
```python
output_n = lazy_value_func_n.func(*A_all, **KW_all)
```
It's rough: Most non-trivial pipelines drown in the time/memory overhead of incremental `numpy` operations - individually fast, but collectively iffy.
The killer feature of `LazyValueFuncFlow` is a sprinkle of black magic:
```python
func_n_jax = lazy_value_func_n.func_jax
output_n = func_n_jax(*A_all, **KW_all) ## Runs on your GPU
```
What happened was, **the entire pipeline** was compiled and optimized for high performance on not just your CPU, _but also (possibly) your GPU_.
All the layered function calls and inefficient incremental processing is **transformed into a high-performance program**.
Thank `jax` - specifically, `jax.jit` (https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), which internally enables this magic with a single function call.
## Other Considerations
**Auto-Differentiation**: Incredibly, `jax.jit` isn't the killer feature of `jax`. The function that comes out of `LazyValueFuncFlow` can also be differentiated with `jax.grad` (read: high-performance Jacobians for optimizing input parameters).
Though designed for machine learning, there's no reason other fields can't enjoy their inventions!
**Impact of Independent Caching**: JIT'ing can be slow.
That's why `LazyValueFuncFlow` has its own `FlowKind` "lane", which means that **only changes to the processing procedures will cause recompilation**.
Generally, adjustable values that affect the output will flow via the `Param` "lane", which has its own incremental caching, and only meets the compiled function when it's "plugged in" for final evaluation.
The effect is a feeling of snappiness and interactivity, even as the volume of data grows.
Attributes:
func: The function that the object encapsulates.
bound_args: Arguments that will be packaged into function, which can't be later modifier.
func_kwargs: Arguments to be specified by the user at the time of use.
supports_jax: Whether the contained `self.function` can be compiled with JAX's JIT compiler.
"""
func: LazyFunction
func_args: list[spux.MathType | spux.PhysicalType] = dataclasses.field(
default_factory=list
)
func_kwargs: dict[str, spux.MathType | spux.PhysicalType] = dataclasses.field(
default_factory=dict
)
supports_jax: bool = False
# Merging
def __or__(
self,
other: typ.Self,
):
return LazyValueFuncFlow(
func=lambda *args, **kwargs: (
self.func(
*list(args[: len(self.func_args)]),
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
),
other.func(
*list(args[len(self.func_args) :]),
**{k: v for k, v in kwargs.items() if k in other.func_kwargs},
),
),
func_args=self.func_args + other.func_args,
func_kwargs=self.func_kwargs | other.func_kwargs,
supports_jax=self.supports_jax and other.supports_jax,
)
# Composition
def compose_within(
self,
enclosing_func: LazyFunction,
enclosing_func_args: list[type] = (),
enclosing_func_kwargs: dict[str, type] = MappingProxyType({}),
supports_jax: bool = False,
) -> typ.Self:
return LazyValueFuncFlow(
func=lambda *args, **kwargs: enclosing_func(
self.func(
*list(args[: len(self.func_args)]),
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
),
*args[len(self.func_args) :],
**{k: v for k, v in kwargs.items() if k not in self.func_kwargs},
),
func_args=self.func_args + list(enclosing_func_args),
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
supports_jax=self.supports_jax and supports_jax,
)
@functools.cached_property
def func_jax(self) -> LazyFunction:
if self.supports_jax:
return jax.jit(self.func)
msg = 'Can\'t express LazyValueFuncFlow as JAX function (using jax.jit), since "self.supports_jax" is False'
raise ValueError(msg)

View File

@ -22,7 +22,12 @@ from types import MappingProxyType
import sympy as sp import sympy as sp
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from blender_maxwell.utils import logger from blender_maxwell.utils import logger, sim_symbols
from .expr_info import ExprInfo
from .flow_kinds import FlowKind
# from .info import InfoFlow
log = logger.get(__name__) log = logger.get(__name__)
@ -32,7 +37,7 @@ class ParamsFlow:
func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list) func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list)
func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict) func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict)
symbols: frozenset[spux.Symbol] = frozenset() symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
@functools.cached_property @functools.cached_property
def sorted_symbols(self) -> list[sp.Symbol]: def sorted_symbols(self) -> list[sp.Symbol]:
@ -44,14 +49,22 @@ class ParamsFlow:
return sorted(self.symbols, key=lambda sym: sym.name) return sorted(self.symbols, key=lambda sym: sym.name)
#################### ####################
# - Scaled Func Args # - Realize Arguments
#################### ####################
def scaled_func_args( def scaled_func_args(
self, self,
unit_system: spux.UnitSystem, unit_system: spux.UnitSystem | None = None,
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}), symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
): ):
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments.""" """Realize the function arguments contained in this `ParamsFlow`, making it ready for insertion into `Func.func()`.
For all `arg`s in `self.func_args`, the following operations are performed.
Notes:
This method is created for the purpose of being able to make this exact call in an `events.on_value_changed` method:
"""
if not all(sym in self.symbols for sym in symbol_values): if not all(sym in self.symbols for sym in symbol_values):
msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}" msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}"
raise ValueError(msg) raise ValueError(msg)
@ -68,7 +81,7 @@ class ParamsFlow:
def scaled_func_kwargs( def scaled_func_kwargs(
self, self,
unit_system: spux.UnitSystem, unit_system: spux.UnitSystem | None = None,
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}), symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
): ):
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments.""" """Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments."""
@ -92,7 +105,7 @@ class ParamsFlow:
): ):
"""Combine two function parameter lists, such that the LHS will be concatenated with the RHS. """Combine two function parameter lists, such that the LHS will be concatenated with the RHS.
Just like its neighbor in `LazyValueFunc`, this effectively combines two functions with unique parameters. Just like its neighbor in `Func`, this effectively combines two functions with unique parameters.
The next composed function will receive a tuple of two arrays, instead of just one, allowing binary operations to occur. The next composed function will receive a tuple of two arrays, instead of just one, allowing binary operations to occur.
""" """
return ParamsFlow( return ParamsFlow(
@ -112,3 +125,61 @@ class ParamsFlow:
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs), func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
symbols=self.symbols | enclosing_symbols, symbols=self.symbols | enclosing_symbols,
) )
####################
# - Generate ExprSocketDef
####################
def sym_expr_infos(self, info, use_range: bool = False) -> dict[str, ExprInfo]:
"""Generate all information needed to define expressions that realize all symbolic parameters in this `ParamsFlow`.
Many nodes need actual data, and as such, they require that the user select actual values for any symbols in the `ParamsFlow`.
The best way to do this is to create one `ExprSocket` for each symbol that needs realizing.
Notes:
This method is created for the purpose of being able to make this exact call in an `events.on_value_changed` method:
```
self.loose_input_sockets = {
sym_name: sockets.ExprSocketDef(**expr_info)
for sym_name, expr_info in params.sym_expr_infos(info).items()
}
```
Parameters:
info: The InfoFlow associated with the `Expr` being realized.
Each symbol in `self.symbols` **must** have an associated same-named dimension in `info`.
use_range: Causes the
The `ExprInfo`s can be directly defererenced `**expr_info`)
"""
for sim_sym in self.sorted_symbols:
if use_range and sim_sym.mathtype is spux.MathType.Complex:
msg = 'No support for complex range in ExprInfo'
raise NotImplementedError(msg)
if use_range and (sim_sym.rows > 1 or sim_sym.cols > 1):
msg = 'No support for non-scalar elements of range in ExprInfo'
raise NotImplementedError(msg)
if sim_sym.rows > 3 or sim_sym.cols > 1:
msg = 'No support for >Vec3 / Matrix values in ExprInfo'
raise NotImplementedError(msg)
return {
sim_sym.name: {
# Declare Kind/Size
## -> Kind: Value prevents user-alteration of config.
## -> Size: Always scalar, since symbols are scalar (for now).
'active_kind': FlowKind.Value if not use_range else FlowKind.Range,
'size': spux.NumberSize1D.Scalar,
# Declare MathType/PhysicalType
## -> MathType: Lookup symbol name in info dimensions.
## -> PhysicalType: Same.
'mathtype': self.dims[sim_sym].mathtype,
'physical_type': self.dims[sim_sym].physical_type,
# TODO: Default Value
# FlowKind.Value: Default Value
#'default_value':
# FlowKind.Range: Default Min/Max/Steps
'default_min': sim_sym.domain.start,
'default_max': sim_sym.domain.end,
'default_steps': 50,
}
for sim_sym in self.sorted_symbols
}

View File

@ -50,6 +50,7 @@ class NodeType(blender_type_enum.BlenderTypeEnum):
# Outputs # Outputs
Viewer = enum.auto() Viewer = enum.auto()
## Outputs / File Exporters ## Outputs / File Exporters
DataFileExporter = enum.auto()
Tidy3DWebExporter = enum.auto() Tidy3DWebExporter = enum.auto()
## Outputs / Web Exporters ## Outputs / Web Exporters
JSONFileExporter = enum.auto() JSONFileExporter = enum.auto()

View File

@ -19,13 +19,21 @@
import dataclasses import dataclasses
import enum import enum
import typing as typ import typing as typ
from pathlib import Path
import jax.numpy as jnp import jax.numpy as jnp
import sympy as sp import jaxtyping as jtyp
import numpy as np
import polars as pl
import tidy3d as td import tidy3d as td
from blender_maxwell.contracts import BLEnumElement
from blender_maxwell.services import tdcloud from blender_maxwell.services import tdcloud
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import logger
from .flow_kinds.info import InfoFlow
log = logger.get(__name__)
#################### ####################
@ -293,3 +301,321 @@ class NewSimCloudTask:
task_name: tdcloud.CloudTaskName task_name: tdcloud.CloudTaskName
cloud_folder: tdcloud.CloudFolder cloud_folder: tdcloud.CloudFolder
####################
# - Data File
####################
_DATA_FILE_EXTS = {
'.txt',
'.txt.gz',
'.csv',
'.npy',
}
class DataFileFormat(enum.StrEnum):
"""Abstraction of a data file format, providing a regularized way of interacting with filesystem data.
Import/export interacts closely with the `Expr` socket's `FlowKind` semantics:
- `FlowKind.Func`: Generally realized on-import/export.
- **Import**: Loading data is generally eager, but memory-mapped file loading would be manageable using this interface.
- **Export**: The function is realized and only the array is inserted into the file.
- `FlowKind.Params`: Generally consumed.
- **Import**: A new, empty `ParamsFlow` object is created.
- **Export**: The `ParamsFlow` is consumed when realizing the `Func`.
- `FlowKind.Info`: As the most important element, it is kept in an (optional) sidecar metadata file.
- **Import**: The sidecar file is loaded, checked, and used, if it exists. A warning about further processing may show if it doesn't.
- **Export**: The sidecar file is written next to the canonical data file, in such a manner that it can be both read and loaded.
Notes:
This enum is UI Compatible, ex. for nodes/sockets desiring a dropdown menu of data file formats.
Attributes:
Txt: Simple no-header text file.
Only supports 1D/2D data.
TxtGz: Identical to `Txt`, but compressed with `gzip`.
Csv: Unspecific "Comma Separated Values".
For loading, `pandas`-default semantics are used.
For saving, very opinionated defaults are used.
Customization is disabled on purpose.
Npy: Generic numpy representation.
Supports all kinds of numpy objects.
Better laziness support via `jax`.
"""
Csv = enum.auto()
Npy = enum.auto()
Txt = enum.auto()
TxtGz = enum.auto()
####################
# - UI
####################
@staticmethod
def to_name(v: typ.Self) -> str:
"""The extension name of the given `DataFileFormat`.
Notes:
Called by the UI when creating an `EnumProperty` dropdown.
"""
return DataFileFormat(v).extension
@staticmethod
def to_icon(v: typ.Self) -> str:
"""No icon.
Notes:
Called by the UI when creating an `EnumProperty` dropdown.
"""
return ''
def bl_enum_element(self, i: int) -> BLEnumElement:
"""Produce a fully functional Blender enum element, given a particular integer index."""
return (
str(self),
DataFileFormat.to_name(self),
DataFileFormat.to_name(self),
DataFileFormat.to_icon(self),
i,
)
@staticmethod
def bl_enum_elements() -> list[BLEnumElement]:
"""Produce an immediately usable list of Blender enum elements, correctly indexed."""
return [
data_file_format.bl_enum_element(i)
for i, data_file_format in enumerate(list(DataFileFormat))
]
####################
# - Properties
####################
@property
def extension(self) -> str:
"""Map to the actual string extension."""
E = DataFileFormat
return {
E.Csv: '.csv',
E.Npy: '.npy',
E.Txt: '.txt',
E.TxtGz: '.txt.gz',
}[self]
####################
# - Creation: Compatibility
####################
@staticmethod
def valid_exts() -> list[str]:
return _DATA_FILE_EXTS
@staticmethod
def ext_has_valid_format(ext: str) -> bool:
return ext in _DATA_FILE_EXTS
@staticmethod
def path_has_valid_format(path: Path) -> bool:
return path.is_file() and DataFileFormat.ext_has_valid_format(
''.join(path.suffixes)
)
def is_path_compatible(
self, path: Path, must_exist: bool = False, can_exist: bool = True
) -> bool:
ext_matches = self.extension == ''.join(path.suffixes)
match (must_exist, can_exist):
case (False, False):
return ext_matches and not path.is_file() and path.parent.is_dir()
case (True, False):
msg = f'DataFileFormat: Path {path} cannot both be required to exist (must_exist=True), but also not be allowed to exist (can_exist=False)'
raise ValueError(msg)
case (False, True):
return ext_matches and path.parent.is_dir()
case (True, True):
return ext_matches and path.is_file()
####################
# - Creation
####################
@staticmethod
def from_ext(ext: str) -> typ.Self | None:
return {
_ext: _data_file_ext
for _data_file_ext, _ext in {
k: k.extension for k in list(DataFileFormat)
}.items()
}.get(ext)
@staticmethod
def from_path(path: Path) -> typ.Self | None:
if DataFileFormat.path_has_valid_format(path):
data_file_ext = DataFileFormat.from_ext(''.join(path.suffixes))
if data_file_ext is not None:
return data_file_ext
msg = f'DataFileFormat: Path "{path}" is compatible, but could not find valid extension'
raise RuntimeError(msg)
return None
####################
# - Functions: Metadata
####################
def supports_metadata(self) -> bool:
E = DataFileFormat
return {
E.Csv: False, ## No RFC 4180 Support for Comments
E.Npy: False, ## Quite simply no support
E.Txt: True, ## Use # Comments
E.TxtGz: True, ## Same as Txt
}[self]
## TODO: Sidecar Metadata
## - The vision is that 'saver' also writes metadata.
## - This metadata is essentially a straight serialization of the InfoFlow.
## - On-load, the metadata is used to re-generate the InfoFlow.
## - This allows interpreting saved data without a ton of shenanigans.
## - These sidecars could also be hand-writable for external data.
## - When sidecars aren't found, the user would "fill in the blanks".
## - ...Thus achieving the same result as if there were a sidecar.
####################
# - Functions: Saver
####################
def is_info_compatible(self, info: InfoFlow) -> bool:
E = DataFileFormat
match self:
case E.Csv:
return len(info.dims) + info.output.rows + info.output.cols - 1 <= 2
case E.Npy:
return True
case E.Txt | E.TxtGz:
return len(info.dims) + info.output.rows + info.output.cols - 1 <= 2
@property
def saver(
self,
) -> typ.Callable[[Path, jtyp.Shaped[jtyp.Array, '...'], InfoFlow], None]:
def save_txt(path, data, info):
np.savetxt(path, data)
def save_txt_gz(path, data, info):
np.savetxt(path, data)
def save_csv(path, data, info):
data_np = np.array(data)
# Extract Input Coordinates
dim_columns = {
dim.name: np.array(dim_idx.realize_array)
for i, (dim, dim_idx) in enumerate(info.dims)
} ## TODO: realize_array might not be defined on some index arrays
# Declare Function to Extract Output Values
output_columns = {}
def declare_output_col(data_col, output_idx=0, use_output_idx=False):
nonlocal output_columns
# Complex: Split to Two Columns
output_idx_str = f'[{output_idx}]' if use_output_idx else ''
if bool(np.any(np.iscomplex(data_col))):
output_columns |= {
f'{info.output.name}{output_idx_str}_re': np.real(data_col),
f'{info.output.name}{output_idx_str}_im': np.imag(data_col),
}
# Else: Use Array Directly
else:
output_columns |= {
f'{info.output.name}{output_idx_str}': data_col,
}
## TODO: Maybe a check to ensure dtype!=object?
# Extract Output Values
## -> 2D: Iterate over columns by-index.
## -> 1D: Declare the array as the only column.
if len(data_np.shape) == 2:
for output_idx in data_np.shape[1]:
declare_output_col(data_np[:, output_idx], output_idx, True)
else:
declare_output_col(data_np)
# Compute DataFrame & Write CSV
df = pl.DataFrame(dim_columns | output_columns)
log.debug('Writing Polars DataFrame to CSV:')
log.debug(df)
df.write_csv(path)
def save_npy(path, data, info):
jnp.save(path, data)
E = DataFileFormat
return {
E.Csv: save_csv,
E.Npy: save_npy,
E.Txt: save_txt,
E.TxtGz: save_txt_gz,
}[self]
####################
# - Functions: Loader
####################
@property
def loader_is_jax_compatible(self) -> bool:
E = DataFileFormat
return {
E.Csv: False,
E.Npy: True,
E.Txt: True,
E.TxtGz: True,
}[self]
@property
def loader(
self,
) -> typ.Callable[[Path], tuple[jtyp.Shaped[jtyp.Array, '...'], InfoFlow]]:
def load_txt(path: Path):
return jnp.asarray(np.loadtxt(path))
def load_csv(path: Path):
return jnp.asarray(pl.read_csv(path).to_numpy())
## TODO: The very next Polars (0.20.27) has a '.to_jax' method!
def load_npy(path: Path):
return jnp.load(path)
E = DataFileFormat
return {
E.Csv: load_csv,
E.Npy: load_npy,
E.Txt: load_txt,
E.TxtGz: load_txt,
}[self]
####################
# - Metadata: Compatibility
####################
def is_info_compatible(self, info: InfoFlow) -> bool:
E = DataFileFormat
match self:
case E.Csv:
return len(info.dims) + (info.output.rows + input.outputs.cols - 1) <= 2
case E.Npy:
return True
case E.Txt | E.TxtGz:
return len(info.dims) + (info.output.rows + info.output.cols - 1) <= 2
def supports_metadata(self) -> bool:
E = DataFileFormat
return {
E.Csv: False, ## No RFC 4180 Support for Comments
E.Npy: False, ## Quite simply no support
E.Txt: True, ## Use # Comments
E.TxtGz: True, ## Same as Txt
}[self]

View File

@ -48,7 +48,6 @@ UNITS_BLENDER: spux.UnitSystem = spux.UNITS_SI | {
# Electrodynamics # Electrodynamics
_PT.CurrentDensity: spu.ampere / spu.um**2, _PT.CurrentDensity: spu.ampere / spu.um**2,
_PT.Conductivity: spu.siemens / spu.um, _PT.Conductivity: spu.siemens / spu.um,
_PT.PoyntingVector: spu.watt / spu.um**2,
_PT.EField: spu.volt / spu.um, _PT.EField: spu.volt / spu.um,
_PT.HField: spu.ampere / spu.um, _PT.HField: spu.ampere / spu.um,
# Mechanical # Mechanical
@ -58,7 +57,6 @@ UNITS_BLENDER: spux.UnitSystem = spux.UNITS_SI | {
_PT.Force: spux.micronewton, _PT.Force: spux.micronewton,
# Luminal # Luminal
# Optics # Optics
_PT.PoyntingVector: spu.watt / spu.um**2,
} ## TODO: Load (dynamically?) from addon preferences } ## TODO: Load (dynamically?) from addon preferences
UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | { UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | {
@ -75,11 +73,9 @@ UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | {
# Electrodynamics # Electrodynamics
_PT.CurrentDensity: spu.ampere / spu.um**2, _PT.CurrentDensity: spu.ampere / spu.um**2,
_PT.Conductivity: spu.siemens / spu.um, _PT.Conductivity: spu.siemens / spu.um,
_PT.PoyntingVector: spu.watt / spu.um**2,
_PT.EField: spu.volt / spu.um, _PT.EField: spu.volt / spu.um,
_PT.HField: spu.ampere / spu.um, _PT.HField: spu.ampere / spu.um,
# Luminal # Luminal
# Optics # Optics
_PT.PoyntingVector: spu.watt / spu.um**2,
## NOTE: w/o source normalization, EField/HField/Modal amps are * 1/Hz ## NOTE: w/o source normalization, EField/HField/Modal amps are * 1/Hz
} }

View File

@ -17,15 +17,15 @@
"""Implements `ExtractDataNode`.""" """Implements `ExtractDataNode`."""
import enum import enum
import functools
import typing as typ import typing as typ
import bpy import bpy
import jax import jax.numpy as jnp
import numpy as np
import sympy.physics.units as spu import sympy.physics.units as spu
import tidy3d as td import tidy3d as td
from blender_maxwell.utils import bl_cache, logger from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from ... import contracts as ct from ... import contracts as ct
@ -37,6 +37,176 @@ log = logger.get(__name__)
TDMonitorData: typ.TypeAlias = td.components.data.monitor_data.MonitorData TDMonitorData: typ.TypeAlias = td.components.data.monitor_data.MonitorData
####################
# - Monitor Label Arrays
####################
def valid_monitor_attrs(sim_data: td.SimulationData, monitor_name: str) -> list[str]:
"""Retrieve the valid attributes of `sim_data.monitor_data' from a valid `sim_data` of type `td.SimulationData`.
Parameters:
monitor_type: The name of the monitor type, with the 'Data' prefix removed.
"""
monitor_data = sim_data.monitor_data[monitor_name]
monitor_type = monitor_data.type
match monitor_type:
case 'Field' | 'FieldTime' | 'Mode':
## TODO: flux, poynting, intensity
return [
field_component
for field_component in ['Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz']
if getattr(monitor_data, field_component, None) is not None
]
case 'Permittivity':
return ['eps_xx', 'eps_yy', 'eps_zz']
case 'Flux' | 'FluxTime':
return ['flux']
case (
'FieldProjectionAngle'
| 'FieldProjectionCartesian'
| 'FieldProjectionKSpace'
| 'Diffraction'
):
return [
'Er',
'Etheta',
'Ephi',
'Hr',
'Htheta',
'Hphi',
]
def extract_info(monitor_data, monitor_attr: str) -> ct.InfoFlow | None: # noqa: PLR0911
"""Extract an InfoFlow encapsulating raw data contained in an attribute of the given monitor data."""
xarr = getattr(monitor_data, monitor_attr, None)
if xarr is None:
return None
def mk_idx_array(axis: str) -> ct.ArrayFlow:
return ct.ArrayFlow(
values=xarr.get_index(axis).values,
unit=symbols[axis].unit,
is_sorted=True,
)
# Compute InfoFlow from XArray
symbols = {
# Cartesian
'x': sim_symbols.space_x(spu.micrometer),
'y': sim_symbols.space_y(spu.micrometer),
'z': sim_symbols.space_z(spu.micrometer),
# Spherical
'r': sim_symbols.ang_r(spu.micrometer),
'theta': sim_symbols.ang_theta(spu.radian),
'phi': sim_symbols.ang_phi(spu.radian),
# Freq|Time
'f': sim_symbols.freq(spu.hertz),
't': sim_symbols.t(spu.second),
# Power Flux
'flux': sim_symbols.flux(spu.watt),
# Cartesian Fields
'Ex': sim_symbols.field_ex(spu.volt / spu.micrometer),
'Ey': sim_symbols.field_ey(spu.volt / spu.micrometer),
'Ez': sim_symbols.field_ez(spu.volt / spu.micrometer),
'Hx': sim_symbols.field_hx(spu.volt / spu.micrometer),
'Hy': sim_symbols.field_hy(spu.volt / spu.micrometer),
'Hz': sim_symbols.field_hz(spu.volt / spu.micrometer),
# Spherical Fields
'Er': sim_symbols.field_er(spu.volt / spu.micrometer),
'Etheta': sim_symbols.ang_theta(spu.volt / spu.micrometer),
'Ephi': sim_symbols.field_ez(spu.volt / spu.micrometer),
'Hr': sim_symbols.field_hr(spu.volt / spu.micrometer),
'Htheta': sim_symbols.field_hy(spu.volt / spu.micrometer),
'Hphi': sim_symbols.field_hz(spu.volt / spu.micrometer),
# Wavevector
'ux': sim_symbols.dir_x(spu.watt),
'uy': sim_symbols.dir_y(spu.watt),
# Diffraction Orders
'orders_x': sim_symbols.diff_order_x(None),
'orders_y': sim_symbols.diff_order_y(None),
}
match monitor_data.type:
case 'Field' | 'FieldProjectionCartesian' | 'Permittivity' | 'Mode':
return ct.InfoFlow(
dims={
symbols['x']: mk_idx_array('x'),
symbols['y']: mk_idx_array('y'),
symbols['z']: mk_idx_array('z'),
symbols['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
)
case 'FieldTime':
return ct.InfoFlow(
dims={
symbols['x']: mk_idx_array('x'),
symbols['y']: mk_idx_array('y'),
symbols['z']: mk_idx_array('z'),
symbols['t']: mk_idx_array('t'),
},
output=symbols[monitor_attr],
)
case 'Flux':
return ct.InfoFlow(
dims={
symbols['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
)
case 'FluxTime':
return ct.InfoFlow(
dims={
symbols['t']: mk_idx_array('t'),
},
output=symbols[monitor_attr],
)
case 'FieldProjectionAngle':
return ct.InfoFlow(
dims={
symbols['r']: mk_idx_array('r'),
symbols['theta']: mk_idx_array('theta'),
symbols['phi']: mk_idx_array('phi'),
symbols['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
)
case 'FieldProjectionKSpace':
return ct.InfoFlow(
dims={
symbols['ux']: mk_idx_array('ux'),
symbols['uy']: mk_idx_array('uy'),
symbols['r']: mk_idx_array('r'),
symbols['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
)
case 'Diffraction':
return ct.InfoFlow(
dims={
symbols['orders_x']: mk_idx_array('orders_x'),
symbols['orders_y']: mk_idx_array('orders_y'),
symbols['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
)
return None
####################
# - Node
####################
class ExtractDataNode(base.MaxwellSimNode): class ExtractDataNode(base.MaxwellSimNode):
"""Extract data from sockets for further analysis. """Extract data from sockets for further analysis.
@ -45,33 +215,21 @@ class ExtractDataNode(base.MaxwellSimNode):
Monitor Data: Extract `Expr`s from monitor data by-component. Monitor Data: Extract `Expr`s from monitor data by-component.
Attributes: Attributes:
extract_filter: Identifier for data to extract from the input. monitor_attr: Identifier for data to extract from the input.
""" """
node_type = ct.NodeType.ExtractData node_type = ct.NodeType.ExtractData
bl_label = 'Extract' bl_label = 'Extract'
input_socket_sets: typ.ClassVar = { input_socket_sets: typ.ClassVar = {
'Sim Data': {'Sim Data': sockets.MaxwellFDTDSimDataSocketDef()}, 'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(),
'Monitor Data': {'Monitor Data': sockets.MaxwellMonitorDataSocketDef()},
} }
output_socket_sets: typ.ClassVar = { output_socket_sets: typ.ClassVar = {
'Sim Data': {'Monitor Data': sockets.MaxwellMonitorDataSocketDef()}, 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
'Monitor Data': {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc)
},
} }
#################### ####################
# - Properties # - Properties: Monitor Name
####################
extract_filter: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_extract_filters(),
cb_depends_on={'sim_data_monitor_nametype', 'monitor_data_type'},
)
####################
# - Computed: Sim Data
#################### ####################
@events.on_value_changed( @events.on_value_changed(
socket_name='Sim Data', socket_name='Sim Data',
@ -101,198 +259,49 @@ class ExtractDataNode(base.MaxwellSimNode):
@bl_cache.cached_bl_property(depends_on={'sim_data'}) @bl_cache.cached_bl_property(depends_on={'sim_data'})
def sim_data_monitor_nametype(self) -> dict[str, str] | None: def sim_data_monitor_nametype(self) -> dict[str, str] | None:
"""For simulation data, deduces a map from the monitor name to the monitor "type". """Dictionary from monitor names on `self.sim_data` to their associated type name (with suffix 'Data' removed).
Return: Return:
The name to type of monitors in the simulation data. The name to type of monitors in the simulation data.
""" """
if self.sim_data is not None: if self.sim_data is not None:
return { return {
monitor_name: monitor_data.type monitor_name: monitor_data.type.removesuffix('Data')
for monitor_name, monitor_data in self.sim_data.monitor_data.items() for monitor_name, monitor_data in self.sim_data.monitor_data.items()
} }
return None return None
#################### monitor_name: enum.StrEnum = bl_cache.BLField(
# - Computed Properties: Monitor Data enum_cb=lambda self, _: self.search_monitor_names(),
#################### cb_depends_on={'sim_data_monitor_nametype'},
@events.on_value_changed(
socket_name='Monitor Data',
input_sockets={'Monitor Data'},
input_sockets_optional={'Monitor Data': True},
) )
def on_monitor_data_changed(self, input_sockets) -> None: # noqa: D102
has_monitor_data = not ct.FlowSignal.check(input_sockets['Monitor Data'])
if has_monitor_data:
self.monitor_data = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property() def search_monitor_names(self) -> list[ct.BLEnumElement]:
def monitor_data(self) -> TDMonitorData | None: """Compute valid values for `self.monitor_attr`, for a dynamic `EnumProperty`.
"""Extracts the monitor data from the input socket.
Return:
Either the monitor data, if available, or None.
"""
monitor_data = self._compute_input(
'Monitor Data', kind=ct.FlowKind.Value, optional=True
)
has_monitor_data = not ct.FlowSignal.check(monitor_data)
if has_monitor_data:
return monitor_data
return None
@bl_cache.cached_bl_property(depends_on={'monitor_data'})
def monitor_data_type(self) -> str | None:
r"""For monitor data, deduces the monitor "type".
- **Field(Time)**: A monitor storing values/pixels/voxels with electromagnetic field values, on the time or frequency domain.
- **Permittivity**: A monitor storing values/pixels/voxels containing the diagonal of the relative permittivity tensor.
- **Flux(Time)**: A monitor storing the directional flux on the time or frequency domain.
For planes, an explicit direction is defined.
For volumes, the the integral of all outgoing energy is stored.
- **FieldProjection(...)**: A monitor storing the spherical-coordinate electromagnetic field components of a near-to-far-field projection.
- **Diffraction**: A monitor storing a near-to-far-field projection by diffraction order.
Notes: Notes:
Should be invalidated with (before) `self.monitor_data_attrs`. Should be reset (via `self.monitor_attr`) with (after) `self.sim_data_monitor_nametype`, `self.monitor_data_attrs`, and (implicitly) `self.monitor_type`.
Return:
The "type" of the monitor, if available, else None.
"""
if self.monitor_data is not None:
return self.monitor_data.type.removesuffix('Data')
return None
@bl_cache.cached_bl_property(depends_on={'monitor_data_type'})
def monitor_data_attrs(self) -> list[str] | None:
r"""For monitor data, deduces the valid data-containing attributes.
The output depends entirely on the output of `self.monitor_data_type`, since the valid attributes of each monitor type is well-defined without needing to perform dynamic lookups.
- **Field(Time)**: Whichever `[E|H][x|y|z]` are not `None` on the monitor.
- **Permittivity**: Specifically `['xx', 'yy', 'zz']`.
- **Flux(Time)**: Only `['flux']`.
- **FieldProjection(...)**: All of $r$, $\theta$, $\phi$ for both `E` and `H`.
- **Diffraction**: Same as `FieldProjection`.
Notes:
Should be invalidated after with `self.monitor_data_type`.
Return:
The "type" of the monitor, if available, else None.
"""
if self.monitor_data is not None:
# Field/FieldTime
if self.monitor_data_type in ['Field', 'FieldTime']:
return [
field_component
for field_component in ['Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz']
if hasattr(self.monitor_data, field_component)
]
# Permittivity
if self.monitor_data_type == 'Permittivity':
return ['xx', 'yy', 'zz']
# Flux/FluxTime
if self.monitor_data_type in ['Flux', 'FluxTime']:
return ['flux']
# FieldProjection(Angle/Cartesian/KSpace)/Diffraction
if self.monitor_data_type in [
'FieldProjectionAngle',
'FieldProjectionCartesian',
'FieldProjectionKSpace',
'Diffraction',
]:
return [
'Er',
'Etheta',
'Ephi',
'Hr',
'Htheta',
'Hphi',
]
return None
####################
# - Extraction Filter Search
####################
def search_extract_filters(self) -> list[ct.BLEnumElement]:
"""Compute valid values for `self.extract_filter`, for a dynamic `EnumProperty`.
Notes:
Should be reset (via `self.extract_filter`) with (after) `self.sim_data_monitor_nametype`, `self.monitor_data_attrs`, and (implicitly) `self.monitor_type`.
See `bl_cache.BLField` for more on dynamic `EnumProperty`. See `bl_cache.BLField` for more on dynamic `EnumProperty`.
Returns: Returns:
Valid `self.extract_filter` in a format compatible with dynamic `EnumProperty`. Valid `self.monitor_attr` in a format compatible with dynamic `EnumProperty`.
""" """
if self.sim_data_monitor_nametype is not None: if self.sim_data_monitor_nametype is not None:
return [ return [
(monitor_name, monitor_name, monitor_type.removesuffix('Data'), '', i) (
monitor_name,
monitor_name,
monitor_type + ' Monitor Data',
'',
i,
)
for i, (monitor_name, monitor_type) in enumerate( for i, (monitor_name, monitor_type) in enumerate(
self.sim_data_monitor_nametype.items() self.sim_data_monitor_nametype.items()
) )
] ]
if self.monitor_data_attrs is not None:
# Field/FieldTime
if self.monitor_data_type in ['Field', 'FieldTime']:
return [
(
monitor_attr,
monitor_attr,
f' {monitor_attr[1]}-polarization of the {"electric" if monitor_attr[0] == "E" else "magnetic"} field',
'',
i,
)
for i, monitor_attr in enumerate(self.monitor_data_attrs)
]
# Permittivity
if self.monitor_data_type == 'Permittivity':
return [
(monitor_attr, monitor_attr, f' ε_{monitor_attr}', '', i)
for i, monitor_attr in enumerate(self.monitor_data_attrs)
]
# Flux/FluxTime
if self.monitor_data_type in ['Flux', 'FluxTime']:
return [
(
monitor_attr,
monitor_attr,
'Power flux integral through the plane / out of the volume',
'',
i,
)
for i, monitor_attr in enumerate(self.monitor_data_attrs)
]
# FieldProjection(Angle/Cartesian/KSpace)/Diffraction
if self.monitor_data_type in [
'FieldProjectionAngle',
'FieldProjectionCartesian',
'FieldProjectionKSpace',
'Diffraction',
]:
return [
(
monitor_attr,
monitor_attr,
f' {monitor_attr[1]}-component of the spherical {"electric" if monitor_attr[0] == "E" else "magnetic"} field',
'',
i,
)
for i, monitor_attr in enumerate(self.monitor_data_attrs)
]
return [] return []
#################### ####################
@ -305,10 +314,9 @@ class ExtractDataNode(base.MaxwellSimNode):
Called by Blender to determine the text to place in the node's header. Called by Blender to determine the text to place in the node's header.
""" """
has_sim_data = self.sim_data_monitor_nametype is not None has_sim_data = self.sim_data_monitor_nametype is not None
has_monitor_data = self.monitor_data_attrs is not None
if has_sim_data or has_monitor_data: if has_sim_data:
return f'Extract: {self.extract_filter}' return f'Extract: {self.monitor_name}'
return self.bl_label return self.bl_label
@ -318,340 +326,115 @@ class ExtractDataNode(base.MaxwellSimNode):
Parameters: Parameters:
col: UI target for drawing. col: UI target for drawing.
""" """
col.prop(self, self.blfields['extract_filter'], text='') col.prop(self, self.blfields['monitor_name'], text='')
#################### ####################
# - FlowKind.Value: Sim Data -> Monitor Data # - FlowKind.Func
####################
@events.computes_output_socket(
'Monitor Data',
kind=ct.FlowKind.Value,
# Loaded
props={'extract_filter'},
input_sockets={'Sim Data'},
input_sockets_optional={'Sim Data': True},
)
def compute_monitor_data(
self, props: dict, input_sockets: dict
) -> TDMonitorData | ct.FlowSignal:
"""Compute `Monitor Data` by querying the attribute of `Sim Data` referenced by the property `self.extract_filter`.
Returns:
Monitor data, if available, else `ct.FlowSignal.FlowPending`.
"""
extract_filter = props['extract_filter']
sim_data = input_sockets['Sim Data']
has_sim_data = not ct.FlowSignal.check(sim_data)
if has_sim_data and extract_filter is not None:
return sim_data.monitor_data[extract_filter]
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Array|LazyValueFunc: Monitor Data -> Expr
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.Array, kind=ct.FlowKind.Func,
# Loaded # Loaded
props={'extract_filter'}, props={'monitor_name'},
input_sockets={'Monitor Data'}, input_sockets={'Sim Data'},
input_socket_kinds={'Monitor Data': ct.FlowKind.Value}, input_socket_kinds={'Sim Data': ct.FlowKind.Value},
input_sockets_optional={'Monitor Data': True},
) )
def compute_expr( def compute_expr(
self, props: dict, input_sockets: dict self, props: dict, input_sockets: dict
) -> jax.Array | ct.FlowSignal: ) -> ct.FuncFlow | ct.FlowSignal:
"""Compute `Expr:Array` by querying an array-like attribute of `Monitor Data`, then constructing an `ct.ArrayFlow` around it. sim_data = input_sockets['Sim Data']
monitor_name = props['monitor_name']
Uses the internal `xarray` data returned by Tidy3D. has_sim_data = not ct.FlowSignal.check(sim_data)
By using `np.array` on the `.data` attribute of the `xarray`, instead of the usual JAX array constructor, we should save a (possibly very big) copy.
Returns: if has_sim_data and monitor_name is not None:
The data array, if available, else `ct.FlowSignal.FlowPending`. monitor_data = sim_data.get(monitor_name)
""" if monitor_data is not None:
extract_filter = props['extract_filter'] # Extract Valid Index Labels
monitor_data = input_sockets['Monitor Data'] ## -> The first output axis will be integer-indexed.
has_monitor_data = not ct.FlowSignal.check(monitor_data) ## -> Each integer will have a string label.
## -> Those string labels explain the integer as ex. Ex, Ey, Hy.
idx_labels = valid_monitor_attrs(sim_data, monitor_name)
if has_monitor_data and extract_filter is not None: # Generate FuncFlow Per Index Label
xarray_data = getattr(monitor_data, extract_filter) ## -> We extract each XArray as an attribute of monitor_data.
return ct.ArrayFlow(values=np.array(xarray_data.data), unit=None) ## -> We then bind its values into a unique func_flow.
## -> This lets us 'stack' then all along the first axis.
return ct.FlowSignal.FlowPending func_flows = []
for idx_label in idx_labels:
@events.computes_output_socket( xarr = getattr(monitor_data, idx_label)
# Trigger func_flows.append(
'Expr', ct.FuncFlow(
kind=ct.FlowKind.LazyValueFunc, func=lambda xarr=xarr: xarr.values,
# Loaded supports_jax=True,
output_sockets={'Expr'}, )
output_socket_kinds={'Expr': ct.FlowKind.Array}, )
output_sockets_optional={'Expr': True},
)
def compute_extracted_data_lazy(
self, output_sockets: dict
) -> ct.LazyValueFuncFlow | None:
"""Declare `Expr:LazyValueFunc` by creating a simple function that directly wraps `Expr:Array`.
Returns:
The composable function array, if available, else `ct.FlowSignal.FlowPending`.
"""
output_expr = output_sockets['Expr']
has_output_expr = not ct.FlowSignal.check(output_expr)
if has_output_expr:
return ct.LazyValueFuncFlow(
func=lambda: output_expr.values, supports_jax=True
)
# Concatenate and Stack Unified FuncFlow
## -> First, 'reduce' lets us __or__ all the FuncFlows together.
## -> Then, 'compose_within' lets us stack them along axis=0.
## -> The "new" axis=0 is int-indexed axis w/idx_labels labels!
return functools.reduce(lambda a, b: a | b, func_flows).compose_within(
enclosing_func=lambda data: jnp.stack(data, axis=0)
)
return ct.FlowSignal.FlowPending
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
#################### ####################
# - FlowKind.Params: Monitor Data -> Expr # - FlowKind.Params
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.Params, kind=ct.FlowKind.Params,
input_sockets={'Sim Data'},
input_socket_kinds={'Sim Data': ct.FlowKind.Params},
) )
def compute_data_params(self) -> ct.ParamsFlow: def compute_data_params(self, input_sockets) -> ct.ParamsFlow:
"""Declare an empty `Data:Params`, to indicate the start of a function-composition pipeline. """Declare an empty `Data:Params`, to indicate the start of a function-composition pipeline.
Returns: Returns:
A completely empty `ParamsFlow`, ready to be composed. A completely empty `ParamsFlow`, ready to be composed.
""" """
sim_params = input_sockets['Sim Data']
has_sim_params = not ct.FlowSignal.check(sim_params)
if has_sim_params:
return sim_params
return ct.ParamsFlow() return ct.ParamsFlow()
#################### ####################
# - FlowKind.Info: Monitor Data -> Expr # - FlowKind.Info
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.Info, kind=ct.FlowKind.Info,
# Loaded # Loaded
props={'monitor_data_type', 'extract_filter'}, props={'monitor_name'},
input_sockets={'Monitor Data'}, input_sockets={'Sim Data'},
input_socket_kinds={'Monitor Data': ct.FlowKind.Value}, input_socket_kinds={'Sim Data': ct.FlowKind.Value},
input_sockets_optional={'Monitor Data': True},
) )
def compute_extracted_data_info( def compute_extracted_data_info(self, props, input_sockets) -> ct.InfoFlow:
self, props: dict, input_sockets: dict
) -> ct.InfoFlow:
"""Declare `Data:Info` by manually selecting appropriate axes, units, etc. for each monitor type. """Declare `Data:Info` by manually selecting appropriate axes, units, etc. for each monitor type.
Returns: Returns:
Information describing the `Data:LazyValueFunc`, if available, else `ct.FlowSignal.FlowPending`. Information describing the `Data:Func`, if available, else `ct.FlowSignal.FlowPending`.
""" """
monitor_data = input_sockets['Monitor Data'] sim_data = input_sockets['Sim Data']
monitor_data_type = props['monitor_data_type'] monitor_name = props['monitor_name']
extract_filter = props['extract_filter']
has_monitor_data = not ct.FlowSignal.check(monitor_data) has_sim_data = not ct.FlowSignal.check(sim_data)
# Edge Case: Dangling 'flux' Access on 'FieldMonitor' if not has_sim_data or monitor_name is None:
## -> Sometimes works - UNLESS the FieldMonitor doesn't have all fields.
## -> We don't allow 'flux' attribute access, but it can dangle.
## -> (The method is called when updating each depschain component.)
if monitor_data_type == 'Field' and extract_filter == 'flux':
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
# Retrieve XArray # Extract Data
if has_monitor_data and extract_filter is not None: ## -> All monitor_data.<idx_label> have the exact same InfoFlow.
xarr = getattr(monitor_data, extract_filter, None) ## -> So, just construct an InfoFlow w/prepended labelled dimension.
if xarr is None: monitor_data = sim_data.get(monitor_name)
return ct.FlowSignal.FlowPending idx_labels = valid_monitor_attrs(sim_data, monitor_name)
else: info = extract_info(monitor_data, idx_labels[0])
return ct.FlowSignal.FlowPending
# Compute InfoFlow from XArray return info.prepend_dim(sim_symbols.idx, idx_labels)
## XYZF: Field / Permittivity / FieldProjectionCartesian
if monitor_data_type in {
'Field',
'Permittivity',
#'FieldProjectionCartesian',
}:
return ct.InfoFlow(
dim_names=['x', 'y', 'z', 'f'],
dim_idx={
axis: ct.ArrayFlow(
values=xarr.get_index(axis).values, unit=spu.um, is_sorted=True
)
for axis in ['x', 'y', 'z']
}
| {
'f': ct.ArrayFlow(
values=xarr.get_index('f').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Complex,
output_unit=(
spu.volt / spu.micrometer if monitor_data_type == 'Field' else None
),
)
## XYZT: FieldTime
if monitor_data_type == 'FieldTime':
return ct.InfoFlow(
dim_names=['x', 'y', 'z', 't'],
dim_idx={
axis: ct.ArrayFlow(
values=xarr.get_index(axis).values, unit=spu.um, is_sorted=True
)
for axis in ['x', 'y', 'z']
}
| {
't': ct.ArrayFlow(
values=xarr.get_index('t').values,
unit=spu.second,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Complex,
output_unit=(
spu.volt / spu.micrometer if monitor_data_type == 'Field' else None
),
)
## F: Flux
if monitor_data_type == 'Flux':
return ct.InfoFlow(
dim_names=['f'],
dim_idx={
'f': ct.ArrayFlow(
values=xarr.get_index('f').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=spu.watt,
)
## T: FluxTime
if monitor_data_type == 'FluxTime':
return ct.InfoFlow(
dim_names=['t'],
dim_idx={
't': ct.ArrayFlow(
values=xarr.get_index('t').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=spu.watt,
)
## RThetaPhiF: FieldProjectionAngle
if monitor_data_type == 'FieldProjectionAngle':
return ct.InfoFlow(
dim_names=['r', 'theta', 'phi', 'f'],
dim_idx={
'r': ct.ArrayFlow(
values=xarr.get_index('r').values,
unit=spu.micrometer,
is_sorted=True,
),
}
| {
c: ct.ArrayFlow(
values=xarr.get_index(c).values,
unit=spu.radian,
is_sorted=True,
)
for c in ['r', 'theta', 'phi']
}
| {
'f': ct.ArrayFlow(
values=xarr.get_index('f').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=(
spu.volt / spu.micrometer
if extract_filter.startswith('E')
else spu.ampere / spu.micrometer
),
)
## UxUyRF: FieldProjectionKSpace
if monitor_data_type == 'FieldProjectionKSpace':
return ct.InfoFlow(
dim_names=['ux', 'uy', 'r', 'f'],
dim_idx={
c: ct.ArrayFlow(
values=xarr.get_index(c).values, unit=None, is_sorted=True
)
for c in ['ux', 'uy']
}
| {
'r': ct.ArrayFlow(
values=xarr.get_index('r').values,
unit=spu.micrometer,
is_sorted=True,
),
'f': ct.ArrayFlow(
values=xarr.get_index('f').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=(
spu.volt / spu.micrometer
if extract_filter.startswith('E')
else spu.ampere / spu.micrometer
),
)
## OrderxOrderyF: Diffraction
if monitor_data_type == 'Diffraction':
return ct.InfoFlow(
dim_names=['orders_x', 'orders_y', 'f'],
dim_idx={
f'orders_{c}': ct.ArrayFlow(
values=xarr.get_index(f'orders_{c}').values,
unit=None,
is_sorted=True,
)
for c in ['x', 'y']
}
| {
'f': ct.ArrayFlow(
values=xarr.get_index('f').values,
unit=spu.hertz,
is_sorted=True,
),
},
output_name=extract_filter,
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=(
spu.volt / spu.micrometer
if extract_filter.startswith('E')
else spu.ampere / spu.micrometer
),
)
msg = f'Unsupported Monitor Data Type {monitor_data_type} in "FlowKind.Info" of "{self.bl_label}"'
raise RuntimeError(msg)
#################### ####################

View File

@ -98,29 +98,29 @@ class FilterOperation(enum.StrEnum):
operations = [] operations = []
# Slice # Slice
if info.dim_names: if info.dims:
operations.append(FO.SliceIdx) operations.append(FO.SliceIdx)
# Pin # Pin
## PinLen1 ## PinLen1
## -> There must be a dimension with length 1. ## -> There must be a dimension with length 1.
if 1 in list(info.dim_lens.values()): if 1 in [dim_idx for dim_idx in info.dims.values() if dim_idx is not None]:
operations.append(FO.PinLen1) operations.append(FO.PinLen1)
## Pin | PinIdx ## Pin | PinIdx
## -> There must be a dimension, full stop. ## -> There must be a dimension, full stop.
if info.dim_names: if info.dims:
operations += [FO.Pin, FO.PinIdx] operations += [FO.Pin, FO.PinIdx]
# Reinterpret # Reinterpret
## Swap ## Swap
## -> There must be at least two dimensions. ## -> There must be at least two dimensions.
if len(info.dim_names) >= 2: # noqa: PLR2004 if len(info.dims) >= 2: # noqa: PLR2004
operations.append(FO.Swap) operations.append(FO.Swap)
## SetDim ## SetDim
## -> There must be a dimension to correct. ## -> There must be a dimension to correct.
if info.dim_names: if info.dims:
operations.append(FO.SetDim) operations.append(FO.SetDim)
return operations return operations
@ -158,33 +158,33 @@ class FilterOperation(enum.StrEnum):
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]: def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
FO = FilterOperation FO = FilterOperation
match self: match self:
case FO.SliceIdx: case FO.SliceIdx | FO.Swap:
return info.dim_names return info.dims
# PinLen1: Only allow dimensions with length=1. # PinLen1: Only allow dimensions with length=1.
case FO.PinLen1: case FO.PinLen1:
return [ return [
dim_name dim
for dim_name in info.dim_names for dim, dim_idx in info.dims.items()
if info.dim_lens[dim_name] == 1 if dim_idx is not None and len(dim_idx) == 1
] ]
# Pin: Only allow dimensions with known indexing. # Pin: Only allow dimensions with discrete index.
case FO.Pin: ## TODO: Shouldn't 'Pin' be allowed to index continuous indices too?
case FO.Pin | FO.PinIdx:
return [ return [
dim_name dim
for dim_name in info.dim_names for dim, dim_idx in info.dims
if info.dim_has_coords[dim_name] != 0 if dim_idx is not None and len(dim_idx) > 0
] ]
case FO.PinIdx | FO.Swap:
return info.dim_names
case FO.SetDim: case FO.SetDim:
return [ return [
dim_name dim
for dim_name in info.dim_names for dim, dim_idx in info.dims
if info.dim_mathtypes[dim_name] == spux.MathType.Integer if dim_idx is not None
and not isinstance(dim_idx, list)
and dim_idx.mathtype == spux.MathType.Integer
] ]
return [] return []
@ -224,22 +224,22 @@ class FilterOperation(enum.StrEnum):
def transform_info( def transform_info(
self, self,
info: ct.InfoFlow, info: ct.InfoFlow,
dim_0: str, dim_0: sim_symbols.SimSymbol,
dim_1: str, dim_1: sim_symbols.SimSymbol,
pin_idx: int | None = None,
slice_tuple: tuple[int, int, int] | None = None, slice_tuple: tuple[int, int, int] | None = None,
corrected_dim: tuple[str, tuple[str, ct.ArrayFlow | ct.LazyArrayRangeFlow]] replaced_dim: tuple[str, tuple[str, ct.ArrayFlow | ct.RangeFlow]] | None = None,
| None = None,
): ):
FO = FilterOperation FO = FilterOperation
return { return {
FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple), FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple),
# Pin # Pin
FO.PinLen1: lambda: info.delete_dimension(dim_0), FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
FO.Pin: lambda: info.delete_dimension(dim_0), FO.Pin: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
FO.PinIdx: lambda: info.delete_dimension(dim_0), FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
# Reinterpret # Reinterpret
FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1), FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
FO.SetDim: lambda: info.replace_dim(*corrected_dim), FO.SetDim: lambda: info.replace_dim(*replaced_dim),
}[self]() }[self]()
@ -265,10 +265,10 @@ class FilterMathNode(base.MaxwellSimNode):
bl_label = 'Filter Math' bl_label = 'Filter Math'
input_sockets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc), 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
} }
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc), 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
} }
#################### ####################
@ -318,11 +318,11 @@ class FilterMathNode(base.MaxwellSimNode):
#################### ####################
# - Properties: Dimension Selection # - Properties: Dimension Selection
#################### ####################
dim_0: enum.StrEnum = bl_cache.BLField( active_dim_0: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_dims(), enum_cb=lambda self, _: self.search_dims(),
cb_depends_on={'operation', 'expr_info'}, cb_depends_on={'operation', 'expr_info'},
) )
dim_1: enum.StrEnum = bl_cache.BLField( active_dim_1: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_dims(), enum_cb=lambda self, _: self.search_dims(),
cb_depends_on={'operation', 'expr_info'}, cb_depends_on={'operation', 'expr_info'},
) )
@ -335,40 +335,23 @@ class FilterMathNode(base.MaxwellSimNode):
] ]
return [] return []
@bl_cache.cached_bl_property(depends_on={'active_dim_0'})
def dim_0(self) -> sim_symbols.SimSymbol | None:
if self.expr_info is not None and self.active_dim_0 is not None:
return self.expr_info.dim_by_name(self.active_dim_0)
return None
@bl_cache.cached_bl_property(depends_on={'active_dim_1'})
def dim_1(self) -> sim_symbols.SimSymbol | None:
if self.expr_info is not None and self.active_dim_1 is not None:
return self.expr_info.dim_by_name(self.active_dim_1)
return None
#################### ####################
# - Properties: Slice # - Properties: Slice
#################### ####################
slice_tuple: tuple[int, int, int] = bl_cache.BLField([0, 1, 1]) slice_tuple: tuple[int, int, int] = bl_cache.BLField([0, 1, 1])
####################
# - Properties: Unit
####################
set_dim_symbol: sim_symbols.CommonSimSymbol = bl_cache.BLField(
sim_symbols.CommonSimSymbol.X
)
set_dim_active_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_valid_units(),
cb_depends_on={'set_dim_symbol'},
)
def search_valid_units(self) -> list[ct.BLEnumElement]:
"""Compute Blender enum elements of valid units for the current `physical_type`."""
physical_type = self.set_dim_symbol.sim_symbol.physical_type
if physical_type is not spux.PhysicalType.NonPhysical:
return [
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
for i, unit in enumerate(physical_type.valid_units)
]
return []
@bl_cache.cached_bl_property(depends_on={'set_dim_active_unit'})
def set_dim_unit(self) -> spux.Unit | None:
if self.set_dim_active_unit is not None:
return spux.unit_str_to_unit(self.set_dim_active_unit)
return None
#################### ####################
# - UI # - UI
#################### ####################
@ -378,27 +361,27 @@ class FilterMathNode(base.MaxwellSimNode):
# Slice # Slice
case FO.SliceIdx: case FO.SliceIdx:
slice_str = ':'.join([str(v) for v in self.slice_tuple]) slice_str = ':'.join([str(v) for v in self.slice_tuple])
return f'Filter: {self.dim_0}[{slice_str}]' return f'Filter: {self.active_dim_0}[{slice_str}]'
# Pin # Pin
case FO.PinLen1: case FO.PinLen1:
return f'Filter: Pin {self.dim_0}[0]' return f'Filter: Pin {self.active_dim_0}[0]'
case FO.Pin: case FO.Pin:
return f'Filter: Pin {self.dim_0}[...]' return f'Filter: Pin {self.active_dim_0}[...]'
case FO.PinIdx: case FO.PinIdx:
pin_idx_axis = self._compute_input( pin_idx_axis = self._compute_input(
'Axis', kind=ct.FlowKind.Value, optional=True 'Axis', kind=ct.FlowKind.Value, optional=True
) )
has_pin_idx_axis = not ct.FlowSignal.check(pin_idx_axis) has_pin_idx_axis = not ct.FlowSignal.check(pin_idx_axis)
if has_pin_idx_axis: if has_pin_idx_axis:
return f'Filter: Pin {self.dim_0}[{pin_idx_axis}]' return f'Filter: Pin {self.active_dim_0}[{pin_idx_axis}]'
return self.bl_label return self.bl_label
# Reinterpret # Reinterpret
case FO.Swap: case FO.Swap:
return f'Filter: Swap [{self.dim_0}]|[{self.dim_1}]' return f'Filter: Swap [{self.active_dim_0}]|[{self.active_dim_1}]'
case FO.SetDim: case FO.SetDim:
return f'Filter: Set [{self.dim_0}]' return f'Filter: Set [{self.active_dim_0}]'
case _: case _:
return self.bl_label return self.bl_label
@ -409,20 +392,15 @@ class FilterMathNode(base.MaxwellSimNode):
if self.operation is not None: if self.operation is not None:
match self.operation.num_dim_inputs: match self.operation.num_dim_inputs:
case 1: case 1:
layout.prop(self, self.blfields['dim_0'], text='') layout.prop(self, self.blfields['active_dim_0'], text='')
case 2: case 2:
row = layout.row(align=True) row = layout.row(align=True)
row.prop(self, self.blfields['dim_0'], text='') row.prop(self, self.blfields['active_dim_0'], text='')
row.prop(self, self.blfields['dim_1'], text='') row.prop(self, self.blfields['active_dim_1'], text='')
if self.operation is FilterOperation.SliceIdx: if self.operation is FilterOperation.SliceIdx:
layout.prop(self, self.blfields['slice_tuple'], text='') layout.prop(self, self.blfields['slice_tuple'], text='')
if self.operation is FilterOperation.SetDim:
row = layout.row(align=True)
row.prop(self, self.blfields['set_dim_symbol'], text='')
row.prop(self, self.blfields['set_dim_active_unit'], text='')
#################### ####################
# - Events # - Events
#################### ####################
@ -450,50 +428,47 @@ class FilterMathNode(base.MaxwellSimNode):
if not has_info: if not has_info:
return return
# Pin Dim by-Value: Synchronize Input Socket dim_0 = props['dim_0']
## -> The user will be given a socket w/correct mathtype, unit, etc. .
## -> Internally, this value will map to a particular index.
if props['operation'] is FilterOperation.Pin and props['dim_0'] is not None:
# Deduce Pinned Information
pinned_unit = info.dim_units[props['dim_0']]
pinned_mathtype = info.dim_mathtypes[props['dim_0']]
pinned_physical_type = spux.PhysicalType.from_unit(pinned_unit)
wanted_mathtype = (
spux.MathType.Complex
if pinned_mathtype == spux.MathType.Complex
and spux.MathType.Complex in pinned_physical_type.valid_mathtypes
else spux.MathType.Real
)
# Get Current and Wanted Socket Defs # Loose Sockets: Pin Dim by-Value
## -> 'Value' may already exist. If not, all is well. ## -> Works with continuous / discrete indexes.
## -> The user will be given a socket w/correct mathtype, unit, etc. .
if (
props['operation'] is FilterOperation.Pin
and dim_0 is not None
and (info.has_idx_cont(dim_0) or info.has_idx_discrete(dim_0))
):
dim = dim_0
current_bl_socket = self.loose_input_sockets.get('Value') current_bl_socket = self.loose_input_sockets.get('Value')
# Determine Whether to Construct
## -> If nothing needs to change, then nothing changes.
if ( if (
current_bl_socket is None current_bl_socket is None
or current_bl_socket.active_kind != ct.FlowKind.Value
or current_bl_socket.size is not spux.NumberSize1D.Scalar or current_bl_socket.size is not spux.NumberSize1D.Scalar
or current_bl_socket.physical_type != pinned_physical_type or current_bl_socket.physical_type != dim.physical_type
or current_bl_socket.mathtype != wanted_mathtype or current_bl_socket.mathtype != dim.mathtype
): ):
self.loose_input_sockets = { self.loose_input_sockets = {
'Value': sockets.ExprSocketDef( 'Value': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Value, active_kind=ct.FlowKind.Value,
physical_type=pinned_physical_type, physical_type=dim.physical_type,
mathtype=wanted_mathtype, mathtype=dim.mathtype,
default_unit=pinned_unit, default_unit=dim.unit,
), ),
} }
# Pin Dim by-Index: Synchronize Input Socket # Loose Sockets: Pin Dim by-Value
## -> The user will be given a simple integer socket. ## -> Works with discrete points / labelled integers.
elif ( elif (
props['operation'] is FilterOperation.PinIdx and props['dim_0'] is not None props['operation'] is FilterOperation.PinIdx
and dim_0 is not None
and (info.has_idx_discrete(dim_0) or info.has_idx_labels(dim_0))
): ):
dim = dim_0
current_bl_socket = self.loose_input_sockets.get('Axis') current_bl_socket = self.loose_input_sockets.get('Axis')
if ( if (
current_bl_socket is None current_bl_socket is None
or current_bl_socket.active_kind != ct.FlowKind.Value
or current_bl_socket.size is not spux.NumberSize1D.Scalar or current_bl_socket.size is not spux.NumberSize1D.Scalar
or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical
or current_bl_socket.mathtype != spux.MathType.Integer or current_bl_socket.mathtype != spux.MathType.Integer
@ -505,28 +480,26 @@ class FilterMathNode(base.MaxwellSimNode):
) )
} }
# Set Dim: Synchronize Input Socket # Loose Sockets: Set Dim
## -> The user must provide a () -> array. ## -> The user must provide a () -> array.
## -> It must be of identical length to the replaced axis. ## -> It must be of identical length to the replaced axis.
elif ( elif props['operation'] is FilterOperation.SetDim and dim_0 is not None:
props['operation'] is FilterOperation.SetDim dim = dim_0
and props['dim_0'] is not None
and info.dim_mathtypes[props['dim_0']] is spux.MathType.Integer
and info.dim_physical_types[props['dim_0']] is spux.PhysicalType.NonPhysical
):
# Deduce Axis Information
current_bl_socket = self.loose_input_sockets.get('Dim') current_bl_socket = self.loose_input_sockets.get('Dim')
if ( if (
current_bl_socket is None current_bl_socket is None
or current_bl_socket.active_kind != ct.FlowKind.LazyValueFunc or current_bl_socket.active_kind != ct.FlowKind.Func
or current_bl_socket.mathtype != spux.MathType.Real or current_bl_socket.size is not spux.NumberSize1D.Scalar
or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical or current_bl_socket.mathtype != dim.mathtype
or current_bl_socket.physical_type != dim.physical_type
): ):
self.loose_input_sockets = { self.loose_input_sockets = {
'Dim': sockets.ExprSocketDef( 'Dim': sockets.ExprSocketDef(
active_kind=ct.FlowKind.LazyValueFunc, active_kind=ct.FlowKind.Func,
mathtype=spux.MathType.Real, physical_type=dim.physical_type,
physical_type=spux.PhysicalType.NonPhysical, mathtype=dim.mathtype,
default_unit=dim.unit,
show_func_ui=False,
show_info_columns=True, show_info_columns=True,
) )
} }
@ -536,42 +509,37 @@ class FilterMathNode(base.MaxwellSimNode):
self.loose_input_sockets = {} self.loose_input_sockets = {}
#################### ####################
# - FlowKind.Value|LazyValueFunc # - FlowKind.Value|Func
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.LazyValueFunc, kind=ct.FlowKind.Func,
props={'operation', 'dim_0', 'dim_1', 'slice_tuple'}, props={'operation', 'dim_0', 'dim_1', 'slice_tuple'},
input_sockets={'Expr'}, input_sockets={'Expr'},
input_socket_kinds={'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info}}, input_socket_kinds={'Expr': {ct.FlowKind.Func, ct.FlowKind.Info}},
) )
def compute_lazy_value_func(self, props: dict, input_sockets: dict): def compute_lazy_func(self, props: dict, input_sockets: dict):
operation = props['operation'] operation = props['operation']
lazy_value_func = input_sockets['Expr'][ct.FlowKind.LazyValueFunc] lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
info = input_sockets['Expr'][ct.FlowKind.Info] info = input_sockets['Expr'][ct.FlowKind.Info]
has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func) has_lazy_func = not ct.FlowSignal.check(lazy_func)
has_info = not ct.FlowSignal.check(info) has_info = not ct.FlowSignal.check(info)
# Dimension(s)
dim_0 = props['dim_0'] dim_0 = props['dim_0']
dim_1 = props['dim_1'] dim_1 = props['dim_1']
slice_tuple = props['slice_tuple']
if ( if (
has_lazy_value_func has_lazy_func
and has_info and has_info
and operation is not None and operation is not None
and operation.are_dims_valid(info, dim_0, dim_1) and operation.are_dims_valid(info, dim_0, dim_1)
): ):
axis_0 = info.dim_names.index(dim_0) if dim_0 is not None else None axis_0 = info.dim_axis(dim_0) if dim_0 is not None else None
axis_1 = info.dim_names.index(dim_1) if dim_1 is not None else None axis_1 = info.dim_axis(dim_1) if dim_1 is not None else None
slice_tuple = (
props['slice_tuple']
if self.operation is FilterOperation.SliceIdx
else None
)
return lazy_value_func.compose_within( return lazy_func.compose_within(
operation.jax_func(axis_0, axis_1, slice_tuple), operation.jax_func(axis_0, axis_1, slice_tuple=slice_tuple),
enclosing_func_args=operation.func_args, enclosing_func_args=operation.func_args,
supports_jax=True, supports_jax=True,
) )
@ -588,27 +556,26 @@ class FilterMathNode(base.MaxwellSimNode):
'dim_1', 'dim_1',
'operation', 'operation',
'slice_tuple', 'slice_tuple',
'set_dim_symbol',
'set_dim_active_unit',
}, },
input_sockets={'Expr', 'Dim'}, input_sockets={'Expr', 'Dim'},
input_socket_kinds={ input_socket_kinds={
'Expr': ct.FlowKind.Info, 'Expr': ct.FlowKind.Info,
'Dim': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params, ct.FlowKind.Info}, 'Dim': {ct.FlowKind.Func, ct.FlowKind.Params, ct.FlowKind.Info},
}, },
input_sockets_optional={'Dim': True}, input_sockets_optional={'Dim': True},
) )
def compute_info(self, props, input_sockets) -> ct.InfoFlow: def compute_info(self, props, input_sockets) -> ct.InfoFlow:
operation = props['operation'] operation = props['operation']
info = input_sockets['Expr'] info = input_sockets['Expr']
dim_coords = input_sockets['Dim'][ct.FlowKind.LazyValueFunc]
dim_params = input_sockets['Dim'][ct.FlowKind.Params]
dim_info = input_sockets['Dim'][ct.FlowKind.Info]
dim_symbol = props['set_dim_symbol']
dim_active_unit = props['set_dim_active_unit']
has_info = not ct.FlowSignal.check(info) has_info = not ct.FlowSignal.check(info)
has_dim_coords = not ct.FlowSignal.check(dim_coords)
# Dim (Op.SetDim)
dim_func = input_sockets['Dim'][ct.FlowKind.Func]
dim_params = input_sockets['Dim'][ct.FlowKind.Params]
dim_info = input_sockets['Dim'][ct.FlowKind.Info]
has_dim_func = not ct.FlowSignal.check(dim_func)
has_dim_params = not ct.FlowSignal.check(dim_params) has_dim_params = not ct.FlowSignal.check(dim_params)
has_dim_info = not ct.FlowSignal.check(dim_info) has_dim_info = not ct.FlowSignal.check(dim_info)
@ -619,44 +586,42 @@ class FilterMathNode(base.MaxwellSimNode):
if has_info and operation is not None: if has_info and operation is not None:
# Set Dimension: Retrieve Array # Set Dimension: Retrieve Array
if props['operation'] is FilterOperation.SetDim: if props['operation'] is FilterOperation.SetDim:
new_dim = (
next(dim_info.dims.keys()) if len(dim_info.dims) >= 1 else None
)
if ( if (
dim_0 is not None dim_0 is not None
# Check Replaced Dimension and new_dim is not None
and has_dim_coords
and len(dim_coords.func_args) == 1
and dim_coords.func_args[0] is spux.MathType.Integer
and not dim_coords.func_kwargs
and dim_coords.supports_jax
# Check Params
and has_dim_params
and len(dim_params.func_args) == 1
and not dim_params.func_kwargs
# Check Info
and has_dim_info and has_dim_info
and has_dim_params
# Check New Dimension Index Array Sizing
and len(dim_info.dims) == 1
and dim_info.output.rows == 1
and dim_info.output.cols == 1
# Check Lack of Params Symbols
and not dim_params.symbols
# Check Expr Dim | New Dim Compatibility
and info.has_idx_discrete(dim_0)
and dim_info.has_idx_discrete(new_dim)
and len(info.dims[dim_0]) == len(dim_info.dims[new_dim])
): ):
# Retrieve Dimension Coordinate Array # Retrieve Dimension Coordinate Array
## -> It must be strictly compatible. ## -> It must be strictly compatible.
values = dim_coords.func_jax(int(dim_params.func_args[0])) values = dim_func.realize(dim_params, spux.UNITS_SI)
if (
len(values.shape) != 1
or values.shape[0] != info.dim_lens[dim_0]
):
return ct.FlowSignal.FlowPending
# Transform Info w/Corrected Dimension # Transform Info w/Corrected Dimension
## -> The existing dimension will be replaced. ## -> The existing dimension will be replaced.
if dim_active_unit is not None:
dim_unit = spux.unit_str_to_unit(dim_active_unit)
else:
dim_unit = None
new_dim_idx = ct.ArrayFlow( new_dim_idx = ct.ArrayFlow(
values=values, values=values,
unit=dim_unit, unit=spux.convert_to_unit_system(
) dim_info.output.unit, spux.UNITS_SI
corrected_dim = [dim_0, (dim_symbol.name, new_dim_idx)] ),
).rescale_to_unit(dim_info.output.unit)
replaced_dim = [dim_0, (dim_info.output.name, new_dim_idx)]
return operation.transform_info( return operation.transform_info(
info, dim_0, dim_1, corrected_dim=corrected_dim info, dim_0, dim_1, replaced_dim=replaced_dim
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
return operation.transform_info(info, dim_0, dim_1, slice_tuple=slice_tuple) return operation.transform_info(info, dim_0, dim_1, slice_tuple=slice_tuple)
@ -702,7 +667,7 @@ class FilterMathNode(base.MaxwellSimNode):
# Pin by-Value: Compute Nearest IDX # Pin by-Value: Compute Nearest IDX
## -> Presume a sorted index array to be able to use binary search. ## -> Presume a sorted index array to be able to use binary search.
if props['operation'] is FilterOperation.Pin and has_pinned_value: if props['operation'] is FilterOperation.Pin and has_pinned_value:
nearest_idx_to_value = info.dim_idx[dim_0].nearest_idx_of( nearest_idx_to_value = info.dims[dim_0].nearest_idx_of(
pinned_value, require_sorted=True pinned_value, require_sorted=True
) )

View File

@ -23,7 +23,7 @@ import bpy
import jax.numpy as jnp import jax.numpy as jnp
import sympy as sp import sympy as sp
from blender_maxwell.utils import bl_cache, logger from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct from .... import contracts as ct
@ -153,40 +153,38 @@ class MapOperation(enum.StrEnum):
# - Ops from Shape # - Ops from Shape
#################### ####################
@staticmethod @staticmethod
def by_element_shape(shape: tuple[int, ...] | None) -> list[typ.Self]: def by_expr_info(info: ct.InfoFlow) -> list[typ.Self]:
## TODO: By info, not shape.
## TODO: Check valid domains/mathtypes for some functions.
MO = MapOperation MO = MapOperation
element_ops = [
MO.Real,
MO.Imag,
MO.Abs,
MO.Sq,
MO.Sqrt,
MO.InvSqrt,
MO.Cos,
MO.Sin,
MO.Tan,
MO.Acos,
MO.Asin,
MO.Atan,
MO.Sinc,
]
match shape: match (info.output.rows, info.output.cols):
case 'noshape': case (1, 1):
return [] return element_ops
# By Number case (_, 1):
case None: return [*element_ops, MO.Norm2]
return [
MO.Real,
MO.Imag,
MO.Abs,
MO.Sq,
MO.Sqrt,
MO.InvSqrt,
MO.Cos,
MO.Sin,
MO.Tan,
MO.Acos,
MO.Asin,
MO.Atan,
MO.Sinc,
]
match len(shape): case (rows, cols) if rows == cols:
# By Vector ## TODO: Check hermitian/posdef for cholesky.
case 1: ## - Can we even do this with just the output symbol approach?
return [
MO.Norm2,
]
# By Matrix
case 2:
return [ return [
*element_ops,
MO.Det, MO.Det,
MO.Cond, MO.Cond,
MO.NormFro, MO.NormFro,
@ -201,6 +199,18 @@ class MapOperation(enum.StrEnum):
MO.Svd, MO.Svd,
] ]
case (rows, cols):
return [
*element_ops,
MO.Cond,
MO.NormFro,
MO.Rank,
MO.SvdVals,
MO.Inv,
MO.Tra,
MO.Svd,
]
return [] return []
#################### ####################
@ -288,41 +298,76 @@ class MapOperation(enum.StrEnum):
def transform_info(self, info: ct.InfoFlow): def transform_info(self, info: ct.InfoFlow):
MO = MapOperation MO = MapOperation
return { return {
# By Number # By Number
MO.Real: lambda: info.set_output_mathtype(spux.MathType.Real), MO.Real: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Imag: lambda: info.set_output_mathtype(spux.MathType.Real), MO.Imag: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Abs: lambda: info.set_output_mathtype(spux.MathType.Real), MO.Abs: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Sq: lambda: info,
MO.Sqrt: lambda: info,
MO.InvSqrt: lambda: info,
MO.Cos: lambda: info,
MO.Sin: lambda: info,
MO.Tan: lambda: info,
MO.Acos: lambda: info,
MO.Asin: lambda: info,
MO.Atan: lambda: info,
MO.Sinc: lambda: info,
# By Vector # By Vector
MO.Norm2: lambda: info.collapse_output( MO.Norm2: lambda: info.update_output(
collapsed_name=MO.to_name(self).replace('v', info.output_name), mathtype=spux.MathType.Real,
collapsed_mathtype=spux.MathType.Real, rows=1,
collapsed_unit=info.output_unit, cols=1,
# Interval
interval_finite_re=(0, sim_symbols.float_max),
interval_inf=(False, True),
interval_closed=(True, False),
), ),
# By Matrix # By Matrix
MO.Det: lambda: info.collapse_output( MO.Det: lambda: info.update_output(
collapsed_name=MO.to_name(self).replace('V', info.output_name), rows=1,
collapsed_mathtype=info.output_mathtype, cols=1,
collapsed_unit=info.output_unit,
), ),
MO.Cond: lambda: info.collapse_output( MO.Cond: lambda: info.update_output(
collapsed_name=MO.to_name(self).replace('V', info.output_name), mathtype=spux.MathType.Real,
collapsed_mathtype=spux.MathType.Real, rows=1,
collapsed_unit=None, cols=1,
physical_type=spux.PhysicalType.NonPhysical,
unit=None,
), ),
MO.NormFro: lambda: info.collapse_output( MO.NormFro: lambda: info.update_output(
collapsed_name=MO.to_name(self).replace('V', info.output_name), mathtype=spux.MathType.Real,
collapsed_mathtype=spux.MathType.Real, rows=1,
collapsed_unit=info.output_unit, cols=1,
# Interval
interval_finite_re=(0, sim_symbols.float_max),
interval_inf=(False, True),
interval_closed=(True, False),
), ),
MO.Rank: lambda: info.collapse_output( MO.Rank: lambda: info.update_output(
collapsed_name=MO.to_name(self).replace('V', info.output_name), mathtype=spux.MathType.Integer,
collapsed_mathtype=spux.MathType.Integer, rows=1,
collapsed_unit=None, cols=1,
physical_type=spux.PhysicalType.NonPhysical,
unit=None,
# Interval
interval_finite_re=(0, sim_symbols.int_max),
interval_inf=(False, True),
interval_closed=(True, False),
), ),
## TODO: Matrix -> Vec # Matrix -> Vector ## TODO: ALL OF THESE
## TODO: Matrix -> Matrices MO.Diag: lambda: info,
}.get(self, lambda: info)() MO.EigVals: lambda: info,
MO.SvdVals: lambda: info,
# Matrix -> Matrix ## TODO: ALL OF THESE
MO.Inv: lambda: info,
MO.Tra: lambda: info,
# Matrix -> Matrices ## TODO: ALL OF THESE
MO.Qr: lambda: info,
MO.Chol: lambda: info,
MO.Svd: lambda: info,
}[self]()
#################### ####################
@ -401,7 +446,7 @@ class MapMathNode(base.MaxwellSimNode):
The name and type of the available symbol is clearly shown, and most valid `sympy` expressions that you would expect to work, should work. The name and type of the available symbol is clearly shown, and most valid `sympy` expressions that you would expect to work, should work.
Use of expressions generally imposes no performance penalty: Just like the baked-in operations, it is compiled to a high-performance `jax` function. Use of expressions generally imposes no performance penalty: Just like the baked-in operations, it is compiled to a high-performance `jax` function.
Thus, it participates in the `ct.FlowKind.LazyValueFunc` composition chain. Thus, it participates in the `ct.FlowKind.Func` composition chain.
Attributes: Attributes:
@ -412,10 +457,10 @@ class MapMathNode(base.MaxwellSimNode):
bl_label = 'Map Math' bl_label = 'Map Math'
input_sockets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc), 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
} }
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc), 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
} }
#################### ####################
@ -435,29 +480,26 @@ class MapMathNode(base.MaxwellSimNode):
) )
if has_info and not info_pending: if has_info and not info_pending:
self.expr_output_shape = bl_cache.Signal.InvalidateCache self.expr_info = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property() @bl_cache.cached_bl_property()
def expr_output_shape(self) -> ct.InfoFlow | None: def expr_info(self) -> ct.InfoFlow | None:
info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True) info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True)
has_info = not ct.FlowSignal.check(info) has_info = not ct.FlowSignal.check(info)
if has_info: if has_info:
return info.output_shape return info
return None
return 'noshape'
operation: MapOperation = bl_cache.BLField( operation: MapOperation = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(), enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_output_shape'}, cb_depends_on={'expr_info'},
) )
def search_operations(self) -> list[ct.BLEnumElement]: def search_operations(self) -> list[ct.BLEnumElement]:
if self.expr_output_shape != 'noshape': if self.info is not None:
return [ return [
operation.bl_enum_element(i) operation.bl_enum_element(i)
for i, operation in enumerate( for i, operation in enumerate(MapOperation.by_expr_info(self.expr_info))
MapOperation.by_element_shape(self.expr_output_shape)
)
] ]
return [] return []
@ -474,7 +516,7 @@ class MapMathNode(base.MaxwellSimNode):
layout.prop(self, self.blfields['operation'], text='') layout.prop(self, self.blfields['operation'], text='')
#################### ####################
# - FlowKind.Value|LazyValueFunc # - FlowKind.Value
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
@ -495,18 +537,19 @@ class MapMathNode(base.MaxwellSimNode):
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.LazyValueFunc, kind=ct.FlowKind.Func,
props={'operation'}, props={'operation'},
input_sockets={'Expr'}, input_sockets={'Expr'},
input_socket_kinds={ input_socket_kinds={
'Expr': ct.FlowKind.LazyValueFunc, 'Expr': ct.FlowKind.Func,
}, },
) )
def compute_func( def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal:
self, props, input_sockets
) -> ct.LazyValueFuncFlow | ct.FlowSignal:
operation = props['operation'] operation = props['operation']
expr = input_sockets['Expr'] expr = input_sockets['Expr']
@ -520,7 +563,7 @@ class MapMathNode(base.MaxwellSimNode):
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
#################### ####################
# - FlowKind.Info|Params # - FlowKind.Info
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
@ -540,6 +583,9 @@ class MapMathNode(base.MaxwellSimNode):
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.Params, kind=ct.FlowKind.Params,

View File

@ -261,12 +261,12 @@ class OperateMathNode(base.MaxwellSimNode):
bl_label = 'Operate Math' bl_label = 'Operate Math'
input_sockets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'Expr L': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc), 'Expr L': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
'Expr R': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc), 'Expr R': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
} }
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef( 'Expr': sockets.ExprSocketDef(
active_kind=ct.FlowKind.LazyValueFunc, show_info_columns=True active_kind=ct.FlowKind.Func, show_info_columns=True
), ),
} }
@ -344,7 +344,7 @@ class OperateMathNode(base.MaxwellSimNode):
layout.prop(self, self.blfields['operation'], text='') layout.prop(self, self.blfields['operation'], text='')
#################### ####################
# - FlowKind.Value|LazyValueFunc # - FlowKind.Value|Func
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
@ -373,12 +373,12 @@ class OperateMathNode(base.MaxwellSimNode):
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.LazyValueFunc, kind=ct.FlowKind.Func,
props={'operation'}, props={'operation'},
input_sockets={'Expr L', 'Expr R'}, input_sockets={'Expr L', 'Expr R'},
input_socket_kinds={ input_socket_kinds={
'Expr L': ct.FlowKind.LazyValueFunc, 'Expr L': ct.FlowKind.Func,
'Expr R': ct.FlowKind.LazyValueFunc, 'Expr R': ct.FlowKind.Func,
}, },
) )
def compose_func(self, props: dict, input_sockets: dict): def compose_func(self, props: dict, input_sockets: dict):

View File

@ -97,7 +97,7 @@ class ReduceMathNode(base.MaxwellSimNode):
'Data', 'Data',
props={'active_socket_set', 'operation'}, props={'active_socket_set', 'operation'},
input_sockets={'Data', 'Axis', 'Reducer'}, input_sockets={'Data', 'Axis', 'Reducer'},
input_socket_kinds={'Reducer': ct.FlowKind.LazyValueFunc}, input_socket_kinds={'Reducer': ct.FlowKind.Func},
input_sockets_optional={'Reducer': True}, input_sockets_optional={'Reducer': True},
) )
def compute_data(self, props: dict, input_sockets: dict): def compute_data(self, props: dict, input_sockets: dict):

View File

@ -107,32 +107,31 @@ class TransformOperation(enum.StrEnum):
# Covariant Transform # Covariant Transform
## Freq <-> VacWL ## Freq <-> VacWL
for dim_name in info.dim_names: for dim in info.dims:
if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq: if dim.physical_type == spux.PhysicalType.Freq:
operations.append(TO.FreqToVacWL) operations.append(TO.FreqToVacWL)
if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq: if dim.physical_type == spux.PhysicalType.Freq:
operations.append(TO.VacWLToFreq) operations.append(TO.VacWLToFreq)
# Fold # Fold
## (Last) Int Dim (=2) to Complex ## (Last) Int Dim (=2) to Complex
if len(info.dim_names) >= 1: if len(info.dims) >= 1:
last_dim_name = info.dim_names[-1] if not info.has_idx_labels(info.last_dim) and len(info.last_dim) == 2: # noqa: PLR2004
if info.dim_lens[last_dim_name] == 2: # noqa: PLR2004
operations.append(TO.IntDimToComplex) operations.append(TO.IntDimToComplex)
## To Vector ## To Vector
if len(info.dim_names) >= 1: if len(info.dims) >= 1:
operations.append(TO.DimToVec) operations.append(TO.DimToVec)
## To Matrix ## To Matrix
if len(info.dim_names) >= 2: # noqa: PLR2004 if len(info.dims) >= 2: # noqa: PLR2004
operations.append(TO.DimsToMat) operations.append(TO.DimsToMat)
# Fourier # Fourier
## 1D Fourier ## 1D Fourier
if info.dim_names: if info.dims:
last_physical_type = info.dim_physical_types[info.dim_names[-1]] last_physical_type = info.last_dim.physical_type
if last_physical_type == spux.PhysicalType.Time: if last_physical_type == spux.PhysicalType.Time:
operations.append(TO.FFT1D) operations.append(TO.FFT1D)
if last_physical_type == spux.PhysicalType.Freq: if last_physical_type == spux.PhysicalType.Freq:
@ -188,15 +187,15 @@ class TransformOperation(enum.StrEnum):
unit: spux.Unit | None = None, unit: spux.Unit | None = None,
) -> ct.InfoFlow | None: ) -> ct.InfoFlow | None:
TO = TransformOperation TO = TransformOperation
if not info.dim_names: if not info.dims:
return None return None
return { return {
# Index # Covariant Transform
TO.FreqToVacWL: lambda: info.replace_dim( TO.FreqToVacWL: lambda: info.replace_dim(
(f_dim := info.dim_names[-1]), (f_dim := info.last_dim),
[ [
'wl', sim_symbols.wl(spu.nanometer),
info.dim_idx[f_dim].rescale( info.dims[f_dim].rescale(
lambda el: sci_constants.vac_speed_of_light / el, lambda el: sci_constants.vac_speed_of_light / el,
reverse=True, reverse=True,
new_unit=spu.nanometer, new_unit=spu.nanometer,
@ -204,10 +203,10 @@ class TransformOperation(enum.StrEnum):
], ],
), ),
TO.VacWLToFreq: lambda: info.replace_dim( TO.VacWLToFreq: lambda: info.replace_dim(
(wl_dim := info.dim_names[-1]), (wl_dim := info.last_dim),
[ [
'f', sim_symbols.freq(spux.THz),
info.dim_idx[wl_dim].rescale( info.dims[wl_dim].rescale(
lambda el: sci_constants.vac_speed_of_light / el, lambda el: sci_constants.vac_speed_of_light / el,
reverse=True, reverse=True,
new_unit=spux.THz, new_unit=spux.THz,
@ -215,26 +214,24 @@ class TransformOperation(enum.StrEnum):
], ],
), ),
# Fold # Fold
TO.IntDimToComplex: lambda: info.delete_dimension( TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output(
info.dim_names[-1] mathtype=spux.MathType.Complex
).set_output_mathtype(spux.MathType.Complex), ),
TO.DimToVec: lambda: info.shift_last_input, TO.DimToVec: lambda: info.fold_last_input(),
TO.DimsToMat: lambda: info.shift_last_input.shift_last_input, TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(),
# Fourier # Fourier
TO.FFT1D: lambda: info.replace_dim( TO.FFT1D: lambda: info.replace_dim(
info.dim_names[-1], info.last_dim,
[ [
'f', sim_symbols.freq(spux.THz),
ct.LazyArrayRangeFlow(start=0, stop=sp.oo, steps=0, unit=spu.hertz), None,
], ],
), ),
TO.InvFFT1D: info.replace_dim( TO.InvFFT1D: info.replace_dim(
info.dim_names[-1], info.last_dim,
[ [
't', sim_symbols.t(spu.second),
ct.LazyArrayRangeFlow( None,
start=0, stop=sp.oo, steps=0, unit=spu.second
),
], ],
), ),
}.get(self, lambda: info)() }.get(self, lambda: info)()
@ -260,10 +257,10 @@ class TransformMathNode(base.MaxwellSimNode):
bl_label = 'Transform Math' bl_label = 'Transform Math'
input_sockets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc), 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
} }
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc), 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
} }
#################### ####################
@ -325,7 +322,7 @@ class TransformMathNode(base.MaxwellSimNode):
layout.prop(self, self.blfields['operation'], text='') layout.prop(self, self.blfields['operation'], text='')
#################### ####################
# - Compute: LazyValueFunc / Array # - Compute: Func / Array
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
@ -348,16 +345,14 @@ class TransformMathNode(base.MaxwellSimNode):
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.LazyValueFunc, kind=ct.FlowKind.Func,
props={'operation'}, props={'operation'},
input_sockets={'Expr'}, input_sockets={'Expr'},
input_socket_kinds={ input_socket_kinds={
'Expr': ct.FlowKind.LazyValueFunc, 'Expr': ct.FlowKind.Func,
}, },
) )
def compute_func( def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal:
self, props, input_sockets
) -> ct.LazyValueFuncFlow | ct.FlowSignal:
operation = props['operation'] operation = props['operation']
expr = input_sockets['Expr'] expr = input_sockets['Expr']

View File

@ -38,7 +38,6 @@ class VizMode(enum.StrEnum):
**NOTE**: >1D output dimensions currently have no viz. **NOTE**: >1D output dimensions currently have no viz.
Plots for `() -> `: Plots for `() -> `:
- Hist1D: Bin-summed distribution.
- BoxPlot1D: Box-plot describing the distribution. - BoxPlot1D: Box-plot describing the distribution.
Plots for `() -> `: Plots for `() -> `:
@ -61,7 +60,6 @@ class VizMode(enum.StrEnum):
- Heatmap3D: Colormapped field with value at each voxel. - Heatmap3D: Colormapped field with value at each voxel.
""" """
Hist1D = enum.auto()
BoxPlot1D = enum.auto() BoxPlot1D = enum.auto()
Curve2D = enum.auto() Curve2D = enum.auto()
@ -78,42 +76,38 @@ class VizMode(enum.StrEnum):
@staticmethod @staticmethod
def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None: def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None:
EMPTY = ()
Z = spux.MathType.Integer Z = spux.MathType.Integer
R = spux.MathType.Real R = spux.MathType.Real
VM = VizMode VM = VizMode
valid_viz_modes = { return {
(EMPTY, (None, R)): [VM.Hist1D, VM.BoxPlot1D], ((Z), (1, 1, R)): [
((Z), (None, R)): [
VM.Hist1D,
VM.BoxPlot1D, VM.BoxPlot1D,
], ],
((R,), (None, R)): [ ((R,), (1, 1, R)): [
VM.Curve2D, VM.Curve2D,
VM.Points2D, VM.Points2D,
VM.Bar, VM.Bar,
], ],
((R, Z), (None, R)): [ ((R, Z), (1, 1, R)): [
VM.Curves2D, VM.Curves2D,
VM.FilledCurves2D, VM.FilledCurves2D,
], ],
((R, R), (None, R)): [ ((R, R), (1, 1, R)): [
VM.Heatmap2D, VM.Heatmap2D,
], ],
((R, R, R), (None, R)): [VM.SqueezedHeatmap2D, VM.Heatmap3D], ((R, R, R), (1, 1, R)): [
VM.SqueezedHeatmap2D,
VM.Heatmap3D,
],
}.get( }.get(
( (
tuple(info.dim_mathtypes.values()), tuple([dim.mathtype for dim in info.dims.values()]),
(info.output_shape, info.output_mathtype), (info.output.rows, info.output.cols, info.output.mathtype),
) ),
[],
) )
if valid_viz_modes is None:
return []
return valid_viz_modes
@staticmethod @staticmethod
def to_plotter( def to_plotter(
value: typ.Self, value: typ.Self,
@ -121,7 +115,6 @@ class VizMode(enum.StrEnum):
[jtyp.Float32[jtyp.Array, '...'], ct.InfoFlow, mpl_ax.Axis], None [jtyp.Float32[jtyp.Array, '...'], ct.InfoFlow, mpl_ax.Axis], None
]: ]:
return { return {
VizMode.Hist1D: image_ops.plot_hist_1d,
VizMode.BoxPlot1D: image_ops.plot_box_plot_1d, VizMode.BoxPlot1D: image_ops.plot_box_plot_1d,
VizMode.Curve2D: image_ops.plot_curve_2d, VizMode.Curve2D: image_ops.plot_curve_2d,
VizMode.Points2D: image_ops.plot_points_2d, VizMode.Points2D: image_ops.plot_points_2d,
@ -136,7 +129,6 @@ class VizMode(enum.StrEnum):
@staticmethod @staticmethod
def to_name(value: typ.Self) -> str: def to_name(value: typ.Self) -> str:
return { return {
VizMode.Hist1D: 'Histogram',
VizMode.BoxPlot1D: 'Box Plot', VizMode.BoxPlot1D: 'Box Plot',
VizMode.Curve2D: 'Curve', VizMode.Curve2D: 'Curve',
VizMode.Points2D: 'Points', VizMode.Points2D: 'Points',
@ -164,7 +156,6 @@ class VizTarget(enum.StrEnum):
@staticmethod @staticmethod
def valid_targets_for(viz_mode: VizMode) -> list[typ.Self] | None: def valid_targets_for(viz_mode: VizMode) -> list[typ.Self] | None:
return { return {
VizMode.Hist1D: [VizTarget.Plot2D],
VizMode.BoxPlot1D: [VizTarget.Plot2D], VizMode.BoxPlot1D: [VizTarget.Plot2D],
VizMode.Curve2D: [VizTarget.Plot2D], VizMode.Curve2D: [VizTarget.Plot2D],
VizMode.Points2D: [VizTarget.Plot2D], VizMode.Points2D: [VizTarget.Plot2D],
@ -209,7 +200,7 @@ class VizNode(base.MaxwellSimNode):
#################### ####################
input_sockets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef( 'Expr': sockets.ExprSocketDef(
active_kind=ct.FlowKind.LazyValueFunc, active_kind=ct.FlowKind.Func,
default_symbols=[sim_symbols.x], default_symbols=[sim_symbols.x],
default_value=2 * sim_symbols.x.sp_symbol, default_value=2 * sim_symbols.x.sp_symbol,
), ),
@ -333,35 +324,20 @@ class VizNode(base.MaxwellSimNode):
## -> This happens if Params contains not-yet-realized symbols. ## -> This happens if Params contains not-yet-realized symbols.
if has_info and has_params and params.symbols: if has_info and has_params and params.symbols:
if set(self.loose_input_sockets) != { if set(self.loose_input_sockets) != {
sym.name for sym in params.symbols if sym.name in info.dim_names dim.name for dim in params.symbols if dim in info.dims
}: }:
self.loose_input_sockets = { self.loose_input_sockets = {
sym.name: sockets.ExprSocketDef( dim_name: sockets.ExprSocketDef(**expr_info)
active_kind=ct.FlowKind.LazyArrayRange, for dim_name, expr_info in params.sym_expr_infos(
size=spux.NumberSize1D.Scalar, info, use_range=True
mathtype=info.dim_mathtypes[sym.name], ).items()
physical_type=info.dim_physical_types[sym.name],
default_min=(
info.dim_idx[sym.name].start
if not sp.S(info.dim_idx[sym.name].start).is_infinite
else sp.S(0)
),
default_max=(
info.dim_idx[sym.name].start
if not sp.S(info.dim_idx[sym.name].stop).is_infinite
else sp.S(1)
),
default_steps=50,
)
for sym in params.sorted_symbols
if sym.name in info.dim_names
} }
elif self.loose_input_sockets: elif self.loose_input_sockets:
self.loose_input_sockets = {} self.loose_input_sockets = {}
##################### #####################
## - Plotting ## - FlowKind.Value
##################### #####################
@events.computes_output_socket( @events.computes_output_socket(
'Preview', 'Preview',
@ -370,37 +346,38 @@ class VizNode(base.MaxwellSimNode):
props={'viz_mode', 'viz_target', 'colormap'}, props={'viz_mode', 'viz_target', 'colormap'},
input_sockets={'Expr'}, input_sockets={'Expr'},
input_socket_kinds={ input_socket_kinds={
'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info, ct.FlowKind.Params} 'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
}, },
all_loose_input_sockets=True, all_loose_input_sockets=True,
) )
def compute_dummy_value(self, props, input_sockets, loose_input_sockets): def compute_dummy_value(self, props, input_sockets, loose_input_sockets):
"""Needed for the plot to regenerate in the viewer."""
return ct.FlowSignal.NoFlow return ct.FlowSignal.NoFlow
#####################
## - On Show Plot
#####################
@events.on_show_plot( @events.on_show_plot(
managed_objs={'plot'}, managed_objs={'plot'},
props={'viz_mode', 'viz_target', 'colormap'}, props={'viz_mode', 'viz_target', 'colormap'},
input_sockets={'Expr'}, input_sockets={'Expr'},
input_socket_kinds={ input_socket_kinds={
'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info, ct.FlowKind.Params} 'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
}, },
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
all_loose_input_sockets=True, all_loose_input_sockets=True,
stop_propagation=True, stop_propagation=True,
) )
def on_show_plot( def on_show_plot(
self, managed_objs, props, input_sockets, loose_input_sockets, unit_systems self, managed_objs, props, input_sockets, loose_input_sockets
): ) -> None:
# Retrieve Inputs # Retrieve Inputs
lazy_value_func = input_sockets['Expr'][ct.FlowKind.LazyValueFunc] lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
info = input_sockets['Expr'][ct.FlowKind.Info] info = input_sockets['Expr'][ct.FlowKind.Info]
params = input_sockets['Expr'][ct.FlowKind.Params] params = input_sockets['Expr'][ct.FlowKind.Params]
has_info = not ct.FlowSignal.check(info) has_info = not ct.FlowSignal.check(info)
has_params = not ct.FlowSignal.check(params) has_params = not ct.FlowSignal.check(params)
# Invalid Mode | Target
## -> To limit branching, return now if things aren't right.
if ( if (
not has_info not has_info
or not has_params or not has_params
@ -409,54 +386,43 @@ class VizNode(base.MaxwellSimNode):
): ):
return return
# Compute LazyArrayRanges for Symbols from Loose Sockets # Compute Ranges for Symbols from Loose Sockets
## -> These are the concrete values of the symbol for plotting.
## -> In a quite nice turn of events, all this is cached lookups. ## -> In a quite nice turn of events, all this is cached lookups.
## -> ...Unless something changed, in which case, well. It changed. ## -> ...Unless something changed, in which case, well. It changed.
symbol_values = { symbol_array_values = {
sym: ( sim_syms: (
loose_input_sockets[sym.name] loose_input_sockets[sim_syms]
.realize_array.rescale_to_unit(info.dim_units[sym.name]) .rescale_to_unit(sim_syms.unit)
.values .realize_array
) )
for sym in params.sorted_symbols for sim_syms in params.sorted_symbols
} }
data = lazy_func.realize(params, symbol_values=symbol_array_values)
# Realize LazyValueFunc w/Symbolic Values, Unit System
## -> This gives us the actual plot data!
data = lazy_value_func.func_jax(
*params.scaled_func_args(
unit_systems['BlenderUnits'], symbol_values=symbol_values
),
**params.scaled_func_kwargs(
unit_systems['BlenderUnits'], symbol_values=symbol_values
),
)
# Replace InfoFlow Indices w/Realized Symbolic Ranges # Replace InfoFlow Indices w/Realized Symbolic Ranges
## -> This ensures correct axis scaling. ## -> This ensures correct axis scaling.
if params.symbols: if params.symbols:
info = info.rescale_dim_idxs(loose_input_sockets) info = info.replace_dims(symbol_array_values)
# Visualize by-Target match props['viz_target']:
if props['viz_target'] == VizTarget.Plot2D: case VizTarget.Plot2D:
managed_objs['plot'].mpl_plot_to_image( managed_objs['plot'].mpl_plot_to_image(
lambda ax: VizMode.to_plotter(props['viz_mode'])(data, info, ax), lambda ax: VizMode.to_plotter(props['viz_mode'])(data, info, ax),
bl_select=True, bl_select=True,
) )
if props['viz_target'] == VizTarget.Pixels: case VizTarget.Pixels:
managed_objs['plot'].map_2d_to_image( managed_objs['plot'].map_2d_to_image(
data, data,
colormap=props['colormap'], colormap=props['colormap'],
bl_select=True, bl_select=True,
) )
if props['viz_target'] == VizTarget.PixelsPlane: case VizTarget.PixelsPlane:
raise NotImplementedError raise NotImplementedError
if props['viz_target'] == VizTarget.Voxels: case VizTarget.Voxels:
raise NotImplementedError raise NotImplementedError
#################### ####################

View File

@ -67,7 +67,7 @@ ManagedObjName: typ.TypeAlias = str
PropName: typ.TypeAlias = str PropName: typ.TypeAlias = str
def event_decorator( def event_decorator( # noqa: PLR0913
event: ct.FlowEvent, event: ct.FlowEvent,
callback_info: EventCallbackInfo | None, callback_info: EventCallbackInfo | None,
stop_propagation: bool = False, stop_propagation: bool = False,
@ -91,31 +91,42 @@ def event_decorator(
scale_input_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}), scale_input_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}),
scale_output_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}), scale_output_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}),
): ):
"""Returns a decorator for a method of `MaxwellSimNode`, declaring it as able respond to events passing through a node. """Low-level decorator declaring a special "event method" of `MaxwellSimNode`, which is able to handle `ct.FlowEvent`s passing through.
Should generally be used via a high-level decorator such as `on_value_changed`.
For more about how event methods are actually registered and run, please refer to the documentation of `MaxwellSimNode`.
Parameters: Parameters:
event: A name describing which event the decorator should respond to. event: A name describing which event the decorator should respond to.
Set to `return_method.event`
callback_info: A dictionary that provides the caller with additional per-`event` information. callback_info: A dictionary that provides the caller with additional per-`event` information.
This might include parameters to help select the most appropriate method(s) to respond to an event with, or events to take after running the callback. This might include parameters to help select the most appropriate method(s) to respond to an event with, or events to take after running the callback.
props: Set of `props` to compute, then pass to the decorated method.
stop_propagation: Whether or stop propagating the event through the graph after encountering this method. stop_propagation: Whether or stop propagating the event through the graph after encountering this method.
Other methods defined on the same node will still run. Other methods defined on the same node will still run.
managed_objs: Set of `managed_objs` to retrieve, then pass to the decorated method. managed_objs: Set of `managed_objs` to retrieve, then pass to the decorated method.
props: Set of `props` to compute, then pass to the decorated method.
input_sockets: Set of `input_sockets` to compute, then pass to the decorated method. input_sockets: Set of `input_sockets` to compute, then pass to the decorated method.
input_sockets_optional: Whether an input socket is required to exist.
When True, lack of socket will produce `ct.FlowSignal.NoFlow`, instead of throwing an error.
input_socket_kinds: The `ct.FlowKind` to compute per-input-socket. input_socket_kinds: The `ct.FlowKind` to compute per-input-socket.
If an input socket isn't specified, it defaults to `ct.FlowKind.Value`. If an input socket isn't specified, it defaults to `ct.FlowKind.Value`.
output_sockets: Set of `output_sockets` to compute, then pass to the decorated method. output_sockets: Set of `output_sockets` to compute, then pass to the decorated method.
output_sockets_optional: Whether an output socket is required to exist.
When True, lack of socket will produce `ct.FlowSignal.NoFlow`, instead of throwing an error.
output_socket_kinds: The `ct.FlowKind` to compute per-output-socket.
If an output socket isn't specified, it defaults to `ct.FlowKind.Value`.
all_loose_input_sockets: Whether to compute all loose input sockets and pass them to the decorated method. all_loose_input_sockets: Whether to compute all loose input sockets and pass them to the decorated method.
Used when the names of the loose input sockets are unknown, but all of their values are needed. Used when the names of the loose input sockets are unknown, but all of their values are needed.
all_loose_output_sockets: Whether to compute all loose output sockets and pass them to the decorated method. all_loose_output_sockets: Whether to compute all loose output sockets and pass them to the decorated method.
Used when the names of the loose output sockets are unknown, but all of their values are needed. Used when the names of the loose output sockets are unknown, but all of their values are needed.
unit_systems: String identifiers under which to load a unit system, made available to the method.
scale_input_sockets: A mapping of input sockets to unit system string idenfiers, which causes the output of that input socket to be scaled to the given unit system.
This greatly simplifies the conformance of particular sockets to particular unit systems, when the socket value must be used in a unit-unaware manner.
scale_output_sockets: A mapping of output sockets to unit system string idenfiers, which causes the output of that input socket to be scaled to the given unit system.
This greatly simplifies the conformance of particular sockets to particular unit systems, when the socket value must be used in a unit-unaware manner.
Returns: Returns:
A decorator, which can be applied to a method of `MaxwellSimNode`. A decorator, which can be applied to a method of `MaxwellSimNode` to make it an "event method".
When a `MaxwellSimNode` subclass initializes, such a decorated method will be picked up on.
When `event` passes through the node, then `callback_info` is used to determine
""" """
req_params = ( req_params = (
{'self'} {'self'}
@ -375,7 +386,6 @@ def on_value_changed(
) )
## TODO: Change name to 'on_output_requested'
def computes_output_socket( def computes_output_socket(
output_socket_name: ct.SocketName | None, output_socket_name: ct.SocketName | None,
kind: ct.FlowKind = ct.FlowKind.Value, kind: ct.FlowKind = ct.FlowKind.Value,

View File

@ -29,12 +29,12 @@ class ExprConstantNode(base.MaxwellSimNode):
input_sockets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef( 'Expr': sockets.ExprSocketDef(
active_kind=ct.FlowKind.LazyValueFunc, active_kind=ct.FlowKind.Func,
), ),
} }
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef( 'Expr': sockets.ExprSocketDef(
active_kind=ct.FlowKind.LazyValueFunc, active_kind=ct.FlowKind.Func,
show_info_columns=True, show_info_columns=True,
), ),
} }
@ -58,12 +58,12 @@ class ExprConstantNode(base.MaxwellSimNode):
@events.computes_output_socket( @events.computes_output_socket(
# Trigger # Trigger
'Expr', 'Expr',
kind=ct.FlowKind.LazyValueFunc, kind=ct.FlowKind.Func,
# Loaded # Loaded
input_sockets={'Expr'}, input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.LazyValueFunc}, input_socket_kinds={'Expr': ct.FlowKind.Func},
) )
def compute_lazy_value_func(self, input_sockets: dict) -> typ.Any: def compute_lazy_func(self, input_sockets: dict) -> typ.Any:
return input_sockets['Expr'] return input_sockets['Expr']
#################### ####################

View File

@ -19,14 +19,10 @@ import typing as typ
from pathlib import Path from pathlib import Path
import bpy import bpy
import jax.numpy as jnp
import jaxtyping as jtyp
import numpy as np
import pandas as pd
import sympy as sp import sympy as sp
import tidy3d as td import tidy3d as td
from blender_maxwell.utils import bl_cache, logger from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct from .... import contracts as ct
@ -35,112 +31,6 @@ from ... import base, events
log = logger.get(__name__) log = logger.get(__name__)
####################
# - Data File Extensions
####################
_DATA_FILE_EXTS = {
'.txt',
'.txt.gz',
'.csv',
'.npy',
}
class DataFileExt(enum.StrEnum):
Txt = enum.auto()
TxtGz = enum.auto()
Csv = enum.auto()
Npy = enum.auto()
####################
# - Enum Elements
####################
@staticmethod
def to_name(v: typ.Self) -> str:
return DataFileExt(v).extension
@staticmethod
def to_icon(v: typ.Self) -> str:
return ''
####################
# - Computed Properties
####################
@property
def extension(self) -> str:
"""Map to the actual string extension."""
E = DataFileExt
return {
E.Txt: '.txt',
E.TxtGz: '.txt.gz',
E.Csv: '.csv',
E.Npy: '.npy',
}[self]
@property
def loader(self) -> typ.Callable[[Path], jtyp.Shaped[jtyp.Array, '...']]:
def load_txt(path: Path):
return jnp.asarray(np.loadtxt(path))
def load_csv(path: Path):
return jnp.asarray(pd.read_csv(path).values)
def load_npy(path: Path):
return jnp.load(path)
E = DataFileExt
return {
E.Txt: load_txt,
E.TxtGz: load_txt,
E.Csv: load_csv,
E.Npy: load_npy,
}[self]
@property
def loader_is_jax_compatible(self) -> bool:
E = DataFileExt
return {
E.Txt: True,
E.TxtGz: True,
E.Csv: False,
E.Npy: True,
}[self]
####################
# - Creation
####################
@staticmethod
def from_ext(ext: str) -> typ.Self | None:
return {
_ext: _data_file_ext
for _data_file_ext, _ext in {
k: k.extension for k in list(DataFileExt)
}.items()
}.get(ext)
@staticmethod
def from_path(path: Path) -> typ.Self | None:
if DataFileExt.is_path_compatible(path):
data_file_ext = DataFileExt.from_ext(''.join(path.suffixes))
if data_file_ext is not None:
return data_file_ext
msg = f'DataFileExt: Path "{path}" is compatible, but could not find valid extension'
raise RuntimeError(msg)
return None
####################
# - Compatibility
####################
@staticmethod
def is_ext_compatible(ext: str):
return ext in _DATA_FILE_EXTS
@staticmethod
def is_path_compatible(path: Path):
return path.is_file() and DataFileExt.is_ext_compatible(''.join(path.suffixes))
#################### ####################
# - Node # - Node
@ -153,7 +43,7 @@ class DataFileImporterNode(base.MaxwellSimNode):
'File Path': sockets.FilePathSocketDef(), 'File Path': sockets.FilePathSocketDef(),
} }
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc), 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
} }
#################### ####################
@ -168,10 +58,6 @@ class DataFileImporterNode(base.MaxwellSimNode):
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102 def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
has_file_path = not ct.FlowSignal.check(input_sockets['File Path']) has_file_path = not ct.FlowSignal.check(input_sockets['File Path'])
has_file_path = ct.FlowSignal.check_single(
input_sockets['File Path'], ct.FlowSignal.FlowPending
)
if has_file_path: if has_file_path:
self.file_path = bl_cache.Signal.InvalidateCache self.file_path = bl_cache.Signal.InvalidateCache
@ -188,10 +74,10 @@ class DataFileImporterNode(base.MaxwellSimNode):
return None return None
@bl_cache.cached_bl_property(depends_on={'file_path'}) @bl_cache.cached_bl_property(depends_on={'file_path'})
def data_file_ext(self) -> DataFileExt | None: def data_file_format(self) -> ct.DataFileFormat | None:
"""Retrieve the file extension by concatenating all suffixes.""" """Retrieve the file extension by concatenating all suffixes."""
if self.file_path is not None: if self.file_path is not None:
return DataFileExt.from_path(self.file_path) return ct.DataFileFormat.from_path(self.file_path)
return None return None
#################### ####################
@ -201,11 +87,93 @@ class DataFileImporterNode(base.MaxwellSimNode):
def expr_info(self) -> ct.InfoFlow | None: def expr_info(self) -> ct.InfoFlow | None:
"""Retrieve the output expression's `InfoFlow`.""" """Retrieve the output expression's `InfoFlow`."""
info = self.compute_output('Expr', kind=ct.FlowKind.Info) info = self.compute_output('Expr', kind=ct.FlowKind.Info)
has_info = not ct.FlowKind.check(info) has_info = not ct.FlowSignal.check(info)
if has_info: if has_info:
return info return info
return None return None
####################
# - Info Guides
####################
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(sim_symbols.SimSymbolName)
output_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
output_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
output_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type),
cb_depends_on={'output_physical_type'},
)
dim_0_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerA
)
dim_0_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_0_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_0_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type),
cb_depends_on={'dim_0_physical_type'},
)
dim_1_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerB
)
dim_1_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_1_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_1_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_1_physical_type),
cb_depends_on={'dim_1_physical_type'},
)
dim_2_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerC
)
dim_2_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_2_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_2_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_2_physical_type),
cb_depends_on={'dim_2_physical_type'},
)
dim_3_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerD
)
dim_3_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
dim_3_physical_type: spux.PhysicalType = bl_cache.BLField(
spux.PhysicalType.NonPhysical
)
dim_3_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_units(self.dim_3_physical_type),
cb_depends_on={'dim_3_physical_type'},
)
def search_units(self, physical_type: spux.PhysicalType) -> list[ct.BLEnumElement]:
if physical_type is not spux.PhysicalType.NonPhysical:
return [
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
for i, unit in enumerate(physical_type.valid_units)
]
return []
def dim(self, i: int):
dim_name = getattr(self, f'dim_{i}_name')
dim_mathtype = getattr(self, f'dim_{i}_mathtype')
dim_physical_type = getattr(self, f'dim_{i}_physical_type')
dim_unit = getattr(self, f'dim_{i}_unit')
return sim_symbols.SimSymbol(
sym_name=dim_name,
mathtype=dim_mathtype,
physical_type=dim_physical_type,
unit=spux.unit_str_to_unit(dim_unit),
)
#################### ####################
# - UI # - UI
#################### ####################
@ -216,13 +184,13 @@ class DataFileImporterNode(base.MaxwellSimNode):
Called by Blender to determine the text to place in the node's header. Called by Blender to determine the text to place in the node's header.
""" """
if self.file_path is not None: if self.file_path is not None:
return 'Load File: ' + self.file_path.name return 'Load: ' + self.file_path.name
return self.bl_label return self.bl_label
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
"""Show information about the loaded file.""" """Show information about the loaded file."""
if self.data_file_ext is not None: if self.data_file_format is not None:
box = layout.box() box = layout.box()
row = box.row() row = box.row()
row.alignment = 'CENTER' row.alignment = 'CENTER'
@ -233,24 +201,27 @@ class DataFileImporterNode(base.MaxwellSimNode):
row.label(text=self.file_path.name) row.label(text=self.file_path.name)
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
pass """Draw loaded properties."""
for i in range(len(self.expr_info.dims)):
col = layout.column(align=True)
row = col.row(align=True)
row.alignment = 'CENTER'
row.label(text=f'Load Dim {i}')
row = col.row(align=True)
row.prop(self, self.blfields[f'dim_{i}_name'], text='')
row.prop(self, self.blfields[f'dim_{i}_mathtype'], text='')
row = col.row(align=True)
row.prop(self, self.blfields[f'dim_{i}_physical_type'], text='')
row.prop(self, self.blfields[f'dim_{i}_unit'], text='')
#################### ####################
# - Events # - FlowKind.Array|Func
####################
@events.on_value_changed(
socket_name='File Path',
input_sockets={'File Path'},
)
def on_file_changed(self, input_sockets) -> None:
pass
####################
# - FlowKind.Array|LazyValueFunc
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.LazyValueFunc, kind=ct.FlowKind.Func,
input_sockets={'File Path'}, input_sockets={'File Path'},
) )
def compute_func(self, input_sockets: dict) -> td.Simulation: def compute_func(self, input_sockets: dict) -> td.Simulation:
@ -264,20 +235,20 @@ class DataFileImporterNode(base.MaxwellSimNode):
has_file_path = not ct.FlowSignal.check(input_sockets['File Path']) has_file_path = not ct.FlowSignal.check(input_sockets['File Path'])
if has_file_path: if has_file_path:
data_file_ext = DataFileExt.from_path(file_path) data_file_format = ct.DataFileFormat.from_path(file_path)
if data_file_ext is not None: if data_file_format is not None:
# Jax Compatibility: Lazy Data Loading # Jax Compatibility: Lazy Data Loading
## -> Delay loading of data from file as long as we can. ## -> Delay loading of data from file as long as we can.
if data_file_ext.loader_is_jax_compatible: if data_file_format.loader_is_jax_compatible:
return ct.LazyValueFuncFlow( return ct.FuncFlow(
func=lambda: data_file_ext.loader(file_path), func=lambda: data_file_format.loader(file_path),
supports_jax=True, supports_jax=True,
) )
# No Jax Compatibility: Eager Data Loading # No Jax Compatibility: Eager Data Loading
## -> Load the data now and bind it. ## -> Load the data now and bind it.
data = data_file_ext.loader(file_path) data = data_file_format.loader(file_path)
return ct.LazyValueFuncFlow(func=lambda: data, supports_jax=True) return ct.FuncFlow(func=lambda: data, supports_jax=True)
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
@ -299,10 +270,12 @@ class DataFileImporterNode(base.MaxwellSimNode):
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
kind=ct.FlowKind.Info, kind=ct.FlowKind.Info,
# Loaded
props={'output_name', 'output_physical_type', 'output_unit'},
output_sockets={'Expr'}, output_sockets={'Expr'},
output_socket_kinds={'Expr': ct.FlowKind.LazyValueFunc}, output_socket_kinds={'Expr': ct.FlowKind.Func},
) )
def compute_info(self, output_sockets) -> ct.InfoFlow: def compute_info(self, props, output_sockets) -> ct.InfoFlow:
"""Declare an `InfoFlow` based on the data shape. """Declare an `InfoFlow` based on the data shape.
This currently requires computing the data. This currently requires computing the data.
@ -321,26 +294,24 @@ class DataFileImporterNode(base.MaxwellSimNode):
# Deduce Dimensionality # Deduce Dimensionality
_shape = data.shape _shape = data.shape
shape = _shape if _shape is not None else () shape = _shape if _shape is not None else ()
dim_names = [f'a{i}' for i in range(len(shape))] dim_syms = [self.dim(i) for i in range(len(shape))]
# Return InfoFlow # Return InfoFlow
## -> TODO: How to interpret the data should be user-defined.
## -> -- This may require those nice dynamic symbols.
return ct.InfoFlow( return ct.InfoFlow(
dim_names=dim_names, ## TODO: User dims={
dim_idx={ dim_sym: ct.RangeFlow(
dim_name: ct.LazyArrayRangeFlow( start=sp.S(0),
start=sp.S(0), ## TODO: User stop=sp.S(shape[i] - 1),
stop=sp.S(shape[i] - 1), ## TODO: User steps=shape[i],
steps=shape[dim_names.index(dim_name)], unit=self.dim(i).unit,
unit=None, ## TODO: User
) )
for i, dim_name in enumerate(dim_names) for i, dim_sym in enumerate(dim_syms)
}, },
output_name='_', output=sim_symbols.SimSymbol(
output_shape=None, sym_name=props['output_name'],
output_mathtype=spux.MathType.Real, ## TODO: User mathtype=props['output_mathtype'],
output_unit=None, ## TODO: User physical_type=props['output_physical_type'],
),
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending

View File

@ -74,11 +74,11 @@ class SceneNode(base.MaxwellSimNode):
return bpy.context.scene.frame_current return bpy.context.scene.frame_current
@property @property
def scene_frame_range(self) -> ct.LazyArrayRangeFlow: def scene_frame_range(self) -> ct.RangeFlow:
"""Retrieve the current start/end frame of the scene, with `steps` corresponding to single-frame steps.""" """Retrieve the current start/end frame of the scene, with `steps` corresponding to single-frame steps."""
frame_start = bpy.context.scene.frame_start frame_start = bpy.context.scene.frame_start
frame_stop = bpy.context.scene.frame_end frame_stop = bpy.context.scene.frame_end
return ct.LazyArrayRangeFlow( return ct.RangeFlow(
start=frame_start, start=frame_start,
stop=frame_stop, stop=frame_stop,
steps=frame_stop - frame_start + 1, steps=frame_stop - frame_start + 1,

View File

@ -100,30 +100,26 @@ class WaveConstantNode(base.MaxwellSimNode):
run_on_init=True, run_on_init=True,
) )
def on_use_range_changed(self, props: dict) -> None: def on_use_range_changed(self, props: dict) -> None:
"""Synchronize the `active_kind` of input/output sockets, to either produce a `ct.FlowKind.Value` or a `ct.FlowKind.LazyArrayRange`.""" """Synchronize the `active_kind` of input/output sockets, to either produce a `ct.FlowKind.Value` or a `ct.FlowKind.Range`."""
if self.inputs.get('WL') is not None: if self.inputs.get('WL') is not None:
active_input = self.inputs['WL'] active_input = self.inputs['WL']
else: else:
active_input = self.inputs['Freq'] active_input = self.inputs['Freq']
# Modify Active Kind(s) # Modify Active Kind(s)
## Input active_kind -> Value/LazyArrayRange ## Input active_kind -> Value/Range
active_input_uses_range = active_input.active_kind == ct.FlowKind.LazyArrayRange active_input_uses_range = active_input.active_kind == ct.FlowKind.Range
if active_input_uses_range != props['use_range']: if active_input_uses_range != props['use_range']:
active_input.active_kind = ( active_input.active_kind = (
ct.FlowKind.LazyArrayRange if props['use_range'] else ct.FlowKind.Value ct.FlowKind.Range if props['use_range'] else ct.FlowKind.Value
) )
## Output active_kind -> Value/LazyArrayRange ## Output active_kind -> Value/Range
for active_output in self.outputs.values(): for active_output in self.outputs.values():
active_output_uses_range = ( active_output_uses_range = active_output.active_kind == ct.FlowKind.Range
active_output.active_kind == ct.FlowKind.LazyArrayRange
)
if active_output_uses_range != props['use_range']: if active_output_uses_range != props['use_range']:
active_output.active_kind = ( active_output.active_kind = (
ct.FlowKind.LazyArrayRange ct.FlowKind.Range if props['use_range'] else ct.FlowKind.Value
if props['use_range']
else ct.FlowKind.Value
) )
#################### ####################
@ -161,11 +157,11 @@ class WaveConstantNode(base.MaxwellSimNode):
@events.computes_output_socket( @events.computes_output_socket(
'WL', 'WL',
kind=ct.FlowKind.LazyArrayRange, kind=ct.FlowKind.Range,
input_sockets={'WL', 'Freq'}, input_sockets={'WL', 'Freq'},
input_socket_kinds={ input_socket_kinds={
'WL': ct.FlowKind.LazyArrayRange, 'WL': ct.FlowKind.Range,
'Freq': ct.FlowKind.LazyArrayRange, 'Freq': ct.FlowKind.Range,
}, },
input_sockets_optional={'WL': True, 'Freq': True}, input_sockets_optional={'WL': True, 'Freq': True},
) )
@ -176,7 +172,7 @@ class WaveConstantNode(base.MaxwellSimNode):
return input_sockets['WL'] return input_sockets['WL']
freq = input_sockets['Freq'] freq = input_sockets['Freq']
return ct.LazyArrayRangeFlow( return ct.RangeFlow(
start=spux.scale_to_unit( start=spux.scale_to_unit(
sci_constants.vac_speed_of_light / (freq.stop * freq.unit), spu.um sci_constants.vac_speed_of_light / (freq.stop * freq.unit), spu.um
), ),
@ -190,11 +186,11 @@ class WaveConstantNode(base.MaxwellSimNode):
@events.computes_output_socket( @events.computes_output_socket(
'Freq', 'Freq',
kind=ct.FlowKind.LazyArrayRange, kind=ct.FlowKind.Range,
input_sockets={'WL', 'Freq'}, input_sockets={'WL', 'Freq'},
input_socket_kinds={ input_socket_kinds={
'WL': ct.FlowKind.LazyArrayRange, 'WL': ct.FlowKind.Range,
'Freq': ct.FlowKind.LazyArrayRange, 'Freq': ct.FlowKind.Range,
}, },
input_sockets_optional={'WL': True, 'Freq': True}, input_sockets_optional={'WL': True, 'Freq': True},
) )
@ -205,7 +201,7 @@ class WaveConstantNode(base.MaxwellSimNode):
return input_sockets['Freq'] return input_sockets['Freq']
wl = input_sockets['WL'] wl = input_sockets['WL']
return ct.LazyArrayRangeFlow( return ct.RangeFlow(
start=spux.scale_to_unit( start=spux.scale_to_unit(
sci_constants.vac_speed_of_light / (wl.stop * wl.unit), spux.THz sci_constants.vac_speed_of_light / (wl.stop * wl.unit), spux.THz
), ),

View File

@ -115,11 +115,11 @@ class LibraryMediumNode(base.MaxwellSimNode):
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {
'Medium': sockets.MaxwellMediumSocketDef(), 'Medium': sockets.MaxwellMediumSocketDef(),
'Valid Freqs': sockets.ExprSocketDef( 'Valid Freqs': sockets.ExprSocketDef(
active_kind=ct.FlowKind.LazyArrayRange, active_kind=ct.FlowKind.Range,
physical_type=spux.PhysicalType.Freq, physical_type=spux.PhysicalType.Freq,
), ),
'Valid WLs': sockets.ExprSocketDef( 'Valid WLs': sockets.ExprSocketDef(
active_kind=ct.FlowKind.LazyArrayRange, active_kind=ct.FlowKind.Range,
physical_type=spux.PhysicalType.Length, physical_type=spux.PhysicalType.Length,
), ),
} }
@ -254,11 +254,11 @@ class LibraryMediumNode(base.MaxwellSimNode):
@events.computes_output_socket( @events.computes_output_socket(
'Valid Freqs', 'Valid Freqs',
kind=ct.FlowKind.LazyArrayRange, kind=ct.FlowKind.Range,
props={'freq_range'}, props={'freq_range'},
) )
def compute_valid_freqs_lazy(self, props) -> sp.Expr: def compute_valid_freqs_lazy(self, props) -> sp.Expr:
return ct.LazyArrayRangeFlow( return ct.RangeFlow(
start=props['freq_range'][0] / spux.THz, start=props['freq_range'][0] / spux.THz,
stop=props['freq_range'][1] / spux.THz, stop=props['freq_range'][1] / spux.THz,
steps=0, steps=0,
@ -268,11 +268,11 @@ class LibraryMediumNode(base.MaxwellSimNode):
@events.computes_output_socket( @events.computes_output_socket(
'Valid WLs', 'Valid WLs',
kind=ct.FlowKind.LazyArrayRange, kind=ct.FlowKind.Range,
props={'wl_range'}, props={'wl_range'},
) )
def compute_valid_wls_lazy(self, props) -> sp.Expr: def compute_valid_wls_lazy(self, props) -> sp.Expr:
return ct.LazyArrayRangeFlow( return ct.RangeFlow(
start=props['wl_range'][0] / spu.nm, start=props['wl_range'][0] / spu.nm,
stop=props['wl_range'][0] / spu.nm, stop=props['wl_range'][0] / spu.nm,
steps=0, steps=0,

View File

@ -63,7 +63,7 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
input_socket_sets: typ.ClassVar = { input_socket_sets: typ.ClassVar = {
'Freq Domain': { 'Freq Domain': {
'Freqs': sockets.ExprSocketDef( 'Freqs': sockets.ExprSocketDef(
active_kind=ct.FlowKind.LazyArrayRange, active_kind=ct.FlowKind.Range,
physical_type=spux.PhysicalType.Freq, physical_type=spux.PhysicalType.Freq,
default_unit=spux.THz, default_unit=spux.THz,
default_min=374.7406, ## 800nm default_min=374.7406, ## 800nm
@ -73,7 +73,7 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
}, },
'Time Domain': { 'Time Domain': {
't Range': sockets.ExprSocketDef( 't Range': sockets.ExprSocketDef(
active_kind=ct.FlowKind.LazyArrayRange, active_kind=ct.FlowKind.Range,
physical_type=spux.PhysicalType.Time, physical_type=spux.PhysicalType.Time,
default_unit=spu.picosecond, default_unit=spu.picosecond,
default_min=0, default_min=0,
@ -119,7 +119,7 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
'Freqs', 'Freqs',
}, },
input_socket_kinds={ input_socket_kinds={
'Freqs': ct.FlowKind.LazyArrayRange, 'Freqs': ct.FlowKind.Range,
}, },
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D}, unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
scale_input_sockets={ scale_input_sockets={
@ -160,7 +160,7 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
't Stride', 't Stride',
}, },
input_socket_kinds={ input_socket_kinds={
't Range': ct.FlowKind.LazyArrayRange, 't Range': ct.FlowKind.Range,
}, },
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D}, unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
scale_input_sockets={ scale_input_sockets={

View File

@ -63,7 +63,7 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
input_socket_sets: typ.ClassVar = { input_socket_sets: typ.ClassVar = {
'Freq Domain': { 'Freq Domain': {
'Freqs': sockets.ExprSocketDef( 'Freqs': sockets.ExprSocketDef(
active_kind=ct.FlowKind.LazyArrayRange, active_kind=ct.FlowKind.Range,
physical_type=spux.PhysicalType.Freq, physical_type=spux.PhysicalType.Freq,
default_unit=spux.THz, default_unit=spux.THz,
default_min=374.7406, ## 800nm default_min=374.7406, ## 800nm
@ -73,7 +73,7 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
}, },
'Time Domain': { 'Time Domain': {
't Range': sockets.ExprSocketDef( 't Range': sockets.ExprSocketDef(
active_kind=ct.FlowKind.LazyArrayRange, active_kind=ct.FlowKind.Range,
physical_type=spux.PhysicalType.Time, physical_type=spux.PhysicalType.Time,
default_unit=spu.picosecond, default_unit=spu.picosecond,
default_min=0, default_min=0,
@ -137,7 +137,7 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
'Freqs', 'Freqs',
}, },
input_socket_kinds={ input_socket_kinds={
'Freqs': ct.FlowKind.LazyArrayRange, 'Freqs': ct.FlowKind.Range,
}, },
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D}, unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
scale_input_sockets={ scale_input_sockets={

View File

@ -58,7 +58,7 @@ class PermittivityMonitorNode(base.MaxwellSimNode):
abs_min=0, abs_min=0,
), ),
'Freqs': sockets.ExprSocketDef( 'Freqs': sockets.ExprSocketDef(
active_kind=ct.FlowKind.LazyArrayRange, active_kind=ct.FlowKind.Range,
physical_type=spux.PhysicalType.Freq, physical_type=spux.PhysicalType.Freq,
default_unit=spux.THz, default_unit=spux.THz,
default_min=374.7406, ## 800nm default_min=374.7406, ## 800nm
@ -87,7 +87,7 @@ class PermittivityMonitorNode(base.MaxwellSimNode):
'Freqs', 'Freqs',
}, },
input_socket_kinds={ input_socket_kinds={
'Freqs': ct.FlowKind.LazyArrayRange, 'Freqs': ct.FlowKind.Range,
}, },
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D}, unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
scale_input_sockets={ scale_input_sockets={

View File

@ -14,16 +14,15 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# from . import file_exporters, viewer, web_exporters from . import file_exporters, viewer, web_exporters
from . import viewer, web_exporters
BL_REGISTER = [ BL_REGISTER = [
*viewer.BL_REGISTER, *viewer.BL_REGISTER,
# *file_exporters.BL_REGISTER, *file_exporters.BL_REGISTER,
*web_exporters.BL_REGISTER, *web_exporters.BL_REGISTER,
] ]
BL_NODES = { BL_NODES = {
**viewer.BL_NODES, **viewer.BL_NODES,
# **file_exporters.BL_NODES, **file_exporters.BL_NODES,
**web_exporters.BL_NODES, **web_exporters.BL_NODES,
} }

View File

@ -14,11 +14,15 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import json_file_exporter from . import data_file_exporter
# from . import json_file_exporter
BL_REGISTER = [ BL_REGISTER = [
*json_file_exporter.BL_REGISTER, *data_file_exporter.BL_REGISTER,
# *json_file_exporter.BL_REGISTER,
] ]
BL_NODES = { BL_NODES = {
**json_file_exporter.BL_NODES, **data_file_exporter.BL_NODES,
# **json_file_exporter.BL_NODES,
} }

View File

@ -0,0 +1,252 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import typing as typ
from pathlib import Path
import bpy
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
####################
# - Operators
####################
class ExportDataFile(bpy.types.Operator):
"""Exports data from the input to `DataFileExporterNode` to the file path given on the same node, if the path is compatible with the chosen export format (a property on the node)."""
bl_idname = ct.OperatorType.NodeExportDataFile
bl_label = 'Save Data File'
bl_description = 'Save a file with the contents, name, and format indicated by a NodeExportDataFile'
@classmethod
def poll(cls, context):
return (
# Check Node
hasattr(context, 'node')
and hasattr(context.node, 'node_type')
and (node := context.node).node_type == ct.NodeType.DataFileExporter
# Check Expr
and node.is_file_path_compatible_with_export_format
)
def execute(self, context: bpy.types.Context):
node = context.node
node.export_format.saver(node.file_path, node.expr_data, node.expr_info)
return {'FINISHED'}
####################
# - Node
####################
class DataFileExporterNode(base.MaxwellSimNode):
# """Export input data to a supported
node_type = ct.NodeType.DataFileExporter
bl_label = 'Data File Importer'
input_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
'File Path': sockets.FilePathSocketDef(),
}
####################
# - Properties: Expr Info
####################
@events.on_value_changed(
socket_name={'Expr'},
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info},
)
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
has_expr = not ct.FlowSignal.check(input_sockets['Expr'])
if has_expr:
self.expr_info = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property(depends_on={'file_path'})
def expr_info(self) -> ct.InfoFlow | None:
"""Retrieve the input expression's `InfoFlow`."""
info = self._compute_input('Expr', kind=ct.FlowKind.Info)
has_info = not ct.FlowSignal.check(info)
if has_info:
return info
return None
@property
def expr_data(self) -> typ.Any | None:
"""Retrieve the input expression's data by evaluating its `Func`."""
func = self._compute_input('Expr', kind=ct.FlowKind.Func)
params = self._compute_input('Expr', kind=ct.FlowKind.Params)
has_func = not ct.FlowSignal.check(func)
has_params = not ct.FlowSignal.check(params)
if has_func and has_params:
symbol_values = {
sym.name: self._compute_input(sym.name, kind=ct.FlowKind.Value)
for sym in params.sorted_symbols
}
return func.func_jax(
*params.scaled_func_args(spux.UNITS_SI, symbol_values=symbol_values),
**params.scaled_func_kwargs(spux.UNITS_SI, symbol_values=symbol_values),
)
return None
####################
# - Properties: File Path
####################
@events.on_value_changed(
socket_name={'File Path'},
input_sockets={'File Path'},
input_socket_kinds={'File Path': ct.FlowKind.Value},
input_sockets_optional={'File Path': True},
)
def on_file_path_changed(self, input_sockets) -> None: # noqa: D102
has_file_path = not ct.FlowSignal.check(input_sockets['File Path'])
if has_file_path:
self.file_path = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def file_path(self) -> Path:
"""Retrieve the input file path."""
file_path = self._compute_input(
'File Path', kind=ct.FlowKind.Value, optional=True
)
has_file_path = not ct.FlowSignal.check(file_path)
if has_file_path:
return file_path
return None
####################
# - Properties: Export Format
####################
export_format: ct.DataFileFormat = bl_cache.BLField(
enum_cb=lambda self, _: self.search_export_formats(),
cb_depends_on={'expr_info'},
)
def search_export_formats(self):
if self.expr_info is not None:
return [
data_file_format.bl_enum_element(i)
for i, data_file_format in enumerate(list(ct.DataFileFormat))
if data_file_format.is_info_compatible(self.expr_info)
]
return ct.DataFileFormat.bl_enum_elements()
####################
# - Properties: File Path Compatibility
####################
@bl_cache.cached_bl_property(depends_on={'file_path', 'export_format'})
def is_file_path_compatible_with_export_format(self) -> bool | None:
"""Determine whether the given file path is actually compatible with the desired export format."""
if self.file_path is not None and self.export_format is not None:
return self.export_format.is_path_compatible(self.file_path)
return None
####################
# - UI
####################
def draw_label(self):
"""Show the extracted file name (w/extension) in the node's header label.
Notes:
Called by Blender to determine the text to place in the node's header.
"""
if self.file_path is not None:
return 'Save: ' + self.file_path.name
return self.bl_label
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
"""Show information about the loaded file."""
if self.export_format is not None:
box = layout.box()
row = box.row()
row.alignment = 'CENTER'
row.label(text='Data File')
row = box.row()
row.alignment = 'CENTER'
row.label(text=self.file_path.name)
compatibility = self.is_file_path_compatible_with_export_format
if compatibility is not None:
row = box.row()
row.alignment = 'CENTER'
if compatibility:
row.label(text='Valid Path | Format', icon='CHECKMARK')
else:
row.label(text='Invalid Path | Format', icon='ERROR')
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, self.blfields['export_format'], text='')
layout.operator(ct.OperatorType.NodeExportDataFile, text='Save Data File')
####################
# - Events
####################
@events.on_value_changed(
# Trigger
socket_name='Expr',
run_on_init=True,
# Loaded
input_sockets={'Expr'},
input_socket_kinds={'Expr': {ct.FlowKind.Info, ct.FlowKind.Params}},
input_sockets_optional={'Expr': True},
)
def on_expr_changed(self, input_sockets: dict) -> None:
"""Declare any loose input sockets needed to realize the input expr's symbols."""
info = input_sockets['Expr'][ct.FlowKind.Info]
params = input_sockets['Expr'][ct.FlowKind.Params]
has_info = not ct.FlowSignal.check(info)
has_params = not ct.FlowSignal.check(params)
# Provide Sockets for Symbol Realization
## -> Only happens if Params contains not-yet-realized symbols.
if has_info and has_params and params.symbols:
if set(self.loose_input_sockets) != {
dim.name for dim in params.symbols if dim in info.dims
}:
self.loose_input_sockets = {
dim_name: sockets.ExprSocketDef(**expr_info)
for dim_name, expr_info in params.sym_expr_infos(info).items()
}
elif self.loose_input_sockets:
self.loose_input_sockets = {}
####################
# - Blender Registration
####################
BL_REGISTER = [
ExportDataFile,
DataFileExporterNode,
]
BL_NODES = {
ct.NodeType.DataFileExporter: (ct.NodeCategory.MAXWELLSIM_OUTPUTS_FILEEXPORTERS)
}

View File

@ -74,7 +74,7 @@ class TemporalShapeNode(base.MaxwellSimNode):
}, },
'Symbolic': { 'Symbolic': {
't Range': sockets.ExprSocketDef( 't Range': sockets.ExprSocketDef(
active_kind=ct.FlowKind.LazyArrayRange, active_kind=ct.FlowKind.Range,
physical_type=spux.PhysicalType.Time, physical_type=spux.PhysicalType.Time,
default_unit=spu.picosecond, default_unit=spu.picosecond,
default_min=0, default_min=0,
@ -132,8 +132,8 @@ class TemporalShapeNode(base.MaxwellSimNode):
'Envelope', 'Envelope',
}, },
input_socket_kinds={ input_socket_kinds={
't Range': ct.FlowKind.LazyArrayRange, 't Range': ct.FlowKind.Range,
'Envelope': ct.FlowKind.LazyValueFunc, 'Envelope': ct.FlowKind.Func,
}, },
input_sockets_optional={ input_sockets_optional={
'max E': True, 'max E': True,

View File

@ -525,9 +525,9 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
msg = f'Socket {self.bl_label} {self.socket_type}): Tried to set "ct.FlowKind.Array", but socket does not define it' msg = f'Socket {self.bl_label} {self.socket_type}): Tried to set "ct.FlowKind.Array", but socket does not define it'
raise NotImplementedError(msg) raise NotImplementedError(msg)
# LazyValueFunc # Func
@property @property
def lazy_value_func(self) -> ct.LazyValueFuncFlow: def lazy_func(self) -> ct.FuncFlow:
"""Throws a descriptive error. """Throws a descriptive error.
Notes: Notes:
@ -538,8 +538,8 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
""" """
return ct.FlowSignal.NoFlow return ct.FlowSignal.NoFlow
@lazy_value_func.setter @lazy_func.setter
def lazy_value_func(self, lazy_value_func: ct.LazyValueFuncFlow) -> None: def lazy_func(self, lazy_func: ct.FuncFlow) -> None:
"""Throws a descriptive error. """Throws a descriptive error.
Notes: Notes:
@ -548,12 +548,12 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Raises: Raises:
NotImplementedError: When used without being overridden. NotImplementedError: When used without being overridden.
""" """
msg = f'Socket {self.bl_label} {self.socket_type}): Tried to set "ct.FlowKind.LazyValueFunc", but socket does not define it' msg = f'Socket {self.bl_label} {self.socket_type}): Tried to set "ct.FlowKind.Func", but socket does not define it'
raise NotImplementedError(msg) raise NotImplementedError(msg)
# LazyArrayRange # Range
@property @property
def lazy_array_range(self) -> ct.LazyArrayRangeFlow: def lazy_range(self) -> ct.RangeFlow:
"""Throws a descriptive error. """Throws a descriptive error.
Notes: Notes:
@ -564,8 +564,8 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
""" """
return ct.FlowSignal.NoFlow return ct.FlowSignal.NoFlow
@lazy_array_range.setter @lazy_range.setter
def lazy_array_range(self, value: ct.LazyArrayRangeFlow) -> None: def lazy_range(self, value: ct.RangeFlow) -> None:
"""Throws a descriptive error. """Throws a descriptive error.
Notes: Notes:
@ -574,7 +574,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Raises: Raises:
NotImplementedError: When used without being overridden. NotImplementedError: When used without being overridden.
""" """
msg = f'Socket {self.bl_label} {self.socket_type}): Tried to set "ct.FlowKind.LazyArrayRange", but socket does not define it' msg = f'Socket {self.bl_label} {self.socket_type}): Tried to set "ct.FlowKind.Range", but socket does not define it'
raise NotImplementedError(msg) raise NotImplementedError(msg)
#################### ####################
@ -595,8 +595,8 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
kind_data_map = { kind_data_map = {
ct.FlowKind.Value: lambda: self.value, ct.FlowKind.Value: lambda: self.value,
ct.FlowKind.Array: lambda: self.array, ct.FlowKind.Array: lambda: self.array,
ct.FlowKind.LazyValueFunc: lambda: self.lazy_value_func, ct.FlowKind.Func: lambda: self.lazy_func,
ct.FlowKind.LazyArrayRange: lambda: self.lazy_array_range, ct.FlowKind.Range: lambda: self.lazy_range,
ct.FlowKind.Params: lambda: self.params, ct.FlowKind.Params: lambda: self.params,
ct.FlowKind.Info: lambda: self.info, ct.FlowKind.Info: lambda: self.info,
} }
@ -783,8 +783,8 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
{ {
ct.FlowKind.Value: self.draw_value, ct.FlowKind.Value: self.draw_value,
ct.FlowKind.Array: self.draw_array, ct.FlowKind.Array: self.draw_array,
ct.FlowKind.LazyArrayRange: self.draw_lazy_array_range, ct.FlowKind.Range: self.draw_lazy_range,
ct.FlowKind.LazyValueFunc: self.draw_lazy_value_func, ct.FlowKind.Func: self.draw_lazy_func,
}[self.active_kind](col) }[self.active_kind](col)
# Info Drawing # Info Drawing
@ -894,11 +894,11 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
col: Target for defining UI elements. col: Target for defining UI elements.
""" """
def draw_lazy_array_range(self, col: bpy.types.UILayout) -> None: def draw_lazy_range(self, col: bpy.types.UILayout) -> None:
"""Draws the socket lazy array range on its own line. """Draws the socket lazy array range on its own line.
Notes: Notes:
Should be overriden by individual socket classes, if they have an editable `FlowKind.LazyArrayRange`. Should be overriden by individual socket classes, if they have an editable `FlowKind.Range`.
Parameters: Parameters:
col: Target for defining UI elements. col: Target for defining UI elements.
@ -914,11 +914,11 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
col: Target for defining UI elements. col: Target for defining UI elements.
""" """
def draw_lazy_value_func(self, col: bpy.types.UILayout) -> None: def draw_lazy_func(self, col: bpy.types.UILayout) -> None:
"""Draws the socket lazy value function UI on its own line. """Draws the socket lazy value function UI on its own line.
Notes: Notes:
Should be overriden by individual socket classes, if they have an editable `FlowKind.LazyValueFunc`. Should be overriden by individual socket classes, if they have an editable `FlowKind.Func`.
Parameters: Parameters:
col: Target for defining UI elements. col: Target for defining UI elements.

View File

@ -100,7 +100,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
When active, `self.active_unit` can be used via the UI to select valid unit of the given `self.physical_type`, and `self.unit` works. When active, `self.active_unit` can be used via the UI to select valid unit of the given `self.physical_type`, and `self.unit` works.
The enum itself can be dynamically altered, ex. via its UI dropdown support. The enum itself can be dynamically altered, ex. via its UI dropdown support.
symbols: The symbolic variables valid in the context of the expression. symbols: The symbolic variables valid in the context of the expression.
Various features, including `LazyValueFunc` support, become available when symbols are in use. Various features, including `Func` support, become available when symbols are in use.
The presence of symbols forces fallback to a string-based `sympy` expression UI. The presence of symbols forces fallback to a string-based `sympy` expression UI.
active_unit: The currently active unit, as a dropdown. active_unit: The currently active unit, as a dropdown.
@ -118,13 +118,15 @@ class ExprBLSocket(base.MaxwellSimSocket):
physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical) physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical)
# Symbols # Symbols
# active_symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([]) output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
symbols: frozenset[sp.Symbol] = bl_cache.BLField(frozenset()) sim_symbols.SimSymbolName.Expr
)
active_symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([])
# @property @property
# def symbols(self) -> set[sp.Symbol]: def symbols(self) -> set[sp.Symbol]:
# """Current symbols as an unordered set.""" """Current symbols as an unordered set."""
# return {sim_symbol.sp_symbol for sim_symbol in self.active_symbols} return {sim_symbol.sp_symbol for sim_symbol in self.active_symbols}
@bl_cache.cached_bl_property(depends_on={'symbols'}) @bl_cache.cached_bl_property(depends_on={'symbols'})
def sorted_symbols(self) -> list[sp.Symbol]: def sorted_symbols(self) -> list[sp.Symbol]:
@ -169,7 +171,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
((0.0, 0.0), (0.0, 0.0), (0.0, 0.0)), float_prec=4 ((0.0, 0.0), (0.0, 0.0), (0.0, 0.0)), float_prec=4
) )
# UI: LazyArrayRange # UI: Range
steps: int = bl_cache.BLField(2, soft_min=2, abs_min=0) steps: int = bl_cache.BLField(2, soft_min=2, abs_min=0)
scaling: ct.ScalingMode = bl_cache.BLField(ct.ScalingMode.Lin) scaling: ct.ScalingMode = bl_cache.BLField(ct.ScalingMode.Lin)
## Expression ## Expression
@ -184,6 +186,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
) )
# UI: Info # UI: Info
show_func_ui: bool = bl_cache.BLField(True)
show_info_columns: bool = bl_cache.BLField(False) show_info_columns: bool = bl_cache.BLField(False)
info_columns: set[InfoDisplayCol] = bl_cache.BLField( info_columns: set[InfoDisplayCol] = bl_cache.BLField(
{InfoDisplayCol.Length, InfoDisplayCol.MathType} {InfoDisplayCol.Length, InfoDisplayCol.MathType}
@ -248,7 +251,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
and not self.symbols and not self.symbols
): ):
self.value = self.value.subs({self.unit: prev_unit}) self.value = self.value.subs({self.unit: prev_unit})
self.lazy_array_range = self.lazy_array_range.correct_unit(prev_unit) self.lazy_range = self.lazy_range.correct_unit(prev_unit)
self.prev_unit = self.active_unit self.prev_unit = self.active_unit
@ -454,20 +457,20 @@ class ExprBLSocket(base.MaxwellSimSocket):
) )
#################### ####################
# - FlowKind: LazyArrayRange # - FlowKind: Range
#################### ####################
@property @property
def lazy_array_range(self) -> ct.LazyArrayRangeFlow: def lazy_range(self) -> ct.RangeFlow:
"""Return the not-yet-computed uniform array defined by the socket. """Return the not-yet-computed uniform array defined by the socket.
Notes: Notes:
Called to compute the internal `FlowKind.LazyArrayRange` of this socket. Called to compute the internal `FlowKind.Range` of this socket.
Return: Return:
The range of lengths, which uses no symbols. The range of lengths, which uses no symbols.
""" """
if self.symbols: if self.symbols:
return ct.LazyArrayRangeFlow( return ct.RangeFlow(
start=self.raw_min_sp, start=self.raw_min_sp,
stop=self.raw_max_sp, stop=self.raw_max_sp,
steps=self.steps, steps=self.steps,
@ -493,7 +496,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
], ],
}[self.mathtype]() }[self.mathtype]()
return ct.LazyArrayRangeFlow( return ct.RangeFlow(
start=min_bound, start=min_bound,
stop=max_bound, stop=max_bound,
steps=self.steps, steps=self.steps,
@ -501,12 +504,12 @@ class ExprBLSocket(base.MaxwellSimSocket):
unit=self.unit, unit=self.unit,
) )
@lazy_array_range.setter @lazy_range.setter
def lazy_array_range(self, value: ct.LazyArrayRangeFlow) -> None: def lazy_range(self, value: ct.RangeFlow) -> None:
"""Set the not-yet-computed uniform array defined by the socket. """Set the not-yet-computed uniform array defined by the socket.
Notes: Notes:
Called to compute the internal `FlowKind.LazyArrayRange` of this socket. Called to compute the internal `FlowKind.Range` of this socket.
""" """
self.steps = value.steps self.steps = value.steps
self.scaling = value.scaling self.scaling = value.scaling
@ -544,20 +547,20 @@ class ExprBLSocket(base.MaxwellSimSocket):
] ]
#################### ####################
# - FlowKind: LazyValueFunc (w/Params if Constant) # - FlowKind: Func (w/Params if Constant)
#################### ####################
@property @property
def lazy_value_func(self) -> ct.LazyValueFuncFlow: def lazy_func(self) -> ct.FuncFlow:
"""Returns a lazy value that computes the expression returned by `self.value`. """Returns a lazy value that computes the expression returned by `self.value`.
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be the arguments of the `LazyValueFuncFlow`. If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be the arguments of the `FuncFlow`.
Otherwise, the returned lazy value function will be a simple excuse for `self.params` to pass the verbatim `self.value`. Otherwise, the returned lazy value function will be a simple excuse for `self.params` to pass the verbatim `self.value`.
""" """
# Symbolic # Symbolic
## -> `self.value` is guaranteed to be an expression with unknowns. ## -> `self.value` is guaranteed to be an expression with unknowns.
## -> The function computes `self.value` with unknowns as arguments. ## -> The function computes `self.value` with unknowns as arguments.
if self.symbols: if self.symbols:
return ct.LazyValueFuncFlow( return ct.FuncFlow(
func=sp.lambdify( func=sp.lambdify(
self.sorted_symbols, self.sorted_symbols,
spux.scale_to_unit(self.value, self.unit), spux.scale_to_unit(self.value, self.unit),
@ -572,7 +575,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
## -> ("Dummy" as in returns the same argument that it takes). ## -> ("Dummy" as in returns the same argument that it takes).
## -> This is an excuse to let `ParamsFlow` pass `self.value` verbatim. ## -> This is an excuse to let `ParamsFlow` pass `self.value` verbatim.
## -> Generally only useful for operations with other expressions. ## -> Generally only useful for operations with other expressions.
return ct.LazyValueFuncFlow( return ct.FuncFlow(
func=lambda v: v, func=lambda v: v,
func_args=[ func_args=[
self.physical_type if self.physical_type is not None else self.mathtype self.physical_type if self.physical_type is not None else self.mathtype
@ -582,7 +585,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
@property @property
def params(self) -> ct.ParamsFlow: def params(self) -> ct.ParamsFlow:
"""Returns parameter symbols/values to accompany `self.lazy_value_func`. """Returns parameter symbols/values to accompany `self.lazy_func`.
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be passed into `ParamsFlow`, which will thus be parameterized (and require realization before use). If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be passed into `ParamsFlow`, which will thus be parameterized (and require realization before use).
Otherwise, `self.value` is passed verbatim as the only `ParamsFlow.func_arg`. Otherwise, `self.value` is passed verbatim as the only `ParamsFlow.func_arg`.
@ -605,45 +608,34 @@ class ExprBLSocket(base.MaxwellSimSocket):
@property @property
def info(self) -> ct.ArrayFlow: def info(self) -> ct.ArrayFlow:
r"""Returns parameter symbols/values to accompany `self.lazy_value_func`. r"""Returns parameter symbols/values to accompany `self.lazy_func`.
The output name/size/mathtype/unit corresponds directly the `ExprSocket`. The output name/size/mathtype/unit corresponds directly the `ExprSocket`.
If `self.symbols` has entries, then these will propagate as dimensions with unresolvable `LazyArrayRangeFlow` index descriptions. If `self.symbols` has entries, then these will propagate as dimensions with unresolvable `RangeFlow` index descriptions.
The index range will be $(-\infty,\infty)$, with $0$ steps and no unit. The index range will be $(-\infty,\infty)$, with $0$ steps and no unit.
The order/naming matches `self.params` and `self.lazy_value_func`. The order/naming matches `self.params` and `self.lazy_func`.
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along. Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
""" """
output_sim_sym = (
sim_symbols.SimSymbol(
sym_name=self.output_name,
mathtype=self.mathtype,
physical_type=self.physical_type,
unit=self.unit,
rows=self.size.rows,
cols=self.size.cols,
),
)
if self.symbols: if self.symbols:
return ct.InfoFlow( return ct.InfoFlow(
dim_names=[sym.name for sym in self.sorted_symbols], dims={sim_sym: None for sim_sym in self.active_symbols},
dim_idx={ output=output_sim_sym,
sym.name: ct.LazyArrayRangeFlow(
start=-sp.oo if _check_sym_oo(sym) else -sp.zoo,
stop=sp.oo if _check_sym_oo(sym) else sp.zoo,
steps=0,
unit=None, ## Symbols alone are unitless.
)
## TODO: PhysicalTypes for symbols? Or nah?
## TODO: Can we parse some sp.Interval for explicit domains?
## -> We investigated sp.Symbol(..., domain=...).
## -> It's no good. We can't re-extract the interval given to domain.
for sym in self.sorted_symbols
},
output_name='_', ## Use node:socket name? Or something? Ahh
output_shape=self.size.shape,
output_mathtype=self.mathtype,
output_unit=self.unit,
) )
# Constant # Constant
return ct.InfoFlow( return ct.InfoFlow(output=output_sim_sym)
output_name='_', ## Use node:socket name? Or something? Ahh
output_shape=self.size.shape,
output_mathtype=self.mathtype,
output_unit=self.unit,
)
#################### ####################
# - FlowKind: Capabilities # - FlowKind: Capabilities
@ -805,13 +797,13 @@ class ExprBLSocket(base.MaxwellSimSocket):
for sym in self.symbols: for sym in self.symbols:
col.label(text=spux.pretty_symbol(sym)) col.label(text=spux.pretty_symbol(sym))
def draw_lazy_array_range(self, col: bpy.types.UILayout) -> None: def draw_lazy_range(self, col: bpy.types.UILayout) -> None:
"""Draw the socket body for a simple, uniform range of values between two values/expressions. """Draw the socket body for a simple, uniform range of values between two values/expressions.
Drawn when `self.active_kind == FlowKind.LazyArrayRange`. Drawn when `self.active_kind == FlowKind.Range`.
Notes: Notes:
If `self.steps == 0`, then the `LazyArrayRange` is considered to have a to-be-determined number of steps. If `self.steps == 0`, then the `Range` is considered to have a to-be-determined number of steps.
As such, `self.steps` won't be exposed in the UI. As such, `self.steps` won't be exposed in the UI.
""" """
if self.symbols: if self.symbols:
@ -835,44 +827,49 @@ class ExprBLSocket(base.MaxwellSimSocket):
if self.steps != 0: if self.steps != 0:
col.prop(self, self.blfields['steps'], text='') col.prop(self, self.blfields['steps'], text='')
def draw_lazy_value_func(self, col: bpy.types.UILayout) -> None: def draw_lazy_func(self, col: bpy.types.UILayout) -> None:
"""Draw the socket body for a single flexible value/expression, for down-chain lazy evaluation. """Draw the socket body for a single flexible value/expression, for down-chain lazy evaluation.
This implements the most flexible variant of the `ExprSocket` UI, providing the user with full runtime-configuration of the exact `self.size`, `self.mathtype`, `self.physical_type`, and `self.symbols` of the expression. This implements the most flexible variant of the `ExprSocket` UI, providing the user with full runtime-configuration of the exact `self.size`, `self.mathtype`, `self.physical_type`, and `self.symbols` of the expression.
Notes: Notes:
Drawn when `self.active_kind == FlowKind.LazyValueFunc`. Drawn when `self.active_kind == FlowKind.Func`.
This is an ideal choice for ex. math nodes that need to accept arbitrary expressions as inputs, with an eye towards lazy evaluation of ex. symbolic terms. This is an ideal choice for ex. math nodes that need to accept arbitrary expressions as inputs, with an eye towards lazy evaluation of ex. symbolic terms.
Uses `draw_value` to draw the base UI Uses `draw_value` to draw the base UI
""" """
# Physical Type Selector if self.show_func_ui:
## -> Determines whether/which unit-dropdown will be shown. # Output Name Selector
col.prop(self, self.blfields['physical_type'], text='') ## -> The name of the output
col.prop(self, self.blfields['output_name'], text='')
# Non-Symbolic: Size/Mathtype Selector # Physical Type Selector
## -> Symbols imply str expr input. ## -> Determines whether/which unit-dropdown will be shown.
## -> For arbitrary str exprs, size/mathtype are derived from the expr. col.prop(self, self.blfields['physical_type'], text='')
## -> Otherwise, size/mathtype must be pre-specified for a nice UI.
if not self.symbols:
row = col.row(align=True)
row.prop(self, self.blfields['size'], text='')
row.prop(self, self.blfields['mathtype'], text='')
# Base UI # Non-Symbolic: Size/Mathtype Selector
## -> Draws the UI appropriate for the above choice of constraints. ## -> Symbols imply str expr input.
self.draw_value(col) ## -> For arbitrary str exprs, size/mathtype are derived from the expr.
## -> Otherwise, size/mathtype must be pre-specified for a nice UI.
if not self.symbols:
row = col.row(align=True)
row.prop(self, self.blfields['size'], text='')
row.prop(self, self.blfields['mathtype'], text='')
# Symbol UI # Base UI
## -> Draws the UI appropriate for the above choice of constraints. ## -> Draws the UI appropriate for the above choice of constraints.
## -> TODO self.draw_value(col)
# Symbol UI
## -> Draws the UI appropriate for the above choice of constraints.
## -> TODO
#################### ####################
# - UI: InfoFlow # - UI: InfoFlow
#################### ####################
def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None: def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None:
if self.active_kind == ct.FlowKind.LazyValueFunc and self.show_info_columns: if self.active_kind == ct.FlowKind.Func and self.show_info_columns:
row = col.row() row = col.row()
box = row.box() box = row.box()
grid = box.grid_flow( grid = box.grid_flow(
@ -884,9 +881,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
) )
# Dimensions # Dimensions
for dim_name in info.dim_names: for dim in info.dims:
dim_idx = info.dim_idx[dim_name] dim_idx = info.dims[dim]
grid.label(text=dim_name) grid.label(text=dim.name_pretty)
if InfoDisplayCol.Length in self.info_columns: if InfoDisplayCol.Length in self.info_columns:
grid.label(text=str(len(dim_idx))) grid.label(text=str(len(dim_idx)))
if InfoDisplayCol.MathType in self.info_columns: if InfoDisplayCol.MathType in self.info_columns:
@ -895,27 +892,27 @@ class ExprBLSocket(base.MaxwellSimSocket):
grid.label(text=spux.sp_to_str(dim_idx.unit)) grid.label(text=spux.sp_to_str(dim_idx.unit))
# Outputs # Outputs
grid.label(text=info.output_name) grid.label(text=info.output.name_pretty)
if InfoDisplayCol.Length in self.info_columns: if InfoDisplayCol.Length in self.info_columns:
grid.label(text='', icon=ct.Icon.DataSocketOutput) grid.label(text='', icon=ct.Icon.DataSocketOutput)
if InfoDisplayCol.MathType in self.info_columns: if InfoDisplayCol.MathType in self.info_columns:
grid.label( grid.label(
text=( text=(
spux.MathType.to_str(info.output_mathtype) spux.MathType.to_str(info.output.mathtype)
+ ( + (
'ˣ'.join( 'ˣ'.join(
[ [
unicode_superscript(out_axis) unicode_superscript(out_axis)
for out_axis in info.output_shape for out_axis in info.output.shape
] ]
) )
if info.output_shape if info.output.shape
else '' else ''
) )
) )
) )
if InfoDisplayCol.Unit in self.info_columns: if InfoDisplayCol.Unit in self.info_columns:
grid.label(text=f'{spux.sp_to_str(info.output_unit)}') grid.label(text=f'{spux.sp_to_str(info.output.unit)}')
#################### ####################
@ -925,10 +922,11 @@ class ExprSocketDef(base.SocketDef):
socket_type: ct.SocketType = ct.SocketType.Expr socket_type: ct.SocketType = ct.SocketType.Expr
active_kind: typ.Literal[ active_kind: typ.Literal[
ct.FlowKind.Value, ct.FlowKind.Value,
ct.FlowKind.LazyArrayRange, ct.FlowKind.Range,
ct.FlowKind.Array, ct.FlowKind.Array,
ct.FlowKind.LazyValueFunc, ct.FlowKind.Func,
] = ct.FlowKind.Value ] = ct.FlowKind.Value
output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName
# Socket Interface # Socket Interface
size: spux.NumberSize1D = spux.NumberSize1D.Scalar size: spux.NumberSize1D = spux.NumberSize1D.Scalar
@ -938,22 +936,19 @@ class ExprSocketDef(base.SocketDef):
default_unit: spux.Unit | None = None default_unit: spux.Unit | None = None
default_symbols: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list) default_symbols: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list)
@property
def symbols(self) -> set[sp.Symbol]:
return {sim_symbol.sp_symbol for sim_symbol in self.default_symbols}
# FlowKind: Value # FlowKind: Value
default_value: spux.SympyExpr = 0 default_value: spux.SympyExpr = 0
abs_min: spux.SympyExpr | None = None abs_min: spux.SympyExpr | None = None
abs_max: spux.SympyExpr | None = None abs_max: spux.SympyExpr | None = None
# FlowKind: LazyArrayRange # FlowKind: Range
default_min: spux.SympyExpr = 0 default_min: spux.SympyExpr = 0
default_max: spux.SympyExpr = 1 default_max: spux.SympyExpr = 1
default_steps: int = 2 default_steps: int = 2
default_scaling: ct.ScalingMode = ct.ScalingMode.Lin default_scaling: ct.ScalingMode = ct.ScalingMode.Lin
# UI # UI
show_func_ui: bool = True
show_info_columns: bool = False show_info_columns: bool = False
#################### ####################
@ -1107,7 +1102,7 @@ class ExprSocketDef(base.SocketDef):
return self return self
#################### ####################
# - Parse FlowKind.LazyArrayRange # - Parse FlowKind.Range
#################### ####################
@pyd.field_validator('default_steps') @pyd.field_validator('default_steps')
@classmethod @classmethod
@ -1120,8 +1115,8 @@ class ExprSocketDef(base.SocketDef):
return v return v
@pyd.model_validator(mode='after') @pyd.model_validator(mode='after')
def parse_default_lazy_array_range_numbers(self) -> typ.Self: def parse_default_lazy_range_numbers(self) -> typ.Self:
"""Guarantees that the default `ct.LazyArrayRange` bounds are sympy expressions. """Guarantees that the default `ct.Range` bounds are sympy expressions.
If `self.default_value` is a scalar Python type, it will be coerced into the corresponding Sympy type using `sp.S`. If `self.default_value` is a scalar Python type, it will be coerced into the corresponding Sympy type using `sp.S`.
@ -1150,9 +1145,17 @@ class ExprSocketDef(base.SocketDef):
if mathtype_guide == 'expr': if mathtype_guide == 'expr':
dv_mathtype = spux.MathType.from_expr(bound) dv_mathtype = spux.MathType.from_expr(bound)
if not self.mathtype.is_compatible(dv_mathtype): if not self.mathtype.is_compatible(dv_mathtype):
msg = f'ExprSocket: Mathtype {dv_mathtype} of a default LazyArrayRange min or max expression {bound} (type {type(self.default_value)}) is incompatible with socket MathType {self.mathtype}' msg = f'ExprSocket: Mathtype {dv_mathtype} of a default Range min or max expression {bound} (type {type(self.default_value)}) is incompatible with socket MathType {self.mathtype}'
raise ValueError(msg) raise ValueError(msg)
# Coerce from Infinite
if bound.is_infinite and self.mathtype is spux.MathType.Integer:
new_bounds[i] = sp.S(-1) if i == 0 else sp.S(1)
if bound.is_infinite and self.mathtype is spux.MathType.Rational:
new_bounds[i] = sp.Rational(-1, 1) if i == 0 else sp.Rational(1, 1)
if bound.is_infinite and self.mathtype is spux.MathType.Real:
new_bounds[i] = sp.S(-1.0) if i == 0 else sp.S(1.0)
if new_bounds[0] is not None: if new_bounds[0] is not None:
self.default_min = new_bounds[0] self.default_min = new_bounds[0]
if new_bounds[1] is not None: if new_bounds[1] is not None:
@ -1161,8 +1164,8 @@ class ExprSocketDef(base.SocketDef):
return self return self
@pyd.model_validator(mode='after') @pyd.model_validator(mode='after')
def parse_default_lazy_array_range_size(self) -> typ.Self: def parse_default_lazy_range_size(self) -> typ.Self:
"""Guarantees that the default `ct.LazyArrayRange` bounds are unshaped. """Guarantees that the default `ct.Range` bounds are unshaped.
Raises: Raises:
ValueError: If `self.default_min` or `self.default_max` are shaped. ValueError: If `self.default_min` or `self.default_max` are shaped.
@ -1170,16 +1173,16 @@ class ExprSocketDef(base.SocketDef):
# Check ActiveKind and Size # Check ActiveKind and Size
## -> NOTE: This doesn't protect against dynamic changes to either. ## -> NOTE: This doesn't protect against dynamic changes to either.
if ( if (
self.active_kind == ct.FlowKind.LazyArrayRange self.active_kind == ct.FlowKind.Range
and self.size is not spux.NumberSize1D.Scalar and self.size is not spux.NumberSize1D.Scalar
): ):
msg = "Can't have a non-Scalar size when LazyArrayRange is set as the active kind." msg = "Can't have a non-Scalar size when Range is set as the active kind."
raise ValueError(msg) raise ValueError(msg)
# Check that Bounds are Shapeless # Check that Bounds are Shapeless
for bound in [self.default_min, self.default_max]: for bound in [self.default_min, self.default_max]:
if hasattr(bound, 'shape'): if hasattr(bound, 'shape'):
msg = f'ExprSocket: A default bound {bound} (type {type(bound)}) has a shape, but LazyArrayRange supports no shape in ExprSockets.' msg = f'ExprSocket: A default bound {bound} (type {type(bound)}) has a shape, but Range supports no shape in ExprSockets.'
raise ValueError(msg) raise ValueError(msg)
return self return self
@ -1217,13 +1220,14 @@ class ExprSocketDef(base.SocketDef):
#################### ####################
def init(self, bl_socket: ExprBLSocket) -> None: def init(self, bl_socket: ExprBLSocket) -> None:
bl_socket.active_kind = self.active_kind bl_socket.active_kind = self.active_kind
bl_socket.output_name = self.output_name
# Socket Interface # Socket Interface
## -> Recall that auto-updates are turned off during init() ## -> Recall that auto-updates are turned off during init()
bl_socket.size = self.size bl_socket.size = self.size
bl_socket.mathtype = self.mathtype bl_socket.mathtype = self.mathtype
bl_socket.physical_type = self.physical_type bl_socket.physical_type = self.physical_type
bl_socket.symbols = self.symbols bl_socket.active_symbols = self.symbols
# FlowKind.Value # FlowKind.Value
## -> We must take units into account when setting bl_socket.value ## -> We must take units into account when setting bl_socket.value
@ -1235,9 +1239,9 @@ class ExprSocketDef(base.SocketDef):
bl_socket.prev_unit = bl_socket.active_unit bl_socket.prev_unit = bl_socket.active_unit
# FlowKind.LazyArrayRange # FlowKind.Range
## -> We can directly pass None to unit. ## -> We can directly pass None to unit.
bl_socket.lazy_array_range = ct.LazyArrayRangeFlow( bl_socket.lazy_range = ct.RangeFlow(
start=self.default_min, start=self.default_min,
stop=self.default_max, stop=self.default_max,
steps=self.default_steps, steps=self.default_steps,
@ -1246,6 +1250,7 @@ class ExprSocketDef(base.SocketDef):
) )
# UI # UI
bl_socket.show_func_ui = self.show_func_ui
bl_socket.show_info_columns = self.show_info_columns bl_socket.show_info_columns = self.show_info_columns
# Info Draw # Info Draw

View File

@ -61,7 +61,6 @@ SympyType = (
class MathType(enum.StrEnum): class MathType(enum.StrEnum):
"""Type identifiers that encompass common sets of mathematical objects.""" """Type identifiers that encompass common sets of mathematical objects."""
Bool = enum.auto()
Integer = enum.auto() Integer = enum.auto()
Rational = enum.auto() Rational = enum.auto()
Real = enum.auto() Real = enum.auto()
@ -77,8 +76,6 @@ class MathType(enum.StrEnum):
return MathType.Rational return MathType.Rational
if MathType.Integer in mathtypes: if MathType.Integer in mathtypes:
return MathType.Integer return MathType.Integer
if MathType.Bool in mathtypes:
return MathType.Bool
msg = f"Can't combine mathtypes {mathtypes}" msg = f"Can't combine mathtypes {mathtypes}"
raise ValueError(msg) raise ValueError(msg)
@ -88,7 +85,6 @@ class MathType(enum.StrEnum):
return ( return (
other other
in { in {
MT.Bool: [MT.Bool],
MT.Integer: [MT.Integer], MT.Integer: [MT.Integer],
MT.Rational: [MT.Integer, MT.Rational], MT.Rational: [MT.Integer, MT.Rational],
MT.Real: [MT.Integer, MT.Rational, MT.Real], MT.Real: [MT.Integer, MT.Rational, MT.Real],
@ -98,11 +94,9 @@ class MathType(enum.StrEnum):
def coerce_compatible_pyobj( def coerce_compatible_pyobj(
self, pyobj: bool | int | Fraction | float | complex self, pyobj: bool | int | Fraction | float | complex
) -> bool | int | Fraction | float | complex: ) -> int | Fraction | float | complex:
MT = MathType MT = MathType
match self: match self:
case MT.Bool:
return pyobj
case MT.Integer: case MT.Integer:
return int(pyobj) return int(pyobj)
case MT.Rational if isinstance(pyobj, int): case MT.Rational if isinstance(pyobj, int):
@ -123,8 +117,6 @@ class MathType(enum.StrEnum):
*[MathType.from_expr(v) for v in sp.flatten(sp_obj)] *[MathType.from_expr(v) for v in sp.flatten(sp_obj)]
) )
if isinstance(sp_obj, sp.logic.boolalg.Boolean):
return MathType.Bool
if sp_obj.is_integer: if sp_obj.is_integer:
return MathType.Integer return MathType.Integer
if sp_obj.is_rational: if sp_obj.is_rational:
@ -146,7 +138,6 @@ class MathType(enum.StrEnum):
@staticmethod @staticmethod
def from_pytype(dtype: type) -> type: def from_pytype(dtype: type) -> type:
return { return {
bool: MathType.Bool,
int: MathType.Integer, int: MathType.Integer,
Fraction: MathType.Rational, Fraction: MathType.Rational,
float: MathType.Real, float: MathType.Real,
@ -166,7 +157,6 @@ class MathType(enum.StrEnum):
def pytype(self) -> type: def pytype(self) -> type:
MT = MathType MT = MathType
return { return {
MT.Bool: bool,
MT.Integer: int, MT.Integer: int,
MT.Rational: Fraction, MT.Rational: Fraction,
MT.Real: float, MT.Real: float,
@ -177,17 +167,25 @@ class MathType(enum.StrEnum):
def symbolic_set(self) -> type: def symbolic_set(self) -> type:
MT = MathType MT = MathType
return { return {
MT.Bool: sp.Set([sp.S(False), sp.S(True)]),
MT.Integer: sp.Integers, MT.Integer: sp.Integers,
MT.Rational: sp.Rationals, MT.Rational: sp.Rationals,
MT.Real: sp.Reals, MT.Real: sp.Reals,
MT.Complex: sp.Complexes, MT.Complex: sp.Complexes,
}[self] }[self]
@property
def sp_symbol_a(self) -> type:
MT = MathType
return {
MT.Integer: sp.Symbol('a', integer=True),
MT.Rational: sp.Symbol('a', rational=True),
MT.Real: sp.Symbol('a', real=True),
MT.Complex: sp.Symbol('a', complex=True),
}[self]
@staticmethod @staticmethod
def to_str(value: typ.Self) -> type: def to_str(value: typ.Self) -> type:
return { return {
MathType.Bool: 'T|F',
MathType.Integer: '', MathType.Integer: '',
MathType.Rational: '', MathType.Rational: '',
MathType.Real: '', MathType.Real: '',
@ -212,6 +210,9 @@ class MathType(enum.StrEnum):
) )
####################
# - Size: 1D
####################
class NumberSize1D(enum.StrEnum): class NumberSize1D(enum.StrEnum):
"""Valid 1D-constrained shape.""" """Valid 1D-constrained shape."""
@ -278,6 +279,20 @@ class NumberSize1D(enum.StrEnum):
(4, 1): NS.Vec4, (4, 1): NS.Vec4,
}[shape] }[shape]
@property
def rows(self):
NS = NumberSize1D
return {
NS.Scalar: 1,
NS.Vec2: 2,
NS.Vec3: 3,
NS.Vec4: 4,
}[self]
@property
def cols(self):
return 1
@property @property
def shape(self): def shape(self):
NS = NumberSize1D NS = NumberSize1D
@ -297,6 +312,30 @@ def symbol_range(sym: sp.Symbol) -> str:
) )
####################
# - Symbol Sizes
####################
class SimpleSize2D(enum.StrEnum):
"""Simple subset of sizes for rank-2 tensors."""
Scalar = enum.auto()
# Vectors
Vec2 = enum.auto() ## 2x1
Vec3 = enum.auto() ## 3x1
Vec4 = enum.auto() ## 4x1
# Covectors
CoVec2 = enum.auto() ## 1x2
CoVec3 = enum.auto() ## 1x3
CoVec4 = enum.auto() ## 1x4
# Square Matrices
Mat22 = enum.auto() ## 2x2
Mat33 = enum.auto() ## 3x3
Mat44 = enum.auto() ## 4x4
#################### ####################
# - Unit Dimensions # - Unit Dimensions
#################### ####################
@ -382,6 +421,8 @@ UNIT_BY_SYMBOL: dict[sp.Symbol, spu.Quantity] = {
unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity) unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity)
} | {unit.name: unit for unit in globals().values() if isinstance(unit, spu.Quantity)} } | {unit.name: unit for unit in globals().values() if isinstance(unit, spu.Quantity)}
UNIT_TO_1: dict[spu.Quantity, 1] = {unit: 1 for unit in UNIT_BY_SYMBOL.values()}
#################### ####################
# - Expr Analysis: Units # - Expr Analysis: Units
@ -907,10 +948,6 @@ class PhysicalType(enum.StrEnum):
LumIntensity = enum.auto() LumIntensity = enum.auto()
LumFlux = enum.auto() LumFlux = enum.auto()
Illuminance = enum.auto() Illuminance = enum.auto()
# Optics
OrdinaryWaveVector = enum.auto()
AngularWaveVector = enum.auto()
PoyntingVector = enum.auto()
@functools.cached_property @functools.cached_property
def unit_dim(self): def unit_dim(self):
@ -956,10 +993,6 @@ class PhysicalType(enum.StrEnum):
PT.LumIntensity: Dims.luminous_intensity, PT.LumIntensity: Dims.luminous_intensity,
PT.LumFlux: Dims.luminous_intensity * spu.steradian.dimension, PT.LumFlux: Dims.luminous_intensity * spu.steradian.dimension,
PT.Illuminance: Dims.luminous_intensity / Dims.length**2, PT.Illuminance: Dims.luminous_intensity / Dims.length**2,
# Optics
PT.OrdinaryWaveVector: Dims.frequency,
PT.AngularWaveVector: Dims.angle * Dims.frequency,
PT.PoyntingVector: Dims.power / Dims.length**2,
}[self] }[self]
@functools.cached_property @functools.cached_property
@ -1196,10 +1229,6 @@ class PhysicalType(enum.StrEnum):
PT.HField: [None, (2,), (3,)], PT.HField: [None, (2,), (3,)],
# Luminal # Luminal
PT.LumFlux: [None, (2,), (3,)], PT.LumFlux: [None, (2,), (3,)],
# Optics
PT.OrdinaryWaveVector: [None, (2,), (3,)],
PT.AngularWaveVector: [None, (2,), (3,)],
PT.PoyntingVector: [None, (2,), (3,)],
} }
return overrides.get(self, [None]) return overrides.get(self, [None])
@ -1222,7 +1251,6 @@ class PhysicalType(enum.StrEnum):
- **Charge**: Generally, it is real. - **Charge**: Generally, it is real.
However, an imaginary phase term seems to have research applications when dealing with high-order harmonics in high-energy pulsed lasers: <https://iopscience.iop.org/article/10.1088/1361-6455/aac787> However, an imaginary phase term seems to have research applications when dealing with high-order harmonics in high-energy pulsed lasers: <https://iopscience.iop.org/article/10.1088/1361-6455/aac787>
- **Conductance**: The imaginary part represents the extinction, in the Drude-model sense. - **Conductance**: The imaginary part represents the extinction, in the Drude-model sense.
- **Poynting**: The imaginary part represents the oscillation in the power flux over time.
""" """
MT = MathType MT = MathType
@ -1249,10 +1277,6 @@ class PhysicalType(enum.StrEnum):
PT.EField: [MT.Real, MT.Complex], ## Im -> Phase PT.EField: [MT.Real, MT.Complex], ## Im -> Phase
PT.HField: [MT.Real, MT.Complex], ## Im -> Phase PT.HField: [MT.Real, MT.Complex], ## Im -> Phase
# Luminal # Luminal
# Optics
PT.OrdinaryWaveVector: [MT.Real, MT.Complex], ## Im -> Phase
PT.AngularWaveVector: [MT.Real, MT.Complex], ## Im -> Phase
PT.PoyntingVector: [MT.Real, MT.Complex], ## Im -> Reactive Power
} }
return overrides.get(self, [MT.Real]) return overrides.get(self, [MT.Real])
@ -1323,10 +1347,6 @@ UNITS_SI: UnitSystem = {
_PT.LumIntensity: spu.candela, _PT.LumIntensity: spu.candela,
_PT.LumFlux: lumen, _PT.LumFlux: lumen,
_PT.Illuminance: spu.lux, _PT.Illuminance: spu.lux,
# Optics
_PT.OrdinaryWaveVector: spu.hertz,
_PT.AngularWaveVector: spu.radian * spu.hertz,
_PT.PoyntingVector: spu.watt / spu.meter**2,
} }
@ -1380,15 +1400,20 @@ def sympy_to_python(
#################### ####################
# - Convert to Unit System # - Convert to Unit System
#################### ####################
def convert_to_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr: def convert_to_unit_system(
sp_obj: SympyExpr, unit_system: UnitSystem | None
) -> SympyExpr:
"""Convert an expression to the units of a given unit system, with appropriate scaling.""" """Convert an expression to the units of a given unit system, with appropriate scaling."""
if unit_system is None:
return sp_obj
return spu.convert_to( return spu.convert_to(
sp_obj, sp_obj,
{unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)}, {unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)},
) )
def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr: def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem | None) -> SympyExpr:
"""Strip units occurring in the given unit system from the expression. """Strip units occurring in the given unit system from the expression.
Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`". Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`".
@ -1397,11 +1422,13 @@ def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr:
Notes: Notes:
You should probably use `scale_to_unit_system()` or `convert_to_unit_system()`. You should probably use `scale_to_unit_system()` or `convert_to_unit_system()`.
""" """
if unit_system is None:
return sp_obj.subs(UNIT_TO_1)
return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None}) return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None})
def scale_to_unit_system( def scale_to_unit_system(
sp_obj: SympyExpr, unit_system: UnitSystem, use_jax_array: bool = False sp_obj: SympyExpr, unit_system: UnitSystem | None, use_jax_array: bool = False
) -> int | float | complex | tuple | jax.Array: ) -> int | float | complex | tuple | jax.Array:
"""Convert an expression to the units of a given unit system, then strip all units of the unit system. """Convert an expression to the units of a given unit system, then strip all units of the unit system.

View File

@ -29,11 +29,13 @@ import matplotlib.axis as mpl_ax
import matplotlib.backends.backend_agg import matplotlib.backends.backend_agg
import matplotlib.figure import matplotlib.figure
import matplotlib.style as mplstyle import matplotlib.style as mplstyle
import seaborn as sns
from blender_maxwell import contracts as ct from blender_maxwell import contracts as ct
from blender_maxwell.utils import logger from blender_maxwell.utils import logger
mplstyle.use('fast') ## TODO: Does this do anything? mplstyle.use('fast') ## TODO: Does this do anything?
sns.set_theme()
log = logger.get(__name__) log = logger.get(__name__)
@ -149,125 +151,98 @@ def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
#################### ####################
# - Plotters # - Plotters
#################### ####################
# () ->
def plot_hist_1d(
data: jtyp.Float32[jtyp.Array, ' size'], info, ax: mpl_ax.Axis
) -> None:
y_name = info.output_name
y_unit = info.output_unit
ax.hist(data, bins=30, alpha=0.75)
ax.set_title('Histogram')
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
# () -> # () ->
def plot_box_plot_1d( def plot_box_plot_1d(
data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis
) -> None: ) -> None:
x_name = info.dim_names[0] x_sym = info.last_dim
y_name = info.output_name y_sym = info.output
y_unit = info.output_unit
ax.boxplot(data) ax.boxplot([data])
ax.set_title('Box Plot') ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_xlabel(f'{x_name}') ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) ax.set_xlabel(y_sym.plot_label)
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None:
x_sym = info.last_dim
y_sym = info.output
p = ax.bar(info.dims[x_sym], data)
ax.bar_label(p, label_type='center')
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
# () -> # () ->
def plot_curve_2d( def plot_curve_2d(
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis
) -> None: ) -> None:
times = [time.perf_counter()] x_sym = info.last_dim
y_sym = info.output
x_name = info.dim_names[0] ax.plot(info.dims[x_sym].realize_array.values, data)
x_unit = info.dim_units[x_name] ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
y_name = info.output_name ax.set_xlabel(x_sym.plot_label)
y_unit = info.output_unit ax.set_xlabel(y_sym.plot_label)
times.append(time.perf_counter() - times[0])
ax.plot(info.dim_idx_arrays[0], data)
times.append(time.perf_counter() - times[0])
ax.set_title('2D Curve')
times.append(time.perf_counter() - times[0])
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
times.append(time.perf_counter() - times[0])
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
times.append(time.perf_counter() - times[0])
# log.critical('Timing of Curve2D: %s', str(times))
def plot_points_2d( def plot_points_2d(
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis
) -> None: ) -> None:
x_name = info.dim_names[0] x_sym = info.last_dim
x_unit = info.dim_units[x_name] y_sym = info.output
y_name = info.output_name
y_unit = info.output_unit
ax.scatter(info.dim_idx_arrays[0], data, alpha=0.6) ax.scatter(x_sym.realize_array.values, data)
ax.set_title('2D Points') ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) ax.set_xlabel(y_sym.plot_label)
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None:
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.output_name
y_unit = info.output_unit
ax.bar(info.dim_idx_arrays[0], data, alpha=0.7)
ax.set_title('2D Bar')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
# (, ) -> # (, ) ->
def plot_curves_2d( def plot_curves_2d(
data: jtyp.Float32[jtyp.Array, 'x_size categories'], info, ax: mpl_ax.Axis data: jtyp.Float32[jtyp.Array, 'x_size categories'], info, ax: mpl_ax.Axis
) -> None: ) -> None:
x_name = info.dim_names[0] x_sym = info.first_dim
x_unit = info.dim_units[x_name] y_sym = info.output
y_name = info.output_name
y_unit = info.output_unit
for category in range(data.shape[1]): for i, category in enumerate(info.dims[info.last_dim]):
ax.plot(info.dim_idx_arrays[0], data[:, category]) ax.plot(info.dims[x_sym], data[:, i], label=category)
ax.set_title('2D Curves') ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) ax.set_xlabel(y_sym.plot_label)
ax.legend() ax.legend()
def plot_filled_curves_2d( def plot_filled_curves_2d(
data: jtyp.Float32[jtyp.Array, 'x_size 2'], info, ax: mpl_ax.Axis data: jtyp.Float32[jtyp.Array, 'x_size 2'], info, ax: mpl_ax.Axis
) -> None: ) -> None:
x_name = info.dim_names[0] x_sym = info.first_dim
x_unit = info.dim_units[x_name] y_sym = info.output
y_name = info.output_name
y_unit = info.output_unit
shared_x_idx = info.dim_idx_arrays[0] shared_x_idx = info.dims[info.last_dim]
ax.fill_between(shared_x_idx, data[:, 0], shared_x_idx, data[:, 1]) ax.fill_between(shared_x_idx, data[:, 0], shared_x_idx, data[:, 1])
ax.set_title('2D Filled Curves') ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) ax.set_xlabel(y_sym.plot_label)
ax.legend()
# (, ) -> # (, ) ->
def plot_heatmap_2d( def plot_heatmap_2d(
data: jtyp.Float32[jtyp.Array, 'x_size y_size'], info, ax: mpl_ax.Axis data: jtyp.Float32[jtyp.Array, 'x_size y_size'], info, ax: mpl_ax.Axis
) -> None: ) -> None:
x_name = info.dim_names[0] x_sym = info.first_dim
x_unit = info.dim_units[x_name] y_sym = info.last_dim
y_name = info.dim_names[1] c_sym = info.output
y_unit = info.dim_units[y_name]
heatmap = ax.imshow(data, aspect='auto', interpolation='none') heatmap = ax.imshow(data, aspect='equal', interpolation='none')
# ax.figure.colorbar(heatmap, ax=ax) ax.figure.colorbar(heatmap, cax=ax)
ax.set_title('Heatmap')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) ax.set_title(f'({x_sym.name_pretty}, {y_sym.name_pretty}) -> {c_sym.plot_label}')
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
ax.legend()

View File

@ -18,26 +18,67 @@ import dataclasses
import enum import enum
import sys import sys
import typing as typ import typing as typ
from fractions import Fraction
import sympy as sp import sympy as sp
from . import extra_sympy_units as spux from . import extra_sympy_units as spux
int_min = -(2**64)
int_max = 2**64
float_min = sys.float_info.min
float_max = sys.float_info.max
#################### ####################
# - Simulation Symbols # - Simulation Symbol Names
#################### ####################
class SimSymbolName(enum.StrEnum): class SimSymbolName(enum.StrEnum):
# Lower
LowerA = enum.auto() LowerA = enum.auto()
LowerB = enum.auto()
LowerC = enum.auto()
LowerD = enum.auto()
LowerI = enum.auto()
LowerT = enum.auto() LowerT = enum.auto()
LowerX = enum.auto() LowerX = enum.auto()
LowerY = enum.auto() LowerY = enum.auto()
LowerZ = enum.auto() LowerZ = enum.auto()
# Physics # Fields
Ex = enum.auto()
Ey = enum.auto()
Ez = enum.auto()
Hx = enum.auto()
Hy = enum.auto()
Hz = enum.auto()
Er = enum.auto()
Etheta = enum.auto()
Ephi = enum.auto()
Hr = enum.auto()
Htheta = enum.auto()
Hphi = enum.auto()
# Optics
Wavelength = enum.auto() Wavelength = enum.auto()
Frequency = enum.auto() Frequency = enum.auto()
Flux = enum.auto()
PermXX = enum.auto()
PermYY = enum.auto()
PermZZ = enum.auto()
DiffOrderX = enum.auto()
DiffOrderY = enum.auto()
# Generic
Expr = enum.auto()
####################
# - UI
####################
@staticmethod @staticmethod
def to_name(v: typ.Self) -> str: def to_name(v: typ.Self) -> str:
"""Convert the enum value to a human-friendly name. """Convert the enum value to a human-friendly name.
@ -50,27 +91,6 @@ class SimSymbolName(enum.StrEnum):
""" """
return SimSymbolName(v).name return SimSymbolName(v).name
@property
def name(self) -> str:
SSN = SimSymbolName
return {
SSN.LowerA: 'a',
SSN.LowerT: 't',
SSN.LowerX: 'x',
SSN.LowerY: 'y',
SSN.LowerZ: 'z',
SSN.Wavelength: 'wl',
SSN.Frequency: 'freq',
}[self]
@property
def name_pretty(self) -> str:
SSN = SimSymbolName
return {
SSN.Wavelength: 'λ',
SSN.Frequency: '𝑓',
}.get(self, self.name)
@staticmethod @staticmethod
def to_icon(_: typ.Self) -> str: def to_icon(_: typ.Self) -> str:
"""Convert the enum value to a Blender icon. """Convert the enum value to a Blender icon.
@ -83,6 +103,75 @@ class SimSymbolName(enum.StrEnum):
""" """
return '' return ''
####################
# - Computed Properties
####################
@property
def name(self) -> str:
SSN = SimSymbolName
return {
# Lower
SSN.LowerA: 'a',
SSN.LowerB: 'b',
SSN.LowerC: 'c',
SSN.LowerD: 'd',
SSN.LowerI: 'i',
SSN.LowerT: 't',
SSN.LowerX: 'x',
SSN.LowerY: 'y',
SSN.LowerZ: 'z',
# Fields
SSN.Ex: 'Ex',
SSN.Ey: 'Ey',
SSN.Ez: 'Ez',
SSN.Hx: 'Hx',
SSN.Hy: 'Hy',
SSN.Hz: 'Hz',
SSN.Er: 'Ex',
SSN.Etheta: 'Ey',
SSN.Ephi: 'Ez',
SSN.Hr: 'Hx',
SSN.Htheta: 'Hy',
SSN.Hphi: 'Hz',
# Optics
SSN.Wavelength: 'wl',
SSN.Frequency: 'freq',
SSN.Flux: 'flux',
SSN.PermXX: 'eps_xx',
SSN.PermYY: 'eps_yy',
SSN.PermZZ: 'eps_zz',
SSN.DiffOrderX: 'order_x',
SSN.DiffOrderY: 'order_y',
# Generic
SSN.Expr: 'expr',
}[self]
@property
def name_pretty(self) -> str:
SSN = SimSymbolName
return {
SSN.Wavelength: 'λ',
SSN.Frequency: '𝑓',
}.get(self, self.name)
####################
# - Simulation Symbol
####################
def mk_interval(
interval_finite: tuple[int | Fraction | float, int | Fraction | float],
interval_inf: tuple[bool, bool],
interval_closed: tuple[bool, bool],
unit_factor: typ.Literal[1] | spux.Unit,
) -> sp.Interval:
"""Create a symbolic interval from the tuples (and unit) defining it."""
return sp.Interval(
start=(interval_finite[0] * unit_factor if not interval_inf[0] else -sp.oo),
end=(interval_finite[1] * unit_factor if not interval_inf[1] else sp.oo),
left_open=(True if interval_inf[0] else not interval_closed[0]),
right_open=(True if interval_inf[1] else not interval_closed[1]),
)
@dataclasses.dataclass(kw_only=True, frozen=True) @dataclasses.dataclass(kw_only=True, frozen=True)
class SimSymbol: class SimSymbol:
@ -94,66 +183,145 @@ class SimSymbol:
It's easy to persist, easy to transport, and has many helpful properties which greatly simplify working with symbols. It's easy to persist, easy to transport, and has many helpful properties which greatly simplify working with symbols.
""" """
sim_node_name: SimSymbolName = SimSymbolName.LowerX sym_name: SimSymbolName
mathtype: spux.MathType = spux.MathType.Real mathtype: spux.MathType = spux.MathType.Real
physical_type: spux.PhysicalType = spux.PhysicalType.NonPhysical physical_type: spux.PhysicalType = spux.PhysicalType.NonPhysical
## TODO: Shape/size support? Incl. MatrixSymbol. # Units
## -> 'None' indicates that no particular unit has yet been chosen.
## -> Not exposed in the UI; must be set some other way.
unit: spux.Unit | None = None
# Domain # Size
interval_finite: tuple[float, float] = (0, 1) ## -> All SimSymbol sizes are "2D", but interpreted by convention.
## -> 1x1: "Scalar".
## -> nx1: "Vector".
## -> 1xn: "Covector".
## -> nxn: "Matrix".
rows: int = 1
cols: int = 1
# Scalar Domain: "Interval"
## -> NOTE: interval_finite_*[0] must be strictly smaller than [1].
## -> See self.domain.
## -> We have to deconstruct symbolic interval semantics a bit for UI.
interval_finite_z: tuple[int, int] = (0, 1)
interval_finite_q: tuple[tuple[int, int], tuple[int, int]] = ((0, 1), (1, 1))
interval_finite_re: tuple[float, float] = (0, 1)
interval_inf: tuple[bool, bool] = (True, True) interval_inf: tuple[bool, bool] = (True, True)
interval_closed: tuple[bool, bool] = (False, False) interval_closed: tuple[bool, bool] = (False, False)
interval_finite_im: tuple[float, float] = (0, 1)
interval_inf_im: tuple[bool, bool] = (True, True)
interval_closed_im: tuple[bool, bool] = (False, False)
#################### ####################
# - Properties # - Properties
#################### ####################
@property @property
def name(self) -> str: def name(self) -> str:
return self.sim_node_name.name """Usable name for the symbol."""
return self.sym_name.name
@property
def name_pretty(self) -> str:
"""Pretty (possibly unicode) name for the thing."""
return self.sym_name.name_pretty
## TODO: Formatting conventions for bolding/etc. of vectors/mats/...
@property
def plot_label(self) -> str:
"""Pretty plot-oriented label."""
return f'{self.name_pretty}' + (
f'({self.unit})' if self.unit is not None else ''
)
@property
def unit_factor(self) -> spux.SympyExpr:
"""Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking."""
return self.unit if self.unit is not None else sp.S(1)
@property
def shape(self) -> tuple[int, ...]:
match (self.rows, self.cols):
case (1, 1):
return ()
case (_, 1):
return (self.rows,)
case (1, _):
return (1, self.rows)
case (_, _):
return (self.rows, self.cols)
@property @property
def domain(self) -> sp.Interval | sp.Set: def domain(self) -> sp.Interval | sp.Set:
"""Return the domain of valid values for the symbol. """Return the scalar domain of valid values for each element of the symbol.
For integer/rational/real symbols, the domain is an interval defined using the `interval_*` properties. For integer/rational/real symbols, the domain is an interval defined using the `interval_*` properties.
This interval **must** have the property`start <= stop`. This interval **must** have the property`start <= stop`.
Otherwise, the domain is the symbolic set corresponding to `self.mathtype`. Otherwise, the domain is the symbolic set corresponding to `self.mathtype`.
""" """
if self.mathtype in [ match self.mathtype:
spux.MathType.Integer, case spux.MathType.Integer:
spux.MathType.Rational, return mk_interval(
spux.MathType.Real, self.interval_finite_z,
]: self.interval_inf,
return sp.Interval( self.interval_closed,
start=self.interval_finite[0] if not self.interval_inf[0] else -sp.oo, self.unit_factor,
end=self.interval_finite[1] if not self.interval_inf[1] else sp.oo, )
left_open=(
True if self.interval_inf[0] else not self.interval_closed[0]
),
right_open=(
True if self.interval_inf[1] else not self.interval_closed[1]
),
)
return self.mathtype.symbolic_set case spux.MathType.Rational:
return mk_interval(
Fraction(*self.interval_finite_q),
self.interval_inf,
self.interval_closed,
self.unit_factor,
)
case spux.MathType.Real:
return mk_interval(
self.interval_finite_re,
self.interval_inf,
self.interval_closed,
self.unit_factor,
)
case spux.MathType.Complex:
return (
mk_interval(
self.interval_finite_re,
self.interval_inf,
self.interval_closed,
self.unit_factor,
),
mk_interval(
self.interval_finite_im,
self.interval_inf_im,
self.interval_closed_im,
self.unit_factor,
),
)
#################### ####################
# - Properties # - Properties
#################### ####################
@property @property
def sp_symbol(self) -> sp.Symbol: def sp_symbol(self) -> sp.Symbol:
"""Return a symbolic variable corresponding to this `SimSymbol`. """Return a symbolic variable w/unit, corresponding to this `SimSymbol`.
As much as possible, appropriate `assumptions` are set in the constructor of `sp.Symbol`, insofar as they can be determined. As much as possible, appropriate `assumptions` are set in the constructor of `sp.Symbol`, insofar as they can be determined.
However, the assumptions system alone is rather limited, and implementations should therefore also strongly consider transporting `SimSymbols` directly, instead of `sp.Symbol`. - **MathType**: Depending on `self.mathtype`.
This allows making use of other properties like `self.domain`, when appropriate. - **Positive/Negative**: Depending on `self.domain`.
- **Nonzero**: Depending on `self.domain`, including open/closed boundary specifications.
Notes:
**The assumptions system is rather limited**, and implementations should strongly consider transporting `SimSymbols` instead of `sp.Symbol`.
This allows tracking ex. the valid interval domain for a symbol.
""" """
# MathType Domain Constraint # MathType Assumption
## -> We must feed the assumptions system.
mathtype_kwargs = {} mathtype_kwargs = {}
match self.mathtype: match self.mathtype:
case spux.MathType.Integer: case spux.MathType.Integer:
@ -165,53 +333,138 @@ class SimSymbol:
case spux.MathType.Complex: case spux.MathType.Complex:
mathtype_kwargs |= {'complex': True} mathtype_kwargs |= {'complex': True}
# Interval Constraints # Non-Zero Assumption
if isinstance(self.domain, sp.Interval): if (
# Assumption: Non-Zero (
if ( self.domain.left == 0
( and self.domain.left_open
self.domain.left == 0 or self.domain.right == 0
and self.domain.left_open and self.domain.right_open
or self.domain.right == 0 )
and self.domain.right_open or self.domain.left > 0
) or self.domain.right < 0
or self.domain.left > 0 ):
or self.domain.right < 0 mathtype_kwargs |= {'nonzero': True}
):
mathtype_kwargs |= {'nonzero': True}
# Assumption: Positive/Negative # Positive/Negative Assumption
if self.domain.left >= 0: if self.domain.left >= 0:
mathtype_kwargs |= {'positive': True} mathtype_kwargs |= {'positive': True}
elif self.domain.right <= 0: elif self.domain.right <= 0:
mathtype_kwargs |= {'negative': True} mathtype_kwargs |= {'negative': True}
# Construct the Symbol return sp.Symbol(self.sym_name.name, **mathtype_kwargs) * self.unit_factor
return sp.Symbol(self.sim_node_name.name, **mathtype_kwargs)
####################
# - Operations
####################
def update(self, **kwargs) -> typ.Self:
def get_attr(attr: str):
_notfound = 'notfound'
if kwargs.get(attr, _notfound) is _notfound:
return getattr(self, attr)
return kwargs[attr]
return SimSymbol(
sym_name=get_attr('sym_name'),
mathtype=get_attr('mathtype'),
physical_type=get_attr('physical_type'),
unit=get_attr('unit'),
rows=get_attr('rows'),
cols=get_attr('cols'),
interval_finite_z=get_attr('interval_finite_z'),
interval_finite_q=get_attr('interval_finite_q'),
interval_finite_re=get_attr('interval_finite_q'),
interval_inf=get_attr('interval_inf'),
interval_closed=get_attr('interval_closed'),
interval_finite_im=get_attr('interval_finite_im'),
interval_inf_im=get_attr('interval_inf_im'),
interval_closed_im=get_attr('interval_closed_im'),
)
def set_size(self, rows: int, cols: int) -> typ.Self:
return SimSymbol(
sym_name=self.sym_name,
mathtype=self.mathtype,
physical_type=self.physical_type,
unit=self.unit,
rows=rows,
cols=cols,
interval_finite_z=self.interval_finite_z,
interval_finite_q=self.interval_finite_q,
interval_finite_re=self.interval_finite_re,
interval_inf=self.interval_inf,
interval_closed=self.interval_closed,
interval_finite_im=self.interval_finite_im,
interval_inf_im=self.interval_inf_im,
interval_closed_im=self.interval_closed_im,
)
#################### ####################
# - Common Sim Symbols # - Common Sim Symbols
#################### ####################
class CommonSimSymbol(enum.StrEnum): class CommonSimSymbol(enum.StrEnum):
"""A set of pre-defined symbols that might commonly be used in the context of physical simulation. """Identifiers for commonly used `SimSymbol`s, with all information about ex. `MathType`, `PhysicalType`, and (in general) valid intervals all pre-loaded.
Each entry maps directly to a particular `SimSymbol`. The enum is UI-compatible making it easy to declare a UI-driven dropdown of commonly used symbols that will all behave as expected.
The enum is compatible with `BLField`, making it easy to declare a UI-driven dropdown of symbols that behave as expected.
Attributes: Attributes:
X:
Time: A symbol representing a real-valued wavelength.
Wavelength: A symbol representing a real-valued wavelength. Wavelength: A symbol representing a real-valued wavelength.
Implicitly, this symbol often represents "vacuum wavelength" in particular. Implicitly, this symbol often represents "vacuum wavelength" in particular.
Wavelength: A symbol representing a real-valued frequency. Wavelength: A symbol representing a real-valued frequency.
Generally, this is the non-angular frequency. Generally, this is the non-angular frequency.
""" """
X = enum.auto() Index = enum.auto()
# Space|Time
SpaceX = enum.auto()
SpaceY = enum.auto()
SpaceZ = enum.auto()
AngR = enum.auto()
AngTheta = enum.auto()
AngPhi = enum.auto()
DirX = enum.auto()
DirY = enum.auto()
DirZ = enum.auto()
Time = enum.auto() Time = enum.auto()
# Fields
FieldEx = enum.auto()
FieldEy = enum.auto()
FieldEz = enum.auto()
FieldHx = enum.auto()
FieldHy = enum.auto()
FieldHz = enum.auto()
FieldEr = enum.auto()
FieldEtheta = enum.auto()
FieldEphi = enum.auto()
FieldHr = enum.auto()
FieldHtheta = enum.auto()
FieldHphi = enum.auto()
# Optics
Wavelength = enum.auto() Wavelength = enum.auto()
Frequency = enum.auto() Frequency = enum.auto()
DiffOrderX = enum.auto()
DiffOrderY = enum.auto()
Flux = enum.auto()
WaveVecX = enum.auto()
WaveVecY = enum.auto()
WaveVecZ = enum.auto()
####################
# - UI
####################
@staticmethod @staticmethod
def to_name(v: typ.Self) -> str: def to_name(v: typ.Self) -> str:
"""Convert the enum value to a human-friendly name. """Convert the enum value to a human-friendly name.
@ -222,7 +475,7 @@ class CommonSimSymbol(enum.StrEnum):
Returns: Returns:
A human-friendly name corresponding to the enum value. A human-friendly name corresponding to the enum value.
""" """
return CommonSimSymbol(v).sim_symbol_name.name return CommonSimSymbol(v).name
@staticmethod @staticmethod
def to_icon(_: typ.Self) -> str: def to_icon(_: typ.Self) -> str:
@ -241,55 +494,125 @@ class CommonSimSymbol(enum.StrEnum):
#################### ####################
@property @property
def name(self) -> str: def name(self) -> str:
return self.sim_symbol.name
@property
def sim_symbol_name(self) -> str:
SSN = SimSymbolName SSN = SimSymbolName
CSS = CommonSimSymbol CSS = CommonSimSymbol
return { return {
CSS.X: SSN.LowerX, CSS.Index: SSN.LowerI,
# Space|Time
CSS.SpaceX: SSN.LowerX,
CSS.SpaceY: SSN.LowerY,
CSS.SpaceZ: SSN.LowerZ,
CSS.AngR: SSN.LowerR,
CSS.AngTheta: SSN.LowerTheta,
CSS.AngPhi: SSN.LowerPhi,
CSS.DirX: SSN.LowerX,
CSS.DirY: SSN.LowerY,
CSS.DirZ: SSN.LowerZ,
CSS.Time: SSN.LowerT, CSS.Time: SSN.LowerT,
CSS.Wavelength: SSN.Wavelength, # Fields
CSS.FieldEx: SSN.Ex,
CSS.FieldEy: SSN.Ey,
CSS.FieldEz: SSN.Ez,
CSS.FieldHx: SSN.Hx,
CSS.FieldHy: SSN.Hy,
CSS.FieldHz: SSN.Hz,
CSS.FieldEr: SSN.Er,
CSS.FieldHr: SSN.Hr,
# Optics
CSS.Frequency: SSN.Frequency, CSS.Frequency: SSN.Frequency,
CSS.Wavelength: SSN.Wavelength,
CSS.DiffOrderX: SSN.DiffOrderX,
CSS.DiffOrderY: SSN.DiffOrderY,
}[self] }[self]
@property def sim_symbol(self, unit: spux.Unit | None) -> SimSymbol:
def sim_symbol(self) -> SimSymbol:
"""Retrieve the `SimSymbol` associated with the `CommonSimSymbol`.""" """Retrieve the `SimSymbol` associated with the `CommonSimSymbol`."""
CSS = CommonSimSymbol CSS = CommonSimSymbol
# Space
sym_space = SimSymbol(
sym_name=self.name,
physical_type=spux.PhysicalType.Length,
unit=unit,
)
sym_ang = SimSymbol(
sym_name=self.name,
physical_type=spux.PhysicalType.Angle,
unit=unit,
)
# Fields
def sym_field(eh: typ.Literal['e', 'h']) -> SimSymbol:
return SimSymbol(
sym_name=self.name,
physical_type=spux.PhysicalType.EField
if eh == 'e'
else spux.PhysicalType.HField,
unit=unit,
interval_finite_re=(0, sys.float_info.max),
interval_inf_re=(False, True),
interval_closed_re=(True, False),
interval_finite_im=(sys.float_info.min, sys.float_info.max),
interval_inf_im=(True, True),
)
return { return {
CSS.X: SimSymbol( CSS.Index: SimSymbol(
sim_node_name=self.sim_symbol_name, sym_name=self.name,
mathtype=spux.MathType.Real, mathtype=spux.MathType.Integer,
physical_type=spux.PhysicalType.NonPhysical, interval_finite_z=(0, 2**64),
## TODO: Unit of Picosecond
interval_finite=(sys.float_info.min, sys.float_info.max),
interval_inf=(True, True),
interval_closed=(False, False),
),
CSS.Time: SimSymbol(
sim_node_name=self.sim_symbol_name,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Time,
## TODO: Unit of Picosecond
interval_finite=(0, sys.float_info.max),
interval_inf=(False, True), interval_inf=(False, True),
interval_closed=(True, False), interval_closed=(True, False),
), ),
# Space|Time
CSS.SpaceX: sym_space,
CSS.SpaceY: sym_space,
CSS.SpaceZ: sym_space,
CSS.AngR: sym_space,
CSS.AngTheta: sym_ang,
CSS.AngPhi: sym_ang,
CSS.Time: SimSymbol(
sym_name=self.name,
physical_type=spux.PhysicalType.Time,
unit=unit,
interval_finite_re=(0, sys.float_info.max),
interval_inf=(False, True),
interval_closed=(True, False),
),
# Fields
CSS.FieldEx: sym_field('e'),
CSS.FieldEy: sym_field('e'),
CSS.FieldEz: sym_field('e'),
CSS.FieldHx: sym_field('h'),
CSS.FieldHy: sym_field('h'),
CSS.FieldHz: sym_field('h'),
CSS.FieldEr: sym_field('e'),
CSS.FieldEtheta: sym_field('e'),
CSS.FieldEphi: sym_field('e'),
CSS.FieldHr: sym_field('h'),
CSS.FieldHtheta: sym_field('h'),
CSS.FieldHphi: sym_field('h'),
CSS.Flux: SimSymbol(
sym_name=SimSymbolName.Flux,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Power,
unit=unit,
),
# Optics
CSS.Wavelength: SimSymbol( CSS.Wavelength: SimSymbol(
sim_node_name=self.sim_symbol_name, sym_name=self.name,
mathtype=spux.MathType.Real, mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Length, physical_type=spux.PhysicalType.Length,
## TODO: Unit of Picosecond unit=unit,
interval_finite=(0, sys.float_info.max), interval_finite=(0, sys.float_info.max),
interval_inf=(False, True), interval_inf=(False, True),
interval_closed=(False, False), interval_closed=(False, False),
), ),
CSS.Frequency: SimSymbol( CSS.Frequency: SimSymbol(
sim_node_name=self.sim_symbol_name, sym_name=self.name,
mathtype=spux.MathType.Real, mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Freq, physical_type=spux.PhysicalType.Freq,
unit=unit,
interval_finite=(0, sys.float_info.max), interval_finite=(0, sys.float_info.max),
interval_inf=(False, True), interval_inf=(False, True),
interval_closed=(False, False), interval_closed=(False, False),
@ -298,9 +621,33 @@ class CommonSimSymbol(enum.StrEnum):
#################### ####################
# - Selected Direct Access # - Selected Direct-Access to SimSymbols
#################### ####################
x = CommonSimSymbol.X.sim_symbol idx = CommonSimSymbol.Index.sim_symbol
t = CommonSimSymbol.Time.sim_symbol t = CommonSimSymbol.Time.sim_symbol
wl = CommonSimSymbol.Wavelength.sim_symbol wl = CommonSimSymbol.Wavelength.sim_symbol
freq = CommonSimSymbol.Frequency.sim_symbol freq = CommonSimSymbol.Frequency.sim_symbol
space_x = CommonSimSymbol.SpaceX.sim_symbol
space_y = CommonSimSymbol.SpaceY.sim_symbol
space_z = CommonSimSymbol.SpaceZ.sim_symbol
dir_x = CommonSimSymbol.DirX.sim_symbol
dir_y = CommonSimSymbol.DirY.sim_symbol
dir_z = CommonSimSymbol.DirZ.sim_symbol
ang_r = CommonSimSymbol.AngR.sim_symbol
ang_theta = CommonSimSymbol.AngTheta.sim_symbol
ang_phi = CommonSimSymbol.AngPhi.sim_symbol
field_ex = CommonSimSymbol.FieldEx.sim_symbol
field_ey = CommonSimSymbol.FieldEy.sim_symbol
field_ez = CommonSimSymbol.FieldEz.sim_symbol
field_hx = CommonSimSymbol.FieldHx.sim_symbol
field_hy = CommonSimSymbol.FieldHx.sim_symbol
field_hz = CommonSimSymbol.FieldHx.sim_symbol
flux = CommonSimSymbol.Flux.sim_symbol
diff_order_x = CommonSimSymbol.DiffOrderX.sim_symbol
diff_order_y = CommonSimSymbol.DiffOrderY.sim_symbol