From 353a2c997e7850b3655c7880ca6d49c8edef2a1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sofus=20Albert=20H=C3=B8gsbro=20Rose?= Date: Tue, 21 May 2024 22:57:56 +0200 Subject: [PATCH] refactor: end-of-day commit (sim symbol flow for data import/export & inverse design) --- pyproject.toml | 1 + requirements-dev.lock | 15 + requirements.lock | 15 + .../contracts/flow_kinds/array.py | 155 ++-- .../contracts/flow_kinds/expr_info.py | 36 + .../contracts/flow_kinds/flow_kinds.py | 110 +-- .../contracts/flow_kinds/info.py | 365 ++++----- .../contracts/flow_kinds/lazy_func.py | 21 + .../contracts/flow_kinds/lazy_range.py | 484 ++++++------ .../contracts/flow_kinds/params.py | 71 +- .../maxwell_sim_nodes/contracts/sim_types.py | 20 +- .../contracts/unit_systems.py | 4 - .../nodes/analysis/extract_data.py | 733 +++++++----------- .../nodes/analysis/math/filter_math.py | 277 +++---- .../nodes/analysis/math/map_math.py | 184 +++-- .../nodes/analysis/math/transform_math.py | 57 +- .../maxwell_sim_nodes/nodes/analysis/viz.py | 132 ++-- .../maxwell_sim_nodes/nodes/events.py | 28 +- .../file_importers/data_file_importer.py | 132 +++- .../file_exporters/data_file_exporter.py | 6 +- .../maxwell_sim_nodes/sockets/expr.py | 125 +-- .../utils/extra_sympy_units.py | 101 ++- src/blender_maxwell/utils/image_ops.py | 137 ++-- src/blender_maxwell/utils/sim_symbols.py | 559 ++++++++++--- 24 files changed, 2058 insertions(+), 1710 deletions(-) create mode 100644 src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/expr_info.py diff --git a/pyproject.toml b/pyproject.toml index 4aae8b2..8e32d91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ #"charset-normalizer==2.0.10", ## Conflict with dev-dep commitizen "certifi==2021.10.8", "polars>=0.20.26", + "seaborn[stats]>=0.13.2", ] ## When it comes to dev-dep conflicts: ## -> It's okay to leave Blender-pinned deps out of prod; Blender still has them. diff --git a/requirements-dev.lock b/requirements-dev.lock index 37d0aa6..98731f6 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -81,6 +81,7 @@ locket==1.0.0 markupsafe==2.1.5 # via jinja2 matplotlib==3.8.3 + # via seaborn # via tidy3d ml-dtypes==0.4.0 # via jax @@ -102,8 +103,11 @@ numpy==1.24.3 # via ml-dtypes # via numba # via opt-einsum + # via patsy # via scipy + # via seaborn # via shapely + # via statsmodels # via tidy3d # via trimesh # via xarray @@ -114,11 +118,16 @@ packaging==24.0 # via dask # via h5netcdf # via matplotlib + # via statsmodels # via xarray pandas==2.2.1 + # via seaborn + # via statsmodels # via xarray partd==1.4.1 # via dask +patsy==0.5.6 + # via statsmodels pillow==10.2.0 # via matplotlib platformdirs==4.2.1 @@ -167,13 +176,19 @@ s3transfer==0.5.2 scipy==1.12.0 # via jax # via jaxlib + # via seaborn + # via statsmodels # via tidy3d +seaborn==0.13.2 setuptools==69.5.1 # via nodeenv shapely==2.0.3 # via tidy3d six==1.16.0 + # via patsy # via python-dateutil +statsmodels==0.14.2 + # via seaborn sympy==1.12 termcolor==2.4.0 # via commitizen diff --git a/requirements.lock b/requirements.lock index 9a5b1ea..4641252 100644 --- a/requirements.lock +++ b/requirements.lock @@ -59,6 +59,7 @@ llvmlite==0.42.0 locket==1.0.0 # via partd matplotlib==3.8.3 + # via seaborn # via tidy3d ml-dtypes==0.4.0 # via jax @@ -78,8 +79,11 @@ numpy==1.24.3 # via ml-dtypes # via numba # via opt-einsum + # via patsy # via scipy + # via seaborn # via shapely + # via statsmodels # via tidy3d # via trimesh # via xarray @@ -89,11 +93,16 @@ packaging==24.0 # via dask # via h5netcdf # via matplotlib + # via statsmodels # via xarray pandas==2.2.1 + # via seaborn + # via statsmodels # via xarray partd==1.4.1 # via dask +patsy==0.5.6 + # via statsmodels pillow==10.2.0 # via matplotlib polars==0.20.26 @@ -132,11 +141,17 @@ s3transfer==0.5.2 scipy==1.12.0 # via jax # via jaxlib + # via seaborn + # via statsmodels # via tidy3d +seaborn==0.13.2 shapely==2.0.3 # via tidy3d six==1.16.0 + # via patsy # via python-dateutil +statsmodels==0.14.2 + # via seaborn sympy==1.12 tidy3d==2.6.3 toml==0.10.2 diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py index 9aad6de..f33ffde 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py @@ -29,9 +29,12 @@ from blender_maxwell.utils import logger log = logger.get(__name__) +# TODO: Our handling of 'is_sorted' is sloppy and probably wrong. @dataclasses.dataclass(frozen=True, kw_only=True) class ArrayFlow: - """A simple, flat array of values with an optionally-attached unit. + """A homogeneous, realized array of numerical values with an optionally-attached unit and sort-tracking. + + While the principle is simple, arrays-with-units ends up being a powerful basis for derived and computed features/methods/processing. Attributes: values: An ND array-like object of arbitrary numerical type. @@ -44,13 +47,97 @@ class ArrayFlow: is_sorted: bool = False + #################### + # - Computed Properties + #################### + @property + def is_symbolic(self) -> bool: + """Always False, as ArrayFlows are never unrealized.""" + return False + def __len__(self) -> int: + """Outer length of the contained array.""" return len(self.values) @functools.cached_property def mathtype(self) -> spux.MathType: + """Deduce the `spux.MathType` of the first element of the contained array. + + This is generally a heuristic, but because `jax` enforces homogeneous arrays, this is actually a well-defined approach. + """ return spux.MathType.from_pytype(type(self.values.item(0))) + @functools.cached_property + def physical_type(self) -> spux.MathType: + """Deduce the `spux.PhysicalType` of the unit.""" + return spux.PhysicalType.from_unit(self.unit) + + #################### + # - Array Features + #################### + @property + def realize_array(self) -> jtyp.Shaped[jtyp.Array, '...']: + """Standardized access to `self.values`.""" + return self.values + + @functools.cached_property + def shape(self) -> int: + """Shape of the contained array.""" + return self.values.shape + + def __getitem__(self, subscript: slice) -> typ.Self | spux.SympyExpr: + """Implement indexing and slicing in a sane way. + + - **Integer Index**: For scalar output, return a `sympy` expression of the scalar multiplied by the unit, else just a sympy expression of the value. + - **Slice**: Slice the internal array directly, and wrap the result in a new `ArrayFlow`. + """ + if isinstance(subscript, slice): + return ArrayFlow( + values=self.values[subscript], + unit=self.unit, + is_sorted=self.is_sorted, + ) + + if isinstance(subscript, int): + value = self.values[subscript] + if len(value.shape) == 0: + return value * self.unit if self.unit is not None else sp.S(value) + return ArrayFlow(values=value, unit=self.unit, is_sorted=self.is_sorted) + + raise NotImplementedError + + #################### + # - Methods + #################### + def rescale( + self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None + ) -> typ.Self: + """Apply an order-preserving function to each element of the array, then (optionally) transform the result w/new unit and/or order. + + An optimized expression will be built and applied to `self.values` using `sympy.lambdify()`. + + Parameters: + rescale_func: An **order-preserving** function to apply to each array element. + reverse: Whether to reverse the order of the result. + new_unit: An (optional) new unit to scale the result to. + """ + # Compile JAX-Compatible Rescale Function + a = self.mathtype.sp_symbol_a + rescale_expr = ( + spux.scale_to_unit(rescale_func(a * self.unit), new_unit) + if self.unit is not None + else rescale_func(a) + ) + _rescale_func = sp.lambdify(a, rescale_expr, 'jax') + values = _rescale_func(self.values) + + # Return ArrayFlow + return ArrayFlow( + values=values[::-1] if reverse else values, + unit=new_unit, + is_sorted=self.is_sorted, + ) + def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int: """Find the index of the value that is closest to the given value. @@ -88,56 +175,26 @@ class ArrayFlow: return right_idx - def correct_unit(self, corrected_unit: spu.Quantity) -> typ.Self: - if self.unit is not None: - return ArrayFlow( - values=self.values, unit=corrected_unit, is_sorted=self.is_sorted - ) + #################### + # - Unit Transforms + #################### + def correct_unit(self, unit: spux.Unit) -> typ.Self: + """Simply replace the existing unit with the given one. - msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"' - raise ValueError(msg) + Parameters: + corrected_unit: The new unit to insert. + **MUST** be associable with a well-defined `PhysicalType`. + """ + return ArrayFlow(values=self.values, unit=unit, is_sorted=self.is_sorted) - def rescale_to_unit(self, unit: spu.Quantity | None) -> typ.Self: - ## TODO: Cache by unit would be a very nice speedup for Viz node. - if self.unit is not None: - return ArrayFlow( - values=float(spux.scaling_factor(self.unit, unit)) * self.values, - unit=unit, - is_sorted=self.is_sorted, - ) + def rescale_to_unit(self, new_unit: spux.Unit | None) -> typ.Self: + """Rescale the `ArrayFlow` to be expressed in the given unit. - if unit is None: - return self - - msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}' - raise ValueError(msg) - - def rescale( - self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None - ) -> typ.Self: - # Compile JAX-Compatible Rescale Function - a = sp.Symbol('a') - rescale_expr = ( - spux.scale_to_unit(rescale_func(a * self.unit), new_unit) - if self.unit is not None - else rescale_func(a * self.unit) - ) - _rescale_func = sp.lambdify(a, rescale_expr, 'jax') - values = _rescale_func(self.values) - - # Return ArrayFlow - return ArrayFlow( - values=values[::-1] if reverse else values, - unit=new_unit, - is_sorted=self.is_sorted, - ) - - def __getitem__(self, subscript: slice): - if isinstance(subscript, slice): - return ArrayFlow( - values=self.values[subscript], - unit=self.unit, - is_sorted=self.is_sorted, - ) + Parameters: + corrected_unit: The new unit to insert. + **MUST** be associable with a well-defined `PhysicalType`. + """ + return self.rescale(lambda v: v, new_unit=new_unit) + def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self: raise NotImplementedError diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/expr_info.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/expr_info.py new file mode 100644 index 0000000..412cd05 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/expr_info.py @@ -0,0 +1,36 @@ +# blender_maxwell +# Copyright (C) 2024 blender_maxwell Project Contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import typing as typ + +from blender_maxwell.utils import extra_sympy_units as spux + +from . import FlowKind + + +class ExprInfo(typ.TypedDict): + active_kind: FlowKind + size: spux.NumberSize1D + mathtype: spux.MathType + physical_type: spux.PhysicalType + + # Value + default_value: spux.SympyExpr + + # Range + default_min: spux.SympyExpr + default_max: spux.SympyExpr + default_steps: int diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py index 4a7b17a..d4058e9 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py @@ -19,6 +19,7 @@ import typing as typ from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import logger +from blender_maxwell.utils.staticproperty import staticproperty log = logger.get(__name__) @@ -51,16 +52,71 @@ class FlowKind(enum.StrEnum): Capabilities = enum.auto() # Values - Value = enum.auto() - Array = enum.auto() + Value = enum.auto() ## 'value' + Array = enum.auto() ## 'array' # Lazy - Func = enum.auto() - Range = enum.auto() + Func = enum.auto() ## 'lazy_func' + Range = enum.auto() ## 'lazy_range' # Auxiliary - Params = enum.auto() - Info = enum.auto() + Params = enum.auto() ## 'params' + Info = enum.auto() ## 'info' + + #################### + # - UI + #################### + @staticmethod + def to_name(v: typ.Self) -> str: + return { + FlowKind.Capabilities: 'Capabilities', + # Values + FlowKind.Value: 'Value', + FlowKind.Array: 'Array', + # Lazy + FlowKind.Range: 'Range', + FlowKind.Func: 'Func', + # Auxiliary + FlowKind.Params: 'Params', + FlowKind.Info: 'Info', + }[v] + + @staticmethod + def to_icon(_: typ.Self) -> str: + return '' + + #################### + # - Static Properties + #################### + @staticproperty + def active_kinds() -> list[typ.Self]: + """Return a list of `FlowKind`s that are able to be considered "active". + + "Active" `FlowKind`s are considered the primary data type of a socket's flow. + For example, for sockets to be linkeable, their active `FlowKind` must generally match. + """ + return [ + FlowKind.Value, + FlowKind.Array, + FlowKind.Range, + FlowKind.Func, + ] + + @property + def socket_shape(self) -> str: + """Return the socket shape associated with this `FlowKind`. + + **ONLY** valid for `FlowKind`s that can be considered "active". + + Raises: + ValueError: If this `FlowKind` cannot ever be considered "active". + """ + return { + FlowKind.Value: 'CIRCLE', + FlowKind.Array: 'SQUARE', + FlowKind.Range: 'SQUARE', + FlowKind.Func: 'DIAMOND', + }[self] #################### # - Class Methods @@ -69,7 +125,7 @@ class FlowKind(enum.StrEnum): def scale_to_unit_system( cls, kind: typ.Self, - flow_obj, + flow_obj: spux.SympyExpr, unit_system: spux.UnitSystem, ): # log.debug('%s: Scaling "%s" to Unit System', kind, str(flow_obj)) @@ -87,43 +143,3 @@ class FlowKind(enum.StrEnum): msg = 'Tried to scale unknown kind' raise ValueError(msg) - - #################### - # - Computed - #################### - @property - def flow_kind(self) -> str: - return { - FlowKind.Value: FlowKind.Value, - FlowKind.Array: FlowKind.Array, - FlowKind.Func: FlowKind.Func, - FlowKind.Range: FlowKind.Range, - }[self] - - @property - def socket_shape(self) -> str: - return { - FlowKind.Value: 'CIRCLE', - FlowKind.Array: 'SQUARE', - FlowKind.Range: 'SQUARE', - FlowKind.Func: 'DIAMOND', - }[self] - - #################### - # - Blender Enum - #################### - @staticmethod - def to_name(v: typ.Self) -> str: - return { - FlowKind.Capabilities: 'Capabilities', - FlowKind.Value: 'Value', - FlowKind.Array: 'Array', - FlowKind.Range: 'Range', - FlowKind.Func: 'Func', - FlowKind.Params: 'Parameters', - FlowKind.Info: 'Information', - }[v] - - @staticmethod - def to_icon(_: typ.Self) -> str: - return '' diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py index e1252fd..f717b9a 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py @@ -18,246 +18,247 @@ import dataclasses import functools import typing as typ -import jax - from blender_maxwell.utils import extra_sympy_units as spux -from blender_maxwell.utils import logger +from blender_maxwell.utils import logger, sim_symbols from .array import ArrayFlow from .lazy_range import RangeFlow log = logger.get(__name__) +LabelArray: typ.TypeAlias = list[str] + +# IndexArray: Identifies Discrete Dimension Values +## -> ArrayFlow (rat|real): Index by particular, not-guaranteed-uniform index. +## -> RangeFlow (rat|real): Index by unrealized array scaled between boundaries. +## -> LabelArray (int): For int-index arrays, interpret using these labels. +## -> None: Non-Discrete/unrealized indexing; use 'dim.domain'. +IndexArray: typ.TypeAlias = ArrayFlow | RangeFlow | LabelArray | None + @dataclasses.dataclass(frozen=True, kw_only=True) class InfoFlow: - #################### - # - Covariant Input - #################### - dim_names: list[str] = dataclasses.field(default_factory=list) - dim_idx: dict[str, ArrayFlow | RangeFlow] = dataclasses.field( - default_factory=dict - ) ## TODO: Rename to dim_idxs + """Contains dimension and output information characterizing the array produced by a parallel `FuncFlow`. - @functools.cached_property - def dim_has_coords(self) -> dict[str, int]: - return { - dim_name: not ( - isinstance(dim_idx, RangeFlow) - and (dim_idx.start.is_infinite or dim_idx.stop.is_infinite) - ) - for dim_name, dim_idx in self.dim_idx.items() - } + Functionally speaking, `InfoFlow` provides essential mathematical and physical context to raw array data, with terminology adapted from multilinear algebra. - @functools.cached_property - def dim_lens(self) -> dict[str, int]: - return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()} + # From Arrays to Tensors + The best way to illustrate how it works is to specify how raw-array concepts map to an array described by an `InfoFlow`: - @functools.cached_property - def dim_mathtypes(self) -> dict[str, spux.MathType]: - return { - dim_name: dim_idx.mathtype for dim_name, dim_idx in self.dim_idx.items() - } + - **Index**: In raw arrays, the "index" is generally constrained to an integer ring, and has no semantic meaning. + **(Covariant) Dimension**: The "dimension" is an named "index array", which assigns each integer index a **scalar** value of particular mathematical type, name, and unit (if not unitless). + - **Value**: In raw arrays, the "value" is some particular computational type, or another raw array. + **(Contravariant) Output**: The "output" is a strictly named, sized object that can only be produced - @functools.cached_property - def dim_units(self) -> dict[str, spux.Unit]: - return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()} + In essence, `InfoFlow` allows us to treat raw data as a tensor, then operate on its dimensionality as split into parts whose transform varies _with_ the output (aka. a _covariant_ index), and parts whose transform varies _against_ the output (aka. _contravariant_ value). - @functools.cached_property - def dim_physical_types(self) -> dict[str, spux.PhysicalType]: - return { - dim_name: spux.PhysicalType.from_unit(dim_idx.unit) - for dim_name, dim_idx in self.dim_idx.items() - } + ## Benefits + The reasons to do this are numerous: - @functools.cached_property - def dim_idx_arrays(self) -> list[jax.Array]: - return [ - dim_idx.realize().values - if isinstance(dim_idx, RangeFlow) - else dim_idx.values - for dim_idx in self.dim_idx.values() - ] + - **Clarity**: Using `InfoFlow`, it's easy to understand what the data is, and what can be done to it, making it much easier to implement complex operations in math nodes without sacrificing the user's mental model. + - **Zero-Cost Operations**: Transforming indices, "folding" dimensions into the output, and other such operations don't actually do anything to the data, enabling a lot of operations to feel "free" in terms of performance. + - **Semantic Indexing**: Using `InfoFlow`, it's easy to index and slice arrays using ex. nanometer vacuum wavelengths, instead of arbitrary integers. + """ #################### - # - Contravariant Output + # - Dimensions: Covariant Index #################### - # Output Information - ## TODO: Add PhysicalType - output_name: str = dataclasses.field(default_factory=list) - output_shape: tuple[int, ...] | None = dataclasses.field(default=None) - output_mathtype: spux.MathType = dataclasses.field() - output_unit: spux.Unit | None = dataclasses.field() - - @property - def output_shape_len(self) -> int: - if self.output_shape is None: - return 0 - return len(self.output_shape) - - # Pinned Dimension Information - ## TODO: Add PhysicalType - pinned_dim_names: list[str] = dataclasses.field(default_factory=list) - pinned_dim_values: dict[str, float | complex] = dataclasses.field( + dims: dict[sim_symbols.SimSymbol, IndexArray] = dataclasses.field( default_factory=dict ) - pinned_dim_mathtypes: dict[str, spux.MathType] = dataclasses.field( - default_factory=dict - ) - pinned_dim_units: dict[str, spux.Unit] = dataclasses.field(default_factory=dict) + + @functools.cached_property + def last_dim(self) -> sim_symbols.SimSymbol | None: + """The integer axis occupied by the dimension. + + Can be used to index `.shape` of the represented raw array. + """ + if self.dims: + return next(iter(self.dims.keys())) + return None + + @functools.cached_property + def last_dim(self) -> sim_symbols.SimSymbol | None: + """The integer axis occupied by the dimension. + + Can be used to index `.shape` of the represented raw array. + """ + if self.dims: + return list(self.dims.keys())[-1] + return None + + def dim_axis(self, dim: sim_symbols.SimSymbol) -> int: + """The integer axis occupied by the dimension. + + Can be used to index `.shape` of the represented raw array. + """ + return list(self.dims.keys()).index(dim) + + def has_idx_cont(self, dim: sim_symbols.SimSymbol) -> bool: + """Whether the dim's index is continuous, and therefore index array. + + This happens when the dimension is generated from a symbolic function, as opposed to from discrete observations. + In these cases, the `SimSymbol.domain` of the dimension should be used to determine the overall domain of validity. + + Other than that, it's up to the user to select a particular way of indexing. + """ + return self.dims[dim] is None + + def has_idx_discrete(self, dim: sim_symbols.SimSymbol) -> bool: + """Whether the (rat|real) dim is indexed by an `ArrayFlow` / `RangeFlow`.""" + return isinstance(self.dims[dim], ArrayFlow | RangeFlow) + + def has_idx_labels(self, dim: sim_symbols.SimSymbol) -> bool: + """Whether the (int) dim is indexed by a `LabelArray`.""" + if dim.mathtype is spux.MathType.Integer: + return isinstance(self.dims[dim], list) + return False #################### - # - Methods + # - Output: Contravariant Value #################### - def slice_dim(self, dim_name: str, slice_tuple: tuple[int, int, int]) -> typ.Self: + output: sim_symbols.SimSymbol + + #################### + # - Pinned Dimension Values + #################### + ## -> Whenever a dimension is deleted, we retain what that index value was. + ## -> This proves to be very helpful for clear visualization. + pinned_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = dataclasses.field( + default_factory=dict + ) + + #################### + # - Operations: Dimensions + #################### + def prepend_dim( + self, dim: sim_symbols.SimSymbol, dim_idx: sim_symbols.SimSymbol + ) -> typ.Self: + """Insert a new dimension at index 0.""" return InfoFlow( - # Dimensions - dim_names=self.dim_names, - dim_idx={ - _dim_name: ( - dim_idx - if _dim_name != dim_name - else dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]] - ) - for _dim_name, dim_idx in self.dim_idx.items() + dims={dim: dim_idx} | self.dims, + output=self.output, + pinned_values=self.pinned_values, + ) + + def slice_dim( + self, dim: sim_symbols.SimSymbol, slice_tuple: tuple[int, int, int] + ) -> typ.Self: + """Slice a dimensional array by-index along a particular dimension.""" + return InfoFlow( + dims={ + _dim: dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]] + if _dim == dim + else _dim + for _dim, dim_idx in self.dims.items() }, - # Outputs - output_name=self.output_name, - output_shape=self.output_shape, - output_mathtype=self.output_mathtype, - output_unit=self.output_unit, + output=self.output, + pinned_values=self.pinned_values, ) def replace_dim( - self, old_dim_name: str, new_dim_idx: tuple[str, ArrayFlow | RangeFlow] + self, + old_dim: sim_symbols.SimSymbol, + new_dim: sim_symbols.SimSymbol, + new_dim_idx: IndexArray, ) -> typ.Self: - """Replace a dimension (and its indexing) with a new name and index array/range.""" + """Replace a dimension entirely, in-place, including symbol and index array.""" return InfoFlow( - # Dimensions - dim_names=[ - dim_name if dim_name != old_dim_name else new_dim_idx[0] - for dim_name in self.dim_names - ], - dim_idx={ - (dim_name if dim_name != old_dim_name else new_dim_idx[0]): ( - dim_idx if dim_name != old_dim_name else new_dim_idx[1] + dims={ + (new_dim if _dim == old_dim else _dim): ( + new_dim_idx if _dim == old_dim else _dim ) - for dim_name, dim_idx in self.dim_idx.items() + for _dim, dim_idx in self.dims.items() }, - # Outputs - output_name=self.output_name, - output_shape=self.output_shape, - output_mathtype=self.output_mathtype, - output_unit=self.output_unit, + output=self.output, + pinned_values=self.pinned_values, ) - def rescale_dim_idxs(self, new_dim_idxs: dict[str, RangeFlow]) -> typ.Self: + def replace_dims( + self, new_dims: dict[sim_symbols.SimSymbol, IndexArray] + ) -> typ.Self: """Replace several dimensional indices with new index arrays/ranges.""" return InfoFlow( - # Dimensions - dim_names=self.dim_names, - dim_idx={ - _dim_name: new_dim_idxs.get(_dim_name, dim_idx) - for _dim_name, dim_idx in self.dim_idx.items() + dims={ + dim: new_dims.get(dim, dim_idx) for dim, dim_idx in self.dim_idx.items() }, - # Outputs - output_name=self.output_name, - output_shape=self.output_shape, - output_mathtype=self.output_mathtype, - output_unit=self.output_unit, + output=self.output, + pinned_values=self.pinned_values, ) - def delete_dimension(self, dim_name: str) -> typ.Self: - """Delete a dimension.""" + def delete_dim( + self, dim_to_remove: sim_symbols.SimSymbol, pin_idx: int | None = None + ) -> typ.Self: + """Delete a dimension, optionally pinning the value of an index from that dimension.""" + new_pin = ( + {dim_to_remove: self.dims[dim_to_remove][pin_idx]} + if pin_idx is not None + else {} + ) return InfoFlow( - # Dimensions - dim_names=[ - _dim_name for _dim_name in self.dim_names if _dim_name != dim_name - ], - dim_idx={ - _dim_name: dim_idx - for _dim_name, dim_idx in self.dim_idx.items() - if _dim_name != dim_name + dims={ + dim: dim_idx + for dim, dim_idx in self.dims.items() + if dim != dim_to_remove }, - # Outputs - output_name=self.output_name, - output_shape=self.output_shape, - output_mathtype=self.output_mathtype, - output_unit=self.output_unit, + output=self.output, + pinned_values=self.pinned_values | new_pin, ) - def swap_dimensions(self, dim_0_name: str, dim_1_name: str) -> typ.Self: - """Swap the position of two dimensions.""" + def swap_dimensions(self, dim_0: str, dim_1: str) -> typ.Self: + """Swap the positions of two dimensions.""" - # Compute Swapped Dimension Name List + # Swapped Dimension Keys def name_swapper(dim_name): return ( dim_name - if dim_name not in [dim_0_name, dim_1_name] - else {dim_0_name: dim_1_name, dim_1_name: dim_0_name}[dim_name] + if dim_name not in [dim_0, dim_1] + else {dim_0: dim_1, dim_1: dim_0}[dim_name] ) - dim_names = [name_swapper(dim_name) for dim_name in self.dim_names] + swapped_dim_keys = [name_swapper(dim) for dim in self.dims] - # Compute Info return InfoFlow( - # Dimensions - dim_names=dim_names, - dim_idx={dim_name: self.dim_idx[dim_name] for dim_name in dim_names}, - # Outputs - output_name=self.output_name, - output_shape=self.output_shape, - output_mathtype=self.output_mathtype, - output_unit=self.output_unit, + dims={dim_key: self.dims[dim_key] for dim_key in swapped_dim_keys}, + output=self.output, + pinned_values=self.pinned_values, ) - def set_output_mathtype(self, output_mathtype: spux.MathType) -> typ.Self: - """Set the MathType of the output.""" + #################### + # - Operations: Output + #################### + def update_output(self, **kwargs) -> typ.Self: + """Passthrough to `SimSymbol.update()` method on `self.output`.""" return InfoFlow( - dim_names=self.dim_names, - dim_idx=self.dim_idx, - # Outputs - output_name=self.output_name, - output_shape=self.output_shape, - output_mathtype=output_mathtype, - output_unit=self.output_unit, + dims=self.dims, + output=self.output.update(**kwargs), + pinned_values=self.pinned_values, ) - def collapse_output( - self, - collapsed_name: str, - collapsed_mathtype: spux.MathType, - collapsed_unit: spux.Unit, - ) -> typ.Self: - """Replace the (scalar) output with the given corrected values.""" - return InfoFlow( - # Dimensions - dim_names=self.dim_names, - dim_idx=self.dim_idx, - output_name=collapsed_name, - output_shape=None, - output_mathtype=collapsed_mathtype, - output_unit=collapsed_unit, - ) + #################### + # - Operations: Fold + #################### + def fold_last_input(self): + """Fold the last input dimension into the output.""" + last_key = list(self.dims.keys())[-1] + last_idx = list(self.dims.values())[-1] + + rows = self.output.rows + cols = self.output.cols + match (rows, cols): + case (1, 1): + new_output = self.output.set_size(len(last_idx), 1) + case (_, 1): + new_output = self.output.set_size(rows, len(last_idx)) + case (1, _): + new_output = self.output.set_size(len(last_idx), cols) + case (_, _): + raise NotImplementedError ## Not yet :) - @functools.cached_property - def shift_last_input(self): - """Shift the last input dimension to the output.""" return InfoFlow( - # Dimensions - dim_names=self.dim_names[:-1], - dim_idx={ - dim_name: dim_idx - for dim_name, dim_idx in self.dim_idx.items() - if dim_name != self.dim_names[-1] + dims={ + dim: dim_idx for dim, dim_idx in self.dims.items() if dim != last_key }, - # Outputs - output_name=self.output_name, - output_shape=( - (self.dim_lens[self.dim_names[-1]],) - if self.output_shape is None - else (self.dim_lens[self.dim_names[-1]], *self.output_shape) - ), - output_mathtype=self.output_mathtype, - output_unit=self.output_unit, + output=new_output, + pinned_values=self.pinned_values, ) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py index 5f96504..bb3b087 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py @@ -24,6 +24,8 @@ import jax from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import logger +from .params import ParamsFlow + log = logger.get(__name__) LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any] @@ -307,6 +309,25 @@ class FuncFlow: msg = 'Can\'t express FuncFlow as JAX function (using jax.jit), since "self.supports_jax" is False' raise ValueError(msg) + #################### + # - Realization + #################### + def realize( + self, + params: ParamsFlow, + unit_system: spux.UnitSystem | None = None, + symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}), + ) -> typ.Self: + if self.supports_jax: + return self.func_jax( + *params.scaled_func_args(unit_system, symbol_values), + *params.scaled_func_kwargs(unit_system, symbol_values), + ) + return self.func( + *params.scaled_func_args(unit_system, symbol_values), + *params.scaled_func_kwargs(unit_system, symbol_values), + ) + #################### # - Composition Operations #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py index e80a528..7d10636 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py @@ -28,13 +28,20 @@ from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import logger from .array import ArrayFlow -from .flow_kinds import FlowKind from .lazy_func import FuncFlow log = logger.get(__name__) class ScalingMode(enum.StrEnum): + """Identifier for how to space steps between two boundaries. + + Attributes: + Lin: Uniform spacing between two endpoints. + Geom: Log spacing between two endpoints, given as values. + Log: Log spacing between two endpoints, given as powers of a common base. + """ + Lin = enum.auto() Geom = enum.auto() Log = enum.auto() @@ -55,36 +62,20 @@ class ScalingMode(enum.StrEnum): @dataclasses.dataclass(frozen=True, kw_only=True) class RangeFlow: - r"""Represents a linearly/logarithmically spaced array using symbolic boundary expressions, with support for units and lazy evaluation. + r"""Represents a spaced array using symbolic boundary expressions. - # Advantages Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous. - ## Memory + # Memory Scaling `ArrayFlow` generally has a memory scaling of $O(n)$. Naturally, `RangeFlow` is always constant, since only the boundaries and steps are stored. - ## Symbolic - Both boundary points are symbolic expressions, within which pre-defined `sp.Symbol`s can participate in a constrained manner (ex. an integer symbol). + # Symbolic Bounds + `self.start` and `self.stop` boundary points are symbolic expressions, within which any element of `self.symbols` can participate. - One need not know the value of the symbols immediately - such decisions can be deferred until later in the computational flow. + **It is the user's responsibility** to ensure that `self.start < self.stop`. - ## Performant Unit-Aware Operations - While `ArrayFlow`s are also unit-aware, the time-cost of _any_ unit-scaling operation scales with $O(n)$. - `RangeFlow`, by contrast, scales as $O(1)$. - - As a result, more complicated operations (like symbolic or unit-based) that might be difficult to perform interactively in real-time on an `ArrayFlow` will work perfectly with this object, even with added complexity - - ## High-Performance Composition and Gradiant - With `self.as_func`, a `jax` function is produced that generates the array according to the symbolic `start`, `stop` and `steps`. - There are two nice things about this: - - - **Gradient**: The gradient of the output array, with respect to any symbols used to define the input bounds, can easily be found using `jax.grad` over `self.as_func`. - - **JIT**: When `self.as_func` is composed with other `jax` functions, and `jax.jit` is run to optimize the entire thing, the "cost of array generation" _will often be optimized away significantly or entirely_. - - Thus, as part of larger computations, the performance properties of `RangeFlow` is extremely favorable. - - ## Numerical Properties + # Numerical Properties Since the bounds support exact (ex. rational) calculations and symbolic manipulations (_by virtue of being symbolic expressions_), the opportunities for certain kinds of numerical instability are mitigated. Attributes: @@ -108,8 +99,11 @@ class RangeFlow: unit: spux.Unit | None = None - symbols: frozenset[spux.IntSymbol] = frozenset() + symbols: frozenset[spux.Symbol] = frozenset() + #################### + # - Computed Properties + #################### @functools.cached_property def sorted_symbols(self) -> list[sp.Symbol]: """Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name. @@ -121,6 +115,19 @@ class RangeFlow: """ return sorted(self.symbols, key=lambda sym: sym.name) + @property + def is_symbolic(self) -> bool: + """Whether the `RangeFlow` has unrealized symbols.""" + return len(self.symbols) > 0 + + def __len__(self) -> int: + """Compute the length of the array that would be realized. + + Returns: + The number of steps. + """ + return self.steps + @functools.cached_property def mathtype(self) -> spux.MathType: """Conservatively compute the most stringent `spux.MathType` that can represent both `self.start` and `self.stop`. @@ -156,13 +163,206 @@ class RangeFlow: ) return combined_mathtype - def __len__(self): - """Compute the length of the array to be realized. + #################### + # - Methods + #################### + def rescale( + self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None + ) -> typ.Self: + """Apply an order-preserving function to each bound, then (optionally) transform the result w/new unit and/or order. + + An optimized expression will be built and applied to `self.values` using `sympy.lambdify()`. + + Parameters: + rescale_func: An **order-preserving** function to apply to each array element. + reverse: Whether to reverse the order of the result. + new_unit: An (optional) new unit to scale the result to. + """ + new_pre_start = self.start if not reverse else self.stop + new_pre_stop = self.stop if not reverse else self.start + + new_start = rescale_func(new_pre_start * self.unit) + new_stop = rescale_func(new_pre_stop * self.unit) + + return RangeFlow( + start=( + spux.scale_to_unit(new_start, new_unit) + if new_unit is not None + else new_start + ), + stop=( + spux.scale_to_unit(new_stop, new_unit) + if new_unit is not None + else new_stop + ), + steps=self.steps, + scaling=self.scaling, + unit=new_unit, + symbols=self.symbols, + ) + + def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int: + raise NotImplementedError + + #################### + # - Exporters + #################### + @functools.cached_property + def array_generator( + self, + ) -> typ.Callable[ + [int | float | complex, int | float | complex, int], + jtyp.Inexact[jtyp.Array, ' steps'], + ]: + """Compute the correct `jnp.*space` array generator, where `*` is one of the supported scaling methods. Returns: - The number of steps. + A `jax` function that takes a valid `start`, `stop`, and `steps`, and returns a 1D `jax` array. """ - return self.steps + jnp_nspace = { + ScalingMode.Lin: jnp.linspace, + ScalingMode.Geom: jnp.geomspace, + ScalingMode.Log: jnp.logspace, + }.get(self.scaling) + + if jnp_nspace is None: + msg = f'ArrayFlow scaling method {self.scaling} is unsupported' + raise RuntimeError(msg) + return jnp_nspace + + @functools.cached_property + def as_func( + self, + ) -> typ.Callable[[int | float | complex, ...], jtyp.Inexact[jtyp.Array, ' steps']]: + """Create a function that can compute the non-lazy output array as a function of the symbols in the expressions for `start` and `stop`. + + Notes: + The ordering of the symbols is identical to `self.symbols`, which is guaranteed to be a deterministically sorted list of symbols. + + Returns: + A `FuncFlow` that, given the input symbols defined in `self.symbols`, + """ + # Compile JAX Functions for Start/End Expressions + ## -> FYI, JAX-in-JAX works perfectly fine. + start_jax = sp.lambdify(self.sorted_symbols, self.start, 'jax') + stop_jax = sp.lambdify(self.sorted_symbols, self.stop, 'jax') + + # Compile ArrayGen Function + def gen_array( + *args: list[int | float | complex], + ) -> jtyp.Inexact[jtyp.Array, ' steps']: + return self.array_generator(start_jax(*args), stop_jax(*args), self.steps) + + # Return ArrayGen Function + return gen_array + + @functools.cached_property + def as_lazy_func(self) -> FuncFlow: + """Creates a `FuncFlow` using the output of `self.as_func`. + + This is useful for ex. parameterizing the first array in the node graph, without binding an entire computed array. + + Notes: + The the function enclosed in the `FuncFlow` is identical to the one returned by `self.as_func`. + + Returns: + A `FuncFlow` containing `self.as_func`, as well as appropriate supporting settings. + """ + return FuncFlow( + func=self.as_func, + func_args=[(spux.MathType.from_expr(sym)) for sym in self.symbols], + supports_jax=True, + ) + + #################### + # - Realization + #################### + def realize_start( + self, + symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}), + ) -> int | float | complex: + """Realize the start-bound by inserting particular values for each symbol.""" + return spux.sympy_to_python( + self.start.subs({sym: symbol_values[sym.name] for sym in self.symbols}) + ) + + def realize_stop( + self, + symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}), + ) -> int | float | complex: + """Realize the stop-bound by inserting particular values for each symbol.""" + return spux.sympy_to_python( + self.stop.subs({sym: symbol_values[sym.name] for sym in self.symbols}) + ) + + def realize_step_size( + self, + symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}), + ) -> int | float | complex: + """Realize the stop-bound by inserting particular values for each symbol.""" + if self.scaling is not ScalingMode.Lin: + raise NotImplementedError('Non-linear scaling mode not yet suported') + + raw_step_size = (self.realize_stop() - self.realize_start() + 1) / self.steps + + if self.mathtype is spux.MathType.Integer and raw_step_size.is_integer(): + return int(raw_step_size) + return raw_step_size + + def realize( + self, + symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}), + ) -> ArrayFlow: + """Realize the array represented by this `RangeFlow` by realizing each bound, then generating all intermediate values as an array. + + Parameters: + symbol_values: The particular values for each symbol, which will be inserted into the expression of each bound to realize them. + + Returns: + An `ArrayFlow` containing this realized `RangeFlow`. + """ + ## TODO: Check symbol values for coverage. + + return ArrayFlow( + values=self.as_func(*[symbol_values[sym] for sym in self.sorted_symbols]), + unit=self.unit, + is_sorted=True, + ) + + @functools.cached_property + def realize_array(self) -> ArrayFlow: + """Standardized access to `self.realize()` when there are no symbols.""" + return self.realize() + + def __getitem__(self, subscript: slice): + """Implement indexing and slicing in a sane way. + + - **Integer Index**: Not yet implemented. + - **Slice**: Return the `RangeFlow` that creates the same `ArrayFlow` as would be created by computing `self.realize_array`, then slicing that. + """ + if isinstance(subscript, slice) and self.scaling == ScalingMode.Lin: + # Parse Slice + start = subscript.start if subscript.start is not None else 0 + stop = subscript.stop if subscript.stop is not None else self.steps + step = subscript.step if subscript.step is not None else 1 + + slice_steps = (stop - start + step - 1) // step + + # Compute New Start/Stop + step_size = self.realize_step_size() + new_start = step_size * start + new_stop = new_start + step_size * slice_steps + + return RangeFlow( + start=sp.S(new_start), + stop=sp.S(new_stop), + steps=slice_steps, + scaling=self.scaling, + unit=self.unit, + symbols=self.symbols, + ) + + raise NotImplementedError #################### # - Units @@ -264,231 +464,3 @@ class RangeFlow: f'Tried to rescale unitless LazyDataValueRange to unit system {unit_system}' ) raise ValueError(msg) - - #################### - # - Bound Operations - #################### - def rescale( - self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None - ) -> typ.Self: - new_pre_start = self.start if not reverse else self.stop - new_pre_stop = self.stop if not reverse else self.start - - new_start = rescale_func(new_pre_start * self.unit) - new_stop = rescale_func(new_pre_stop * self.unit) - - return RangeFlow( - start=( - spux.scale_to_unit(new_start, new_unit) - if new_unit is not None - else new_start - ), - stop=( - spux.scale_to_unit(new_stop, new_unit) - if new_unit is not None - else new_stop - ), - steps=self.steps, - scaling=self.scaling, - unit=new_unit, - symbols=self.symbols, - ) - - def rescale_bounds( - self, - rescale_func: typ.Callable[ - [spux.ScalarUnitlessComplexExpr], spux.ScalarUnitlessComplexExpr - ], - reverse: bool = False, - ) -> typ.Self: - """Apply a function to the bounds, effectively rescaling the represented array. - - Notes: - **It is presumed that the bounds are scaled with the same factor**. - Breaking this presumption may have unexpected results. - - The scalar, unitless, complex-valuedness of the bounds must also be respected; additionally, new symbols must not be introduced. - - Parameters: - scaler: The function that scales each bound. - reverse: Whether to reverse the bounds after running the `scaler`. - - Returns: - A rescaled `RangeFlow`. - """ - return RangeFlow( - start=rescale_func(self.start if not reverse else self.stop), - stop=rescale_func(self.stop if not reverse else self.start), - steps=self.steps, - scaling=self.scaling, - unit=self.unit, - symbols=self.symbols, - ) - - #################### - # - Lazy Representation - #################### - @functools.cached_property - def array_generator( - self, - ) -> typ.Callable[ - [int | float | complex, int | float | complex, int], - jtyp.Inexact[jtyp.Array, ' steps'], - ]: - """Compute the correct `jnp.*space` array generator, where `*` is one of the supported scaling methods. - - Returns: - A `jax` function that takes a valid `start`, `stop`, and `steps`, and returns a 1D `jax` array. - """ - jnp_nspace = { - ScalingMode.Lin: jnp.linspace, - ScalingMode.Geom: jnp.geomspace, - ScalingMode.Log: jnp.logspace, - }.get(self.scaling) - if jnp_nspace is None: - msg = f'ArrayFlow scaling method {self.scaling} is unsupported' - raise RuntimeError(msg) - - return jnp_nspace - - @functools.cached_property - def as_func( - self, - ) -> typ.Callable[[int | float | complex, ...], jtyp.Inexact[jtyp.Array, ' steps']]: - """Create a function that can compute the non-lazy output array as a function of the symbols in the expressions for `start` and `stop`. - - Notes: - The ordering of the symbols is identical to `self.symbols`, which is guaranteed to be a deterministically sorted list of symbols. - - Returns: - A `FuncFlow` that, given the input symbols defined in `self.symbols`, - """ - # Compile JAX Functions for Start/End Expressions - ## FYI, JAX-in-JAX works perfectly fine. - start_jax = sp.lambdify(self.symbols, self.start, 'jax') - stop_jax = sp.lambdify(self.symbols, self.stop, 'jax') - - # Compile ArrayGen Function - def gen_array( - *args: list[int | float | complex], - ) -> jtyp.Inexact[jtyp.Array, ' steps']: - return self.array_generator(start_jax(*args), stop_jax(*args), self.steps) - - # Return ArrayGen Function - return gen_array - - @functools.cached_property - def as_lazy_func(self) -> FuncFlow: - """Creates a `FuncFlow` using the output of `self.as_func`. - - This is useful for ex. parameterizing the first array in the node graph, without binding an entire computed array. - - Notes: - The the function enclosed in the `FuncFlow` is identical to the one returned by `self.as_func`. - - Returns: - A `FuncFlow` containing `self.as_func`, as well as appropriate supporting settings. - """ - return FuncFlow( - func=self.as_func, - func_args=[(spux.MathType.from_expr(sym)) for sym in self.symbols], - supports_jax=True, - ) - - #################### - # - Realization - #################### - def realize_start( - self, - symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}), - ) -> ArrayFlow | FuncFlow: - return spux.sympy_to_python( - self.start.subs({sym: symbol_values[sym.name] for sym in self.symbols}) - ) - - def realize_stop( - self, - symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}), - ) -> ArrayFlow | FuncFlow: - return spux.sympy_to_python( - self.stop.subs({sym: symbol_values[sym.name] for sym in self.symbols}) - ) - - def realize_step_size( - self, - symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}), - ) -> ArrayFlow | FuncFlow: - raw_step_size = (self.realize_stop() - self.realize_start() + 1) / self.steps - - if self.mathtype is spux.MathType.Integer and raw_step_size.is_integer(): - return int(raw_step_size) - return raw_step_size - - def realize( - self, - symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}), - kind: typ.Literal[FlowKind.Array, FlowKind.Func] = FlowKind.Array, - ) -> ArrayFlow | FuncFlow: - """Apply a function to the bounds, effectively rescaling the represented array. - - Notes: - **It is presumed that the bounds are scaled with the same factor**. - Breaking this presumption may have unexpected results. - - The scalar, unitless, complex-valuedness of the bounds must also be respected; additionally, new symbols must not be introduced. - - Parameters: - scaler: The function that scales each bound. - reverse: Whether to reverse the bounds after running the `scaler`. - - Returns: - A rescaled `RangeFlow`. - """ - if not set(self.symbols).issubset(set(symbol_values.keys())): - msg = f'Provided symbols ({set(symbol_values.keys())}) do not provide values for all expression symbols ({self.symbols}) that may be found in the boundary expressions (start={self.start}, end={self.end})' - raise ValueError(msg) - - # Realize Symbols - realized_start = self.realize_start(symbol_values) - realized_stop = self.realize_stop(symbol_values) - - # Return Linspace / Logspace - def gen_array() -> jtyp.Inexact[jtyp.Array, ' steps']: - return self.array_generator(realized_start, realized_stop, self.steps) - - if kind == FlowKind.Array: - return ArrayFlow(values=gen_array(), unit=self.unit, is_sorted=True) - if kind == FlowKind.Func: - return FuncFlow(func=gen_array, supports_jax=True) - - msg = f'Invalid kind: {kind}' - raise TypeError(msg) - - @functools.cached_property - def realize_array(self) -> ArrayFlow: - return self.realize() - - def __getitem__(self, subscript: slice): - if isinstance(subscript, slice) and self.scaling == ScalingMode.Lin: - # Parse Slice - start = subscript.start if subscript.start is not None else 0 - stop = subscript.stop if subscript.stop is not None else self.steps - step = subscript.step if subscript.step is not None else 1 - - slice_steps = (stop - start + step - 1) // step - - # Compute New Start/Stop - step_size = self.realize_step_size() - new_start = step_size * start - new_stop = new_start + step_size * slice_steps - - return RangeFlow( - start=sp.S(new_start), - stop=sp.S(new_stop), - steps=slice_steps, - scaling=self.scaling, - unit=self.unit, - symbols=self.symbols, - ) - - raise NotImplementedError diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py index 1fdaafc..2d6dcae 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py @@ -22,35 +22,22 @@ from types import MappingProxyType import sympy as sp from blender_maxwell.utils import extra_sympy_units as spux -from blender_maxwell.utils import logger +from blender_maxwell.utils import logger, sim_symbols +from .expr_info import ExprInfo from .flow_kinds import FlowKind -from .info import InfoFlow + +# from .info import InfoFlow log = logger.get(__name__) -class ExprInfo(typ.TypedDict): - active_kind: FlowKind - size: spux.NumberSize1D - mathtype: spux.MathType - physical_type: spux.PhysicalType - - # Value - default_value: spux.SympyExpr - - # Range - default_min: spux.SympyExpr - default_max: spux.SympyExpr - default_steps: int - - @dataclasses.dataclass(frozen=True, kw_only=True) class ParamsFlow: func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list) func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict) - symbols: frozenset[spux.Symbol] = frozenset() + symbols: frozenset[sim_symbols.SimSymbol] = frozenset() @functools.cached_property def sorted_symbols(self) -> list[sp.Symbol]: @@ -66,21 +53,18 @@ class ParamsFlow: #################### def scaled_func_args( self, - unit_system: spux.UnitSystem, - symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}), + unit_system: spux.UnitSystem | None = None, + symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType( + {} + ), ): """Realize the function arguments contained in this `ParamsFlow`, making it ready for insertion into `Func.func()`. - For all `arg`s in `self.func_args`, the following operations are performed: - - **Unit System**: If `arg` - + For all `arg`s in `self.func_args`, the following operations are performed. Notes: This method is created for the purpose of being able to make this exact call in an `events.on_value_changed` method: - """ - - """Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments.""" if not all(sym in self.symbols for sym in symbol_values): msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}" raise ValueError(msg) @@ -97,7 +81,7 @@ class ParamsFlow: def scaled_func_kwargs( self, - unit_system: spux.UnitSystem, + unit_system: spux.UnitSystem | None = None, symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}), ): """Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments.""" @@ -145,9 +129,7 @@ class ParamsFlow: #################### # - Generate ExprSocketDef #################### - def sym_expr_infos( - self, info: InfoFlow, use_range: bool = False - ) -> dict[str, ExprInfo]: + def sym_expr_infos(self, info, use_range: bool = False) -> dict[str, ExprInfo]: """Generate all information needed to define expressions that realize all symbolic parameters in this `ParamsFlow`. Many nodes need actual data, and as such, they require that the user select actual values for any symbols in the `ParamsFlow`. @@ -169,26 +151,35 @@ class ParamsFlow: The `ExprInfo`s can be directly defererenced `**expr_info`) """ + for sim_sym in self.sorted_symbols: + if use_range and sim_sym.mathtype is spux.MathType.Complex: + msg = 'No support for complex range in ExprInfo' + raise NotImplementedError(msg) + if use_range and (sim_sym.rows > 1 or sim_sym.cols > 1): + msg = 'No support for non-scalar elements of range in ExprInfo' + raise NotImplementedError(msg) + if sim_sym.rows > 3 or sim_sym.cols > 1: + msg = 'No support for >Vec3 / Matrix values in ExprInfo' + raise NotImplementedError(msg) return { - sym.name: { + sim_sym.name: { # Declare Kind/Size ## -> Kind: Value prevents user-alteration of config. ## -> Size: Always scalar, since symbols are scalar (for now). - 'active_kind': FlowKind.Value, + 'active_kind': FlowKind.Value if not use_range else FlowKind.Range, 'size': spux.NumberSize1D.Scalar, # Declare MathType/PhysicalType ## -> MathType: Lookup symbol name in info dimensions. ## -> PhysicalType: Same. - 'mathtype': info.dim_mathtypes[sym.name], - 'physical_type': info.dim_physical_types[sym.name], - # TODO: Default Values + 'mathtype': self.dims[sim_sym].mathtype, + 'physical_type': self.dims[sim_sym].physical_type, + # TODO: Default Value # FlowKind.Value: Default Value #'default_value': # FlowKind.Range: Default Min/Max/Steps - #'default_min': - #'default_max': - #'default_steps': + 'default_min': sim_sym.domain.start, + 'default_max': sim_sym.domain.end, + 'default_steps': 50, } - for sym in self.sorted_symbols - if sym.name in info.dim_names + for sim_sym in self.sorted_symbols } diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py index 613f2ae..ff160b2 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py @@ -489,11 +489,11 @@ class DataFileFormat(enum.StrEnum): E = DataFileFormat match self: case E.Csv: - return len(info.dim_names) + info.output_shape_len <= 2 + return len(info.dims) + info.output.rows + info.output.cols - 1 <= 2 case E.Npy: return True case E.Txt | E.TxtGz: - return len(info.dim_names) + info.output_shape_len <= 2 + return len(info.dims) + info.output.rows + info.output.cols - 1 <= 2 @property def saver( @@ -510,9 +510,9 @@ class DataFileFormat(enum.StrEnum): # Extract Input Coordinates dim_columns = { - dim_name: np.array(info.dim_idx_arrays[i]) - for i, dim_name in enumerate(info.dim_names) - } + dim.name: np.array(dim_idx.realize_array) + for i, (dim, dim_idx) in enumerate(info.dims) + } ## TODO: realize_array might not be defined on some index arrays # Declare Function to Extract Output Values output_columns = {} @@ -524,14 +524,14 @@ class DataFileFormat(enum.StrEnum): output_idx_str = f'[{output_idx}]' if use_output_idx else '' if bool(np.any(np.iscomplex(data_col))): output_columns |= { - f'{info.output_name}{output_idx_str}_re': np.real(data_col), - f'{info.output_name}{output_idx_str}_im': np.imag(data_col), + f'{info.output.name}{output_idx_str}_re': np.real(data_col), + f'{info.output.name}{output_idx_str}_im': np.imag(data_col), } # Else: Use Array Directly else: output_columns |= { - f'{info.output_name}{output_idx_str}': data_col, + f'{info.output.name}{output_idx_str}': data_col, } ## TODO: Maybe a check to ensure dtype!=object? @@ -605,11 +605,11 @@ class DataFileFormat(enum.StrEnum): E = DataFileFormat match self: case E.Csv: - return len(info.dim_names) + (info.output_shape_len + 1) <= 2 + return len(info.dims) + (info.output.rows + input.outputs.cols - 1) <= 2 case E.Npy: return True case E.Txt | E.TxtGz: - return len(info.dim_names) + (info.output_shape_len + 1) <= 2 + return len(info.dims) + (info.output.rows + info.output.cols - 1) <= 2 def supports_metadata(self) -> bool: E = DataFileFormat diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py index 6200b6f..0dc4752 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py @@ -48,7 +48,6 @@ UNITS_BLENDER: spux.UnitSystem = spux.UNITS_SI | { # Electrodynamics _PT.CurrentDensity: spu.ampere / spu.um**2, _PT.Conductivity: spu.siemens / spu.um, - _PT.PoyntingVector: spu.watt / spu.um**2, _PT.EField: spu.volt / spu.um, _PT.HField: spu.ampere / spu.um, # Mechanical @@ -58,7 +57,6 @@ UNITS_BLENDER: spux.UnitSystem = spux.UNITS_SI | { _PT.Force: spux.micronewton, # Luminal # Optics - _PT.PoyntingVector: spu.watt / spu.um**2, } ## TODO: Load (dynamically?) from addon preferences UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | { @@ -75,11 +73,9 @@ UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | { # Electrodynamics _PT.CurrentDensity: spu.ampere / spu.um**2, _PT.Conductivity: spu.siemens / spu.um, - _PT.PoyntingVector: spu.watt / spu.um**2, _PT.EField: spu.volt / spu.um, _PT.HField: spu.ampere / spu.um, # Luminal # Optics - _PT.PoyntingVector: spu.watt / spu.um**2, ## NOTE: w/o source normalization, EField/HField/Modal amps are * 1/Hz } diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py index 21aecc6..4240d60 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py @@ -17,15 +17,15 @@ """Implements `ExtractDataNode`.""" import enum +import functools import typing as typ import bpy -import jax -import numpy as np +import jax.numpy as jnp import sympy.physics.units as spu import tidy3d as td -from blender_maxwell.utils import bl_cache, logger +from blender_maxwell.utils import bl_cache, logger, sim_symbols from blender_maxwell.utils import extra_sympy_units as spux from ... import contracts as ct @@ -37,6 +37,176 @@ log = logger.get(__name__) TDMonitorData: typ.TypeAlias = td.components.data.monitor_data.MonitorData +#################### +# - Monitor Label Arrays +#################### +def valid_monitor_attrs(sim_data: td.SimulationData, monitor_name: str) -> list[str]: + """Retrieve the valid attributes of `sim_data.monitor_data' from a valid `sim_data` of type `td.SimulationData`. + + Parameters: + monitor_type: The name of the monitor type, with the 'Data' prefix removed. + """ + monitor_data = sim_data.monitor_data[monitor_name] + monitor_type = monitor_data.type + + match monitor_type: + case 'Field' | 'FieldTime' | 'Mode': + ## TODO: flux, poynting, intensity + return [ + field_component + for field_component in ['Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz'] + if getattr(monitor_data, field_component, None) is not None + ] + + case 'Permittivity': + return ['eps_xx', 'eps_yy', 'eps_zz'] + + case 'Flux' | 'FluxTime': + return ['flux'] + + case ( + 'FieldProjectionAngle' + | 'FieldProjectionCartesian' + | 'FieldProjectionKSpace' + | 'Diffraction' + ): + return [ + 'Er', + 'Etheta', + 'Ephi', + 'Hr', + 'Htheta', + 'Hphi', + ] + + +def extract_info(monitor_data, monitor_attr: str) -> ct.InfoFlow | None: # noqa: PLR0911 + """Extract an InfoFlow encapsulating raw data contained in an attribute of the given monitor data.""" + xarr = getattr(monitor_data, monitor_attr, None) + if xarr is None: + return None + + def mk_idx_array(axis: str) -> ct.ArrayFlow: + return ct.ArrayFlow( + values=xarr.get_index(axis).values, + unit=symbols[axis].unit, + is_sorted=True, + ) + + # Compute InfoFlow from XArray + symbols = { + # Cartesian + 'x': sim_symbols.space_x(spu.micrometer), + 'y': sim_symbols.space_y(spu.micrometer), + 'z': sim_symbols.space_z(spu.micrometer), + # Spherical + 'r': sim_symbols.ang_r(spu.micrometer), + 'theta': sim_symbols.ang_theta(spu.radian), + 'phi': sim_symbols.ang_phi(spu.radian), + # Freq|Time + 'f': sim_symbols.freq(spu.hertz), + 't': sim_symbols.t(spu.second), + # Power Flux + 'flux': sim_symbols.flux(spu.watt), + # Cartesian Fields + 'Ex': sim_symbols.field_ex(spu.volt / spu.micrometer), + 'Ey': sim_symbols.field_ey(spu.volt / spu.micrometer), + 'Ez': sim_symbols.field_ez(spu.volt / spu.micrometer), + 'Hx': sim_symbols.field_hx(spu.volt / spu.micrometer), + 'Hy': sim_symbols.field_hy(spu.volt / spu.micrometer), + 'Hz': sim_symbols.field_hz(spu.volt / spu.micrometer), + # Spherical Fields + 'Er': sim_symbols.field_er(spu.volt / spu.micrometer), + 'Etheta': sim_symbols.ang_theta(spu.volt / spu.micrometer), + 'Ephi': sim_symbols.field_ez(spu.volt / spu.micrometer), + 'Hr': sim_symbols.field_hr(spu.volt / spu.micrometer), + 'Htheta': sim_symbols.field_hy(spu.volt / spu.micrometer), + 'Hphi': sim_symbols.field_hz(spu.volt / spu.micrometer), + # Wavevector + 'ux': sim_symbols.dir_x(spu.watt), + 'uy': sim_symbols.dir_y(spu.watt), + # Diffraction Orders + 'orders_x': sim_symbols.diff_order_x(None), + 'orders_y': sim_symbols.diff_order_y(None), + } + + match monitor_data.type: + case 'Field' | 'FieldProjectionCartesian' | 'Permittivity' | 'Mode': + return ct.InfoFlow( + dims={ + symbols['x']: mk_idx_array('x'), + symbols['y']: mk_idx_array('y'), + symbols['z']: mk_idx_array('z'), + symbols['f']: mk_idx_array('f'), + }, + output=symbols[monitor_attr], + ) + + case 'FieldTime': + return ct.InfoFlow( + dims={ + symbols['x']: mk_idx_array('x'), + symbols['y']: mk_idx_array('y'), + symbols['z']: mk_idx_array('z'), + symbols['t']: mk_idx_array('t'), + }, + output=symbols[monitor_attr], + ) + + case 'Flux': + return ct.InfoFlow( + dims={ + symbols['f']: mk_idx_array('f'), + }, + output=symbols[monitor_attr], + ) + + case 'FluxTime': + return ct.InfoFlow( + dims={ + symbols['t']: mk_idx_array('t'), + }, + output=symbols[monitor_attr], + ) + + case 'FieldProjectionAngle': + return ct.InfoFlow( + dims={ + symbols['r']: mk_idx_array('r'), + symbols['theta']: mk_idx_array('theta'), + symbols['phi']: mk_idx_array('phi'), + symbols['f']: mk_idx_array('f'), + }, + output=symbols[monitor_attr], + ) + + case 'FieldProjectionKSpace': + return ct.InfoFlow( + dims={ + symbols['ux']: mk_idx_array('ux'), + symbols['uy']: mk_idx_array('uy'), + symbols['r']: mk_idx_array('r'), + symbols['f']: mk_idx_array('f'), + }, + output=symbols[monitor_attr], + ) + + case 'Diffraction': + return ct.InfoFlow( + dims={ + symbols['orders_x']: mk_idx_array('orders_x'), + symbols['orders_y']: mk_idx_array('orders_y'), + symbols['f']: mk_idx_array('f'), + }, + output=symbols[monitor_attr], + ) + + return None + + +#################### +# - Node +#################### class ExtractDataNode(base.MaxwellSimNode): """Extract data from sockets for further analysis. @@ -45,31 +215,21 @@ class ExtractDataNode(base.MaxwellSimNode): Monitor Data: Extract `Expr`s from monitor data by-component. Attributes: - extract_filter: Identifier for data to extract from the input. + monitor_attr: Identifier for data to extract from the input. """ node_type = ct.NodeType.ExtractData bl_label = 'Extract' input_socket_sets: typ.ClassVar = { - 'Sim Data': {'Sim Data': sockets.MaxwellFDTDSimDataSocketDef()}, - 'Monitor Data': {'Monitor Data': sockets.MaxwellMonitorDataSocketDef()}, + 'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(), } output_socket_sets: typ.ClassVar = { - 'Sim Data': {'Monitor Data': sockets.MaxwellMonitorDataSocketDef()}, - 'Monitor Data': {'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func)}, + 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func), } #################### - # - Properties - #################### - extract_filter: enum.StrEnum = bl_cache.BLField( - enum_cb=lambda self, _: self.search_extract_filters(), - cb_depends_on={'sim_data_monitor_nametype', 'monitor_data_type'}, - ) - - #################### - # - Computed: Sim Data + # - Properties: Monitor Name #################### @events.on_value_changed( socket_name='Sim Data', @@ -99,198 +259,49 @@ class ExtractDataNode(base.MaxwellSimNode): @bl_cache.cached_bl_property(depends_on={'sim_data'}) def sim_data_monitor_nametype(self) -> dict[str, str] | None: - """For simulation data, deduces a map from the monitor name to the monitor "type". + """Dictionary from monitor names on `self.sim_data` to their associated type name (with suffix 'Data' removed). Return: The name to type of monitors in the simulation data. """ if self.sim_data is not None: return { - monitor_name: monitor_data.type + monitor_name: monitor_data.type.removesuffix('Data') for monitor_name, monitor_data in self.sim_data.monitor_data.items() } return None - #################### - # - Computed Properties: Monitor Data - #################### - @events.on_value_changed( - socket_name='Monitor Data', - input_sockets={'Monitor Data'}, - input_sockets_optional={'Monitor Data': True}, + monitor_name: enum.StrEnum = bl_cache.BLField( + enum_cb=lambda self, _: self.search_monitor_names(), + cb_depends_on={'sim_data_monitor_nametype'}, ) - def on_monitor_data_changed(self, input_sockets) -> None: # noqa: D102 - has_monitor_data = not ct.FlowSignal.check(input_sockets['Monitor Data']) - if has_monitor_data: - self.monitor_data = bl_cache.Signal.InvalidateCache - @bl_cache.cached_bl_property() - def monitor_data(self) -> TDMonitorData | None: - """Extracts the monitor data from the input socket. - - Return: - Either the monitor data, if available, or None. - """ - monitor_data = self._compute_input( - 'Monitor Data', kind=ct.FlowKind.Value, optional=True - ) - has_monitor_data = not ct.FlowSignal.check(monitor_data) - if has_monitor_data: - return monitor_data - - return None - - @bl_cache.cached_bl_property(depends_on={'monitor_data'}) - def monitor_data_type(self) -> str | None: - r"""For monitor data, deduces the monitor "type". - - - **Field(Time)**: A monitor storing values/pixels/voxels with electromagnetic field values, on the time or frequency domain. - - **Permittivity**: A monitor storing values/pixels/voxels containing the diagonal of the relative permittivity tensor. - - **Flux(Time)**: A monitor storing the directional flux on the time or frequency domain. - For planes, an explicit direction is defined. - For volumes, the the integral of all outgoing energy is stored. - - **FieldProjection(...)**: A monitor storing the spherical-coordinate electromagnetic field components of a near-to-far-field projection. - - **Diffraction**: A monitor storing a near-to-far-field projection by diffraction order. + def search_monitor_names(self) -> list[ct.BLEnumElement]: + """Compute valid values for `self.monitor_attr`, for a dynamic `EnumProperty`. Notes: - Should be invalidated with (before) `self.monitor_data_attrs`. - - Return: - The "type" of the monitor, if available, else None. - """ - if self.monitor_data is not None: - return self.monitor_data.type.removesuffix('Data') - - return None - - @bl_cache.cached_bl_property(depends_on={'monitor_data_type'}) - def monitor_data_attrs(self) -> list[str] | None: - r"""For monitor data, deduces the valid data-containing attributes. - - The output depends entirely on the output of `self.monitor_data_type`, since the valid attributes of each monitor type is well-defined without needing to perform dynamic lookups. - - - **Field(Time)**: Whichever `[E|H][x|y|z]` are not `None` on the monitor. - - **Permittivity**: Specifically `['xx', 'yy', 'zz']`. - - **Flux(Time)**: Only `['flux']`. - - **FieldProjection(...)**: All of $r$, $\theta$, $\phi$ for both `E` and `H`. - - **Diffraction**: Same as `FieldProjection`. - - Notes: - Should be invalidated after with `self.monitor_data_type`. - - Return: - The "type" of the monitor, if available, else None. - """ - if self.monitor_data is not None: - # Field/FieldTime - if self.monitor_data_type in ['Field', 'FieldTime']: - return [ - field_component - for field_component in ['Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz'] - if hasattr(self.monitor_data, field_component) - ] - - # Permittivity - if self.monitor_data_type == 'Permittivity': - return ['xx', 'yy', 'zz'] - - # Flux/FluxTime - if self.monitor_data_type in ['Flux', 'FluxTime']: - return ['flux'] - - # FieldProjection(Angle/Cartesian/KSpace)/Diffraction - if self.monitor_data_type in [ - 'FieldProjectionAngle', - 'FieldProjectionCartesian', - 'FieldProjectionKSpace', - 'Diffraction', - ]: - return [ - 'Er', - 'Etheta', - 'Ephi', - 'Hr', - 'Htheta', - 'Hphi', - ] - - return None - - #################### - # - Extraction Filter Search - #################### - def search_extract_filters(self) -> list[ct.BLEnumElement]: - """Compute valid values for `self.extract_filter`, for a dynamic `EnumProperty`. - - Notes: - Should be reset (via `self.extract_filter`) with (after) `self.sim_data_monitor_nametype`, `self.monitor_data_attrs`, and (implicitly) `self.monitor_type`. + Should be reset (via `self.monitor_attr`) with (after) `self.sim_data_monitor_nametype`, `self.monitor_data_attrs`, and (implicitly) `self.monitor_type`. See `bl_cache.BLField` for more on dynamic `EnumProperty`. Returns: - Valid `self.extract_filter` in a format compatible with dynamic `EnumProperty`. + Valid `self.monitor_attr` in a format compatible with dynamic `EnumProperty`. """ if self.sim_data_monitor_nametype is not None: return [ - (monitor_name, monitor_name, monitor_type.removesuffix('Data'), '', i) + ( + monitor_name, + monitor_name, + monitor_type + ' Monitor Data', + '', + i, + ) for i, (monitor_name, monitor_type) in enumerate( self.sim_data_monitor_nametype.items() ) ] - if self.monitor_data_attrs is not None: - # Field/FieldTime - if self.monitor_data_type in ['Field', 'FieldTime']: - return [ - ( - monitor_attr, - monitor_attr, - f'ℂ {monitor_attr[1]}-polarization of the {"electric" if monitor_attr[0] == "E" else "magnetic"} field', - '', - i, - ) - for i, monitor_attr in enumerate(self.monitor_data_attrs) - ] - - # Permittivity - if self.monitor_data_type == 'Permittivity': - return [ - (monitor_attr, monitor_attr, f'ℂ ε_{monitor_attr}', '', i) - for i, monitor_attr in enumerate(self.monitor_data_attrs) - ] - - # Flux/FluxTime - if self.monitor_data_type in ['Flux', 'FluxTime']: - return [ - ( - monitor_attr, - monitor_attr, - 'Power flux integral through the plane / out of the volume', - '', - i, - ) - for i, monitor_attr in enumerate(self.monitor_data_attrs) - ] - - # FieldProjection(Angle/Cartesian/KSpace)/Diffraction - if self.monitor_data_type in [ - 'FieldProjectionAngle', - 'FieldProjectionCartesian', - 'FieldProjectionKSpace', - 'Diffraction', - ]: - return [ - ( - monitor_attr, - monitor_attr, - f'ℂ {monitor_attr[1]}-component of the spherical {"electric" if monitor_attr[0] == "E" else "magnetic"} field', - '', - i, - ) - for i, monitor_attr in enumerate(self.monitor_data_attrs) - ] - return [] #################### @@ -303,10 +314,9 @@ class ExtractDataNode(base.MaxwellSimNode): Called by Blender to determine the text to place in the node's header. """ has_sim_data = self.sim_data_monitor_nametype is not None - has_monitor_data = self.monitor_data_attrs is not None - if has_sim_data or has_monitor_data: - return f'Extract: {self.extract_filter}' + if has_sim_data: + return f'Extract: {self.monitor_name}' return self.bl_label @@ -316,336 +326,115 @@ class ExtractDataNode(base.MaxwellSimNode): Parameters: col: UI target for drawing. """ - col.prop(self, self.blfields['extract_filter'], text='') + col.prop(self, self.blfields['monitor_name'], text='') #################### - # - FlowKind.Value: Sim Data -> Monitor Data + # - FlowKind.Func #################### @events.computes_output_socket( - 'Monitor Data', - kind=ct.FlowKind.Value, - # Loaded - props={'extract_filter'}, - input_sockets={'Sim Data'}, - input_sockets_optional={'Sim Data': True}, - ) - def compute_monitor_data( - self, props: dict, input_sockets: dict - ) -> TDMonitorData | ct.FlowSignal: - """Compute `Monitor Data` by querying the attribute of `Sim Data` referenced by the property `self.extract_filter`. - - Returns: - Monitor data, if available, else `ct.FlowSignal.FlowPending`. - """ - extract_filter = props['extract_filter'] - sim_data = input_sockets['Sim Data'] - has_sim_data = not ct.FlowSignal.check(sim_data) - - if has_sim_data and extract_filter is not None: - return sim_data.monitor_data[extract_filter] - - return ct.FlowSignal.FlowPending - - #################### - # - FlowKind.Array|Func: Monitor Data -> Expr - #################### - @events.computes_output_socket( - 'Expr', - kind=ct.FlowKind.Array, - # Loaded - props={'extract_filter'}, - input_sockets={'Monitor Data'}, - input_socket_kinds={'Monitor Data': ct.FlowKind.Value}, - input_sockets_optional={'Monitor Data': True}, - ) - def compute_expr( - self, props: dict, input_sockets: dict - ) -> jax.Array | ct.FlowSignal: - """Compute `Expr:Array` by querying an array-like attribute of `Monitor Data`, then constructing an `ct.ArrayFlow` around it. - - Uses the internal `xarray` data returned by Tidy3D. - By using `np.array` on the `.data` attribute of the `xarray`, instead of the usual JAX array constructor, we should save a (possibly very big) copy. - - Returns: - The data array, if available, else `ct.FlowSignal.FlowPending`. - """ - extract_filter = props['extract_filter'] - monitor_data = input_sockets['Monitor Data'] - has_monitor_data = not ct.FlowSignal.check(monitor_data) - - if has_monitor_data and extract_filter is not None: - xarray_data = getattr(monitor_data, extract_filter) - return ct.ArrayFlow(values=np.array(xarray_data.data), unit=None) - - return ct.FlowSignal.FlowPending - - @events.computes_output_socket( - # Trigger 'Expr', kind=ct.FlowKind.Func, # Loaded - output_sockets={'Expr'}, - output_socket_kinds={'Expr': ct.FlowKind.Array}, - output_sockets_optional={'Expr': True}, + props={'monitor_name'}, + input_sockets={'Sim Data'}, + input_socket_kinds={'Sim Data': ct.FlowKind.Value}, ) - def compute_extracted_data_lazy(self, output_sockets: dict) -> ct.FuncFlow | None: - """Declare `Expr:Func` by creating a simple function that directly wraps `Expr:Array`. + def compute_expr( + self, props: dict, input_sockets: dict + ) -> ct.FuncFlow | ct.FlowSignal: + sim_data = input_sockets['Sim Data'] + monitor_name = props['monitor_name'] - Returns: - The composable function array, if available, else `ct.FlowSignal.FlowPending`. - """ - output_expr = output_sockets['Expr'] - has_output_expr = not ct.FlowSignal.check(output_expr) + has_sim_data = not ct.FlowSignal.check(sim_data) - if has_output_expr: - return ct.FuncFlow(func=lambda: output_expr.values, supports_jax=True) + if has_sim_data and monitor_name is not None: + monitor_data = sim_data.get(monitor_name) + if monitor_data is not None: + # Extract Valid Index Labels + ## -> The first output axis will be integer-indexed. + ## -> Each integer will have a string label. + ## -> Those string labels explain the integer as ex. Ex, Ey, Hy. + idx_labels = valid_monitor_attrs(sim_data, monitor_name) + # Generate FuncFlow Per Index Label + ## -> We extract each XArray as an attribute of monitor_data. + ## -> We then bind its values into a unique func_flow. + ## -> This lets us 'stack' then all along the first axis. + func_flows = [] + for idx_label in idx_labels: + xarr = getattr(monitor_data, idx_label) + func_flows.append( + ct.FuncFlow( + func=lambda xarr=xarr: xarr.values, + supports_jax=True, + ) + ) + + # Concatenate and Stack Unified FuncFlow + ## -> First, 'reduce' lets us __or__ all the FuncFlows together. + ## -> Then, 'compose_within' lets us stack them along axis=0. + ## -> The "new" axis=0 is int-indexed axis w/idx_labels labels! + return functools.reduce(lambda a, b: a | b, func_flows).compose_within( + enclosing_func=lambda data: jnp.stack(data, axis=0) + ) + return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending #################### - # - FlowKind.Params: Monitor Data -> Expr + # - FlowKind.Params #################### @events.computes_output_socket( 'Expr', kind=ct.FlowKind.Params, + input_sockets={'Sim Data'}, + input_socket_kinds={'Sim Data': ct.FlowKind.Params}, ) - def compute_data_params(self) -> ct.ParamsFlow: + def compute_data_params(self, input_sockets) -> ct.ParamsFlow: """Declare an empty `Data:Params`, to indicate the start of a function-composition pipeline. Returns: A completely empty `ParamsFlow`, ready to be composed. """ + sim_params = input_sockets['Sim Data'] + has_sim_params = not ct.FlowSignal.check(sim_params) + + if has_sim_params: + return sim_params return ct.ParamsFlow() #################### - # - FlowKind.Info: Monitor Data -> Expr + # - FlowKind.Info #################### @events.computes_output_socket( 'Expr', kind=ct.FlowKind.Info, # Loaded - props={'monitor_data_type', 'extract_filter'}, - input_sockets={'Monitor Data'}, - input_socket_kinds={'Monitor Data': ct.FlowKind.Value}, - input_sockets_optional={'Monitor Data': True}, + props={'monitor_name'}, + input_sockets={'Sim Data'}, + input_socket_kinds={'Sim Data': ct.FlowKind.Value}, ) - def compute_extracted_data_info( - self, props: dict, input_sockets: dict - ) -> ct.InfoFlow: + def compute_extracted_data_info(self, props, input_sockets) -> ct.InfoFlow: """Declare `Data:Info` by manually selecting appropriate axes, units, etc. for each monitor type. Returns: Information describing the `Data:Func`, if available, else `ct.FlowSignal.FlowPending`. """ - monitor_data = input_sockets['Monitor Data'] - monitor_data_type = props['monitor_data_type'] - extract_filter = props['extract_filter'] + sim_data = input_sockets['Sim Data'] + monitor_name = props['monitor_name'] - has_monitor_data = not ct.FlowSignal.check(monitor_data) + has_sim_data = not ct.FlowSignal.check(sim_data) - # Edge Case: Dangling 'flux' Access on 'FieldMonitor' - ## -> Sometimes works - UNLESS the FieldMonitor doesn't have all fields. - ## -> We don't allow 'flux' attribute access, but it can dangle. - ## -> (The method is called when updating each depschain component.) - if monitor_data_type == 'Field' and extract_filter == 'flux': + if not has_sim_data or monitor_name is None: return ct.FlowSignal.FlowPending - # Retrieve XArray - if has_monitor_data and extract_filter is not None: - xarr = getattr(monitor_data, extract_filter, None) - if xarr is None: - return ct.FlowSignal.FlowPending - else: - return ct.FlowSignal.FlowPending + # Extract Data + ## -> All monitor_data. have the exact same InfoFlow. + ## -> So, just construct an InfoFlow w/prepended labelled dimension. + monitor_data = sim_data.get(monitor_name) + idx_labels = valid_monitor_attrs(sim_data, monitor_name) + info = extract_info(monitor_data, idx_labels[0]) - # Compute InfoFlow from XArray - ## XYZF: Field / Permittivity / FieldProjectionCartesian - if monitor_data_type in { - 'Field', - 'Permittivity', - #'FieldProjectionCartesian', - }: - return ct.InfoFlow( - dim_names=['x', 'y', 'z', 'f'], - dim_idx={ - axis: ct.ArrayFlow( - values=xarr.get_index(axis).values, unit=spu.um, is_sorted=True - ) - for axis in ['x', 'y', 'z'] - } - | { - 'f': ct.ArrayFlow( - values=xarr.get_index('f').values, - unit=spu.hertz, - is_sorted=True, - ), - }, - output_name=extract_filter, - output_shape=None, - output_mathtype=spux.MathType.Complex, - output_unit=( - spu.volt / spu.micrometer if monitor_data_type == 'Field' else None - ), - ) - - ## XYZT: FieldTime - if monitor_data_type == 'FieldTime': - return ct.InfoFlow( - dim_names=['x', 'y', 'z', 't'], - dim_idx={ - axis: ct.ArrayFlow( - values=xarr.get_index(axis).values, unit=spu.um, is_sorted=True - ) - for axis in ['x', 'y', 'z'] - } - | { - 't': ct.ArrayFlow( - values=xarr.get_index('t').values, - unit=spu.second, - is_sorted=True, - ), - }, - output_name=extract_filter, - output_shape=None, - output_mathtype=spux.MathType.Complex, - output_unit=( - spu.volt / spu.micrometer if monitor_data_type == 'Field' else None - ), - ) - - ## F: Flux - if monitor_data_type == 'Flux': - return ct.InfoFlow( - dim_names=['f'], - dim_idx={ - 'f': ct.ArrayFlow( - values=xarr.get_index('f').values, - unit=spu.hertz, - is_sorted=True, - ), - }, - output_name=extract_filter, - output_shape=None, - output_mathtype=spux.MathType.Real, - output_unit=spu.watt, - ) - - ## T: FluxTime - if monitor_data_type == 'FluxTime': - return ct.InfoFlow( - dim_names=['t'], - dim_idx={ - 't': ct.ArrayFlow( - values=xarr.get_index('t').values, - unit=spu.hertz, - is_sorted=True, - ), - }, - output_name=extract_filter, - output_shape=None, - output_mathtype=spux.MathType.Real, - output_unit=spu.watt, - ) - - ## RThetaPhiF: FieldProjectionAngle - if monitor_data_type == 'FieldProjectionAngle': - return ct.InfoFlow( - dim_names=['r', 'theta', 'phi', 'f'], - dim_idx={ - 'r': ct.ArrayFlow( - values=xarr.get_index('r').values, - unit=spu.micrometer, - is_sorted=True, - ), - } - | { - c: ct.ArrayFlow( - values=xarr.get_index(c).values, - unit=spu.radian, - is_sorted=True, - ) - for c in ['r', 'theta', 'phi'] - } - | { - 'f': ct.ArrayFlow( - values=xarr.get_index('f').values, - unit=spu.hertz, - is_sorted=True, - ), - }, - output_name=extract_filter, - output_shape=None, - output_mathtype=spux.MathType.Real, - output_unit=( - spu.volt / spu.micrometer - if extract_filter.startswith('E') - else spu.ampere / spu.micrometer - ), - ) - - ## UxUyRF: FieldProjectionKSpace - if monitor_data_type == 'FieldProjectionKSpace': - return ct.InfoFlow( - dim_names=['ux', 'uy', 'r', 'f'], - dim_idx={ - c: ct.ArrayFlow( - values=xarr.get_index(c).values, unit=None, is_sorted=True - ) - for c in ['ux', 'uy'] - } - | { - 'r': ct.ArrayFlow( - values=xarr.get_index('r').values, - unit=spu.micrometer, - is_sorted=True, - ), - 'f': ct.ArrayFlow( - values=xarr.get_index('f').values, - unit=spu.hertz, - is_sorted=True, - ), - }, - output_name=extract_filter, - output_shape=None, - output_mathtype=spux.MathType.Real, - output_unit=( - spu.volt / spu.micrometer - if extract_filter.startswith('E') - else spu.ampere / spu.micrometer - ), - ) - - ## OrderxOrderyF: Diffraction - if monitor_data_type == 'Diffraction': - return ct.InfoFlow( - dim_names=['orders_x', 'orders_y', 'f'], - dim_idx={ - f'orders_{c}': ct.ArrayFlow( - values=xarr.get_index(f'orders_{c}').values, - unit=None, - is_sorted=True, - ) - for c in ['x', 'y'] - } - | { - 'f': ct.ArrayFlow( - values=xarr.get_index('f').values, - unit=spu.hertz, - is_sorted=True, - ), - }, - output_name=extract_filter, - output_shape=None, - output_mathtype=spux.MathType.Real, - output_unit=( - spu.volt / spu.micrometer - if extract_filter.startswith('E') - else spu.ampere / spu.micrometer - ), - ) - - msg = f'Unsupported Monitor Data Type {monitor_data_type} in "FlowKind.Info" of "{self.bl_label}"' - raise RuntimeError(msg) + return info.prepend_dim(sim_symbols.idx, idx_labels) #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py index 4dc9be9..f2f73b1 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py @@ -98,29 +98,29 @@ class FilterOperation(enum.StrEnum): operations = [] # Slice - if info.dim_names: + if info.dims: operations.append(FO.SliceIdx) # Pin ## PinLen1 ## -> There must be a dimension with length 1. - if 1 in list(info.dim_lens.values()): + if 1 in [dim_idx for dim_idx in info.dims.values() if dim_idx is not None]: operations.append(FO.PinLen1) ## Pin | PinIdx ## -> There must be a dimension, full stop. - if info.dim_names: + if info.dims: operations += [FO.Pin, FO.PinIdx] # Reinterpret ## Swap ## -> There must be at least two dimensions. - if len(info.dim_names) >= 2: # noqa: PLR2004 + if len(info.dims) >= 2: # noqa: PLR2004 operations.append(FO.Swap) ## SetDim ## -> There must be a dimension to correct. - if info.dim_names: + if info.dims: operations.append(FO.SetDim) return operations @@ -158,33 +158,33 @@ class FilterOperation(enum.StrEnum): def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]: FO = FilterOperation match self: - case FO.SliceIdx: - return info.dim_names + case FO.SliceIdx | FO.Swap: + return info.dims # PinLen1: Only allow dimensions with length=1. case FO.PinLen1: return [ - dim_name - for dim_name in info.dim_names - if info.dim_lens[dim_name] == 1 + dim + for dim, dim_idx in info.dims.items() + if dim_idx is not None and len(dim_idx) == 1 ] - # Pin: Only allow dimensions with known indexing. - case FO.Pin: + # Pin: Only allow dimensions with discrete index. + ## TODO: Shouldn't 'Pin' be allowed to index continuous indices too? + case FO.Pin | FO.PinIdx: return [ - dim_name - for dim_name in info.dim_names - if info.dim_has_coords[dim_name] != 0 + dim + for dim, dim_idx in info.dims + if dim_idx is not None and len(dim_idx) > 0 ] - case FO.PinIdx | FO.Swap: - return info.dim_names - case FO.SetDim: return [ - dim_name - for dim_name in info.dim_names - if info.dim_mathtypes[dim_name] == spux.MathType.Integer + dim + for dim, dim_idx in info.dims + if dim_idx is not None + and not isinstance(dim_idx, list) + and dim_idx.mathtype == spux.MathType.Integer ] return [] @@ -224,22 +224,22 @@ class FilterOperation(enum.StrEnum): def transform_info( self, info: ct.InfoFlow, - dim_0: str, - dim_1: str, + dim_0: sim_symbols.SimSymbol, + dim_1: sim_symbols.SimSymbol, + pin_idx: int | None = None, slice_tuple: tuple[int, int, int] | None = None, - corrected_dim: tuple[str, tuple[str, ct.ArrayFlow | ct.RangeFlow]] - | None = None, + replaced_dim: tuple[str, tuple[str, ct.ArrayFlow | ct.RangeFlow]] | None = None, ): FO = FilterOperation return { FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple), # Pin - FO.PinLen1: lambda: info.delete_dimension(dim_0), - FO.Pin: lambda: info.delete_dimension(dim_0), - FO.PinIdx: lambda: info.delete_dimension(dim_0), + FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx), + FO.Pin: lambda: info.delete_dim(dim_0, pin_idx=pin_idx), + FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx), # Reinterpret FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1), - FO.SetDim: lambda: info.replace_dim(*corrected_dim), + FO.SetDim: lambda: info.replace_dim(*replaced_dim), }[self]() @@ -318,11 +318,11 @@ class FilterMathNode(base.MaxwellSimNode): #################### # - Properties: Dimension Selection #################### - dim_0: enum.StrEnum = bl_cache.BLField( + active_dim_0: enum.StrEnum = bl_cache.BLField( enum_cb=lambda self, _: self.search_dims(), cb_depends_on={'operation', 'expr_info'}, ) - dim_1: enum.StrEnum = bl_cache.BLField( + active_dim_1: enum.StrEnum = bl_cache.BLField( enum_cb=lambda self, _: self.search_dims(), cb_depends_on={'operation', 'expr_info'}, ) @@ -335,40 +335,23 @@ class FilterMathNode(base.MaxwellSimNode): ] return [] + @bl_cache.cached_bl_property(depends_on={'active_dim_0'}) + def dim_0(self) -> sim_symbols.SimSymbol | None: + if self.expr_info is not None and self.active_dim_0 is not None: + return self.expr_info.dim_by_name(self.active_dim_0) + return None + + @bl_cache.cached_bl_property(depends_on={'active_dim_1'}) + def dim_1(self) -> sim_symbols.SimSymbol | None: + if self.expr_info is not None and self.active_dim_1 is not None: + return self.expr_info.dim_by_name(self.active_dim_1) + return None + #################### # - Properties: Slice #################### slice_tuple: tuple[int, int, int] = bl_cache.BLField([0, 1, 1]) - #################### - # - Properties: Unit - #################### - set_dim_symbol: sim_symbols.CommonSimSymbol = bl_cache.BLField( - sim_symbols.CommonSimSymbol.X - ) - - set_dim_active_unit: enum.StrEnum = bl_cache.BLField( - enum_cb=lambda self, _: self.search_valid_units(), - cb_depends_on={'set_dim_symbol'}, - ) - - def search_valid_units(self) -> list[ct.BLEnumElement]: - """Compute Blender enum elements of valid units for the current `physical_type`.""" - physical_type = self.set_dim_symbol.sim_symbol.physical_type - if physical_type is not spux.PhysicalType.NonPhysical: - return [ - (sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i) - for i, unit in enumerate(physical_type.valid_units) - ] - return [] - - @bl_cache.cached_bl_property(depends_on={'set_dim_active_unit'}) - def set_dim_unit(self) -> spux.Unit | None: - if self.set_dim_active_unit is not None: - return spux.unit_str_to_unit(self.set_dim_active_unit) - - return None - #################### # - UI #################### @@ -378,27 +361,27 @@ class FilterMathNode(base.MaxwellSimNode): # Slice case FO.SliceIdx: slice_str = ':'.join([str(v) for v in self.slice_tuple]) - return f'Filter: {self.dim_0}[{slice_str}]' + return f'Filter: {self.active_dim_0}[{slice_str}]' # Pin case FO.PinLen1: - return f'Filter: Pin {self.dim_0}[0]' + return f'Filter: Pin {self.active_dim_0}[0]' case FO.Pin: - return f'Filter: Pin {self.dim_0}[...]' + return f'Filter: Pin {self.active_dim_0}[...]' case FO.PinIdx: pin_idx_axis = self._compute_input( 'Axis', kind=ct.FlowKind.Value, optional=True ) has_pin_idx_axis = not ct.FlowSignal.check(pin_idx_axis) if has_pin_idx_axis: - return f'Filter: Pin {self.dim_0}[{pin_idx_axis}]' + return f'Filter: Pin {self.active_dim_0}[{pin_idx_axis}]' return self.bl_label # Reinterpret case FO.Swap: - return f'Filter: Swap [{self.dim_0}]|[{self.dim_1}]' + return f'Filter: Swap [{self.active_dim_0}]|[{self.active_dim_1}]' case FO.SetDim: - return f'Filter: Set [{self.dim_0}]' + return f'Filter: Set [{self.active_dim_0}]' case _: return self.bl_label @@ -409,20 +392,15 @@ class FilterMathNode(base.MaxwellSimNode): if self.operation is not None: match self.operation.num_dim_inputs: case 1: - layout.prop(self, self.blfields['dim_0'], text='') + layout.prop(self, self.blfields['active_dim_0'], text='') case 2: row = layout.row(align=True) - row.prop(self, self.blfields['dim_0'], text='') - row.prop(self, self.blfields['dim_1'], text='') + row.prop(self, self.blfields['active_dim_0'], text='') + row.prop(self, self.blfields['active_dim_1'], text='') if self.operation is FilterOperation.SliceIdx: layout.prop(self, self.blfields['slice_tuple'], text='') - if self.operation is FilterOperation.SetDim: - row = layout.row(align=True) - row.prop(self, self.blfields['set_dim_symbol'], text='') - row.prop(self, self.blfields['set_dim_active_unit'], text='') - #################### # - Events #################### @@ -450,50 +428,47 @@ class FilterMathNode(base.MaxwellSimNode): if not has_info: return - # Pin Dim by-Value: Synchronize Input Socket - ## -> The user will be given a socket w/correct mathtype, unit, etc. . - ## -> Internally, this value will map to a particular index. - if props['operation'] is FilterOperation.Pin and props['dim_0'] is not None: - # Deduce Pinned Information - pinned_unit = info.dim_units[props['dim_0']] - pinned_mathtype = info.dim_mathtypes[props['dim_0']] - pinned_physical_type = spux.PhysicalType.from_unit(pinned_unit) - wanted_mathtype = ( - spux.MathType.Complex - if pinned_mathtype == spux.MathType.Complex - and spux.MathType.Complex in pinned_physical_type.valid_mathtypes - else spux.MathType.Real - ) + dim_0 = props['dim_0'] - # Get Current and Wanted Socket Defs - ## -> 'Value' may already exist. If not, all is well. + # Loose Sockets: Pin Dim by-Value + ## -> Works with continuous / discrete indexes. + ## -> The user will be given a socket w/correct mathtype, unit, etc. . + if ( + props['operation'] is FilterOperation.Pin + and dim_0 is not None + and (info.has_idx_cont(dim_0) or info.has_idx_discrete(dim_0)) + ): + dim = dim_0 current_bl_socket = self.loose_input_sockets.get('Value') - # Determine Whether to Construct - ## -> If nothing needs to change, then nothing changes. if ( current_bl_socket is None + or current_bl_socket.active_kind != ct.FlowKind.Value or current_bl_socket.size is not spux.NumberSize1D.Scalar - or current_bl_socket.physical_type != pinned_physical_type - or current_bl_socket.mathtype != wanted_mathtype + or current_bl_socket.physical_type != dim.physical_type + or current_bl_socket.mathtype != dim.mathtype ): self.loose_input_sockets = { 'Value': sockets.ExprSocketDef( active_kind=ct.FlowKind.Value, - physical_type=pinned_physical_type, - mathtype=wanted_mathtype, - default_unit=pinned_unit, + physical_type=dim.physical_type, + mathtype=dim.mathtype, + default_unit=dim.unit, ), } - # Pin Dim by-Index: Synchronize Input Socket - ## -> The user will be given a simple integer socket. + # Loose Sockets: Pin Dim by-Value + ## -> Works with discrete points / labelled integers. elif ( - props['operation'] is FilterOperation.PinIdx and props['dim_0'] is not None + props['operation'] is FilterOperation.PinIdx + and dim_0 is not None + and (info.has_idx_discrete(dim_0) or info.has_idx_labels(dim_0)) ): + dim = dim_0 current_bl_socket = self.loose_input_sockets.get('Axis') if ( current_bl_socket is None + or current_bl_socket.active_kind != ct.FlowKind.Value or current_bl_socket.size is not spux.NumberSize1D.Scalar or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical or current_bl_socket.mathtype != spux.MathType.Integer @@ -505,28 +480,26 @@ class FilterMathNode(base.MaxwellSimNode): ) } - # Set Dim: Synchronize Input Socket + # Loose Sockets: Set Dim ## -> The user must provide a (ℤ) -> ℝ array. ## -> It must be of identical length to the replaced axis. - elif ( - props['operation'] is FilterOperation.SetDim - and props['dim_0'] is not None - and info.dim_mathtypes[props['dim_0']] is spux.MathType.Integer - and info.dim_physical_types[props['dim_0']] is spux.PhysicalType.NonPhysical - ): - # Deduce Axis Information + elif props['operation'] is FilterOperation.SetDim and dim_0 is not None: + dim = dim_0 current_bl_socket = self.loose_input_sockets.get('Dim') if ( current_bl_socket is None or current_bl_socket.active_kind != ct.FlowKind.Func - or current_bl_socket.mathtype != spux.MathType.Real - or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical + or current_bl_socket.size is not spux.NumberSize1D.Scalar + or current_bl_socket.mathtype != dim.mathtype + or current_bl_socket.physical_type != dim.physical_type ): self.loose_input_sockets = { 'Dim': sockets.ExprSocketDef( active_kind=ct.FlowKind.Func, - mathtype=spux.MathType.Real, - physical_type=spux.PhysicalType.NonPhysical, + physical_type=dim.physical_type, + mathtype=dim.mathtype, + default_unit=dim.unit, + show_func_ui=False, show_info_columns=True, ) } @@ -553,25 +526,20 @@ class FilterMathNode(base.MaxwellSimNode): has_lazy_func = not ct.FlowSignal.check(lazy_func) has_info = not ct.FlowSignal.check(info) - # Dimension(s) dim_0 = props['dim_0'] dim_1 = props['dim_1'] + slice_tuple = props['slice_tuple'] if ( has_lazy_func and has_info and operation is not None and operation.are_dims_valid(info, dim_0, dim_1) ): - axis_0 = info.dim_names.index(dim_0) if dim_0 is not None else None - axis_1 = info.dim_names.index(dim_1) if dim_1 is not None else None - slice_tuple = ( - props['slice_tuple'] - if self.operation is FilterOperation.SliceIdx - else None - ) + axis_0 = info.dim_axis(dim_0) if dim_0 is not None else None + axis_1 = info.dim_axis(dim_1) if dim_1 is not None else None return lazy_func.compose_within( - operation.jax_func(axis_0, axis_1, slice_tuple), + operation.jax_func(axis_0, axis_1, slice_tuple=slice_tuple), enclosing_func_args=operation.func_args, supports_jax=True, ) @@ -588,8 +556,6 @@ class FilterMathNode(base.MaxwellSimNode): 'dim_1', 'operation', 'slice_tuple', - 'set_dim_symbol', - 'set_dim_active_unit', }, input_sockets={'Expr', 'Dim'}, input_socket_kinds={ @@ -601,14 +567,15 @@ class FilterMathNode(base.MaxwellSimNode): def compute_info(self, props, input_sockets) -> ct.InfoFlow: operation = props['operation'] info = input_sockets['Expr'] - dim_coords = input_sockets['Dim'][ct.FlowKind.Func] - dim_params = input_sockets['Dim'][ct.FlowKind.Params] - dim_info = input_sockets['Dim'][ct.FlowKind.Info] - dim_symbol = props['set_dim_symbol'] - dim_active_unit = props['set_dim_active_unit'] has_info = not ct.FlowSignal.check(info) - has_dim_coords = not ct.FlowSignal.check(dim_coords) + + # Dim (Op.SetDim) + dim_func = input_sockets['Dim'][ct.FlowKind.Func] + dim_params = input_sockets['Dim'][ct.FlowKind.Params] + dim_info = input_sockets['Dim'][ct.FlowKind.Info] + + has_dim_func = not ct.FlowSignal.check(dim_func) has_dim_params = not ct.FlowSignal.check(dim_params) has_dim_info = not ct.FlowSignal.check(dim_info) @@ -619,44 +586,42 @@ class FilterMathNode(base.MaxwellSimNode): if has_info and operation is not None: # Set Dimension: Retrieve Array if props['operation'] is FilterOperation.SetDim: + new_dim = ( + next(dim_info.dims.keys()) if len(dim_info.dims) >= 1 else None + ) + if ( dim_0 is not None - # Check Replaced Dimension - and has_dim_coords - and len(dim_coords.func_args) == 1 - and dim_coords.func_args[0] is spux.MathType.Integer - and not dim_coords.func_kwargs - and dim_coords.supports_jax - # Check Params - and has_dim_params - and len(dim_params.func_args) == 1 - and not dim_params.func_kwargs - # Check Info + and new_dim is not None and has_dim_info + and has_dim_params + # Check New Dimension Index Array Sizing + and len(dim_info.dims) == 1 + and dim_info.output.rows == 1 + and dim_info.output.cols == 1 + # Check Lack of Params Symbols + and not dim_params.symbols + # Check Expr Dim | New Dim Compatibility + and info.has_idx_discrete(dim_0) + and dim_info.has_idx_discrete(new_dim) + and len(info.dims[dim_0]) == len(dim_info.dims[new_dim]) ): # Retrieve Dimension Coordinate Array ## -> It must be strictly compatible. - values = dim_coords.func_jax(int(dim_params.func_args[0])) - if ( - len(values.shape) != 1 - or values.shape[0] != info.dim_lens[dim_0] - ): - return ct.FlowSignal.FlowPending + values = dim_func.realize(dim_params, spux.UNITS_SI) # Transform Info w/Corrected Dimension ## -> The existing dimension will be replaced. - if dim_active_unit is not None: - dim_unit = spux.unit_str_to_unit(dim_active_unit) - else: - dim_unit = None - new_dim_idx = ct.ArrayFlow( values=values, - unit=dim_unit, - ) - corrected_dim = [dim_0, (dim_symbol.name, new_dim_idx)] + unit=spux.convert_to_unit_system( + dim_info.output.unit, spux.UNITS_SI + ), + ).rescale_to_unit(dim_info.output.unit) + + replaced_dim = [dim_0, (dim_info.output.name, new_dim_idx)] return operation.transform_info( - info, dim_0, dim_1, corrected_dim=corrected_dim + info, dim_0, dim_1, replaced_dim=replaced_dim ) return ct.FlowSignal.FlowPending return operation.transform_info(info, dim_0, dim_1, slice_tuple=slice_tuple) @@ -702,7 +667,7 @@ class FilterMathNode(base.MaxwellSimNode): # Pin by-Value: Compute Nearest IDX ## -> Presume a sorted index array to be able to use binary search. if props['operation'] is FilterOperation.Pin and has_pinned_value: - nearest_idx_to_value = info.dim_idx[dim_0].nearest_idx_of( + nearest_idx_to_value = info.dims[dim_0].nearest_idx_of( pinned_value, require_sorted=True ) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py index 742fe6e..29a66da 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py @@ -23,7 +23,7 @@ import bpy import jax.numpy as jnp import sympy as sp -from blender_maxwell.utils import bl_cache, logger +from blender_maxwell.utils import bl_cache, logger, sim_symbols from blender_maxwell.utils import extra_sympy_units as spux from .... import contracts as ct @@ -153,40 +153,38 @@ class MapOperation(enum.StrEnum): # - Ops from Shape #################### @staticmethod - def by_element_shape(shape: tuple[int, ...] | None) -> list[typ.Self]: + def by_expr_info(info: ct.InfoFlow) -> list[typ.Self]: + ## TODO: By info, not shape. + ## TODO: Check valid domains/mathtypes for some functions. MO = MapOperation + element_ops = [ + MO.Real, + MO.Imag, + MO.Abs, + MO.Sq, + MO.Sqrt, + MO.InvSqrt, + MO.Cos, + MO.Sin, + MO.Tan, + MO.Acos, + MO.Asin, + MO.Atan, + MO.Sinc, + ] - match shape: - case 'noshape': - return [] + match (info.output.rows, info.output.cols): + case (1, 1): + return element_ops - # By Number - case None: - return [ - MO.Real, - MO.Imag, - MO.Abs, - MO.Sq, - MO.Sqrt, - MO.InvSqrt, - MO.Cos, - MO.Sin, - MO.Tan, - MO.Acos, - MO.Asin, - MO.Atan, - MO.Sinc, - ] + case (_, 1): + return [*element_ops, MO.Norm2] - match len(shape): - # By Vector - case 1: - return [ - MO.Norm2, - ] - # By Matrix - case 2: + case (rows, cols) if rows == cols: + ## TODO: Check hermitian/posdef for cholesky. + ## - Can we even do this with just the output symbol approach? return [ + *element_ops, MO.Det, MO.Cond, MO.NormFro, @@ -201,6 +199,18 @@ class MapOperation(enum.StrEnum): MO.Svd, ] + case (rows, cols): + return [ + *element_ops, + MO.Cond, + MO.NormFro, + MO.Rank, + MO.SvdVals, + MO.Inv, + MO.Tra, + MO.Svd, + ] + return [] #################### @@ -288,41 +298,76 @@ class MapOperation(enum.StrEnum): def transform_info(self, info: ct.InfoFlow): MO = MapOperation + return { # By Number - MO.Real: lambda: info.set_output_mathtype(spux.MathType.Real), - MO.Imag: lambda: info.set_output_mathtype(spux.MathType.Real), - MO.Abs: lambda: info.set_output_mathtype(spux.MathType.Real), + MO.Real: lambda: info.update_output(mathtype=spux.MathType.Real), + MO.Imag: lambda: info.update_output(mathtype=spux.MathType.Real), + MO.Abs: lambda: info.update_output(mathtype=spux.MathType.Real), + MO.Sq: lambda: info, + MO.Sqrt: lambda: info, + MO.InvSqrt: lambda: info, + MO.Cos: lambda: info, + MO.Sin: lambda: info, + MO.Tan: lambda: info, + MO.Acos: lambda: info, + MO.Asin: lambda: info, + MO.Atan: lambda: info, + MO.Sinc: lambda: info, # By Vector - MO.Norm2: lambda: info.collapse_output( - collapsed_name=MO.to_name(self).replace('v', info.output_name), - collapsed_mathtype=spux.MathType.Real, - collapsed_unit=info.output_unit, + MO.Norm2: lambda: info.update_output( + mathtype=spux.MathType.Real, + rows=1, + cols=1, + # Interval + interval_finite_re=(0, sim_symbols.float_max), + interval_inf=(False, True), + interval_closed=(True, False), ), # By Matrix - MO.Det: lambda: info.collapse_output( - collapsed_name=MO.to_name(self).replace('V', info.output_name), - collapsed_mathtype=info.output_mathtype, - collapsed_unit=info.output_unit, + MO.Det: lambda: info.update_output( + rows=1, + cols=1, ), - MO.Cond: lambda: info.collapse_output( - collapsed_name=MO.to_name(self).replace('V', info.output_name), - collapsed_mathtype=spux.MathType.Real, - collapsed_unit=None, + MO.Cond: lambda: info.update_output( + mathtype=spux.MathType.Real, + rows=1, + cols=1, + physical_type=spux.PhysicalType.NonPhysical, + unit=None, ), - MO.NormFro: lambda: info.collapse_output( - collapsed_name=MO.to_name(self).replace('V', info.output_name), - collapsed_mathtype=spux.MathType.Real, - collapsed_unit=info.output_unit, + MO.NormFro: lambda: info.update_output( + mathtype=spux.MathType.Real, + rows=1, + cols=1, + # Interval + interval_finite_re=(0, sim_symbols.float_max), + interval_inf=(False, True), + interval_closed=(True, False), ), - MO.Rank: lambda: info.collapse_output( - collapsed_name=MO.to_name(self).replace('V', info.output_name), - collapsed_mathtype=spux.MathType.Integer, - collapsed_unit=None, + MO.Rank: lambda: info.update_output( + mathtype=spux.MathType.Integer, + rows=1, + cols=1, + physical_type=spux.PhysicalType.NonPhysical, + unit=None, + # Interval + interval_finite_re=(0, sim_symbols.int_max), + interval_inf=(False, True), + interval_closed=(True, False), ), - ## TODO: Matrix -> Vec - ## TODO: Matrix -> Matrices - }.get(self, lambda: info)() + # Matrix -> Vector ## TODO: ALL OF THESE + MO.Diag: lambda: info, + MO.EigVals: lambda: info, + MO.SvdVals: lambda: info, + # Matrix -> Matrix ## TODO: ALL OF THESE + MO.Inv: lambda: info, + MO.Tra: lambda: info, + # Matrix -> Matrices ## TODO: ALL OF THESE + MO.Qr: lambda: info, + MO.Chol: lambda: info, + MO.Svd: lambda: info, + }[self]() #################### @@ -435,29 +480,26 @@ class MapMathNode(base.MaxwellSimNode): ) if has_info and not info_pending: - self.expr_output_shape = bl_cache.Signal.InvalidateCache + self.expr_info = bl_cache.Signal.InvalidateCache @bl_cache.cached_bl_property() - def expr_output_shape(self) -> ct.InfoFlow | None: + def expr_info(self) -> ct.InfoFlow | None: info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True) has_info = not ct.FlowSignal.check(info) if has_info: - return info.output_shape - - return 'noshape' + return info + return None operation: MapOperation = bl_cache.BLField( enum_cb=lambda self, _: self.search_operations(), - cb_depends_on={'expr_output_shape'}, + cb_depends_on={'expr_info'}, ) def search_operations(self) -> list[ct.BLEnumElement]: - if self.expr_output_shape != 'noshape': + if self.info is not None: return [ operation.bl_enum_element(i) - for i, operation in enumerate( - MapOperation.by_element_shape(self.expr_output_shape) - ) + for i, operation in enumerate(MapOperation.by_expr_info(self.expr_info)) ] return [] @@ -474,7 +516,7 @@ class MapMathNode(base.MaxwellSimNode): layout.prop(self, self.blfields['operation'], text='') #################### - # - FlowKind.Value|Func + # - FlowKind.Value #################### @events.computes_output_socket( 'Expr', @@ -495,6 +537,9 @@ class MapMathNode(base.MaxwellSimNode): return ct.FlowSignal.FlowPending + #################### + # - FlowKind.Func + #################### @events.computes_output_socket( 'Expr', kind=ct.FlowKind.Func, @@ -518,7 +563,7 @@ class MapMathNode(base.MaxwellSimNode): return ct.FlowSignal.FlowPending #################### - # - FlowKind.Info|Params + # - FlowKind.Info #################### @events.computes_output_socket( 'Expr', @@ -538,6 +583,9 @@ class MapMathNode(base.MaxwellSimNode): return ct.FlowSignal.FlowPending + #################### + # - FlowKind.Params + #################### @events.computes_output_socket( 'Expr', kind=ct.FlowKind.Params, diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py index b6c6d60..74c4d6d 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py @@ -107,32 +107,31 @@ class TransformOperation(enum.StrEnum): # Covariant Transform ## Freq <-> VacWL - for dim_name in info.dim_names: - if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq: + for dim in info.dims: + if dim.physical_type == spux.PhysicalType.Freq: operations.append(TO.FreqToVacWL) - if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq: + if dim.physical_type == spux.PhysicalType.Freq: operations.append(TO.VacWLToFreq) # Fold ## (Last) Int Dim (=2) to Complex - if len(info.dim_names) >= 1: - last_dim_name = info.dim_names[-1] - if info.dim_lens[last_dim_name] == 2: # noqa: PLR2004 + if len(info.dims) >= 1: + if not info.has_idx_labels(info.last_dim) and len(info.last_dim) == 2: # noqa: PLR2004 operations.append(TO.IntDimToComplex) ## To Vector - if len(info.dim_names) >= 1: + if len(info.dims) >= 1: operations.append(TO.DimToVec) ## To Matrix - if len(info.dim_names) >= 2: # noqa: PLR2004 + if len(info.dims) >= 2: # noqa: PLR2004 operations.append(TO.DimsToMat) # Fourier ## 1D Fourier - if info.dim_names: - last_physical_type = info.dim_physical_types[info.dim_names[-1]] + if info.dims: + last_physical_type = info.last_dim.physical_type if last_physical_type == spux.PhysicalType.Time: operations.append(TO.FFT1D) if last_physical_type == spux.PhysicalType.Freq: @@ -188,15 +187,15 @@ class TransformOperation(enum.StrEnum): unit: spux.Unit | None = None, ) -> ct.InfoFlow | None: TO = TransformOperation - if not info.dim_names: + if not info.dims: return None return { - # Index + # Covariant Transform TO.FreqToVacWL: lambda: info.replace_dim( - (f_dim := info.dim_names[-1]), + (f_dim := info.last_dim), [ - 'wl', - info.dim_idx[f_dim].rescale( + sim_symbols.wl(spu.nanometer), + info.dims[f_dim].rescale( lambda el: sci_constants.vac_speed_of_light / el, reverse=True, new_unit=spu.nanometer, @@ -204,10 +203,10 @@ class TransformOperation(enum.StrEnum): ], ), TO.VacWLToFreq: lambda: info.replace_dim( - (wl_dim := info.dim_names[-1]), + (wl_dim := info.last_dim), [ - 'f', - info.dim_idx[wl_dim].rescale( + sim_symbols.freq(spux.THz), + info.dims[wl_dim].rescale( lambda el: sci_constants.vac_speed_of_light / el, reverse=True, new_unit=spux.THz, @@ -215,24 +214,24 @@ class TransformOperation(enum.StrEnum): ], ), # Fold - TO.IntDimToComplex: lambda: info.delete_dimension( - info.dim_names[-1] - ).set_output_mathtype(spux.MathType.Complex), - TO.DimToVec: lambda: info.shift_last_input, - TO.DimsToMat: lambda: info.shift_last_input.shift_last_input, + TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output( + mathtype=spux.MathType.Complex + ), + TO.DimToVec: lambda: info.fold_last_input(), + TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(), # Fourier TO.FFT1D: lambda: info.replace_dim( - info.dim_names[-1], + info.last_dim, [ - 'f', - ct.RangeFlow(start=0, stop=sp.oo, steps=0, unit=spu.hertz), + sim_symbols.freq(spux.THz), + None, ], ), TO.InvFFT1D: info.replace_dim( - info.dim_names[-1], + info.last_dim, [ - 't', - ct.RangeFlow(start=0, stop=sp.oo, steps=0, unit=spu.second), + sim_symbols.t(spu.second), + None, ], ), }.get(self, lambda: info)() diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py index 60101b5..d6a785c 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py @@ -38,7 +38,6 @@ class VizMode(enum.StrEnum): **NOTE**: >1D output dimensions currently have no viz. Plots for `() -> ℝ`: - - Hist1D: Bin-summed distribution. - BoxPlot1D: Box-plot describing the distribution. Plots for `(ℤ) -> ℝ`: @@ -61,7 +60,6 @@ class VizMode(enum.StrEnum): - Heatmap3D: Colormapped field with value at each voxel. """ - Hist1D = enum.auto() BoxPlot1D = enum.auto() Curve2D = enum.auto() @@ -78,42 +76,38 @@ class VizMode(enum.StrEnum): @staticmethod def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None: - EMPTY = () Z = spux.MathType.Integer R = spux.MathType.Real VM = VizMode - valid_viz_modes = { - (EMPTY, (None, R)): [VM.Hist1D, VM.BoxPlot1D], - ((Z), (None, R)): [ - VM.Hist1D, + return { + ((Z), (1, 1, R)): [ VM.BoxPlot1D, ], - ((R,), (None, R)): [ + ((R,), (1, 1, R)): [ VM.Curve2D, VM.Points2D, VM.Bar, ], - ((R, Z), (None, R)): [ + ((R, Z), (1, 1, R)): [ VM.Curves2D, VM.FilledCurves2D, ], - ((R, R), (None, R)): [ + ((R, R), (1, 1, R)): [ VM.Heatmap2D, ], - ((R, R, R), (None, R)): [VM.SqueezedHeatmap2D, VM.Heatmap3D], + ((R, R, R), (1, 1, R)): [ + VM.SqueezedHeatmap2D, + VM.Heatmap3D, + ], }.get( ( - tuple(info.dim_mathtypes.values()), - (info.output_shape, info.output_mathtype), - ) + tuple([dim.mathtype for dim in info.dims.values()]), + (info.output.rows, info.output.cols, info.output.mathtype), + ), + [], ) - if valid_viz_modes is None: - return [] - - return valid_viz_modes - @staticmethod def to_plotter( value: typ.Self, @@ -121,7 +115,6 @@ class VizMode(enum.StrEnum): [jtyp.Float32[jtyp.Array, '...'], ct.InfoFlow, mpl_ax.Axis], None ]: return { - VizMode.Hist1D: image_ops.plot_hist_1d, VizMode.BoxPlot1D: image_ops.plot_box_plot_1d, VizMode.Curve2D: image_ops.plot_curve_2d, VizMode.Points2D: image_ops.plot_points_2d, @@ -136,7 +129,6 @@ class VizMode(enum.StrEnum): @staticmethod def to_name(value: typ.Self) -> str: return { - VizMode.Hist1D: 'Histogram', VizMode.BoxPlot1D: 'Box Plot', VizMode.Curve2D: 'Curve', VizMode.Points2D: 'Points', @@ -164,7 +156,6 @@ class VizTarget(enum.StrEnum): @staticmethod def valid_targets_for(viz_mode: VizMode) -> list[typ.Self] | None: return { - VizMode.Hist1D: [VizTarget.Plot2D], VizMode.BoxPlot1D: [VizTarget.Plot2D], VizMode.Curve2D: [VizTarget.Plot2D], VizMode.Points2D: [VizTarget.Plot2D], @@ -333,35 +324,20 @@ class VizNode(base.MaxwellSimNode): ## -> This happens if Params contains not-yet-realized symbols. if has_info and has_params and params.symbols: if set(self.loose_input_sockets) != { - sym.name for sym in params.symbols if sym.name in info.dim_names + dim.name for dim in params.symbols if dim in info.dims }: self.loose_input_sockets = { - sym.name: sockets.ExprSocketDef( - active_kind=ct.FlowKind.Range, - size=spux.NumberSize1D.Scalar, - mathtype=info.dim_mathtypes[sym.name], - physical_type=info.dim_physical_types[sym.name], - default_min=( - info.dim_idx[sym.name].start - if not sp.S(info.dim_idx[sym.name].start).is_infinite - else sp.S(0) - ), - default_max=( - info.dim_idx[sym.name].start - if not sp.S(info.dim_idx[sym.name].stop).is_infinite - else sp.S(1) - ), - default_steps=50, - ) - for sym in params.sorted_symbols - if sym.name in info.dim_names + dim_name: sockets.ExprSocketDef(**expr_info) + for dim_name, expr_info in params.sym_expr_infos( + info, use_range=True + ).items() } elif self.loose_input_sockets: self.loose_input_sockets = {} ##################### - ## - Plotting + ## - FlowKind.Value ##################### @events.computes_output_socket( 'Preview', @@ -375,8 +351,12 @@ class VizNode(base.MaxwellSimNode): all_loose_input_sockets=True, ) def compute_dummy_value(self, props, input_sockets, loose_input_sockets): + """Needed for the plot to regenerate in the viewer.""" return ct.FlowSignal.NoFlow + ##################### + ## - On Show Plot + ##################### @events.on_show_plot( managed_objs={'plot'}, props={'viz_mode', 'viz_target', 'colormap'}, @@ -384,13 +364,12 @@ class VizNode(base.MaxwellSimNode): input_socket_kinds={ 'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params} }, - unit_systems={'BlenderUnits': ct.UNITS_BLENDER}, all_loose_input_sockets=True, stop_propagation=True, ) def on_show_plot( - self, managed_objs, props, input_sockets, loose_input_sockets, unit_systems - ): + self, managed_objs, props, input_sockets, loose_input_sockets + ) -> None: # Retrieve Inputs lazy_func = input_sockets['Expr'][ct.FlowKind.Func] info = input_sockets['Expr'][ct.FlowKind.Info] @@ -399,8 +378,6 @@ class VizNode(base.MaxwellSimNode): has_info = not ct.FlowSignal.check(info) has_params = not ct.FlowSignal.check(params) - # Invalid Mode | Target - ## -> To limit branching, return now if things aren't right. if ( not has_info or not has_params @@ -410,53 +387,42 @@ class VizNode(base.MaxwellSimNode): return # Compute Ranges for Symbols from Loose Sockets - ## -> These are the concrete values of the symbol for plotting. ## -> In a quite nice turn of events, all this is cached lookups. ## -> ...Unless something changed, in which case, well. It changed. - symbol_values = { - sym: ( - loose_input_sockets[sym.name] - .realize_array.rescale_to_unit(info.dim_units[sym.name]) - .values + symbol_array_values = { + sim_syms: ( + loose_input_sockets[sim_syms] + .rescale_to_unit(sim_syms.unit) + .realize_array ) - for sym in params.sorted_symbols + for sim_syms in params.sorted_symbols } - - # Realize Func w/Symbolic Values, Unit System - ## -> This gives us the actual plot data! - data = lazy_func.func_jax( - *params.scaled_func_args( - unit_systems['BlenderUnits'], symbol_values=symbol_values - ), - **params.scaled_func_kwargs( - unit_systems['BlenderUnits'], symbol_values=symbol_values - ), - ) + data = lazy_func.realize(params, symbol_values=symbol_array_values) # Replace InfoFlow Indices w/Realized Symbolic Ranges ## -> This ensures correct axis scaling. if params.symbols: - info = info.rescale_dim_idxs(loose_input_sockets) + info = info.replace_dims(symbol_array_values) - # Visualize by-Target - if props['viz_target'] == VizTarget.Plot2D: - managed_objs['plot'].mpl_plot_to_image( - lambda ax: VizMode.to_plotter(props['viz_mode'])(data, info, ax), - bl_select=True, - ) + match props['viz_target']: + case VizTarget.Plot2D: + managed_objs['plot'].mpl_plot_to_image( + lambda ax: VizMode.to_plotter(props['viz_mode'])(data, info, ax), + bl_select=True, + ) - if props['viz_target'] == VizTarget.Pixels: - managed_objs['plot'].map_2d_to_image( - data, - colormap=props['colormap'], - bl_select=True, - ) + case VizTarget.Pixels: + managed_objs['plot'].map_2d_to_image( + data, + colormap=props['colormap'], + bl_select=True, + ) - if props['viz_target'] == VizTarget.PixelsPlane: - raise NotImplementedError + case VizTarget.PixelsPlane: + raise NotImplementedError - if props['viz_target'] == VizTarget.Voxels: - raise NotImplementedError + case VizTarget.Voxels: + raise NotImplementedError #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py index 389c311..fa56228 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py @@ -67,7 +67,7 @@ ManagedObjName: typ.TypeAlias = str PropName: typ.TypeAlias = str -def event_decorator( +def event_decorator( # noqa: PLR0913 event: ct.FlowEvent, callback_info: EventCallbackInfo | None, stop_propagation: bool = False, @@ -91,31 +91,42 @@ def event_decorator( scale_input_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}), scale_output_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}), ): - """Returns a decorator for a method of `MaxwellSimNode`, declaring it as able respond to events passing through a node. + """Low-level decorator declaring a special "event method" of `MaxwellSimNode`, which is able to handle `ct.FlowEvent`s passing through. + + Should generally be used via a high-level decorator such as `on_value_changed`. + + For more about how event methods are actually registered and run, please refer to the documentation of `MaxwellSimNode`. Parameters: event: A name describing which event the decorator should respond to. - Set to `return_method.event` callback_info: A dictionary that provides the caller with additional per-`event` information. This might include parameters to help select the most appropriate method(s) to respond to an event with, or events to take after running the callback. - props: Set of `props` to compute, then pass to the decorated method. stop_propagation: Whether or stop propagating the event through the graph after encountering this method. Other methods defined on the same node will still run. managed_objs: Set of `managed_objs` to retrieve, then pass to the decorated method. + props: Set of `props` to compute, then pass to the decorated method. input_sockets: Set of `input_sockets` to compute, then pass to the decorated method. + input_sockets_optional: Whether an input socket is required to exist. + When True, lack of socket will produce `ct.FlowSignal.NoFlow`, instead of throwing an error. input_socket_kinds: The `ct.FlowKind` to compute per-input-socket. If an input socket isn't specified, it defaults to `ct.FlowKind.Value`. output_sockets: Set of `output_sockets` to compute, then pass to the decorated method. + output_sockets_optional: Whether an output socket is required to exist. + When True, lack of socket will produce `ct.FlowSignal.NoFlow`, instead of throwing an error. + output_socket_kinds: The `ct.FlowKind` to compute per-output-socket. + If an output socket isn't specified, it defaults to `ct.FlowKind.Value`. all_loose_input_sockets: Whether to compute all loose input sockets and pass them to the decorated method. Used when the names of the loose input sockets are unknown, but all of their values are needed. all_loose_output_sockets: Whether to compute all loose output sockets and pass them to the decorated method. Used when the names of the loose output sockets are unknown, but all of their values are needed. + unit_systems: String identifiers under which to load a unit system, made available to the method. + scale_input_sockets: A mapping of input sockets to unit system string idenfiers, which causes the output of that input socket to be scaled to the given unit system. + This greatly simplifies the conformance of particular sockets to particular unit systems, when the socket value must be used in a unit-unaware manner. + scale_output_sockets: A mapping of output sockets to unit system string idenfiers, which causes the output of that input socket to be scaled to the given unit system. + This greatly simplifies the conformance of particular sockets to particular unit systems, when the socket value must be used in a unit-unaware manner. Returns: - A decorator, which can be applied to a method of `MaxwellSimNode`. - When a `MaxwellSimNode` subclass initializes, such a decorated method will be picked up on. - - When `event` passes through the node, then `callback_info` is used to determine + A decorator, which can be applied to a method of `MaxwellSimNode` to make it an "event method". """ req_params = ( {'self'} @@ -375,7 +386,6 @@ def on_value_changed( ) -## TODO: Change name to 'on_output_requested' def computes_output_socket( output_socket_name: ct.SocketName | None, kind: ct.FlowKind = ct.FlowKind.Value, diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py index e22412a..7db23bc 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py @@ -14,6 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import enum import typing as typ from pathlib import Path @@ -21,7 +22,7 @@ import bpy import sympy as sp import tidy3d as td -from blender_maxwell.utils import bl_cache, logger +from blender_maxwell.utils import bl_cache, logger, sim_symbols from blender_maxwell.utils import extra_sympy_units as spux from .... import contracts as ct @@ -91,6 +92,88 @@ class DataFileImporterNode(base.MaxwellSimNode): return info return None + #################### + # - Info Guides + #################### + output_name: sim_symbols.SimSymbolName = bl_cache.BLField(sim_symbols.SimSymbolName) + output_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real) + output_physical_type: spux.PhysicalType = bl_cache.BLField( + spux.PhysicalType.NonPhysical + ) + output_unit: enum.StrEnum = bl_cache.BLField( + enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type), + cb_depends_on={'output_physical_type'}, + ) + + dim_0_name: sim_symbols.SimSymbolName = bl_cache.BLField( + sim_symbols.SimSymbolName.LowerA + ) + dim_0_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real) + dim_0_physical_type: spux.PhysicalType = bl_cache.BLField( + spux.PhysicalType.NonPhysical + ) + dim_0_unit: enum.StrEnum = bl_cache.BLField( + enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type), + cb_depends_on={'dim_0_physical_type'}, + ) + + dim_1_name: sim_symbols.SimSymbolName = bl_cache.BLField( + sim_symbols.SimSymbolName.LowerB + ) + dim_1_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real) + dim_1_physical_type: spux.PhysicalType = bl_cache.BLField( + spux.PhysicalType.NonPhysical + ) + dim_1_unit: enum.StrEnum = bl_cache.BLField( + enum_cb=lambda self, _: self.search_units(self.dim_1_physical_type), + cb_depends_on={'dim_1_physical_type'}, + ) + + dim_2_name: sim_symbols.SimSymbolName = bl_cache.BLField( + sim_symbols.SimSymbolName.LowerC + ) + dim_2_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real) + dim_2_physical_type: spux.PhysicalType = bl_cache.BLField( + spux.PhysicalType.NonPhysical + ) + dim_2_unit: enum.StrEnum = bl_cache.BLField( + enum_cb=lambda self, _: self.search_units(self.dim_2_physical_type), + cb_depends_on={'dim_2_physical_type'}, + ) + + dim_3_name: sim_symbols.SimSymbolName = bl_cache.BLField( + sim_symbols.SimSymbolName.LowerD + ) + dim_3_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real) + dim_3_physical_type: spux.PhysicalType = bl_cache.BLField( + spux.PhysicalType.NonPhysical + ) + dim_3_unit: enum.StrEnum = bl_cache.BLField( + enum_cb=lambda self, _: self.search_units(self.dim_3_physical_type), + cb_depends_on={'dim_3_physical_type'}, + ) + + def search_units(self, physical_type: spux.PhysicalType) -> list[ct.BLEnumElement]: + if physical_type is not spux.PhysicalType.NonPhysical: + return [ + (sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i) + for i, unit in enumerate(physical_type.valid_units) + ] + return [] + + def dim(self, i: int): + dim_name = getattr(self, f'dim_{i}_name') + dim_mathtype = getattr(self, f'dim_{i}_mathtype') + dim_physical_type = getattr(self, f'dim_{i}_physical_type') + dim_unit = getattr(self, f'dim_{i}_unit') + + return sim_symbols.SimSymbol( + sym_name=dim_name, + mathtype=dim_mathtype, + physical_type=dim_physical_type, + unit=spux.unit_str_to_unit(dim_unit), + ) + #################### # - UI #################### @@ -118,7 +201,20 @@ class DataFileImporterNode(base.MaxwellSimNode): row.label(text=self.file_path.name) def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: - pass + """Draw loaded properties.""" + for i in range(len(self.expr_info.dims)): + col = layout.column(align=True) + row = col.row(align=True) + row.alignment = 'CENTER' + row.label(text=f'Load Dim {i}') + + row = col.row(align=True) + row.prop(self, self.blfields[f'dim_{i}_name'], text='') + row.prop(self, self.blfields[f'dim_{i}_mathtype'], text='') + + row = col.row(align=True) + row.prop(self, self.blfields[f'dim_{i}_physical_type'], text='') + row.prop(self, self.blfields[f'dim_{i}_unit'], text='') #################### # - FlowKind.Array|Func @@ -174,10 +270,12 @@ class DataFileImporterNode(base.MaxwellSimNode): @events.computes_output_socket( 'Expr', kind=ct.FlowKind.Info, + # Loaded + props={'output_name', 'output_physical_type', 'output_unit'}, output_sockets={'Expr'}, output_socket_kinds={'Expr': ct.FlowKind.Func}, ) - def compute_info(self, output_sockets) -> ct.InfoFlow: + def compute_info(self, props, output_sockets) -> ct.InfoFlow: """Declare an `InfoFlow` based on the data shape. This currently requires computing the data. @@ -196,26 +294,24 @@ class DataFileImporterNode(base.MaxwellSimNode): # Deduce Dimensionality _shape = data.shape shape = _shape if _shape is not None else () - dim_names = [f'a{i}' for i in range(len(shape))] + dim_syms = [self.dim(i) for i in range(len(shape))] # Return InfoFlow - ## -> TODO: How to interpret the data should be user-defined. - ## -> -- This may require those nice dynamic symbols. return ct.InfoFlow( - dim_names=dim_names, ## TODO: User - dim_idx={ - dim_name: ct.RangeFlow( - start=sp.S(0), ## TODO: User - stop=sp.S(shape[i] - 1), ## TODO: User - steps=shape[dim_names.index(dim_name)], - unit=None, ## TODO: User + dims={ + dim_sym: ct.RangeFlow( + start=sp.S(0), + stop=sp.S(shape[i] - 1), + steps=shape[i], + unit=self.dim(i).unit, ) - for i, dim_name in enumerate(dim_names) + for i, dim_sym in enumerate(dim_syms) }, - output_name='_', - output_shape=None, - output_mathtype=spux.MathType.Real, ## TODO: User - output_unit=None, ## TODO: User + output=sim_symbols.SimSymbol( + sym_name=props['output_name'], + mathtype=props['output_mathtype'], + physical_type=props['output_physical_type'], + ), ) return ct.FlowSignal.FlowPending diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/file_exporters/data_file_exporter.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/file_exporters/data_file_exporter.py index 892db69..99b7257 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/file_exporters/data_file_exporter.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/file_exporters/data_file_exporter.py @@ -229,11 +229,11 @@ class DataFileExporterNode(base.MaxwellSimNode): ## -> Only happens if Params contains not-yet-realized symbols. if has_info and has_params and params.symbols: if set(self.loose_input_sockets) != { - sym.name for sym in params.symbols if sym.name in info.dim_names + dim.name for dim in params.symbols if dim in info.dims }: self.loose_input_sockets = { - sym_name: sockets.ExprSocketDef(**expr_info) - for sym_name, expr_info in params.sym_expr_infos(info).items() + dim_name: sockets.ExprSocketDef(**expr_info) + for dim_name, expr_info in params.sym_expr_infos(info).items() } elif self.loose_input_sockets: diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py index e38dc41..b7fc830 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py @@ -118,13 +118,15 @@ class ExprBLSocket(base.MaxwellSimSocket): physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical) # Symbols - # active_symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([]) - symbols: frozenset[sp.Symbol] = bl_cache.BLField(frozenset()) + output_name: sim_symbols.SimSymbolName = bl_cache.BLField( + sim_symbols.SimSymbolName.Expr + ) + active_symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([]) - # @property - # def symbols(self) -> set[sp.Symbol]: - # """Current symbols as an unordered set.""" - # return {sim_symbol.sp_symbol for sim_symbol in self.active_symbols} + @property + def symbols(self) -> set[sp.Symbol]: + """Current symbols as an unordered set.""" + return {sim_symbol.sp_symbol for sim_symbol in self.active_symbols} @bl_cache.cached_bl_property(depends_on={'symbols'}) def sorted_symbols(self) -> list[sp.Symbol]: @@ -184,6 +186,7 @@ class ExprBLSocket(base.MaxwellSimSocket): ) # UI: Info + show_func_ui: bool = bl_cache.BLField(True) show_info_columns: bool = bl_cache.BLField(False) info_columns: set[InfoDisplayCol] = bl_cache.BLField( {InfoDisplayCol.Length, InfoDisplayCol.MathType} @@ -615,35 +618,24 @@ class ExprBLSocket(base.MaxwellSimSocket): Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along. """ + output_sim_sym = ( + sim_symbols.SimSymbol( + sym_name=self.output_name, + mathtype=self.mathtype, + physical_type=self.physical_type, + unit=self.unit, + rows=self.size.rows, + cols=self.size.cols, + ), + ) if self.symbols: return ct.InfoFlow( - dim_names=[sym.name for sym in self.sorted_symbols], - dim_idx={ - sym.name: ct.RangeFlow( - start=-sp.oo if _check_sym_oo(sym) else -sp.zoo, - stop=sp.oo if _check_sym_oo(sym) else sp.zoo, - steps=0, - unit=None, ## Symbols alone are unitless. - ) - ## TODO: PhysicalTypes for symbols? Or nah? - ## TODO: Can we parse some sp.Interval for explicit domains? - ## -> We investigated sp.Symbol(..., domain=...). - ## -> It's no good. We can't re-extract the interval given to domain. - for sym in self.sorted_symbols - }, - output_name='_', ## Use node:socket name? Or something? Ahh - output_shape=self.size.shape, - output_mathtype=self.mathtype, - output_unit=self.unit, + dims={sim_sym: None for sim_sym in self.active_symbols}, + output=output_sim_sym, ) # Constant - return ct.InfoFlow( - output_name='_', ## Use node:socket name? Or something? Ahh - output_shape=self.size.shape, - output_mathtype=self.mathtype, - output_unit=self.unit, - ) + return ct.InfoFlow(output=output_sim_sym) #################### # - FlowKind: Capabilities @@ -847,26 +839,31 @@ class ExprBLSocket(base.MaxwellSimSocket): Uses `draw_value` to draw the base UI """ - # Physical Type Selector - ## -> Determines whether/which unit-dropdown will be shown. - col.prop(self, self.blfields['physical_type'], text='') + if self.show_func_ui: + # Output Name Selector + ## -> The name of the output + col.prop(self, self.blfields['output_name'], text='') - # Non-Symbolic: Size/Mathtype Selector - ## -> Symbols imply str expr input. - ## -> For arbitrary str exprs, size/mathtype are derived from the expr. - ## -> Otherwise, size/mathtype must be pre-specified for a nice UI. - if not self.symbols: - row = col.row(align=True) - row.prop(self, self.blfields['size'], text='') - row.prop(self, self.blfields['mathtype'], text='') + # Physical Type Selector + ## -> Determines whether/which unit-dropdown will be shown. + col.prop(self, self.blfields['physical_type'], text='') - # Base UI - ## -> Draws the UI appropriate for the above choice of constraints. - self.draw_value(col) + # Non-Symbolic: Size/Mathtype Selector + ## -> Symbols imply str expr input. + ## -> For arbitrary str exprs, size/mathtype are derived from the expr. + ## -> Otherwise, size/mathtype must be pre-specified for a nice UI. + if not self.symbols: + row = col.row(align=True) + row.prop(self, self.blfields['size'], text='') + row.prop(self, self.blfields['mathtype'], text='') - # Symbol UI - ## -> Draws the UI appropriate for the above choice of constraints. - ## -> TODO + # Base UI + ## -> Draws the UI appropriate for the above choice of constraints. + self.draw_value(col) + + # Symbol UI + ## -> Draws the UI appropriate for the above choice of constraints. + ## -> TODO #################### # - UI: InfoFlow @@ -884,9 +881,9 @@ class ExprBLSocket(base.MaxwellSimSocket): ) # Dimensions - for dim_name in info.dim_names: - dim_idx = info.dim_idx[dim_name] - grid.label(text=dim_name) + for dim in info.dims: + dim_idx = info.dims[dim] + grid.label(text=dim.name_pretty) if InfoDisplayCol.Length in self.info_columns: grid.label(text=str(len(dim_idx))) if InfoDisplayCol.MathType in self.info_columns: @@ -895,27 +892,27 @@ class ExprBLSocket(base.MaxwellSimSocket): grid.label(text=spux.sp_to_str(dim_idx.unit)) # Outputs - grid.label(text=info.output_name) + grid.label(text=info.output.name_pretty) if InfoDisplayCol.Length in self.info_columns: grid.label(text='', icon=ct.Icon.DataSocketOutput) if InfoDisplayCol.MathType in self.info_columns: grid.label( text=( - spux.MathType.to_str(info.output_mathtype) + spux.MathType.to_str(info.output.mathtype) + ( 'ˣ'.join( [ unicode_superscript(out_axis) - for out_axis in info.output_shape + for out_axis in info.output.shape ] ) - if info.output_shape + if info.output.shape else '' ) ) ) if InfoDisplayCol.Unit in self.info_columns: - grid.label(text=f'{spux.sp_to_str(info.output_unit)}') + grid.label(text=f'{spux.sp_to_str(info.output.unit)}') #################### @@ -929,6 +926,7 @@ class ExprSocketDef(base.SocketDef): ct.FlowKind.Array, ct.FlowKind.Func, ] = ct.FlowKind.Value + output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName # Socket Interface size: spux.NumberSize1D = spux.NumberSize1D.Scalar @@ -938,10 +936,6 @@ class ExprSocketDef(base.SocketDef): default_unit: spux.Unit | None = None default_symbols: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list) - @property - def symbols(self) -> set[sp.Symbol]: - return {sim_symbol.sp_symbol for sim_symbol in self.default_symbols} - # FlowKind: Value default_value: spux.SympyExpr = 0 abs_min: spux.SympyExpr | None = None @@ -954,6 +948,7 @@ class ExprSocketDef(base.SocketDef): default_scaling: ct.ScalingMode = ct.ScalingMode.Lin # UI + show_func_ui: bool = True show_info_columns: bool = False #################### @@ -1153,6 +1148,14 @@ class ExprSocketDef(base.SocketDef): msg = f'ExprSocket: Mathtype {dv_mathtype} of a default Range min or max expression {bound} (type {type(self.default_value)}) is incompatible with socket MathType {self.mathtype}' raise ValueError(msg) + # Coerce from Infinite + if bound.is_infinite and self.mathtype is spux.MathType.Integer: + new_bounds[i] = sp.S(-1) if i == 0 else sp.S(1) + if bound.is_infinite and self.mathtype is spux.MathType.Rational: + new_bounds[i] = sp.Rational(-1, 1) if i == 0 else sp.Rational(1, 1) + if bound.is_infinite and self.mathtype is spux.MathType.Real: + new_bounds[i] = sp.S(-1.0) if i == 0 else sp.S(1.0) + if new_bounds[0] is not None: self.default_min = new_bounds[0] if new_bounds[1] is not None: @@ -1217,13 +1220,14 @@ class ExprSocketDef(base.SocketDef): #################### def init(self, bl_socket: ExprBLSocket) -> None: bl_socket.active_kind = self.active_kind + bl_socket.output_name = self.output_name # Socket Interface ## -> Recall that auto-updates are turned off during init() bl_socket.size = self.size bl_socket.mathtype = self.mathtype bl_socket.physical_type = self.physical_type - bl_socket.symbols = self.symbols + bl_socket.active_symbols = self.symbols # FlowKind.Value ## -> We must take units into account when setting bl_socket.value @@ -1246,6 +1250,7 @@ class ExprSocketDef(base.SocketDef): ) # UI + bl_socket.show_func_ui = self.show_func_ui bl_socket.show_info_columns = self.show_info_columns # Info Draw diff --git a/src/blender_maxwell/utils/extra_sympy_units.py b/src/blender_maxwell/utils/extra_sympy_units.py index 871455a..5cfde7f 100644 --- a/src/blender_maxwell/utils/extra_sympy_units.py +++ b/src/blender_maxwell/utils/extra_sympy_units.py @@ -61,7 +61,6 @@ SympyType = ( class MathType(enum.StrEnum): """Type identifiers that encompass common sets of mathematical objects.""" - Bool = enum.auto() Integer = enum.auto() Rational = enum.auto() Real = enum.auto() @@ -77,8 +76,6 @@ class MathType(enum.StrEnum): return MathType.Rational if MathType.Integer in mathtypes: return MathType.Integer - if MathType.Bool in mathtypes: - return MathType.Bool msg = f"Can't combine mathtypes {mathtypes}" raise ValueError(msg) @@ -88,7 +85,6 @@ class MathType(enum.StrEnum): return ( other in { - MT.Bool: [MT.Bool], MT.Integer: [MT.Integer], MT.Rational: [MT.Integer, MT.Rational], MT.Real: [MT.Integer, MT.Rational, MT.Real], @@ -98,11 +94,9 @@ class MathType(enum.StrEnum): def coerce_compatible_pyobj( self, pyobj: bool | int | Fraction | float | complex - ) -> bool | int | Fraction | float | complex: + ) -> int | Fraction | float | complex: MT = MathType match self: - case MT.Bool: - return pyobj case MT.Integer: return int(pyobj) case MT.Rational if isinstance(pyobj, int): @@ -123,8 +117,6 @@ class MathType(enum.StrEnum): *[MathType.from_expr(v) for v in sp.flatten(sp_obj)] ) - if isinstance(sp_obj, sp.logic.boolalg.Boolean): - return MathType.Bool if sp_obj.is_integer: return MathType.Integer if sp_obj.is_rational: @@ -146,7 +138,6 @@ class MathType(enum.StrEnum): @staticmethod def from_pytype(dtype: type) -> type: return { - bool: MathType.Bool, int: MathType.Integer, Fraction: MathType.Rational, float: MathType.Real, @@ -166,7 +157,6 @@ class MathType(enum.StrEnum): def pytype(self) -> type: MT = MathType return { - MT.Bool: bool, MT.Integer: int, MT.Rational: Fraction, MT.Real: float, @@ -177,17 +167,25 @@ class MathType(enum.StrEnum): def symbolic_set(self) -> type: MT = MathType return { - MT.Bool: sp.Set([sp.S(False), sp.S(True)]), MT.Integer: sp.Integers, MT.Rational: sp.Rationals, MT.Real: sp.Reals, MT.Complex: sp.Complexes, }[self] + @property + def sp_symbol_a(self) -> type: + MT = MathType + return { + MT.Integer: sp.Symbol('a', integer=True), + MT.Rational: sp.Symbol('a', rational=True), + MT.Real: sp.Symbol('a', real=True), + MT.Complex: sp.Symbol('a', complex=True), + }[self] + @staticmethod def to_str(value: typ.Self) -> type: return { - MathType.Bool: 'T|F', MathType.Integer: 'ℤ', MathType.Rational: 'ℚ', MathType.Real: 'ℝ', @@ -212,6 +210,9 @@ class MathType(enum.StrEnum): ) +#################### +# - Size: 1D +#################### class NumberSize1D(enum.StrEnum): """Valid 1D-constrained shape.""" @@ -278,6 +279,20 @@ class NumberSize1D(enum.StrEnum): (4, 1): NS.Vec4, }[shape] + @property + def rows(self): + NS = NumberSize1D + return { + NS.Scalar: 1, + NS.Vec2: 2, + NS.Vec3: 3, + NS.Vec4: 4, + }[self] + + @property + def cols(self): + return 1 + @property def shape(self): NS = NumberSize1D @@ -297,6 +312,30 @@ def symbol_range(sym: sp.Symbol) -> str: ) +#################### +# - Symbol Sizes +#################### +class SimpleSize2D(enum.StrEnum): + """Simple subset of sizes for rank-2 tensors.""" + + Scalar = enum.auto() + + # Vectors + Vec2 = enum.auto() ## 2x1 + Vec3 = enum.auto() ## 3x1 + Vec4 = enum.auto() ## 4x1 + + # Covectors + CoVec2 = enum.auto() ## 1x2 + CoVec3 = enum.auto() ## 1x3 + CoVec4 = enum.auto() ## 1x4 + + # Square Matrices + Mat22 = enum.auto() ## 2x2 + Mat33 = enum.auto() ## 3x3 + Mat44 = enum.auto() ## 4x4 + + #################### # - Unit Dimensions #################### @@ -382,6 +421,8 @@ UNIT_BY_SYMBOL: dict[sp.Symbol, spu.Quantity] = { unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity) } | {unit.name: unit for unit in globals().values() if isinstance(unit, spu.Quantity)} +UNIT_TO_1: dict[spu.Quantity, 1] = {unit: 1 for unit in UNIT_BY_SYMBOL.values()} + #################### # - Expr Analysis: Units @@ -907,10 +948,6 @@ class PhysicalType(enum.StrEnum): LumIntensity = enum.auto() LumFlux = enum.auto() Illuminance = enum.auto() - # Optics - OrdinaryWaveVector = enum.auto() - AngularWaveVector = enum.auto() - PoyntingVector = enum.auto() @functools.cached_property def unit_dim(self): @@ -956,10 +993,6 @@ class PhysicalType(enum.StrEnum): PT.LumIntensity: Dims.luminous_intensity, PT.LumFlux: Dims.luminous_intensity * spu.steradian.dimension, PT.Illuminance: Dims.luminous_intensity / Dims.length**2, - # Optics - PT.OrdinaryWaveVector: Dims.frequency, - PT.AngularWaveVector: Dims.angle * Dims.frequency, - PT.PoyntingVector: Dims.power / Dims.length**2, }[self] @functools.cached_property @@ -1196,10 +1229,6 @@ class PhysicalType(enum.StrEnum): PT.HField: [None, (2,), (3,)], # Luminal PT.LumFlux: [None, (2,), (3,)], - # Optics - PT.OrdinaryWaveVector: [None, (2,), (3,)], - PT.AngularWaveVector: [None, (2,), (3,)], - PT.PoyntingVector: [None, (2,), (3,)], } return overrides.get(self, [None]) @@ -1222,7 +1251,6 @@ class PhysicalType(enum.StrEnum): - **Charge**: Generally, it is real. However, an imaginary phase term seems to have research applications when dealing with high-order harmonics in high-energy pulsed lasers: - **Conductance**: The imaginary part represents the extinction, in the Drude-model sense. - - **Poynting**: The imaginary part represents the oscillation in the power flux over time. """ MT = MathType @@ -1249,10 +1277,6 @@ class PhysicalType(enum.StrEnum): PT.EField: [MT.Real, MT.Complex], ## Im -> Phase PT.HField: [MT.Real, MT.Complex], ## Im -> Phase # Luminal - # Optics - PT.OrdinaryWaveVector: [MT.Real, MT.Complex], ## Im -> Phase - PT.AngularWaveVector: [MT.Real, MT.Complex], ## Im -> Phase - PT.PoyntingVector: [MT.Real, MT.Complex], ## Im -> Reactive Power } return overrides.get(self, [MT.Real]) @@ -1323,10 +1347,6 @@ UNITS_SI: UnitSystem = { _PT.LumIntensity: spu.candela, _PT.LumFlux: lumen, _PT.Illuminance: spu.lux, - # Optics - _PT.OrdinaryWaveVector: spu.hertz, - _PT.AngularWaveVector: spu.radian * spu.hertz, - _PT.PoyntingVector: spu.watt / spu.meter**2, } @@ -1380,15 +1400,20 @@ def sympy_to_python( #################### # - Convert to Unit System #################### -def convert_to_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr: +def convert_to_unit_system( + sp_obj: SympyExpr, unit_system: UnitSystem | None +) -> SympyExpr: """Convert an expression to the units of a given unit system, with appropriate scaling.""" + if unit_system is None: + return sp_obj + return spu.convert_to( sp_obj, {unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)}, ) -def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr: +def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem | None) -> SympyExpr: """Strip units occurring in the given unit system from the expression. Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`". @@ -1397,11 +1422,13 @@ def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr: Notes: You should probably use `scale_to_unit_system()` or `convert_to_unit_system()`. """ + if unit_system is None: + return sp_obj.subs(UNIT_TO_1) return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None}) def scale_to_unit_system( - sp_obj: SympyExpr, unit_system: UnitSystem, use_jax_array: bool = False + sp_obj: SympyExpr, unit_system: UnitSystem | None, use_jax_array: bool = False ) -> int | float | complex | tuple | jax.Array: """Convert an expression to the units of a given unit system, then strip all units of the unit system. diff --git a/src/blender_maxwell/utils/image_ops.py b/src/blender_maxwell/utils/image_ops.py index da51681..f4026a3 100644 --- a/src/blender_maxwell/utils/image_ops.py +++ b/src/blender_maxwell/utils/image_ops.py @@ -29,11 +29,13 @@ import matplotlib.axis as mpl_ax import matplotlib.backends.backend_agg import matplotlib.figure import matplotlib.style as mplstyle +import seaborn as sns from blender_maxwell import contracts as ct from blender_maxwell.utils import logger mplstyle.use('fast') ## TODO: Does this do anything? +sns.set_theme() log = logger.get(__name__) @@ -149,125 +151,98 @@ def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int): #################### # - Plotters #################### -# () -> ℝ -def plot_hist_1d( - data: jtyp.Float32[jtyp.Array, ' size'], info, ax: mpl_ax.Axis -) -> None: - y_name = info.output_name - y_unit = info.output_unit - - ax.hist(data, bins=30, alpha=0.75) - ax.set_title('Histogram') - ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) - - # (ℤ) -> ℝ def plot_box_plot_1d( data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis ) -> None: - x_name = info.dim_names[0] - y_name = info.output_name - y_unit = info.output_unit + x_sym = info.last_dim + y_sym = info.output - ax.boxplot(data) - ax.set_title('Box Plot') - ax.set_xlabel(f'{x_name}') - ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + ax.boxplot([data]) + ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}') + ax.set_xlabel(x_sym.plot_label) + ax.set_xlabel(y_sym.plot_label) + + +def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None: + x_sym = info.last_dim + y_sym = info.output + + p = ax.bar(info.dims[x_sym], data) + ax.bar_label(p, label_type='center') + + ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}') + ax.set_xlabel(x_sym.plot_label) + ax.set_xlabel(y_sym.plot_label) # (ℝ) -> ℝ def plot_curve_2d( data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis ) -> None: - times = [time.perf_counter()] + x_sym = info.last_dim + y_sym = info.output - x_name = info.dim_names[0] - x_unit = info.dim_units[x_name] - y_name = info.output_name - y_unit = info.output_unit - - times.append(time.perf_counter() - times[0]) - ax.plot(info.dim_idx_arrays[0], data) - times.append(time.perf_counter() - times[0]) - ax.set_title('2D Curve') - times.append(time.perf_counter() - times[0]) - ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) - times.append(time.perf_counter() - times[0]) - ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) - times.append(time.perf_counter() - times[0]) - # log.critical('Timing of Curve2D: %s', str(times)) + ax.plot(info.dims[x_sym].realize_array.values, data) + ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}') + ax.set_xlabel(x_sym.plot_label) + ax.set_xlabel(y_sym.plot_label) def plot_points_2d( data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis ) -> None: - x_name = info.dim_names[0] - x_unit = info.dim_units[x_name] - y_name = info.output_name - y_unit = info.output_unit + x_sym = info.last_dim + y_sym = info.output - ax.scatter(info.dim_idx_arrays[0], data, alpha=0.6) - ax.set_title('2D Points') - ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) - ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) - - -def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None: - x_name = info.dim_names[0] - x_unit = info.dim_units[x_name] - y_name = info.output_name - y_unit = info.output_unit - - ax.bar(info.dim_idx_arrays[0], data, alpha=0.7) - ax.set_title('2D Bar') - ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) - ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + ax.scatter(x_sym.realize_array.values, data) + ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}') + ax.set_xlabel(x_sym.plot_label) + ax.set_xlabel(y_sym.plot_label) # (ℝ, ℤ) -> ℝ def plot_curves_2d( data: jtyp.Float32[jtyp.Array, 'x_size categories'], info, ax: mpl_ax.Axis ) -> None: - x_name = info.dim_names[0] - x_unit = info.dim_units[x_name] - y_name = info.output_name - y_unit = info.output_unit + x_sym = info.first_dim + y_sym = info.output - for category in range(data.shape[1]): - ax.plot(info.dim_idx_arrays[0], data[:, category]) + for i, category in enumerate(info.dims[info.last_dim]): + ax.plot(info.dims[x_sym], data[:, i], label=category) - ax.set_title('2D Curves') - ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) - ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}') + ax.set_xlabel(x_sym.plot_label) + ax.set_xlabel(y_sym.plot_label) ax.legend() def plot_filled_curves_2d( data: jtyp.Float32[jtyp.Array, 'x_size 2'], info, ax: mpl_ax.Axis ) -> None: - x_name = info.dim_names[0] - x_unit = info.dim_units[x_name] - y_name = info.output_name - y_unit = info.output_unit + x_sym = info.first_dim + y_sym = info.output - shared_x_idx = info.dim_idx_arrays[0] + shared_x_idx = info.dims[info.last_dim] ax.fill_between(shared_x_idx, data[:, 0], shared_x_idx, data[:, 1]) - ax.set_title('2D Filled Curves') - ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) - ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}') + ax.set_xlabel(x_sym.plot_label) + ax.set_xlabel(y_sym.plot_label) + ax.legend() # (ℝ, ℝ) -> ℝ def plot_heatmap_2d( data: jtyp.Float32[jtyp.Array, 'x_size y_size'], info, ax: mpl_ax.Axis ) -> None: - x_name = info.dim_names[0] - x_unit = info.dim_units[x_name] - y_name = info.dim_names[1] - y_unit = info.dim_units[y_name] + x_sym = info.first_dim + y_sym = info.last_dim + c_sym = info.output - heatmap = ax.imshow(data, aspect='auto', interpolation='none') - # ax.figure.colorbar(heatmap, ax=ax) - ax.set_title('Heatmap') - ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else '')) - ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else '')) + heatmap = ax.imshow(data, aspect='equal', interpolation='none') + ax.figure.colorbar(heatmap, cax=ax) + + ax.set_title(f'({x_sym.name_pretty}, {y_sym.name_pretty}) -> {c_sym.plot_label}') + ax.set_xlabel(x_sym.plot_label) + ax.set_xlabel(y_sym.plot_label) + ax.legend() diff --git a/src/blender_maxwell/utils/sim_symbols.py b/src/blender_maxwell/utils/sim_symbols.py index df13a4c..7d1544e 100644 --- a/src/blender_maxwell/utils/sim_symbols.py +++ b/src/blender_maxwell/utils/sim_symbols.py @@ -18,26 +18,67 @@ import dataclasses import enum import sys import typing as typ +from fractions import Fraction import sympy as sp from . import extra_sympy_units as spux +int_min = -(2**64) +int_max = 2**64 +float_min = sys.float_info.min +float_max = sys.float_info.max + #################### -# - Simulation Symbols +# - Simulation Symbol Names #################### class SimSymbolName(enum.StrEnum): + # Lower LowerA = enum.auto() + LowerB = enum.auto() + LowerC = enum.auto() + LowerD = enum.auto() + LowerI = enum.auto() LowerT = enum.auto() LowerX = enum.auto() LowerY = enum.auto() LowerZ = enum.auto() - # Physics + # Fields + Ex = enum.auto() + Ey = enum.auto() + Ez = enum.auto() + Hx = enum.auto() + Hy = enum.auto() + Hz = enum.auto() + + Er = enum.auto() + Etheta = enum.auto() + Ephi = enum.auto() + Hr = enum.auto() + Htheta = enum.auto() + Hphi = enum.auto() + + # Optics Wavelength = enum.auto() Frequency = enum.auto() + Flux = enum.auto() + + PermXX = enum.auto() + PermYY = enum.auto() + PermZZ = enum.auto() + + DiffOrderX = enum.auto() + DiffOrderY = enum.auto() + + # Generic + Expr = enum.auto() + + #################### + # - UI + #################### @staticmethod def to_name(v: typ.Self) -> str: """Convert the enum value to a human-friendly name. @@ -50,27 +91,6 @@ class SimSymbolName(enum.StrEnum): """ return SimSymbolName(v).name - @property - def name(self) -> str: - SSN = SimSymbolName - return { - SSN.LowerA: 'a', - SSN.LowerT: 't', - SSN.LowerX: 'x', - SSN.LowerY: 'y', - SSN.LowerZ: 'z', - SSN.Wavelength: 'wl', - SSN.Frequency: 'freq', - }[self] - - @property - def name_pretty(self) -> str: - SSN = SimSymbolName - return { - SSN.Wavelength: 'λ', - SSN.Frequency: '𝑓', - }.get(self, self.name) - @staticmethod def to_icon(_: typ.Self) -> str: """Convert the enum value to a Blender icon. @@ -83,6 +103,75 @@ class SimSymbolName(enum.StrEnum): """ return '' + #################### + # - Computed Properties + #################### + @property + def name(self) -> str: + SSN = SimSymbolName + return { + # Lower + SSN.LowerA: 'a', + SSN.LowerB: 'b', + SSN.LowerC: 'c', + SSN.LowerD: 'd', + SSN.LowerI: 'i', + SSN.LowerT: 't', + SSN.LowerX: 'x', + SSN.LowerY: 'y', + SSN.LowerZ: 'z', + # Fields + SSN.Ex: 'Ex', + SSN.Ey: 'Ey', + SSN.Ez: 'Ez', + SSN.Hx: 'Hx', + SSN.Hy: 'Hy', + SSN.Hz: 'Hz', + SSN.Er: 'Ex', + SSN.Etheta: 'Ey', + SSN.Ephi: 'Ez', + SSN.Hr: 'Hx', + SSN.Htheta: 'Hy', + SSN.Hphi: 'Hz', + # Optics + SSN.Wavelength: 'wl', + SSN.Frequency: 'freq', + SSN.Flux: 'flux', + SSN.PermXX: 'eps_xx', + SSN.PermYY: 'eps_yy', + SSN.PermZZ: 'eps_zz', + SSN.DiffOrderX: 'order_x', + SSN.DiffOrderY: 'order_y', + # Generic + SSN.Expr: 'expr', + }[self] + + @property + def name_pretty(self) -> str: + SSN = SimSymbolName + return { + SSN.Wavelength: 'λ', + SSN.Frequency: '𝑓', + }.get(self, self.name) + + +#################### +# - Simulation Symbol +#################### +def mk_interval( + interval_finite: tuple[int | Fraction | float, int | Fraction | float], + interval_inf: tuple[bool, bool], + interval_closed: tuple[bool, bool], + unit_factor: typ.Literal[1] | spux.Unit, +) -> sp.Interval: + """Create a symbolic interval from the tuples (and unit) defining it.""" + return sp.Interval( + start=(interval_finite[0] * unit_factor if not interval_inf[0] else -sp.oo), + end=(interval_finite[1] * unit_factor if not interval_inf[1] else sp.oo), + left_open=(True if interval_inf[0] else not interval_closed[0]), + right_open=(True if interval_inf[1] else not interval_closed[1]), + ) + @dataclasses.dataclass(kw_only=True, frozen=True) class SimSymbol: @@ -94,66 +183,145 @@ class SimSymbol: It's easy to persist, easy to transport, and has many helpful properties which greatly simplify working with symbols. """ - sim_node_name: SimSymbolName = SimSymbolName.LowerX + sym_name: SimSymbolName mathtype: spux.MathType = spux.MathType.Real - physical_type: spux.PhysicalType = spux.PhysicalType.NonPhysical - ## TODO: Shape/size support? Incl. MatrixSymbol. + # Units + ## -> 'None' indicates that no particular unit has yet been chosen. + ## -> Not exposed in the UI; must be set some other way. + unit: spux.Unit | None = None - # Domain - interval_finite: tuple[float, float] = (0, 1) + # Size + ## -> All SimSymbol sizes are "2D", but interpreted by convention. + ## -> 1x1: "Scalar". + ## -> nx1: "Vector". + ## -> 1xn: "Covector". + ## -> nxn: "Matrix". + rows: int = 1 + cols: int = 1 + + # Scalar Domain: "Interval" + ## -> NOTE: interval_finite_*[0] must be strictly smaller than [1]. + ## -> See self.domain. + ## -> We have to deconstruct symbolic interval semantics a bit for UI. + interval_finite_z: tuple[int, int] = (0, 1) + interval_finite_q: tuple[tuple[int, int], tuple[int, int]] = ((0, 1), (1, 1)) + interval_finite_re: tuple[float, float] = (0, 1) interval_inf: tuple[bool, bool] = (True, True) interval_closed: tuple[bool, bool] = (False, False) + interval_finite_im: tuple[float, float] = (0, 1) + interval_inf_im: tuple[bool, bool] = (True, True) + interval_closed_im: tuple[bool, bool] = (False, False) + #################### # - Properties #################### @property def name(self) -> str: - return self.sim_node_name.name + """Usable name for the symbol.""" + return self.sym_name.name + + @property + def name_pretty(self) -> str: + """Pretty (possibly unicode) name for the thing.""" + return self.sym_name.name_pretty + ## TODO: Formatting conventions for bolding/etc. of vectors/mats/... + + @property + def plot_label(self) -> str: + """Pretty plot-oriented label.""" + return f'{self.name_pretty}' + ( + f'({self.unit})' if self.unit is not None else '' + ) + + @property + def unit_factor(self) -> spux.SympyExpr: + """Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking.""" + return self.unit if self.unit is not None else sp.S(1) + + @property + def shape(self) -> tuple[int, ...]: + match (self.rows, self.cols): + case (1, 1): + return () + case (_, 1): + return (self.rows,) + case (1, _): + return (1, self.rows) + case (_, _): + return (self.rows, self.cols) @property def domain(self) -> sp.Interval | sp.Set: - """Return the domain of valid values for the symbol. + """Return the scalar domain of valid values for each element of the symbol. For integer/rational/real symbols, the domain is an interval defined using the `interval_*` properties. This interval **must** have the property`start <= stop`. Otherwise, the domain is the symbolic set corresponding to `self.mathtype`. """ - if self.mathtype in [ - spux.MathType.Integer, - spux.MathType.Rational, - spux.MathType.Real, - ]: - return sp.Interval( - start=self.interval_finite[0] if not self.interval_inf[0] else -sp.oo, - end=self.interval_finite[1] if not self.interval_inf[1] else sp.oo, - left_open=( - True if self.interval_inf[0] else not self.interval_closed[0] - ), - right_open=( - True if self.interval_inf[1] else not self.interval_closed[1] - ), - ) + match self.mathtype: + case spux.MathType.Integer: + return mk_interval( + self.interval_finite_z, + self.interval_inf, + self.interval_closed, + self.unit_factor, + ) - return self.mathtype.symbolic_set + case spux.MathType.Rational: + return mk_interval( + Fraction(*self.interval_finite_q), + self.interval_inf, + self.interval_closed, + self.unit_factor, + ) + + case spux.MathType.Real: + return mk_interval( + self.interval_finite_re, + self.interval_inf, + self.interval_closed, + self.unit_factor, + ) + + case spux.MathType.Complex: + return ( + mk_interval( + self.interval_finite_re, + self.interval_inf, + self.interval_closed, + self.unit_factor, + ), + mk_interval( + self.interval_finite_im, + self.interval_inf_im, + self.interval_closed_im, + self.unit_factor, + ), + ) #################### # - Properties #################### @property def sp_symbol(self) -> sp.Symbol: - """Return a symbolic variable corresponding to this `SimSymbol`. + """Return a symbolic variable w/unit, corresponding to this `SimSymbol`. As much as possible, appropriate `assumptions` are set in the constructor of `sp.Symbol`, insofar as they can be determined. - However, the assumptions system alone is rather limited, and implementations should therefore also strongly consider transporting `SimSymbols` directly, instead of `sp.Symbol`. - This allows making use of other properties like `self.domain`, when appropriate. + - **MathType**: Depending on `self.mathtype`. + - **Positive/Negative**: Depending on `self.domain`. + - **Nonzero**: Depending on `self.domain`, including open/closed boundary specifications. + + Notes: + **The assumptions system is rather limited**, and implementations should strongly consider transporting `SimSymbols` instead of `sp.Symbol`. + + This allows tracking ex. the valid interval domain for a symbol. """ - # MathType Domain Constraint - ## -> We must feed the assumptions system. + # MathType Assumption mathtype_kwargs = {} match self.mathtype: case spux.MathType.Integer: @@ -165,53 +333,138 @@ class SimSymbol: case spux.MathType.Complex: mathtype_kwargs |= {'complex': True} - # Interval Constraints - if isinstance(self.domain, sp.Interval): - # Assumption: Non-Zero - if ( - ( - self.domain.left == 0 - and self.domain.left_open - or self.domain.right == 0 - and self.domain.right_open - ) - or self.domain.left > 0 - or self.domain.right < 0 - ): - mathtype_kwargs |= {'nonzero': True} + # Non-Zero Assumption + if ( + ( + self.domain.left == 0 + and self.domain.left_open + or self.domain.right == 0 + and self.domain.right_open + ) + or self.domain.left > 0 + or self.domain.right < 0 + ): + mathtype_kwargs |= {'nonzero': True} - # Assumption: Positive/Negative - if self.domain.left >= 0: - mathtype_kwargs |= {'positive': True} - elif self.domain.right <= 0: - mathtype_kwargs |= {'negative': True} + # Positive/Negative Assumption + if self.domain.left >= 0: + mathtype_kwargs |= {'positive': True} + elif self.domain.right <= 0: + mathtype_kwargs |= {'negative': True} - # Construct the Symbol - return sp.Symbol(self.sim_node_name.name, **mathtype_kwargs) + return sp.Symbol(self.sym_name.name, **mathtype_kwargs) * self.unit_factor + + #################### + # - Operations + #################### + def update(self, **kwargs) -> typ.Self: + def get_attr(attr: str): + _notfound = 'notfound' + if kwargs.get(attr, _notfound) is _notfound: + return getattr(self, attr) + return kwargs[attr] + + return SimSymbol( + sym_name=get_attr('sym_name'), + mathtype=get_attr('mathtype'), + physical_type=get_attr('physical_type'), + unit=get_attr('unit'), + rows=get_attr('rows'), + cols=get_attr('cols'), + interval_finite_z=get_attr('interval_finite_z'), + interval_finite_q=get_attr('interval_finite_q'), + interval_finite_re=get_attr('interval_finite_q'), + interval_inf=get_attr('interval_inf'), + interval_closed=get_attr('interval_closed'), + interval_finite_im=get_attr('interval_finite_im'), + interval_inf_im=get_attr('interval_inf_im'), + interval_closed_im=get_attr('interval_closed_im'), + ) + + def set_size(self, rows: int, cols: int) -> typ.Self: + return SimSymbol( + sym_name=self.sym_name, + mathtype=self.mathtype, + physical_type=self.physical_type, + unit=self.unit, + rows=rows, + cols=cols, + interval_finite_z=self.interval_finite_z, + interval_finite_q=self.interval_finite_q, + interval_finite_re=self.interval_finite_re, + interval_inf=self.interval_inf, + interval_closed=self.interval_closed, + interval_finite_im=self.interval_finite_im, + interval_inf_im=self.interval_inf_im, + interval_closed_im=self.interval_closed_im, + ) #################### # - Common Sim Symbols #################### class CommonSimSymbol(enum.StrEnum): - """A set of pre-defined symbols that might commonly be used in the context of physical simulation. + """Identifiers for commonly used `SimSymbol`s, with all information about ex. `MathType`, `PhysicalType`, and (in general) valid intervals all pre-loaded. - Each entry maps directly to a particular `SimSymbol`. - - The enum is compatible with `BLField`, making it easy to declare a UI-driven dropdown of symbols that behave as expected. + The enum is UI-compatible making it easy to declare a UI-driven dropdown of commonly used symbols that will all behave as expected. Attributes: + X: + Time: A symbol representing a real-valued wavelength. Wavelength: A symbol representing a real-valued wavelength. Implicitly, this symbol often represents "vacuum wavelength" in particular. Wavelength: A symbol representing a real-valued frequency. Generally, this is the non-angular frequency. """ - X = enum.auto() + Index = enum.auto() + + # Space|Time + SpaceX = enum.auto() + SpaceY = enum.auto() + SpaceZ = enum.auto() + + AngR = enum.auto() + AngTheta = enum.auto() + AngPhi = enum.auto() + + DirX = enum.auto() + DirY = enum.auto() + DirZ = enum.auto() + Time = enum.auto() + + # Fields + FieldEx = enum.auto() + FieldEy = enum.auto() + FieldEz = enum.auto() + FieldHx = enum.auto() + FieldHy = enum.auto() + FieldHz = enum.auto() + + FieldEr = enum.auto() + FieldEtheta = enum.auto() + FieldEphi = enum.auto() + FieldHr = enum.auto() + FieldHtheta = enum.auto() + FieldHphi = enum.auto() + + # Optics Wavelength = enum.auto() Frequency = enum.auto() + DiffOrderX = enum.auto() + DiffOrderY = enum.auto() + + Flux = enum.auto() + + WaveVecX = enum.auto() + WaveVecY = enum.auto() + WaveVecZ = enum.auto() + + #################### + # - UI + #################### @staticmethod def to_name(v: typ.Self) -> str: """Convert the enum value to a human-friendly name. @@ -222,7 +475,7 @@ class CommonSimSymbol(enum.StrEnum): Returns: A human-friendly name corresponding to the enum value. """ - return CommonSimSymbol(v).sim_symbol_name.name + return CommonSimSymbol(v).name @staticmethod def to_icon(_: typ.Self) -> str: @@ -241,55 +494,125 @@ class CommonSimSymbol(enum.StrEnum): #################### @property def name(self) -> str: - return self.sim_symbol.name - - @property - def sim_symbol_name(self) -> str: SSN = SimSymbolName CSS = CommonSimSymbol return { - CSS.X: SSN.LowerX, + CSS.Index: SSN.LowerI, + # Space|Time + CSS.SpaceX: SSN.LowerX, + CSS.SpaceY: SSN.LowerY, + CSS.SpaceZ: SSN.LowerZ, + CSS.AngR: SSN.LowerR, + CSS.AngTheta: SSN.LowerTheta, + CSS.AngPhi: SSN.LowerPhi, + CSS.DirX: SSN.LowerX, + CSS.DirY: SSN.LowerY, + CSS.DirZ: SSN.LowerZ, CSS.Time: SSN.LowerT, - CSS.Wavelength: SSN.Wavelength, + # Fields + CSS.FieldEx: SSN.Ex, + CSS.FieldEy: SSN.Ey, + CSS.FieldEz: SSN.Ez, + CSS.FieldHx: SSN.Hx, + CSS.FieldHy: SSN.Hy, + CSS.FieldHz: SSN.Hz, + CSS.FieldEr: SSN.Er, + CSS.FieldHr: SSN.Hr, + # Optics CSS.Frequency: SSN.Frequency, + CSS.Wavelength: SSN.Wavelength, + CSS.DiffOrderX: SSN.DiffOrderX, + CSS.DiffOrderY: SSN.DiffOrderY, }[self] - @property - def sim_symbol(self) -> SimSymbol: + def sim_symbol(self, unit: spux.Unit | None) -> SimSymbol: """Retrieve the `SimSymbol` associated with the `CommonSimSymbol`.""" CSS = CommonSimSymbol + + # Space + sym_space = SimSymbol( + sym_name=self.name, + physical_type=spux.PhysicalType.Length, + unit=unit, + ) + sym_ang = SimSymbol( + sym_name=self.name, + physical_type=spux.PhysicalType.Angle, + unit=unit, + ) + + # Fields + def sym_field(eh: typ.Literal['e', 'h']) -> SimSymbol: + return SimSymbol( + sym_name=self.name, + physical_type=spux.PhysicalType.EField + if eh == 'e' + else spux.PhysicalType.HField, + unit=unit, + interval_finite_re=(0, sys.float_info.max), + interval_inf_re=(False, True), + interval_closed_re=(True, False), + interval_finite_im=(sys.float_info.min, sys.float_info.max), + interval_inf_im=(True, True), + ) + return { - CSS.X: SimSymbol( - sim_node_name=self.sim_symbol_name, - mathtype=spux.MathType.Real, - physical_type=spux.PhysicalType.NonPhysical, - ## TODO: Unit of Picosecond - interval_finite=(sys.float_info.min, sys.float_info.max), - interval_inf=(True, True), - interval_closed=(False, False), - ), - CSS.Time: SimSymbol( - sim_node_name=self.sim_symbol_name, - mathtype=spux.MathType.Real, - physical_type=spux.PhysicalType.Time, - ## TODO: Unit of Picosecond - interval_finite=(0, sys.float_info.max), + CSS.Index: SimSymbol( + sym_name=self.name, + mathtype=spux.MathType.Integer, + interval_finite_z=(0, 2**64), interval_inf=(False, True), interval_closed=(True, False), ), + # Space|Time + CSS.SpaceX: sym_space, + CSS.SpaceY: sym_space, + CSS.SpaceZ: sym_space, + CSS.AngR: sym_space, + CSS.AngTheta: sym_ang, + CSS.AngPhi: sym_ang, + CSS.Time: SimSymbol( + sym_name=self.name, + physical_type=spux.PhysicalType.Time, + unit=unit, + interval_finite_re=(0, sys.float_info.max), + interval_inf=(False, True), + interval_closed=(True, False), + ), + # Fields + CSS.FieldEx: sym_field('e'), + CSS.FieldEy: sym_field('e'), + CSS.FieldEz: sym_field('e'), + CSS.FieldHx: sym_field('h'), + CSS.FieldHy: sym_field('h'), + CSS.FieldHz: sym_field('h'), + CSS.FieldEr: sym_field('e'), + CSS.FieldEtheta: sym_field('e'), + CSS.FieldEphi: sym_field('e'), + CSS.FieldHr: sym_field('h'), + CSS.FieldHtheta: sym_field('h'), + CSS.FieldHphi: sym_field('h'), + CSS.Flux: SimSymbol( + sym_name=SimSymbolName.Flux, + mathtype=spux.MathType.Real, + physical_type=spux.PhysicalType.Power, + unit=unit, + ), + # Optics CSS.Wavelength: SimSymbol( - sim_node_name=self.sim_symbol_name, + sym_name=self.name, mathtype=spux.MathType.Real, physical_type=spux.PhysicalType.Length, - ## TODO: Unit of Picosecond + unit=unit, interval_finite=(0, sys.float_info.max), interval_inf=(False, True), interval_closed=(False, False), ), CSS.Frequency: SimSymbol( - sim_node_name=self.sim_symbol_name, + sym_name=self.name, mathtype=spux.MathType.Real, physical_type=spux.PhysicalType.Freq, + unit=unit, interval_finite=(0, sys.float_info.max), interval_inf=(False, True), interval_closed=(False, False), @@ -298,9 +621,33 @@ class CommonSimSymbol(enum.StrEnum): #################### -# - Selected Direct Access +# - Selected Direct-Access to SimSymbols #################### -x = CommonSimSymbol.X.sim_symbol +idx = CommonSimSymbol.Index.sim_symbol t = CommonSimSymbol.Time.sim_symbol wl = CommonSimSymbol.Wavelength.sim_symbol freq = CommonSimSymbol.Frequency.sim_symbol + +space_x = CommonSimSymbol.SpaceX.sim_symbol +space_y = CommonSimSymbol.SpaceY.sim_symbol +space_z = CommonSimSymbol.SpaceZ.sim_symbol + +dir_x = CommonSimSymbol.DirX.sim_symbol +dir_y = CommonSimSymbol.DirY.sim_symbol +dir_z = CommonSimSymbol.DirZ.sim_symbol + +ang_r = CommonSimSymbol.AngR.sim_symbol +ang_theta = CommonSimSymbol.AngTheta.sim_symbol +ang_phi = CommonSimSymbol.AngPhi.sim_symbol + +field_ex = CommonSimSymbol.FieldEx.sim_symbol +field_ey = CommonSimSymbol.FieldEy.sim_symbol +field_ez = CommonSimSymbol.FieldEz.sim_symbol +field_hx = CommonSimSymbol.FieldHx.sim_symbol +field_hy = CommonSimSymbol.FieldHx.sim_symbol +field_hz = CommonSimSymbol.FieldHx.sim_symbol + +flux = CommonSimSymbol.Flux.sim_symbol + +diff_order_x = CommonSimSymbol.DiffOrderX.sim_symbol +diff_order_y = CommonSimSymbol.DiffOrderY.sim_symbol