diff --git a/TODO.md b/TODO.md
index c3dc0a8..e69de29 100644
--- a/TODO.md
+++ b/TODO.md
@@ -1,56 +0,0 @@
-# Working TODO
-- [x] Wave Constant
-- Sources
- - [x] Temporal Shapes / Continuous Wave Temporal Shape
- - [x] Temporal Shapes / Symbolic Temporal Shape
- - [x] Plane Wave Source
- - [ ] TFSF Source
- - [x] Gaussian Beam Source
- - [ ] Astig. Gauss Beam
-- Monitors
- - [x] EH Field
- - [x] Power Flux
- - [x] Permittivity
- - [ ] Diffraction
-- Tidy3D / Integration
- - [ ] Exporter
- - [ ] Combine
- - [ ] Importer
-- Sim Grid
- - [ ] Sim Grid
- - [ ] Auto
- - [ ] Manual
- - [ ] Uniform
- - [ ] Data
-- Structures
- - [x] Cylinder
- - [ ] Cylinder Array
- - [ ] L-Cavity Cylinder
- - [ ] H-Cavity Cylinder
- - [ ] FCC Lattice
- - [ ] BCC Lattice
- - [ ] Monkey
-- Expr Socket
- - [x] LVF Mode
-- Math Nodes
- - [ ] Reduce Math
- - [x] Transform Math - reindex freq->wl
-- Material Data Fitting
- - [ ] Data File Import
- - [ ] DataFit Medium
-- Mediums
- - [ ] Non-Linearities
- - [ ] PEC Medium
- - [ ] Isotropic Medium
- - [ ] Sellmeier Medium
- - [ ] Drude Medium
- - [ ] Debye Medium
- - [ ] Anisotropic Medium
-- Integration
- - [x] Simulation and Analysis of Maxim's Cavity
-- Constants
- - [x] Number Constant
- - [x] Vector Constant
- - [x] Physical Constant
-
-- [x] Fix many problems by persisting `_enum_cb_cache` and `_str_cb_cache`.
diff --git a/pyproject.toml b/pyproject.toml
index 8e32d91..d17fda7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,13 +6,13 @@ authors = [
{ name = "Sofus Albert Høgsbro Rose", email = "blender-maxwell@sofusrose.com" }
]
dependencies = [
- "tidy3d>=2.6.3",
+ "tidy3d==2.7.0rc2",
"pydantic>=2.7.1",
"sympy==1.12",
"scipy==1.12.*",
"trimesh==4.2.*",
"networkx==3.2.*",
- "rich==12.5.*",
+ "rich>=13.7.1",
"rtree==1.2.*",
"jax[cpu]==0.4.26",
"msgspec[toml]==0.18.6",
@@ -28,6 +28,8 @@ dependencies = [
"certifi==2021.10.8",
"polars>=0.20.26",
"seaborn[stats]>=0.13.2",
+ "frozendict>=2.4.4",
+ "pydantic-tensor>=0.2.0",
]
## When it comes to dev-dep conflicts:
## -> It's okay to leave Blender-pinned deps out of prod; Blender still has them.
diff --git a/requirements-dev.lock b/requirements-dev.lock
index 98731f6..133dcd4 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -7,13 +7,17 @@
# all-features: false
# with-sources: false
+absl-py==2.1.0
+ # via chex
+ # via optax
+ # via orbax-checkpoint
annotated-types==0.6.0
# via pydantic
argcomplete==3.3.0
# via commitizen
-boto3==1.23.1
+boto3==1.34.123
# via tidy3d
-botocore==1.26.10
+botocore==1.34.123
# via boto3
# via s3transfer
certifi==2021.10.8
@@ -23,7 +27,9 @@ cfgv==3.4.0
charset-normalizer==2.1.1
# via commitizen
# via requests
-click==8.0.3
+chex==0.1.86
+ # via optax
+click==8.1.7
# via dask
# via tidy3d
cloudpickle==3.0.0
@@ -31,8 +37,6 @@ cloudpickle==3.0.0
colorama==0.4.6
# via commitizen
commitizen==3.25.0
-commonmark==0.9.1
- # via rich
contourpy==1.2.0
# via matplotlib
cycler==0.12.1
@@ -43,13 +47,19 @@ decli==0.6.2
# via commitizen
distlib==0.3.8
# via virtualenv
+etils==1.9.1
+ # via orbax-checkpoint
fake-bpy-module-4-0==20231118
filelock==3.14.0
# via virtualenv
+flax==0.8.4
+ # via tidy3d
fonttools==4.49.0
# via matplotlib
+frozendict==2.4.4
fsspec==2024.2.0
# via dask
+ # via etils
h5netcdf==1.0.2
# via tidy3d
h5py==3.10.0
@@ -63,38 +73,61 @@ importlib-metadata==6.11.0
# via commitizen
# via dask
# via tidy3d
+importlib-resources==6.4.0
+ # via etils
jax==0.4.26
+ # via chex
+ # via flax
+ # via optax
+ # via orbax-checkpoint
jaxlib==0.4.26
+ # via chex
# via jax
+ # via optax
+ # via orbax-checkpoint
jaxtyping==0.2.28
jinja2==3.1.3
# via commitizen
jmespath==1.0.1
# via boto3
# via botocore
+joblib==1.4.2
+ # via tidy3d
kiwisolver==1.4.5
# via matplotlib
llvmlite==0.42.0
# via numba
locket==1.0.0
# via partd
+markdown-it-py==3.0.0
+ # via rich
markupsafe==2.1.5
# via jinja2
matplotlib==3.8.3
# via seaborn
# via tidy3d
+mdurl==0.1.2
+ # via markdown-it-py
ml-dtypes==0.4.0
# via jax
# via jaxlib
+ # via tensorstore
mpmath==1.3.0
# via sympy
+msgpack==1.0.8
+ # via flax
+ # via orbax-checkpoint
msgspec==0.18.6
+nest-asyncio==1.6.0
+ # via orbax-checkpoint
networkx==3.2
nodeenv==1.8.0
# via pre-commit
numba==0.59.1
numpy==1.24.3
+ # via chex
# via contourpy
+ # via flax
# via h5py
# via jax
# via jaxlib
@@ -103,16 +136,24 @@ numpy==1.24.3
# via ml-dtypes
# via numba
# via opt-einsum
+ # via optax
+ # via orbax-checkpoint
# via patsy
+ # via pydantic-tensor
# via scipy
# via seaborn
# via shapely
# via statsmodels
+ # via tensorstore
# via tidy3d
# via trimesh
# via xarray
opt-einsum==3.3.0
# via jax
+optax==0.2.2
+ # via flax
+orbax-checkpoint==0.5.15
+ # via flax
packaging==24.0
# via commitizen
# via dask
@@ -123,6 +164,7 @@ packaging==24.0
pandas==2.2.1
# via seaborn
# via statsmodels
+ # via tidy3d
# via xarray
partd==1.4.1
# via dask
@@ -136,10 +178,14 @@ polars==0.20.26
pre-commit==3.7.0
prompt-toolkit==3.0.36
# via questionary
+protobuf==5.27.1
+ # via orbax-checkpoint
pydantic==2.7.1
+ # via pydantic-tensor
# via tidy3d
pydantic-core==2.18.2
# via pydantic
+pydantic-tensor==0.2.0
pygments==2.17.2
# via rich
pyjwt==2.8.0
@@ -157,6 +203,8 @@ pytz==2024.1
pyyaml==6.0.1
# via commitizen
# via dask
+ # via flax
+ # via orbax-checkpoint
# via pre-commit
# via responses
# via tidy3d
@@ -167,11 +215,12 @@ requests==2.31.0
# via tidy3d
responses==0.23.1
# via tidy3d
-rich==12.5.0
+rich==13.7.1
+ # via flax
# via tidy3d
rtree==1.2.0
ruff==0.4.3
-s3transfer==0.5.2
+s3transfer==0.10.1
# via boto3
scipy==1.12.0
# via jax
@@ -190,9 +239,12 @@ six==1.16.0
statsmodels==0.14.2
# via seaborn
sympy==1.12
+tensorstore==0.1.61
+ # via flax
+ # via orbax-checkpoint
termcolor==2.4.0
# via commitizen
-tidy3d==2.6.3
+tidy3d==2.7.0rc2
toml==0.10.2
# via tidy3d
tomli-w==1.0.0
@@ -200,6 +252,7 @@ tomli-w==1.0.0
tomlkit==0.12.4
# via commitizen
toolz==0.12.1
+ # via chex
# via dask
# via partd
trimesh==4.2.0
@@ -208,6 +261,10 @@ typeguard==2.13.3
types-pyyaml==6.0.12.20240311
# via responses
typing-extensions==4.10.0
+ # via chex
+ # via etils
+ # via flax
+ # via orbax-checkpoint
# via pydantic
# via pydantic-core
tzdata==2024.1
@@ -223,4 +280,5 @@ wcwidth==0.2.13
xarray==2024.2.0
# via tidy3d
zipp==3.18.0
+ # via etils
# via importlib-metadata
diff --git a/requirements.lock b/requirements.lock
index 4641252..b9d5c24 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -7,34 +7,44 @@
# all-features: false
# with-sources: false
+absl-py==2.1.0
+ # via chex
+ # via optax
+ # via orbax-checkpoint
annotated-types==0.6.0
# via pydantic
-boto3==1.23.1
+boto3==1.34.123
# via tidy3d
-botocore==1.26.10
+botocore==1.34.123
# via boto3
# via s3transfer
certifi==2021.10.8
# via requests
charset-normalizer==2.0.10
# via requests
-click==8.0.3
+chex==0.1.86
+ # via optax
+click==8.1.7
# via dask
# via tidy3d
cloudpickle==3.0.0
# via dask
-commonmark==0.9.1
- # via rich
contourpy==1.2.0
# via matplotlib
cycler==0.12.1
# via matplotlib
dask==2023.10.1
# via tidy3d
+etils==1.9.1
+ # via orbax-checkpoint
+flax==0.8.4
+ # via tidy3d
fonttools==4.49.0
# via matplotlib
+frozendict==2.4.4
fsspec==2024.2.0
# via dask
+ # via etils
h5netcdf==1.0.2
# via tidy3d
h5py==3.10.0
@@ -45,32 +55,55 @@ idna==3.3
importlib-metadata==6.11.0
# via dask
# via tidy3d
+importlib-resources==6.4.0
+ # via etils
jax==0.4.26
+ # via chex
+ # via flax
+ # via optax
+ # via orbax-checkpoint
jaxlib==0.4.26
+ # via chex
# via jax
+ # via optax
+ # via orbax-checkpoint
jaxtyping==0.2.28
jmespath==1.0.1
# via boto3
# via botocore
+joblib==1.4.2
+ # via tidy3d
kiwisolver==1.4.5
# via matplotlib
llvmlite==0.42.0
# via numba
locket==1.0.0
# via partd
+markdown-it-py==3.0.0
+ # via rich
matplotlib==3.8.3
# via seaborn
# via tidy3d
+mdurl==0.1.2
+ # via markdown-it-py
ml-dtypes==0.4.0
# via jax
# via jaxlib
+ # via tensorstore
mpmath==1.3.0
# via sympy
+msgpack==1.0.8
+ # via flax
+ # via orbax-checkpoint
msgspec==0.18.6
+nest-asyncio==1.6.0
+ # via orbax-checkpoint
networkx==3.2
numba==0.59.1
numpy==1.24.3
+ # via chex
# via contourpy
+ # via flax
# via h5py
# via jax
# via jaxlib
@@ -79,16 +112,24 @@ numpy==1.24.3
# via ml-dtypes
# via numba
# via opt-einsum
+ # via optax
+ # via orbax-checkpoint
# via patsy
+ # via pydantic-tensor
# via scipy
# via seaborn
# via shapely
# via statsmodels
+ # via tensorstore
# via tidy3d
# via trimesh
# via xarray
opt-einsum==3.3.0
# via jax
+optax==0.2.2
+ # via flax
+orbax-checkpoint==0.5.15
+ # via flax
packaging==24.0
# via dask
# via h5netcdf
@@ -98,6 +139,7 @@ packaging==24.0
pandas==2.2.1
# via seaborn
# via statsmodels
+ # via tidy3d
# via xarray
partd==1.4.1
# via dask
@@ -106,10 +148,14 @@ patsy==0.5.6
pillow==10.2.0
# via matplotlib
polars==0.20.26
+protobuf==5.27.1
+ # via orbax-checkpoint
pydantic==2.7.1
+ # via pydantic-tensor
# via tidy3d
pydantic-core==2.18.2
# via pydantic
+pydantic-tensor==0.2.0
pygments==2.17.2
# via rich
pyjwt==2.8.0
@@ -126,17 +172,20 @@ pytz==2024.1
# via pandas
pyyaml==6.0.1
# via dask
+ # via flax
+ # via orbax-checkpoint
# via responses
# via tidy3d
-requests==2.27.1
+requests==2.32.3
# via responses
# via tidy3d
responses==0.23.1
# via tidy3d
-rich==12.5.0
+rich==13.7.1
+ # via flax
# via tidy3d
rtree==1.2.0
-s3transfer==0.5.2
+s3transfer==0.10.1
# via boto3
scipy==1.12.0
# via jax
@@ -153,12 +202,16 @@ six==1.16.0
statsmodels==0.14.2
# via seaborn
sympy==1.12
-tidy3d==2.6.3
+tensorstore==0.1.61
+ # via flax
+ # via orbax-checkpoint
+tidy3d==2.7.0rc2
toml==0.10.2
# via tidy3d
tomli-w==1.0.0
# via msgspec
toolz==0.12.1
+ # via chex
# via dask
# via partd
trimesh==4.2.0
@@ -167,6 +220,10 @@ typeguard==2.13.3
types-pyyaml==6.0.12.20240311
# via responses
typing-extensions==4.10.0
+ # via chex
+ # via etils
+ # via flax
+ # via orbax-checkpoint
# via pydantic
# via pydantic-core
tzdata==2024.1
@@ -178,4 +235,5 @@ urllib3==1.26.8
xarray==2024.2.0
# via tidy3d
zipp==3.18.0
+ # via etils
# via importlib-metadata
diff --git a/src/blender_maxwell/assets/structures/primitives/ring.blend b/src/blender_maxwell/assets/structures/primitives/ring.blend
index f7873ef..aee2920 100644
--- a/src/blender_maxwell/assets/structures/primitives/ring.blend
+++ b/src/blender_maxwell/assets/structures/primitives/ring.blend
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:418a8fac57f9a5bcc34811f15b3226faad1ea64b8dde50cc4aa07eb15c0b012f
-size 892163
+oid sha256:41ca7e9fdf54aff1645d4413deb4e21b0e040ba59e65cdd91b0dfb348f2c0d35
+size 947807
diff --git a/src/blender_maxwell/contracts/__init__.py b/src/blender_maxwell/contracts/__init__.py
index 8953491..9a51d72 100644
--- a/src/blender_maxwell/contracts/__init__.py
+++ b/src/blender_maxwell/contracts/__init__.py
@@ -34,6 +34,7 @@ from .bl import (
KeymapItemDef,
ManagedObjName,
PresetName,
+ PropName,
SocketName,
)
from .bl_types import BLEnumStrEnum
@@ -64,6 +65,7 @@ __all__ = [
'KeymapItemDef',
'ManagedObjName',
'PresetName',
+ 'PropName',
'SocketName',
'BLEnumStrEnum',
'BLInstance',
diff --git a/src/blender_maxwell/contracts/bl.py b/src/blender_maxwell/contracts/bl.py
index 1408ddb..3b65696 100644
--- a/src/blender_maxwell/contracts/bl.py
+++ b/src/blender_maxwell/contracts/bl.py
@@ -21,8 +21,9 @@ import bpy
####################
# - Blender Strings
####################
-BLEnumID = str
-SocketName = str
+BLEnumID: typ.TypeAlias = str
+SocketName: typ.TypeAlias = str
+PropName: typ.TypeAlias = str
####################
# - Blender Enums
diff --git a/src/blender_maxwell/contracts/operator_types.py b/src/blender_maxwell/contracts/operator_types.py
index f89088b..b131b91 100644
--- a/src/blender_maxwell/contracts/operator_types.py
+++ b/src/blender_maxwell/contracts/operator_types.py
@@ -44,6 +44,10 @@ class OperatorType(enum.StrEnum):
# Node: ExportDataFile
NodeExportDataFile = enum.auto()
+ # Node: PoleResidueMediumNode
+ NodeFitDispersiveMedium = enum.auto()
+ NodeReleaseDispersiveFit = enum.auto()
+
# Node: Tidy3DWebImporter
NodeLoadCloudSim = enum.auto()
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/__init__.py
index 0828da2..b33a39c 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/__init__.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/__init__.py
@@ -34,6 +34,7 @@ from blender_maxwell.contracts import (
OperatorType,
PanelType,
PresetName,
+ PropName,
SocketName,
addon,
)
@@ -62,8 +63,12 @@ from .sim_types import (
BoundCondType,
DataFileFormat,
NewSimCloudTask,
+ Realization,
+ RealizationScalar,
SimAxisDir,
SimFieldPols,
+ SimMetadata,
+ SimRealizations,
SimSpaceAxis,
manual_amp_time,
)
@@ -92,6 +97,7 @@ __all__ = [
'OperatorType',
'PanelType',
'PresetName',
+ 'PropName',
'SocketName',
'addon',
'Icon',
@@ -107,8 +113,12 @@ __all__ = [
'BoundCondType',
'DataFileFormat',
'NewSimCloudTask',
+ 'Realization',
+ 'RealizationScalar',
'SimAxisDir',
'SimFieldPols',
+ 'SimMetadata',
+ 'SimRealizations',
'SimSpaceAxis',
'manual_amp_time',
'NodeCategory',
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py
index b5def9c..e174a07 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/bl_socket_types.py
@@ -14,63 +14,101 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-import dataclasses
import enum
+import functools
import typing as typ
+from fractions import Fraction
import bpy
+import jax
+import numpy as np
+import pydantic as pyd
import sympy as sp
-from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
+from blender_maxwell.utils import sympy_extra as spux
from .socket_types import SocketType
+from .unit_systems import UNITS_BLENDER
log = logger.get(__name__)
BL_SOCKET_DESCR_ANNOT_STRING = ':: '
-@dataclasses.dataclass(kw_only=True, frozen=True)
-class BLSocketInfo:
+class BLSocketInfo(pyd.BaseModel):
+ """Parsed information fully representing a Blender interface socket, enabling translation of values to a format accepted by the Blender socket.
+
+ Notes:
+ All Blender socket values are considered to implicitly be in the unit system `ct.UNITS_BLENDER`.
+ For this reason, we do not declare or handle units / unit systems at the interface.
+
+ This design also prevents spurious (slow) unit conversions in the value-encoding hot path.
+
+ Attributes:
+ has_support: Whether our system explicitly supports mapping to/from this Blender socket.
+ is_preview: Whether this Blender socket is only relevant for driving a preview, not for functionality.
+ socket_type: The corresponding socket type in our system, which the Blender socket can have data pushed from.
+ size: Identifier for the scalar/1D shape of the Blender socket, impacting if/how values can be pushed.
+ mathtype: The mathematical type of the Blender socket value.
+ physical_type: The unit dimension of the Blender socket value.
+ default_value: The Blender socket's own default value, used as the initial socket value in our system.
+ bl_isocket_identifier: The internal identifier of the Blender interface socket in the node tree.
+ This is the **only way** to read/write a particular instance of the socket value through eg. a GeoNodes modifier.
+ """
+
+ model_config = pyd.ConfigDict(frozen=True, arbitrary_types_allowed=True)
+
+ default_value: typ.Any | None
+
has_support: bool
is_preview: bool
+
socket_type: SocketType | None
size: spux.NumberSize1D | None
mathtype: spux.MathType | None
physical_type: spux.PhysicalType | None
- default_value: spux.ScalarUnitlessRealExpr
- bl_isocket_identifier: spux.ScalarUnitlessRealExpr
+ bl_isocket_identifier: str
+
+ @functools.cached_property
+ def unit(self) -> spux.Unit | None:
+ """Deduce the unit of the Blender socket value, by retrieving the Blender units of `self.physical_type`.
+
+ If the socket has no unit dimension, this will also be `None`.
+ """
+ if self.physical_type is not None:
+ return UNITS_BLENDER[self.physical_type]
+ return None
def encode(
- self, raw_value: typ.Any, unit_system: spux.UnitSystem | None
+ self, raw_value: int | Fraction | float | tuple[int | Fraction | float, ...]
) -> typ.Any:
- """Encode a raw value, given a unit system, to be directly writable to a node socket.
+ """Conform a mostly-prepared value, so as to be guaranteed-writable to a node socket.
This encoded form is also guaranteed to support writing to a node socket via a modifier interface.
"""
- # Non-Numerical: Passthrough
- if unit_system is None or self.physical_type is None:
+ MT = spux.MathType
+ # Numerical: Conform to Pure Python Type
+ if self.mathtype is not None:
+ # Coerce jax/np Array -> lists
+ if isinstance(raw_value, np.ndarray | jax.Array):
+ if self.mathtype is MT.Real and isinstance(raw_value.item(0), int):
+ return raw_value.astype(float).flatten().tolist()
+ return raw_value.flatten().tolist()
+ # Coerce int -> float w/Target is Real
+ ## -> The value - modifier - GN path is more strict than properties.
+ if self.mathtype is MT.Real and isinstance(raw_value, int):
+ return float(raw_value)
+
+ # Coerce Fraction -> tuple[int, int] w/Target is Rational
+ if self.mathtype is MT.Rational and isinstance(raw_value, Fraction):
+ return (raw_value.numerator, raw_value.denominator)
+
return raw_value
- # Numerical: Convert to Pure Python Type
- if (
- unit_system is not None
- and self.physical_type is not spux.PhysicalType.NonPhysical
- ):
- unitless_value = spux.scale_to_unit_system(raw_value, unit_system)
- elif isinstance(raw_value, spux.SympyType):
- unitless_value = spux.sympy_to_python(raw_value)
- else:
- unitless_value = raw_value
-
- # Coerce int -> float w/Target is Real
- ## -> The value - modifier - GN path is more strict than properties.
- if self.mathtype is spux.MathType.Real and isinstance(unitless_value, int):
- return float(unitless_value)
-
- return unitless_value
+ # Non-Numerical: Passthrough
+ return raw_value
class BLSocketType(enum.StrEnum):
@@ -403,12 +441,12 @@ class BLSocketType(enum.StrEnum):
## -> The description hint "2D" is the trigger for this.
## -> Ignore the last component to get the effect of "2D".
elif description.startswith('2D'):
- default_value = sp.ImmutableMatrix(tuple(bl_default_value)[:2])
+ default_value = tuple(bl_default_value)[:2]
# 3D/4D: Simple Parse to Sympy Matrix
## -> We don't explicitly check the size.
else:
- default_value = sp.ImmutableMatrix(tuple(bl_default_value))
+ default_value = tuple(bl_default_value)
else:
# Non-Mathematical: Passthrough
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/category_labels.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/category_labels.py
index 491356c..c9e6f1e 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/category_labels.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/category_labels.py
@@ -14,18 +14,22 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from .category_types import NodeCategory as NC
+from .category_types import NodeCategory as NC # noqa: N817
NODE_CAT_LABELS = {
# Analysis/
NC.MAXWELLSIM_ANALYSIS: 'Analysis',
NC.MAXWELLSIM_ANALYSIS_MATH: 'Math',
+ # Utilities/
+ NC.MAXWELLSIM_UTILITIES: 'Utilities',
# Inputs/
NC.MAXWELLSIM_INPUTS: 'Inputs',
NC.MAXWELLSIM_INPUTS_SCENE: 'Scene',
NC.MAXWELLSIM_INPUTS_CONSTANTS: 'Constants',
NC.MAXWELLSIM_INPUTS_FILEIMPORTERS: 'File Importers',
NC.MAXWELLSIM_INPUTS_WEBIMPORTERS: 'Web Importers',
+ # Solvers/
+ NC.MAXWELLSIM_SOLVERS: 'Solvers',
# Outputs/
NC.MAXWELLSIM_OUTPUTS: 'Outputs',
NC.MAXWELLSIM_OUTPUTS_FILEEXPORTERS: 'File Exporters',
@@ -39,15 +43,11 @@ NODE_CAT_LABELS = {
# Structures/
NC.MAXWELLSIM_STRUCTURES: 'Structures',
NC.MAXWELLSIM_STRUCTURES_PRIMITIVES: 'Primitives',
- # Bounds/
- NC.MAXWELLSIM_BOUNDS: 'Bounds',
- NC.MAXWELLSIM_BOUNDS_BOUNDCONDS: 'Conds',
# Monitors/
NC.MAXWELLSIM_MONITORS: 'Monitors',
NC.MAXWELLSIM_MONITORS_PROJECTED: 'Projected',
# Simulations/
NC.MAXWELLSIM_SIMS: 'Simulations',
- NC.MAXWELLSIM_SIMS_SIMGRIDAXES: 'Sim Grid Axes',
- # Utilities/
- NC.MAXWELLSIM_UTILITIES: 'Utilities',
+ NC.MAXWELLSIM_SIMS_BOUNDCONDFACES: 'BC Faces',
+ NC.MAXWELLSIM_SIMS_SIMGRIDAXES: 'Grid Axes',
}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/category_types.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/category_types.py
index f05cd76..e63aaf0 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/category_types.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/category_types.py
@@ -27,6 +27,9 @@ class NodeCategory(blender_type_enum.BlenderTypeEnum):
MAXWELLSIM_ANALYSIS = enum.auto()
MAXWELLSIM_ANALYSIS_MATH = enum.auto()
+ # Utilities/
+ MAXWELLSIM_UTILITIES = enum.auto()
+
# Inputs/
MAXWELLSIM_INPUTS = enum.auto()
MAXWELLSIM_INPUTS_SCENE = enum.auto()
@@ -34,6 +37,9 @@ class NodeCategory(blender_type_enum.BlenderTypeEnum):
MAXWELLSIM_INPUTS_FILEIMPORTERS = enum.auto()
MAXWELLSIM_INPUTS_WEBIMPORTERS = enum.auto()
+ # Solvers/
+ MAXWELLSIM_SOLVERS = enum.auto()
+
# Outputs/
MAXWELLSIM_OUTPUTS = enum.auto()
MAXWELLSIM_OUTPUTS_FILEEXPORTERS = enum.auto()
@@ -51,21 +57,15 @@ class NodeCategory(blender_type_enum.BlenderTypeEnum):
MAXWELLSIM_STRUCTURES = enum.auto()
MAXWELLSIM_STRUCTURES_PRIMITIVES = enum.auto()
- # Bounds/
- MAXWELLSIM_BOUNDS = enum.auto()
- MAXWELLSIM_BOUNDS_BOUNDCONDS = enum.auto()
-
# Monitors/
MAXWELLSIM_MONITORS = enum.auto()
MAXWELLSIM_MONITORS_PROJECTED = enum.auto()
# Simulations/
MAXWELLSIM_SIMS = enum.auto()
+ MAXWELLSIM_SIMS_BOUNDCONDFACES = enum.auto()
MAXWELLSIM_SIMS_SIMGRIDAXES = enum.auto()
- # Utilities/
- MAXWELLSIM_UTILITIES = enum.auto()
-
@classmethod
def get_tree(cls):
## TODO: Refactor
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py
index 1599f45..3e14daf 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/array.py
@@ -14,16 +14,22 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+import base64
import functools
+import io
import typing as typ
+import jax
+import jax.numpy as jnp
import jaxtyping as jtyp
import numpy as np
import pydantic as pyd
import sympy as sp
-from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
+from blender_maxwell.utils import sympy_extra as spux
+from blender_maxwell.utils.jaxarray import JaxArrayBytes
+from blender_maxwell.utils.lru_method import method_lru
log = logger.get(__name__)
@@ -40,16 +46,28 @@ class ArrayFlow(pyd.BaseModel):
None if unitless.
"""
- model_config = pyd.ConfigDict(frozen=True, arbitrary_types_allowed=True)
+ model_config = pyd.ConfigDict(frozen=True)
- values: jtyp.Inexact[jtyp.Array, '...'] ## TODO: Custom field type
unit: spux.Unit | None = None
-
is_sorted: bool = False
+ ####################
+ # - Array Access
+ ####################
+ jax_bytes: JaxArrayBytes ## Immutable jax.Array, anyone? :)
+
+ @functools.cached_property
+ def values(self) -> jax.Array:
+ """Return the jax array."""
+ with io.BytesIO() as memfile:
+ memfile.write(base64.b64decode(self.jax_bytes.decode('utf-8')))
+ memfile.seek(0)
+ return jnp.load(memfile)
+
####################
# - Computed Properties
####################
+ @method_lru()
def __len__(self) -> int:
"""Outer length of the contained array."""
return len(self.values)
@@ -70,7 +88,7 @@ class ArrayFlow(pyd.BaseModel):
####################
# - Array Features
####################
- @property
+ @functools.cached_property
def realize_array(self) -> jtyp.Shaped[jtyp.Array, '...']:
"""Standardized access to `self.values`."""
return self.values
@@ -80,6 +98,13 @@ class ArrayFlow(pyd.BaseModel):
"""Shape of the contained array."""
return self.values.shape
+ @method_lru(maxsize=32)
+ def _getitem_index(self, i: int) -> typ.Self | spux.SympyExpr:
+ value = self.values[i]
+ if len(value.shape) == 0:
+ return value * self.unit if self.unit is not None else sp.S(value)
+ return ArrayFlow(jax_bytes=value, unit=self.unit, is_sorted=self.is_sorted)
+
def __getitem__(self, subscript: slice) -> typ.Self | spux.SympyExpr:
"""Implement indexing and slicing in a sane way.
@@ -88,16 +113,13 @@ class ArrayFlow(pyd.BaseModel):
"""
if isinstance(subscript, slice):
return ArrayFlow(
- values=self.values[subscript],
+ jax_bytes=self.values[subscript],
unit=self.unit,
is_sorted=self.is_sorted,
)
if isinstance(subscript, int):
- value = self.values[subscript]
- if len(value.shape) == 0:
- return value * self.unit if self.unit is not None else sp.S(value)
- return ArrayFlow(values=value, unit=self.unit, is_sorted=self.is_sorted)
+ return self._getitem_index(subscript)
raise NotImplementedError
@@ -121,7 +143,7 @@ class ArrayFlow(pyd.BaseModel):
## -> However, too-large ints may cause JAX to suffer from an overflow.
## -> Jax works in 32-bit domain by default, for performance.
## -> While it can be adjusted, that would also have tradeoffs.
- ## -> Instead, a quick .n() turns all the big-ints into floats.
+ ## -> Instead, a quick float() turns all the big-ints into floats.
## -> Not super satisfying, but hey - it's all numerical anyway.
a = self.mathtype.sp_symbol_a
rescale_expr = (
@@ -134,11 +156,12 @@ class ArrayFlow(pyd.BaseModel):
# Return ArrayFlow
return ArrayFlow(
- values=values[::-1] if reverse else values,
+ jax_bytes=values[::-1] if reverse else values,
unit=new_unit,
is_sorted=self.is_sorted,
)
+ # @method_lru()
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
"""Find the index of the value that is closest to the given value.
@@ -159,26 +182,27 @@ class ArrayFlow(pyd.BaseModel):
# BinSearch for "Right IDX"
## >>> self.values[right_idx] > scaled_value
## >>> self.values[right_idx - 1] < scaled_value
- right_idx = np.searchsorted(self.values, scaled_value, side='left')
+ right_idx = jnp.searchsorted(self.values, scaled_value, side='left')
# Case: Right IDX is Boundary
if right_idx == 0:
- return right_idx
+ return int(right_idx)
if right_idx == len(self.values):
- return right_idx - 1
+ return int(right_idx - 1)
# Find Closest of [Right IDX - 1, Right IDX]
left_val = self.values[right_idx - 1]
right_val = self.values[right_idx]
if (scaled_value - left_val) <= (right_val - scaled_value):
- return right_idx - 1
+ return int(right_idx - 1)
- return right_idx
+ return int(right_idx)
####################
# - Unit Transforms
####################
+ @method_lru()
def correct_unit(self, unit: spux.Unit) -> typ.Self:
"""Simply replace the existing unit with the given one.
@@ -186,8 +210,9 @@ class ArrayFlow(pyd.BaseModel):
corrected_unit: The new unit to insert.
**MUST** be associable with a well-defined `PhysicalType`.
"""
- return ArrayFlow(values=self.values, unit=unit, is_sorted=self.is_sorted)
+ return ArrayFlow(jax_bytes=self.values, unit=unit, is_sorted=self.is_sorted)
+ @method_lru()
def rescale_to_unit(self, new_unit: spux.Unit | None) -> typ.Self:
"""Rescale the `ArrayFlow` to be expressed in the given unit.
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py
index d892ece..d34fa17 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/flow_kinds.py
@@ -19,8 +19,8 @@ import functools
import typing as typ
from blender_maxwell.contracts import BLEnumElement
-from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
+from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils.staticproperty import staticproperty
log = logger.get(__name__)
@@ -189,26 +189,32 @@ class FlowKind(enum.StrEnum):
####################
# - Class Methods
####################
- ## TODO: Remove this (only events uses it).
- @classmethod
def scale_to_unit_system(
- cls,
- kind: typ.Self,
- flow_obj: spux.SympyExpr,
+ self,
+ flow: typ.Any,
unit_system: spux.UnitSystem,
+ use_jax_array: bool = True,
):
- # log.debug('%s: Scaling "%s" to Unit System', kind, str(flow_obj))
- ## TODO: Use a hot-path logger.
- if kind == FlowKind.Value:
- return spux.scale_to_unit_system(
- flow_obj,
- unit_system,
- )
- if kind == FlowKind.Range:
- return flow_obj.rescale_to_unit_system(unit_system)
+ """Perform unit-system scaling per-`FlowKind`."""
+ match self:
+ case FlowKind.Value if isinstance(spux.SympyType):
+ return spux.scale_to_unit_system(
+ flow,
+ unit_system,
+ use_jax_array=use_jax_array,
+ )
- if kind == FlowKind.Params:
- return flow_obj.rescale_to_unit_system(unit_system)
+ case FlowKind.Array | FlowKind.Range:
+ return flow.rescale_to_unit_system(unit_system)
- msg = 'Tried to scale unknown kind'
+ case FlowKind.Func:
+ return flow.scale_to_unit_system(unit_system)
+
+ case FlowKind.Params:
+ return flow
+
+ case FlowKind.Info:
+ return flow.scale_to_unit_system(unit_system)
+
+ msg = f"Applying unit-system scaling to {self} doesn't make sense"
raise ValueError(msg)
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py
index 58eba20..93f4343 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/info.py
@@ -18,15 +18,19 @@ import dataclasses
import functools
import typing as typ
+import pydantic as pyd
+
from blender_maxwell.utils import logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
+from blender_maxwell.utils.frozendict import FrozenDict, frozendict
+from blender_maxwell.utils.lru_method import method_lru
from .array import ArrayFlow
from .lazy_range import RangeFlow
log = logger.get(__name__)
-LabelArray: typ.TypeAlias = list[str]
+LabelArray: typ.TypeAlias = tuple[str, ...]
# IndexArray: Identifies Discrete Dimension Values
## -> ArrayFlow (rat|real): Index by particular, not-guaranteed-uniform index.
@@ -36,8 +40,7 @@ LabelArray: typ.TypeAlias = list[str]
IndexArray: typ.TypeAlias = ArrayFlow | RangeFlow | LabelArray | None
-@dataclasses.dataclass(frozen=True, kw_only=True)
-class InfoFlow:
+class InfoFlow(pyd.BaseModel):
"""Contains dimension and output information characterizing the array produced by a parallel `FuncFlow`.
Functionally speaking, `InfoFlow` provides essential mathematical and physical context to raw array data, with terminology adapted from multilinear algebra.
@@ -60,13 +63,20 @@ class InfoFlow:
- **Semantic Indexing**: Using `InfoFlow`, it's easy to index and slice arrays using ex. nanometer vacuum wavelengths, instead of arbitrary integers.
"""
+ model_config = pyd.ConfigDict(frozen=True)
+
####################
# - Dimensions: Covariant Index
####################
- dims: dict[sim_symbols.SimSymbol, IndexArray] = dataclasses.field(
- default_factory=dict
+ dims: FrozenDict[sim_symbols.SimSymbol, IndexArray] = dataclasses.field(
+ default_factory=frozendict
)
+ @functools.cached_property
+ def dims_list(self) -> list[IndexArray]:
+ """Return the dimensional symbols as an ordered list."""
+ return list(self.dims.keys())
+
# Access
@functools.cached_property
def first_dim(self) -> sim_symbols.SimSymbol | None:
@@ -88,12 +98,7 @@ class InfoFlow:
return list(self.dims.keys())[-1]
return None
- def dim_by_idx(self, idx: int) -> sim_symbols.SimSymbol | None:
- """Retrieve the dimension associated with a particular index."""
- if idx > 0 and idx < len(self.dims) - 1:
- return list(self.dims.keys())[idx]
- return None
-
+ @method_lru()
def dim_by_name(self, dim_name: str, optional: bool = False) -> int | None:
"""The integer axis occupied by the dimension.
@@ -110,6 +115,7 @@ class InfoFlow:
raise ValueError(msg)
# Information By-Dim
+ @method_lru()
def has_idx_cont(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the dim's index is continuous, and therefore index array.
@@ -120,16 +126,19 @@ class InfoFlow:
"""
return self.dims[dim] is None
+ @method_lru()
def has_idx_discrete(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the (rat|real) dim is indexed by an `ArrayFlow` / `RangeFlow`."""
return isinstance(self.dims[dim], ArrayFlow | RangeFlow)
+ @method_lru()
def has_idx_labels(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the (int) dim is indexed by a `LabelArray`."""
if dim.mathtype is spux.MathType.Integer:
- return isinstance(self.dims[dim], list)
+ return isinstance(self.dims[dim], tuple)
return False
+ @method_lru()
def is_idx_uniform(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the given dim has explicitly uniform indexing.
@@ -141,6 +150,7 @@ class InfoFlow:
dim_idx = self.dims[dim]
return isinstance(dim_idx, RangeFlow) and dim_idx.scaling == 'lin'
+ @method_lru()
def dim_axis(self, dim: sim_symbols.SimSymbol) -> int:
"""The integer axis occupied by the dimension.
@@ -158,8 +168,8 @@ class InfoFlow:
####################
## -> Whenever a dimension is deleted, we retain what that index value was.
## -> This proves to be very helpful for clear visualization.
- pinned_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = dataclasses.field(
- default_factory=dict
+ pinned_values: FrozenDict[sim_symbols.SimSymbol, spux.SympyExpr] = (
+ dataclasses.field(default_factory=frozendict)
)
####################
@@ -188,7 +198,7 @@ class InfoFlow:
Notes:
Corresponds to `len(raw_data.shape)`, if `raw_data` is the n-dimensional array corresponding to this `InfoFlow`.
"""
- return len(self.dims) + self.output_shape_len
+ return len(self.dims) + self.output.shape_len
@functools.cached_property
def is_scalar(self) -> tuple[spux.MathType, int, int]:
@@ -216,6 +226,31 @@ class InfoFlow:
####################
# - Operations: Comparison
####################
+ # def broadcast(self, other: typ.Self) -> bool:
+ # """Deduce the output `InfoFlow` by deriving numpy's broadcasting rules for `InfoFlow`, enabling operations between compatible but differently sized `InfoFlow`s."""
+ # # Broadcast Dimensions
+ # new_dims = {}
+ # for _dim_l, _dim_r in itertools.zip_longest(
+ # reversed(self.dims.keys()),
+ # reversed(other.dims.keys()),
+ # fillvalue=None,
+ # ):
+ # dim_l = _dim_l if _dim_l is not None else _dim_r
+ # dim_r = _dim_r if _dim_r is not None else _dim_l
+
+ # if dim_l.compare(dim_r) and self.dims[dim_l] == other.dims[dim_r]:
+ # new_dims |= {dim_l: self.dims[dim_l]}
+
+ # # Broadcast Symbols
+ # return InfoFlow(
+ # dims=new_dims,
+ # output=self.output,
+ # pinned_values=self.pinned_values | other.pinned_values,
+ # )
+
+ # return tuple(reversed(new_shape))
+
+ @method_lru()
def compare_dims_identical(self, other: typ.Self) -> bool:
"""Whether that the quantity and properites of all dimension `SimSymbol`s are "identical".
@@ -226,6 +261,7 @@ class InfoFlow:
for dim_l, dim_r in zip(self.dims, other.dims, strict=True)
)
+ @method_lru()
def compare_addable(
self, other: typ.Self, allow_differing_unit: bool = False
) -> bool:
@@ -239,6 +275,7 @@ class InfoFlow:
other.output, allow_differing_unit=allow_differing_unit
)
+ @method_lru()
def compare_multiplicable(self, other: typ.Self) -> bool:
"""Whether the two `InfoFlow`s can be multiplied (elementwise).
@@ -251,6 +288,7 @@ class InfoFlow:
or self.compare_dims_identical(other)
)
+ @method_lru()
def compare_exponentiable(self, other: typ.Self) -> bool:
"""Whether the two `InfoFlow`s can be exponentiated.
@@ -265,36 +303,54 @@ class InfoFlow:
or self.compare_dims_identical(other)
)
+ ####################
+ # - Operations: General Update
+ ####################
+ @method_lru(maxsize=256)
+ def _update(self, frozen_kwargs: frozendict) -> typ.Self:
+ if not frozen_kwargs:
+ return self
+ return InfoFlow(**(dict(self) | frozen_kwargs))
+
+ def update(self, **kwargs: dict) -> typ.Self:
+ return self._update(frozendict(kwargs))
+
####################
# - Operations: Dimensions
####################
+ @method_lru()
def prepend_dim(
self, dim: sim_symbols.SimSymbol, dim_idx: sim_symbols.SimSymbol
) -> typ.Self:
"""Insert a new dimension at index 0."""
- return InfoFlow(
- dims={dim: dim_idx} | self.dims,
- output=self.output,
- pinned_values=self.pinned_values,
- )
+ return self.update(dims=frozendict({dim: dim_idx} | self.dims))
+ @method_lru()
+ def append_dim(
+ self, dim: sim_symbols.SimSymbol, dim_idx: sim_symbols.SimSymbol
+ ) -> typ.Self:
+ """Insert a new dimension at index -1."""
+ return self.update(dims=frozendict(self.dims | {dim: dim_idx}))
+
+ @method_lru()
def slice_dim(
self, dim: sim_symbols.SimSymbol, slice_tuple: tuple[int, int, int]
) -> typ.Self:
"""Slice a dimensional array by-index along a particular dimension."""
- return InfoFlow(
- dims={
- _dim: (
- dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
- if _dim == dim
- else dim_idx
- )
- for _dim, dim_idx in self.dims.items()
- },
- output=self.output,
- pinned_values=self.pinned_values,
+ return self.update(
+ dims=frozendict(
+ {
+ _dim: (
+ dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
+ if _dim == dim
+ else dim_idx
+ )
+ for _dim, dim_idx in self.dims.items()
+ }
+ )
)
+ @method_lru()
def replace_dim(
self,
old_dim: sim_symbols.SimSymbol,
@@ -302,48 +358,66 @@ class InfoFlow:
new_dim_idx: IndexArray,
) -> typ.Self:
"""Replace a dimension entirely, in-place, including symbol and index array."""
- return InfoFlow(
- dims={
- (new_dim if _dim == old_dim else _dim): (
- new_dim_idx if _dim == old_dim else dim_idx
- )
- for _dim, dim_idx in self.dims.items()
- },
- output=self.output,
- pinned_values=self.pinned_values,
+ return self.update(
+ dims=frozendict(
+ {
+ (new_dim if _dim == old_dim else _dim): (
+ new_dim_idx if _dim == old_dim else dim_idx
+ )
+ for _dim, dim_idx in self.dims.items()
+ }
+ )
)
+ @method_lru()
def replace_dims(
self, new_dims: dict[sim_symbols.SimSymbol, IndexArray]
) -> typ.Self:
"""Replace several dimensional indices with new index arrays/ranges."""
- return InfoFlow(
- dims={
- dim: new_dims.get(dim, dim_idx) for dim, dim_idx in self.dim_idx.items()
- },
- output=self.output,
- pinned_values=self.pinned_values,
+ return self.update(
+ dims=frozendict(
+ {
+ dim: new_dims.get(dim, dim_idx)
+ for dim, dim_idx in self.dim_idx.items()
+ }
+ )
)
+ @method_lru()
def delete_dim(
self, dim_to_remove: sim_symbols.SimSymbol, pin_idx: int | None = None
) -> typ.Self:
"""Delete a dimension, optionally pinning the value of an index from that dimension."""
- new_pin = (
- {dim_to_remove: self.dims[dim_to_remove][pin_idx]}
- if pin_idx is not None
- else {}
- )
- return InfoFlow(
- dims={
- dim: dim_idx
- for dim, dim_idx in self.dims.items()
- if dim != dim_to_remove
- },
- output=self.output,
- pinned_values=self.pinned_values | new_pin,
- )
+ if dim_to_remove in self.dims:
+ dim_idx = self.dims[dim_to_remove]
+ # Deduce Pinned Value
+ if pin_idx is not None:
+ pin_value = {dim_to_remove: dim_idx[pin_idx]}
+ elif (
+ self.has_idx_discrete(dim_to_remove)
+ or self.has_idx_labels(dim_to_remove)
+ ) and len(dim_idx) == 1:
+ pin_value = {dim_to_remove: dim_idx[0]}
+ else:
+ pin_value = {}
+
+ # Delete Dimension
+ return self.update(
+ dims=frozendict(
+ {
+ dim: dim_idx
+ for dim, dim_idx in self.dims.items()
+ if dim != dim_to_remove
+ }
+ ),
+ pinned_values=self.pinned_values | pin_value,
+ )
+
+ msg = 'Dimension to delete is not in the InfoFlow dimensions (to delete: {dim_to_remove}, info={self})'
+ raise ValueError(msg)
+
+ @method_lru()
def swap_dimensions(self, dim_0: str, dim_1: str) -> typ.Self:
"""Swap the positions of two dimensions."""
@@ -357,10 +431,10 @@ class InfoFlow:
swapped_dim_keys = [name_swapper(dim) for dim in self.dims]
- return InfoFlow(
- dims={dim_key: self.dims[dim_key] for dim_key in swapped_dim_keys},
- output=self.output,
- pinned_values=self.pinned_values,
+ return self.update(
+ dims=frozendict(
+ {dim_key: self.dims[dim_key] for dim_key in swapped_dim_keys}
+ ),
)
####################
@@ -368,33 +442,58 @@ class InfoFlow:
####################
def update_output(self, **kwargs) -> typ.Self:
"""Passthrough to `SimSymbol.update()` method on `self.output`."""
- return InfoFlow(
- dims=self.dims,
+ return self.update(
output=self.output.update(**kwargs),
- pinned_values=self.pinned_values,
)
- def operate_output(
- self,
- other: typ.Self,
- op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
- unit_op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
- ) -> spux.SympyExpr:
- """Apply an operation between two the values and units of two `InfoFlow`s by reconstructing the properties of the new output `SimSymbol`."""
- sym_name = sim_symbols.SimSymbolName.Expr
- expr = op(self.output.sp_symbol_phy, other.output.sp_symbol_phy)
- unit_expr = unit_op(self.output.unit_factor, other.output.unit_factor)
- ## TODO: Handle per-cell matrix units?
+ # def map_output(
+ # self,
+ # op: typ.Callable[[spux.SympyExpr], spux.SympyExpr],
+ # unit_op: typ.Callable[[spux.SympyExpr], spux.SympyExpr],
+ # new_domain: sp.Set | None = None,
+ # ) -> spux.SympyExpr:
+ # """Apply an operation to a single `InfoFlow`s by reconstructing the properties of the new output `SimSymbol`."""
+ # sym_name = sim_symbols.SimSymbolName.Expr
+ # expr = op(self.output.sp_symbol_phy)
+ # unit_expr = unit_op(self.output.unit_factor)
- return InfoFlow(
- dims=self.dims,
- output=sim_symbols.SimSymbol.from_expr(sym_name, expr, unit_expr),
- pinned_values=self.pinned_values,
+ # return self.update(
+ # output=sim_symbols.SimSymbol.from_expr(
+ # sym_name, expr, unit_expr, new_domain=new_domain
+ # ),
+ # )
+
+ def scale_to_unit_system(self, unit_system: spux.UnitSystem) -> typ.Self:
+ """Passthrough to `SimSymbol.update()` method on `self.output`."""
+ return self.update(
+ output=self.output.scale_to_unit_system(unit_system),
)
+ # def operate_output(
+ # self,
+ # other: typ.Self,
+ # op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
+ # unit_op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
+ # new_domain: spux.BlessedSet | None = None,
+ # ) -> spux.SympyExpr:
+ # """Apply an operation between two the values and units of two `InfoFlow`s by reconstructing the properties of the new output `SimSymbol`."""
+ # sym_name = sim_symbols.SimSymbolName.Expr
+ # expr = op(self.output.sp_symbol_phy, other.output.sp_symbol_phy)
+ # unit_expr = unit_op(self.output.unit_factor, other.output.unit_factor)
+
+ # return self.update(
+ # output=sim_symbols.SimSymbol.from_expr(
+ # sym_name,
+ # expr,
+ # unit_expr,
+ # new_domain=new_domain,
+ # )
+ # )
+
####################
# - Operations: Fold
####################
+ @functools.cached_property
def fold_last_input(self):
"""Fold the last input dimension into the output."""
last_idx = self.dims[self.last_dim]
@@ -411,12 +510,13 @@ class InfoFlow:
case (_, _):
raise NotImplementedError ## Not yet :)
- return InfoFlow(
- dims={
- dim: dim_idx
- for dim, dim_idx in self.dims.items()
- if dim != self.last_dim
- },
+ return self.update(
+ dims=frozendict(
+ {
+ dim: dim_idx
+ for dim, dim_idx in self.dims.items()
+ if dim != self.last_dim
+ }
+ ),
output=new_output,
- pinned_values=self.pinned_values,
)
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py
index 025ef71..e6fa541 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py
@@ -226,8 +226,10 @@ import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
-from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger, sim_symbols
+from blender_maxwell.utils import sympy_extra as spux
+from blender_maxwell.utils.frozendict import FrozenDict, frozendict
+from blender_maxwell.utils.lru_method import method_lru
from .array import ArrayFlow
from .info import InfoFlow
@@ -256,8 +258,10 @@ class FuncFlow(pyd.BaseModel):
model_config = pyd.ConfigDict(frozen=True)
func: LazyFunction
- func_args: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list)
- func_kwargs: dict[str, sim_symbols.SimSymbol] = pyd.Field(default_factory=dict)
+ func_args: tuple[sim_symbols.SimSymbol, ...] = ()
+ func_kwargs: FrozenDict[str, sim_symbols.SimSymbol] = pyd.Field(
+ default_factory=frozendict
+ )
func_output: sim_symbols.SimSymbol | None = None
supports_jax: bool = False
@@ -319,12 +323,11 @@ class FuncFlow(pyd.BaseModel):
####################
# - Realization
####################
+ @method_lru(maxsize=64)
def realize(
self,
params: ParamsFlow,
- symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
- {}
- ),
+ symbol_values: frozendict[sim_symbols.SimSymbol, spux.SympyExpr] = frozendict(),
disallow_jax: bool = True,
) -> typ.Self:
"""Run the represented function with the best optimization available, given particular choices for all function arguments and for all unrealized symbols.
@@ -337,22 +340,21 @@ class FuncFlow(pyd.BaseModel):
"""
if self.supports_jax and not disallow_jax:
return self.func_jax(
- *params.scaled_func_args(symbol_values),
- **params.scaled_func_kwargs(symbol_values),
+ *params.scaled_func_args(self.func_args, symbol_values),
+ **params.scaled_func_kwargs(self.func_kwargs, symbol_values),
)
return self.func(
- *params.scaled_func_args(symbol_values),
- **params.scaled_func_kwargs(symbol_values),
+ *params.scaled_func_args(self.func_args, symbol_values),
+ **params.scaled_func_kwargs(self.func_kwargs, symbol_values),
)
+ @method_lru(maxsize=64)
def realize_as_data(
self,
info: InfoFlow,
params: ParamsFlow,
- symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
- {}
- ),
- ) -> dict[sim_symbols.SimSymbol, jtyp.Inexact[jtyp.Array, '...']]:
+ symbol_values: frozendict[sim_symbols.SimSymbol, spux.SympyExpr] = frozendict(),
+ ) -> frozendict[sim_symbols.SimSymbol, jtyp.Inexact[jtyp.Array, '...']]:
"""Realize as an ordered dictionary mapping each realized `self.dims` entry, with the last entry containing all output data as mapped from the `self.output`."""
data = {}
for dim, dim_idx in info.dims.items():
@@ -386,9 +388,14 @@ class FuncFlow(pyd.BaseModel):
if info.has_idx_labels(dim):
data |= {dim: dim_idx}
- return data | {info.output: self.realize(params, symbol_values=symbol_values)}
+ return frozendict(
+ data
+ | {info.output: self.realize(params, symbol_values=symbol_values)}
+ | {'pinned': info.pinned_values}
+ )
- def realize_partial(
+ @method_lru(maxsize=64)
+ def generate_realizer(
self, params: ParamsFlow
) -> typ.Callable[
[int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
@@ -399,7 +406,7 @@ class FuncFlow(pyd.BaseModel):
The units/types/shape/etc. of the returned numerical type conforms to the `SimSymbol` specification of relevant `self.func_args` entries and `self.func_output`.
This function should be used whenever the unrealized result of a `FuncFlow` needs to be used as the argument to another `FuncFlow`.
- By using `realize_partial()`, two things are ensured:
+ By using `generate_realizer()`, two things are ensured:
- Since the function defined in `.compose_within()` must be purely numerical, the usual `.realize()` mechanism can't be used to sweep away the pre-realized symbol values.
- Since this `FuncFlow` is completely consumed, with no symbols / arguments / etc. explicitly surviving, its impact on the data flow can be considered to have been effectively terminated after using this function.
@@ -418,24 +425,36 @@ class FuncFlow(pyd.BaseModel):
return self.func(
*[
func_arg_n(*sym_args, *pre_realized_syms)
- for func_arg_n in params.func_args_n
+ for func_arg_n in params.func_args_n(self.func_args)
],
**{
func_arg_name: func_kwarg_n(*sym_args, *pre_realized_syms)
- for func_arg_name, func_kwarg_n in params.func_kwargs_n.items()
+ for func_arg_name, func_kwarg_n in params.func_kwargs_n(
+ self.func_kwargs
+ ).items()
},
)
return realizer
+ @method_lru(maxsize=8)
+ def realize_preview(
+ self, params: ParamsFlow
+ ) -> typ.Callable[
+ [int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
+ jtyp.Inexact[jtyp.Array, '...'],
+ ]:
+ """Realize the function value, by replacing unknown symbols with their declared preview values."""
+ return self.realize(params, symbol_values=params.symbol_preview_values)
+
####################
# - Operations
####################
def compose_within(
self,
enclosing_func: LazyFunction,
- enclosing_func_args: list[sim_symbols.SimSymbol] = (),
- enclosing_func_kwargs: dict[str, sim_symbols.SimSymbol] = MappingProxyType({}),
+ enclosing_func_args: tuple[sim_symbols.SimSymbol, ...] = (),
+ enclosing_func_kwargs: frozendict[str, sim_symbols.SimSymbol] = frozendict(),
enclosing_func_output: sim_symbols.SimSymbol | None = None,
supports_jax: bool = False,
) -> typ.Self:
@@ -480,18 +499,19 @@ class FuncFlow(pyd.BaseModel):
return FuncFlow(
func=lambda *args, **kwargs: enclosing_func(
self.func(
- *list(args[: len(self.func_args)]),
+ *args[: len(self.func_args)],
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
),
*args[len(self.func_args) :],
**{k: v for k, v in kwargs.items() if k not in self.func_kwargs},
),
- func_args=self.func_args + list(enclosing_func_args),
+ func_args=self.func_args + tuple(enclosing_func_args),
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
func_output=enclosing_func_output,
supports_jax=self.supports_jax and supports_jax,
)
+ @method_lru(maxsize=8)
def __or__(
self,
other: typ.Self,
@@ -532,7 +552,7 @@ class FuncFlow(pyd.BaseModel):
def self_func(args, kwargs):
ret = self.func(
- *list(args[: len(self.func_args)]),
+ *args[: len(self.func_args)],
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
)
if not self.is_concatenated:
@@ -543,7 +563,7 @@ class FuncFlow(pyd.BaseModel):
func=lambda *args, **kwargs: (
*self_func(args, kwargs),
other.func(
- *list(args[len(self.func_args) :]),
+ *args[len(self.func_args) :],
**{k: v for k, v in kwargs.items() if k in other.func_kwargs},
),
),
@@ -553,27 +573,15 @@ class FuncFlow(pyd.BaseModel):
is_concatenated=True,
)
+ @method_lru(maxsize=16)
def scale_to_unit(self, unit: spux.Unit | None = None) -> typ.Self:
"""Encloses this function in a unit-converting function, whose output is a converted, unitless scalar.
`unit` must be manually guaranteed to be compatible with `self.unit`.
"""
if self.func_output is not None:
- # Retrieve Output Unit
- output_unit = self.func_output.unit
-
- # Compile Efficient Unit-Conversion Function
- a = self.func_output.mathtype.sp_symbol_a
- unit_convert_expr = (
- spux.scale_to_unit(a * output_unit, unit)
- if self.func_output.unit is not None
- else a
- )
- unit_convert_func = sp.lambdify(a, unit_convert_expr.n(), 'jax')
-
- # Compose Unit-Converted FuncFlow
return self.compose_within(
- enclosing_func=unit_convert_func,
+ enclosing_func=spux.unit_scaling_func_n(self.func_output.unit, unit),
supports_jax=True,
enclosing_func_output=self.func_output.update(unit=unit),
)
@@ -581,6 +589,7 @@ class FuncFlow(pyd.BaseModel):
msg = f'Tried to scale a FuncFlow to a unit system, but it has no tracked output SimSymbol. ({self})'
raise ValueError(msg)
+ @method_lru(maxsize=16)
def scale_to_unit_system(
self, unit_system: spux.UnitSystem | None = None
) -> typ.Self:
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py
index a7fd668..a8cbf4f 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_range.py
@@ -26,6 +26,8 @@ import sympy as sp
from blender_maxwell.utils import logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
+from blender_maxwell.utils.frozendict import frozendict
+from blender_maxwell.utils.lru_method import method_lru
from .array import ArrayFlow
@@ -47,6 +49,10 @@ class ScalingMode(enum.StrEnum):
@staticmethod
def to_name(v: typ.Self) -> str:
+ """Friendly, single-letter, human-readable column names.
+
+ Must be concise, as there is not a lot of header space to contain these.
+ """
SM = ScalingMode
return {
SM.Lin: 'Linear',
@@ -56,6 +62,7 @@ class ScalingMode(enum.StrEnum):
@staticmethod
def to_icon(_: typ.Self) -> str:
+ """No icons."""
return ''
@@ -135,6 +142,7 @@ class RangeFlow(pyd.BaseModel):
msg = f'RangeFlow is incompatible with SimSymbol {sym}'
raise ValueError(msg)
+ @method_lru()
def to_sym(
self,
sym_name: sim_symbols.SimSymbolName,
@@ -154,6 +162,16 @@ class RangeFlow(pyd.BaseModel):
cols=1,
).set_domain(start=self.realize_start(), end=self.realize_end())
+ @functools.cached_property
+ def symbolic_set(self) -> spux.BlessedSet:
+ if not self.symbols:
+ if self.steps == 0:
+ return spux.BlessedSet(sp.Interval(self.start, self.stop))
+ return spux.BlessedSet(sp.Range(self.start, self.stop + 1))
+
+ msg = 'Cant deduce symbolic set from symbolic RangeFlow'
+ raise ValueError(msg)
+
####################
# - Symbols
####################
@@ -250,6 +268,7 @@ class RangeFlow(pyd.BaseModel):
####################
# - Methods
####################
+ @functools.lru_cache(maxsize=16)
@staticmethod
def try_from_array(
array: ArrayFlow, uniformity_tolerance: float = 1e-9
@@ -269,8 +288,8 @@ class RangeFlow(pyd.BaseModel):
diffs = jnp.diff(array.values)
if (
- jnp.all(jnp.abs(diffs - diffs[0]) < uniformity_tolerance)
- and len(array.values) > 2 # noqa: PLR2004
+ len(array.values) > 2 # noqa: PLR2004
+ and jnp.all(jnp.abs(diffs - diffs[0]) < uniformity_tolerance)
and array.values[0] < array.values[-1]
and array.is_sorted
):
@@ -317,6 +336,7 @@ class RangeFlow(pyd.BaseModel):
symbols=self.symbols,
)
+ @method_lru(maxsize=16)
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
raise NotImplementedError
@@ -468,12 +488,11 @@ class RangeFlow(pyd.BaseModel):
####################
# - Realization
####################
+ @method_lru(maxsize=16)
def realize_symbols(
self,
- symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
- {}
- ),
- ) -> dict[sp.Symbol, spux.ScalarUnitlessRealExpr]:
+ symbol_values: frozendict[sim_symbols.SimSymbol, spux.SympyExpr] = frozendict(),
+ ) -> frozendict[sp.Symbol, spux.ScalarUnitlessRealExpr]:
"""Realize **all** input symbols to the `RangeFlow`.
Parameters:
@@ -501,35 +520,32 @@ class RangeFlow(pyd.BaseModel):
raise NotImplementedError(msg)
realized_syms |= {sym: v}
- return realized_syms
+ return frozendict(realized_syms)
msg = f'RangeFlow: Not all symbols were given a value during realization (symbols={self.symbols}, symbol_values={symbol_values})'
raise ValueError(msg)
+ @method_lru(maxsize=16)
def realize_start(
self,
- symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
- {}
- ),
+ symbol_values: frozendict[sim_symbols.SimSymbol, spux.SympyExpr] = frozendict(),
) -> int | float | complex:
"""Realize the start-bound by inserting particular values for each symbol."""
realized_symbols = self.realize_symbols(symbol_values)
return spux.sympy_to_python(self.start.subs(realized_symbols))
+ @method_lru(maxsize=16)
def realize_stop(
self,
- symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
- {}
- ),
+ symbol_values: frozendict[sim_symbols.SimSymbol, spux.SympyExpr] = frozendict(),
) -> int | float | complex:
"""Realize the stop-bound by inserting particular values for each symbol."""
realized_symbols = self.realize_symbols(symbol_values)
return spux.sympy_to_python(self.stop.subs(realized_symbols))
+ @method_lru(maxsize=16)
def realize_step_size(
self,
- symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
- {}
- ),
+ symbol_values: frozendict[sim_symbols.SimSymbol, spux.SympyExpr] = frozendict(),
) -> int | float | complex:
"""Realize the stop-bound by inserting particular values for each symbol."""
if self.scaling is not ScalingMode.Lin:
@@ -544,11 +560,10 @@ class RangeFlow(pyd.BaseModel):
return int(raw_step_size)
return raw_step_size
+ @method_lru(maxsize=16)
def realize(
self,
- symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
- {}
- ),
+ symbol_values: frozendict[sim_symbols.SimSymbol, spux.SympyExpr] = frozendict(),
) -> ArrayFlow:
"""Realize the array represented by this `RangeFlow` by realizing each bound, then generating all intermediate values as an array.
@@ -561,7 +576,7 @@ class RangeFlow(pyd.BaseModel):
## TODO: Check symbol values for coverage.
return ArrayFlow(
- values=self.as_func(
+ jax_bytes=self.as_func(
*[
spux.scale_to_unit_system(symbol_values[sym])
for sym in self.sorted_symbols
@@ -584,7 +599,7 @@ class RangeFlow(pyd.BaseModel):
msg = f'RangeFlow: Cannot use ".realize_array" when symbols are defined (symbols={self.symbols}, RangeFlow={self}'
raise ValueError(msg)
- @property
+ @functools.cached_property
def values(self) -> jtyp.Inexact[jtyp.Array, '...']:
"""Alias for `realize_array.values`."""
return self.realize_array.values
@@ -622,6 +637,7 @@ class RangeFlow(pyd.BaseModel):
####################
# - Units
####################
+ @method_lru(maxsize=32)
def correct_unit(self, corrected_unit: spux.Unit) -> typ.Self:
"""Replaces the unit without rescaling the unitless bounds.
@@ -643,6 +659,7 @@ class RangeFlow(pyd.BaseModel):
symbols=self.symbols,
)
+ @method_lru(maxsize=32)
def rescale_to_unit(self, unit: spux.Unit) -> typ.Self:
"""Replaces the unit, **with** rescaling of the bounds.
@@ -673,6 +690,7 @@ class RangeFlow(pyd.BaseModel):
symbols=self.symbols,
)
+ @method_lru(maxsize=32)
def rescale_to_unit_system(
self, unit_system: spux.UnitSystem | None = None
) -> typ.Self:
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py
index b7a9cb7..570759c 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/params.py
@@ -23,8 +23,10 @@ import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
-from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger, sim_symbols
+from blender_maxwell.utils import sympy_extra as spux
+from blender_maxwell.utils.frozendict import FrozenDict, frozendict
+from blender_maxwell.utils.lru_method import method_lru
from .array import ArrayFlow
from .expr_info import ExprInfo
@@ -42,16 +44,16 @@ class ParamsFlow(pyd.BaseModel):
model_config = pyd.ConfigDict(frozen=True)
- arg_targets: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list)
- kwarg_targets: dict[str, sim_symbols.SimSymbol] = pyd.Field(default_factory=dict)
-
- func_args: list[spux.SympyExpr] = pyd.Field(default_factory=list)
- func_kwargs: dict[str, spux.SympyExpr] = pyd.Field(default_factory=dict)
+ func_args: tuple[spux.SympyExpr, ...] = pyd.Field(default_factory=tuple)
+ func_kwargs: FrozenDict[str, spux.SympyExpr] = pyd.Field(default_factory=frozendict)
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
- realized_symbols: dict[
+ realized_symbols: FrozenDict[
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
- ] = pyd.Field(default_factory=dict)
+ ] = pyd.Field(default_factory=frozendict)
+ previewed_symbols: FrozenDict[sim_symbols.SimSymbol, spux.SympyExpr] = pyd.Field(
+ default_factory=frozendict
+ )
####################
# - Symbols
@@ -101,9 +103,10 @@ class ParamsFlow(pyd.BaseModel):
####################
# - JIT'ed Callables for Numerical Function Arguments
####################
- @functools.cached_property
+ @method_lru()
def func_args_n(
self,
+ arg_targets: tuple[sim_symbols.SimSymbol, ...],
) -> list[
typ.Callable[
[int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
@@ -112,7 +115,7 @@ class ParamsFlow(pyd.BaseModel):
]:
"""Callable functions for evaluating each `self.func_args` entry numerically.
- Before simplification, each `self.func_args` entry will be conformed to the corresponding (by-index) `SimSymbol` in `self.target_syms`.
+ Before simplification, each `self.func_args` entry will be conformed to the corresponding (by-index) `SimSymbol` in the passed `arg_targets`.
Notes:
Before using any `sympy` expressions as arguments to the returned callablees, they **must** be fully conformed and scaled to the corresponding `self.symbols` entry using that entry's `SimSymbol.scale()` method.
@@ -125,14 +128,13 @@ class ParamsFlow(pyd.BaseModel):
target_sym.conform(func_arg, strip_unit=True),
'jax',
)
- for func_arg, target_sym in zip(
- self.func_args, self.arg_targets, strict=True
- )
+ for func_arg, target_sym in zip(self.func_args, arg_targets, strict=True)
]
- @functools.cached_property
+ @method_lru()
def func_kwargs_n(
self,
+ kwarg_targets: frozendict[str, sim_symbols.SimSymbol],
) -> dict[
str,
typ.Callable[
@@ -148,7 +150,7 @@ class ParamsFlow(pyd.BaseModel):
return {
key: sp.lambdify(
self.all_sorted_sp_symbols,
- self.kwarg_targets[key].conform(func_arg, strip_unit=True),
+ kwarg_targets[key].conform(func_arg, strip_unit=True),
'jax',
)
for key, func_arg in self.func_kwargs.items()
@@ -157,9 +159,24 @@ class ParamsFlow(pyd.BaseModel):
####################
# - Realization
####################
+ @functools.cached_property
+ def symbol_preview_values(
+ self,
+ ) -> frozendict[sim_symbols.SimSymbol, sp.Basic] | None:
+ """Provide a dictionary for simplifying all unrealized symbols to their preview values.
+
+ Returns `None` if not all unrealized symbols have preview values.
+ """
+ if all(sym.preview_value is not None for sym in self.all_sorted_symbols):
+ return frozendict(
+ {sym: sym.preview_value_phy for sym in self.all_sorted_symbols}
+ )
+ return None
+
+ @method_lru()
def realize_symbols(
self,
- symbol_values: dict[
+ symbol_values: frozendict[
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
] = MappingProxyType({}),
allow_partial: bool = False,
@@ -209,11 +226,11 @@ class ParamsFlow(pyd.BaseModel):
####################
# - Realize Arguments
####################
+ @method_lru(maxsize=8)
def scaled_func_args(
self,
- symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
- {}
- ),
+ arg_targets: tuple[sim_symbols.SimSymbol, ...],
+ symbol_values: frozendict[sim_symbols.SimSymbol, spux.SympyExpr] = frozendict(),
) -> list[
int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...']
]:
@@ -237,15 +254,20 @@ class ParamsFlow(pyd.BaseModel):
Parameters:
symbol_values: Particular values for all symbols in `self.symbols`, which will be conformed and used to compute the function arguments (before they are conformed to `self.target_syms`).
"""
- realized_symbols = list(
+ realized_symbols = tuple(
self.realize_symbols(symbol_values | self.realized_symbols).values()
)
- return [func_arg_n(*realized_symbols) for func_arg_n in self.func_args_n]
+ return [
+ func_arg_n(*realized_symbols)
+ for func_arg_n in self.func_args_n(arg_targets)
+ ]
+ @method_lru(maxsize=8)
def scaled_func_kwargs(
self,
- symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
- ) -> dict[
+ kwarg_targets: frozendict[str, sim_symbols.SimSymbol],
+ symbol_values: frozendict[spux.Symbol, spux.SympyExpr] = frozendict(),
+ ) -> frozendict[
str, int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...']
]:
"""Realize correctly conformed numerical arguments for `self.func_kwargs`.
@@ -254,14 +276,19 @@ class ParamsFlow(pyd.BaseModel):
"""
realized_symbols = self.realize_symbols(symbol_values | self.realized_symbols)
- return {
- func_arg_name: func_kwarg_n(**realized_symbols)
- for func_arg_name, func_kwarg_n in self.func_kwargs_n.items()
- }
+ return frozendict(
+ {
+ func_arg_name: func_kwarg_n(**realized_symbols)
+ for func_arg_name, func_kwarg_n in self.func_kwargs_n(
+ kwarg_targets
+ ).items()
+ }
+ )
####################
# - Operations
####################
+ @method_lru(maxsize=8)
def __or__(
self,
other: typ.Self,
@@ -272,34 +299,30 @@ class ParamsFlow(pyd.BaseModel):
The next composed function will receive a tuple of two arrays, instead of just one, allowing binary operations to occur.
"""
return ParamsFlow(
- arg_targets=self.arg_targets + other.arg_targets,
- kwarg_targets=self.kwarg_targets | other.kwarg_targets,
func_args=self.func_args + other.func_args,
func_kwargs=self.func_kwargs | other.func_kwargs,
symbols=self.symbols | other.symbols,
realized_symbols=self.realized_symbols | other.realized_symbols,
)
+ @method_lru(maxsize=8)
def compose_within(
self,
- enclosing_arg_targets: list[sim_symbols.SimSymbol] = (),
- enclosing_kwarg_targets: list[sim_symbols.SimSymbol] = (),
- enclosing_func_args: list[spux.SympyExpr] = (),
- enclosing_func_kwargs: dict[str, spux.SympyExpr] = MappingProxyType({}),
+ enclosing_func_args: tuple[spux.SympyExpr, ...] = (),
+ enclosing_func_kwargs: frozendict[str, spux.SympyExpr] = frozendict(),
enclosing_symbols: frozenset[sim_symbols.SimSymbol] = frozenset(),
) -> typ.Self:
return ParamsFlow(
- arg_targets=self.arg_targets + list(enclosing_arg_targets),
- kwarg_targets=self.kwarg_targets | dict(enclosing_kwarg_targets),
- func_args=self.func_args + list(enclosing_func_args),
+ func_args=self.func_args + tuple(enclosing_func_args),
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
symbols=self.symbols | enclosing_symbols,
realized_symbols=self.realized_symbols,
)
+ @method_lru(maxsize=8)
def realize_partial(
self,
- symbol_values: dict[
+ symbol_values: frozendict[
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
],
) -> typ.Self:
@@ -316,11 +339,9 @@ class ParamsFlow(pyd.BaseModel):
Raises:
ValueError: If any symbol in `symbol_values`
"""
- syms = set(symbol_values.keys())
+ syms = frozenset(symbol_values.keys())
if syms.issubset(self.symbols) or not syms:
return ParamsFlow(
- arg_targets=self.arg_targets,
- kwarg_targets=self.kwarg_targets,
func_args=self.func_args,
func_kwargs=self.func_kwargs,
symbols=self.symbols - syms,
@@ -333,7 +354,7 @@ class ParamsFlow(pyd.BaseModel):
# - Generate ExprSocketDef
####################
@functools.cached_property
- def sym_expr_infos(self) -> dict[str, ExprInfo]:
+ def sym_expr_infos(self) -> frozendict[str, ExprInfo]:
"""Generate keyword arguments for defining all `ExprSocket`s needed to realize all `self.symbols`.
Many nodes need actual data, and as such, they require that the user select actual values for any symbols in the `ParamsFlow`.
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_signals.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_signals.py
index 6c3f54b..63915c3 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_signals.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_signals.py
@@ -31,9 +31,10 @@ class FlowSignal(enum.StrEnum):
"""
+ NoFlow = enum.auto()
FlowInitializing = enum.auto()
FlowPending = enum.auto()
- NoFlow = enum.auto()
+ # FlowError = enum.auto()
@classmethod
def all(cls) -> set[typ.Self]:
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/node_types.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/node_types.py
index fc12eeb..68c7075 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/node_types.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/node_types.py
@@ -33,8 +33,12 @@ class NodeType(blender_type_enum.BlenderTypeEnum):
ReduceMath = enum.auto()
TransformMath = enum.auto()
- # Inputs
+ # Utilities
+ Combine = enum.auto()
WaveConstant = enum.auto()
+ ViewText = enum.auto()
+
+ # Inputs
Scene = enum.auto()
## Inputs / Constants
ExprConstant = enum.auto()
@@ -48,6 +52,11 @@ class NodeType(blender_type_enum.BlenderTypeEnum):
DataFileImporter = enum.auto()
Tidy3DFileImporter = enum.auto()
+ # Solvers
+ FDTDSolver = enum.auto()
+ ModeSolver = enum.auto()
+ EMESolver = enum.auto()
+
# Outputs
Viewer = enum.auto()
## Outputs / File Exporters
@@ -82,7 +91,7 @@ class NodeType(blender_type_enum.BlenderTypeEnum):
DebyeMedium = enum.auto()
## Mediums / Non-Linearities
AddNonLinearity = enum.auto()
- ChiThreeSusceptibilityNonLinearity = enum.auto()
+ ChiThreeSuscepNonLinearity = enum.auto()
TwoPhotonAbsorptionNonLinearity = enum.auto()
KerrNonLinearity = enum.auto()
@@ -97,13 +106,6 @@ class NodeType(blender_type_enum.BlenderTypeEnum):
CylinderStructure = enum.auto()
PolySlabStructure = enum.auto()
- # Bounds
- BoundConds = enum.auto()
- ## Bounds / Bound Conds
- PMLBoundCond = enum.auto()
- BlochBoundCond = enum.auto()
- AdiabAbsorbBoundCond = enum.auto()
-
# Monitors
EHFieldMonitor = enum.auto()
PowerFluxMonitor = enum.auto()
@@ -115,15 +117,16 @@ class NodeType(blender_type_enum.BlenderTypeEnum):
KSpaceNearFieldProjectionMonitor = enum.auto()
# Sims
- Combine = enum.auto()
SimDomain = enum.auto()
FDTDSim = enum.auto()
SimGrid = enum.auto()
- ## Sims / Sim Grid Axis
- AutomaticSimGridAxis = enum.auto()
+ BoundConds = enum.auto()
+ ## Sims / Bound Conds
+ PMLBoundCond = enum.auto()
+ BlochBoundCond = enum.auto()
+ AdiabAbsorbBoundCond = enum.auto()
+ ## Sims / Grid Axes
+ AutoSimGridAxis = enum.auto()
ManualSimGridAxis = enum.auto()
UniformSimGridAxis = enum.auto()
ArraySimGridAxis = enum.auto()
-
- # Utilities
- Separate = enum.auto()
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py
index 9765991..a8bf2bf 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/sim_types.py
@@ -18,19 +18,22 @@
import dataclasses
import enum
+import functools
import typing as typ
+from fractions import Fraction
from pathlib import Path
import jax.numpy as jnp
import jaxtyping as jtyp
import numpy as np
import polars as pl
+import pydantic as pyd
import tidy3d as td
from blender_maxwell.contracts import BLEnumElement
from blender_maxwell.services import tdcloud
+from blender_maxwell.utils import logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
-from blender_maxwell.utils import logger
from .flow_kinds.info import InfoFlow
@@ -668,3 +671,56 @@ class DataFileFormat(enum.StrEnum):
E.Txt: True, ## Use # Comments
E.TxtGz: True, ## Same as Txt
}[self]
+
+
+####################
+# - Encode/Decode Metadata
+####################
+RealizationScalar: typ.TypeAlias = int | float
+Realization: typ.TypeAlias = (
+ RealizationScalar
+ | tuple[RealizationScalar, ...]
+ | tuple[tuple[RealizationScalar, ...], ...]
+)
+
+
+class SimRealizations(pyd.BaseModel):
+ """Encodes the realized values of symbols that were used to generate a particular simulation."""
+
+ model_config = pyd.ConfigDict(frozen=True)
+
+ syms: tuple[sim_symbols.SimSymbol, ...] = ()
+ vals: tuple[Realization, ...] = ()
+
+ @pyd.model_validator(mode='after')
+ def syms_vals_eq_length(self) -> typ.Self:
+ """Ensure that `self.syms` and `self.vals` are of equal length."""
+ if len(self.syms) != len(self.vals):
+ msg = f"'syms' and 'vals' are of differing length (syms={self.syms}, vals={self.vals})"
+ raise ValueError(msg)
+
+ return self
+
+
+class SimMetadata(pyd.BaseModel):
+ """Encodes simulation metadata."""
+
+ model_config = pyd.ConfigDict(frozen=True)
+
+ sim_metadata_version: str = '0.1.0'
+ realizations: SimRealizations = SimRealizations()
+
+ @staticmethod
+ def from_sim(sim: td.Simulation | td.SimulationData) -> typ.Self:
+ """Deduce simulation metadata from a simulation / simulation data."""
+ if 'sim_metadata_version' in sim.attrs:
+ ## TODO: Semantic versioning comparison
+ return SimMetadata(**sim.attrs)
+ return SimMetadata()
+
+ @functools.cached_property
+ def syms_vals(
+ self,
+ ) -> tuple[tuple[sim_symbols.SimSymbol, ...], tuple[Realization, ...]]:
+ """Deduce simulation metadata from a simulation / simulation data."""
+ return (self.realizations.syms, self.realizations.vals)
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py
index ce6259b..dcef706 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/unit_systems.py
@@ -27,6 +27,7 @@ Attributes:
import typing as typ
import sympy.physics.units as spu
+from frozendict import frozendict
from blender_maxwell.utils import sympy_extra as spux
@@ -34,48 +35,52 @@ from blender_maxwell.utils import sympy_extra as spux
# - Unit Systems
####################
_PT: typ.TypeAlias = spux.PhysicalType
-UNITS_BLENDER: spux.UnitSystem = spux.UNITS_SI | {
- # Global
- _PT.Time: spu.picosecond,
- _PT.Freq: spux.terahertz,
- _PT.AngFreq: spu.radian * spux.terahertz,
- # Cartesian
- _PT.Length: spu.micrometer,
- _PT.Area: spu.micrometer**2,
- _PT.Volume: spu.micrometer**3,
- # Energy
- _PT.PowerFlux: spu.watt / spu.um**2,
- # Electrodynamics
- _PT.CurrentDensity: spu.ampere / spu.um**2,
- _PT.Conductivity: spu.siemens / spu.um,
- _PT.EField: spu.volt / spu.um,
- _PT.HField: spu.ampere / spu.um,
- # Mechanical
- _PT.Vel: spu.um / spu.second,
- _PT.Accel: spu.um / spu.second,
- _PT.Mass: spu.microgram,
- _PT.Force: spux.micronewton,
- # Luminal
- # Optics
-} ## TODO: Load (dynamically?) from addon preferences
+UNITS_BLENDER: spux.UnitSystem = spux.UNITS_SI | frozendict(
+ {
+ # Global
+ _PT.Time: spu.picosecond,
+ _PT.Freq: spux.terahertz,
+ _PT.AngFreq: spu.radian * spux.terahertz,
+ # Cartesian
+ _PT.Length: spu.micrometer,
+ _PT.Area: spu.micrometer**2,
+ _PT.Volume: spu.micrometer**3,
+ # Energy
+ _PT.PowerFlux: spu.watt / spu.um**2,
+ # Electrodynamics
+ _PT.CurrentDensity: spu.ampere / spu.um**2,
+ _PT.Conductivity: spu.siemens / spu.um,
+ _PT.EField: spu.volt / spu.um,
+ _PT.HField: spu.ampere / spu.um,
+ # Mechanical
+ _PT.Vel: spu.um / spu.second,
+ _PT.Accel: spu.um / spu.second,
+ _PT.Mass: spu.microgram,
+ _PT.Force: spux.micronewton,
+ # Luminal
+ # Optics
+ }
+) ## TODO: Load (dynamically?) from addon preferences
-UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | {
- # Global
- # Cartesian
- _PT.Length: spu.um,
- _PT.Area: spu.um**2,
- _PT.Volume: spu.um**3,
- # Mechanical
- _PT.Vel: spu.um / spu.second,
- _PT.Accel: spu.um / spu.second,
- # Energy
- _PT.PowerFlux: spu.watt / spu.um**2,
- # Electrodynamics
- _PT.CurrentDensity: spu.ampere / spu.um**2,
- _PT.Conductivity: spu.siemens / spu.um,
- _PT.EField: spu.volt / spu.um,
- _PT.HField: spu.ampere / spu.um,
- # Luminal
- # Optics
- ## NOTE: w/o source normalization, EField/HField/Modal amps are * 1/Hz
-}
+UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | frozendict(
+ {
+ # Global
+ # Cartesian
+ _PT.Length: spu.um,
+ _PT.Area: spu.um**2,
+ _PT.Volume: spu.um**3,
+ # Mechanical
+ _PT.Vel: spu.um / spu.second,
+ _PT.Accel: spu.um / spu.second,
+ # Energy
+ _PT.PowerFlux: spu.watt / spu.um**2,
+ # Electrodynamics
+ _PT.CurrentDensity: spu.ampere / spu.um**2,
+ _PT.Conductivity: spu.siemens / spu.um,
+ _PT.EField: spu.volt / spu.um,
+ _PT.HField: spu.ampere / spu.um,
+ # Luminal
+ # Optics
+ ## NOTE: w/o source normalization, EField/HField/Modal amps are * 1/Hz
+ }
+)
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_mesh.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_mesh.py
index 96caac7..d259e25 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_mesh.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_mesh.py
@@ -18,6 +18,7 @@ import contextlib
import bmesh
import bpy
+import jax
import numpy as np
from blender_maxwell.utils import logger
@@ -127,8 +128,18 @@ class ManagedBLMesh(base.ManagedObj):
####################
# - BLMesh Management
####################
- def bl_object(self, location: tuple[float, float, float] = (0, 0, 0)):
- """Returns the managed blender object."""
+ def bl_object(
+ self, location: np.ndarray | jax.Array | tuple[float, float, float] = (0, 0, 0)
+ ):
+ """Returns the managed blender object, centered at the given location."""
+ if isinstance(location, np.ndarray | jax.Array):
+ if isinstance(location.item(0), int):
+ center = tuple(location.astype(float).flatten().tolist())
+ else:
+ center = tuple(location.flatten().tolist())
+ else:
+ center = tuple([float(el) for el in location])
+
# Create Object w/Appropriate Data Block
if not (bl_object := bpy.data.objects.get(self.name)):
log.info(
@@ -143,7 +154,7 @@ class ManagedBLMesh(base.ManagedObj):
)
managed_collection().objects.link(bl_object)
- for i, coord in enumerate(location):
+ for i, coord in enumerate(center):
if bl_object.location[i] != coord:
bl_object.location[i] = coord
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py
index 4242efd..ffa5cd8 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_modifier.py
@@ -19,8 +19,9 @@
import typing as typ
import bpy
+import jax
+import numpy as np
-from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .. import bl_socket_map
@@ -41,13 +42,10 @@ class ModifierAttrsNODES(typ.TypedDict):
Attributes:
node_group: The GeoNodes group to use in the modifier.
- unit_system: The unit system used by the GeoNodes output.
- Generally, `ct.UNITS_BLENDER` is a good choice.
inputs: Values to associate with each GeoNodes interface socket name.
"""
node_group: bpy.types.GeometryNodeTree
- unit_system: UnitSystem
inputs: dict[ct.SocketName, typ.Any]
@@ -122,7 +120,6 @@ def write_modifier_geonodes(
## -> TODO: A special case isn't clean enough.
bl_modifier[iface_id] = socket_infos[socket_name].encode(
raw_value=modifier_attrs['inputs'][socket_name],
- unit_system=modifier_attrs['unit_system'],
)
modifier_altered = True
## TODO: More fine-grained alterations?
@@ -230,7 +227,7 @@ class ManagedBLModifier(base.ManagedObj):
self,
modifier_type: ct.BLModifierType,
modifier_attrs: ModifierAttrs,
- location: tuple[float, float, float] = (0, 0, 0),
+ location: np.ndarray | jax.Array | tuple[float, float, float] = (0, 0, 0),
):
"""Creates a new modifier for the current `bl_object`.
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_text.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_text.py
new file mode 100644
index 0000000..c999ce2
--- /dev/null
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/managed_objs/managed_bl_text.py
@@ -0,0 +1,186 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+"""Declares `ManagedBLText`."""
+
+import time
+import typing as typ
+
+import bpy
+import matplotlib.axis as mpl_ax
+import numpy as np
+
+from blender_maxwell.utils import image_ops, logger
+
+from .. import contracts as ct
+from . import base
+
+log = logger.get(__name__)
+
+AREA_TYPE = 'IMAGE_EDITOR'
+SPACE_TYPE = 'IMAGE_EDITOR'
+
+
+####################
+# - Managed BL Image
+####################
+class ManagedBLText(base.ManagedObj):
+ """Represents a Blender Image datablock, encapsulating various useful interactions with it.
+
+ Attributes:
+ name: The name of the image.
+ """
+
+ managed_obj_type = ct.ManagedObjType.ManagedBLText
+ _bl_image_name: str
+
+ def __init__(self, name: str, prev_name: str | None = None):
+ if prev_name is not None:
+ self._bl_image_name = prev_name
+ else:
+ self._bl_image_name = name
+
+ self.name = name
+
+ @property
+ def name(self):
+ return self._bl_image_name
+
+ @name.setter
+ def name(self, value: str):
+ log.info(
+ 'Changing ManagedBLText from "%s" to "%s"',
+ self.name,
+ value,
+ )
+ existing_bl_image = bpy.data.images.get(self.name)
+
+ # No Existing Image: Set Value to Name
+ if existing_bl_image is None:
+ self._bl_image_name = value
+
+ # Existing Image: Rename to New Name
+ else:
+ existing_bl_image.name = value
+ self._bl_image_name = value
+
+ # Check: Blender Rename -> Synchronization Error
+ ## -> We can't do much else than report to the user & free().
+ if existing_bl_image.name != self._bl_image_name:
+ log.critical(
+ 'BLImage: Failed to set name of %s to %s, as %s already exists.'
+ )
+ self._bl_image_name = existing_bl_image.name
+ self.free()
+
+ def free(self):
+ bl_image = bpy.data.images.get(self.name)
+ if bl_image is not None:
+ log.debug('Freeing ManagedBLText "%s"', self.name)
+ bpy.data.images.remove(bl_image)
+
+ ####################
+ # - Managed Object Management
+ ####################
+ def bl_image(
+ self,
+ width_px: int,
+ height_px: int,
+ color_model: typ.Literal['RGB', 'RGBA'],
+ dtype: typ.Literal['uint8', 'float32'],
+ ):
+ """Returns the managed blender image.
+
+ If the requested image properties are different from the image's, then delete the old image make a new image with correct geometry.
+ """
+ channels = 4 if color_model == 'RGBA' else 3
+
+ # Remove Image (if mismatch)
+ bl_image = bpy.data.images.get(self.name)
+ if bl_image is not None and (
+ bl_image.size[0] != width_px
+ or bl_image.size[1] != height_px
+ or bl_image.channels != channels
+ or bl_image.is_float ^ (dtype == 'float32')
+ ):
+ self.free()
+
+ # Create Image w/Geometry (if none exists)
+ bl_image = bpy.data.images.get(self.name)
+ if bl_image is None:
+ bl_image = bpy.data.images.new(
+ self.name,
+ width=width_px,
+ height=height_px,
+ float_buffer=dtype == 'float32',
+ )
+
+ # Enable Fake User
+ bl_image.use_fake_user = True
+
+ return bl_image
+
+ ####################
+ # - Editor UX Manipulation
+ ####################
+ @classmethod
+ def preview_area(cls) -> bpy.types.Area | None:
+ """Deduces a Blender UI area that can be used for image preview.
+
+ Returns:
+ A Blender UI area, if an appropriate one is visible; else `None`,
+ """
+ valid_areas = [
+ area for area in bpy.context.screen.areas if area.type == AREA_TYPE
+ ]
+ if valid_areas:
+ return valid_areas[0]
+ return None
+
+ @classmethod
+ def preview_space(cls) -> bpy.types.SpaceProperties | None:
+ """Deduces a Blender UI space, within `self.preview_area`, that can be used for image preview.
+
+ Returns:
+ A Blender UI space within `self.preview_area`, if it isn't None; else, `None`.
+ """
+ preview_area = cls.preview_area()
+ if preview_area is not None:
+ valid_spaces = [
+ space for space in preview_area.spaces if space.type == SPACE_TYPE
+ ]
+ if valid_spaces:
+ return valid_spaces[0]
+ return None
+ return None
+
+ ####################
+ # - Methods
+ ####################
+ def bl_select(self) -> None:
+ """Selects the image by loading it into an on-screen UI area/space.
+
+ Notes:
+ The image must already be available, else nothing will happen.
+ """
+ bl_image = bpy.data.images.get(self.name)
+ if bl_image is not None:
+ self.preview_space().image = bl_image
+
+ @classmethod
+ def hide_preview(cls) -> None:
+ """Deselects the image by loading `None` into the on-screen UI area/space."""
+ cls.preview_space().image = None
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/filter.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/filter.py
index 6bea59f..8257e73 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/filter.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/filter.py
@@ -15,6 +15,7 @@
# along with this program. If not, see .
import enum
+import functools
import typing as typ
import jax.lax as jlax
@@ -36,6 +37,8 @@ class FilterOperation(enum.StrEnum):
PinLen1: Remove a len(1) dimension.
Pin: Remove a len(n) dimension by selecting a particular index.
Swap: Swap the positions of two dimensions.
+ ZScore15: The Z-score threshold of values
+ ZScore30: The peak-to-peak along an axis.
"""
# Slice
@@ -47,14 +50,19 @@ class FilterOperation(enum.StrEnum):
Pin = enum.auto()
PinIdx = enum.auto()
- # Dimension
+ # Swizzle
Swap = enum.auto()
+ # Axis Filter
+ ZScore15 = enum.auto()
+ ZScore30 = enum.auto()
+
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
+ """A human-readable UI-oriented name for a physical type."""
FO = FilterOperation
return {
# Slice
@@ -64,15 +72,20 @@ class FilterOperation(enum.StrEnum):
FO.PinLen1: 'a[0] → a',
FO.Pin: 'a[v] ⇝ a',
FO.PinIdx: 'a[i] → a',
- # Reinterpret
+ # Swizzle
FO.Swap: 'a₁ ↔ a₂',
+ # Axis Filter
+ # FO.ZScore15: 'a[v₁:v₂] ∈ σ[1.5]',
+ # FO.ZScore30: 'a[v₁:v₂] ∈ σ[1.5]',
}[value]
@staticmethod
- def to_icon(value: typ.Self) -> str:
+ def to_icon(_: typ.Self) -> str:
+ """No icons."""
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
+ """Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`."""
FO = FilterOperation
return (
str(self),
@@ -82,54 +95,83 @@ class FilterOperation(enum.StrEnum):
i,
)
+ @staticmethod
+ def bl_enum_elements(info: ct.InfoFlow) -> list[ct.BLEnumElement]:
+ """Generate a list of guaranteed-valid operations based on the passed `InfoFlow`s.
+
+ Returns a `bpy.props.EnumProperty.items`-compatible list.
+ """
+ return [
+ operation.bl_enum_element(i)
+ for i, operation in enumerate(FilterOperation.by_info(info))
+ ]
+
####################
# - Ops from Info
####################
@staticmethod
def by_info(info: ct.InfoFlow) -> list[typ.Self]:
FO = FilterOperation
- operations = []
+ ops = []
- # Slice
if info.dims:
- operations.append(FO.SliceIdx)
+ # Slice
+ ops += [FO.SliceIdx]
- # Pin
- ## PinLen1
- ## -> There must be a dimension with length 1.
- if 1 in [dim_idx for dim_idx in info.dims.values() if dim_idx is not None]:
- operations.append(FO.PinLen1)
+ # Pin
+ ## PinLen1
+ ## -> There must be a dimension with length 1.
+ if 1 in [
+ len(dim_idx) for dim_idx in info.dims.values() if dim_idx is not None
+ ]:
+ ops += [FO.PinLen1]
- ## Pin | PinIdx
- ## -> There must be a dimension, full stop.
- if info.dims:
- operations += [FO.Pin, FO.PinIdx]
+ # Pin
+ ## -> There must be a dimension, full stop.
+ ops += [FO.Pin, FO.PinIdx]
- # Reinterpret
- ## Swap
- ## -> There must be at least two dimensions.
- if len(info.dims) >= 2: # noqa: PLR2004
- operations.append(FO.Swap)
+ # Swizzle
+ ## Swap
+ ## -> There must be at least two dimensions to swap between.
+ if len(info.dims) >= 2: # noqa: PLR2004
+ ops += [FO.Swap]
- return operations
+ # Axis Filter
+ ## ZScore
+ ## -> Subjectively, it makes little sense with less than 5 numbers.
+ ## -> Mathematically valid (I suppose) for 2. But not so useful.
+ # if any(
+ # (dim.has_idx_discrete(dim) or dim.has_idx_labels(dim))
+ # and len(dim_idx) > 5 # noqa: PLR2004
+ # for dim, dim_idx in info.dims.items()
+ # ):
+ # ops += [FO.ZScore15, FO.ZScore30]
+
+ return ops
####################
# - Computed Properties
####################
- @property
+ @functools.cached_property
def func_args(self) -> list[sim_symbols.SimSymbol]:
FO = FilterOperation
return {
# Pin
FO.Pin: [sim_symbols.idx(None)],
FO.PinIdx: [sim_symbols.idx(None)],
+ # Swizzle
+ ## -> Swap: JAX requires that swap dims be baked into the function.
+ # Axis Filter
+ # FO.ZScore15: [sim_symbols.idx(None)],
+ # FO.ZScore30: [sim_symbols.idx(None)],
}.get(self, [])
####################
# - Methods
####################
- @property
+ @functools.cached_property
def num_dim_inputs(self) -> None:
+ """Number of dimensions required as inputs to the operation's function."""
FO = FilterOperation
return {
# Slice
@@ -139,11 +181,15 @@ class FilterOperation(enum.StrEnum):
FO.PinLen1: 1,
FO.Pin: 1,
FO.PinIdx: 1,
- # Reinterpret
+ # Swizzle
FO.Swap: 2,
+ # Axis Filter
+ # FO.ZScore15: 1,
+ # FO.ZScore30: 1,
}[self]
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
+ """The valid dimensions that can be selected between, fo each of the"""
FO = FilterOperation
match self:
# Slice
@@ -171,30 +217,20 @@ class FilterOperation(enum.StrEnum):
case FO.Swap:
return info.dims
+ # TODO: ZScore
+
return []
- def are_dims_valid(
- self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None
- ) -> bool:
- """Check whether the given dimension inputs are valid in the context of this operation, and of the information."""
- if self.num_dim_inputs == 1:
- return dim_0 in self.valid_dims(info)
-
- if self.num_dim_inputs == 2: # noqa: PLR2004
- valid_dims = self.valid_dims(info)
- return dim_0 in valid_dims and dim_1 in valid_dims
-
- return False
-
####################
- # - UI
+ # - Implementations
####################
def jax_func(
self,
axis_0: int | None,
- axis_1: int | None,
+ axis_1: int | None = None,
slice_tuple: tuple[int, int, int] | None = None,
):
+ """Implements the identified filtering using `jax`."""
FO = FilterOperation
return {
# Pin
@@ -210,13 +246,65 @@ class FilterOperation(enum.StrEnum):
FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
# Dimension
FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1),
+ # TODO: Axis Filters
+ ## -> The jnp.compress() function is ideal for this kind of thing.
+ ## -> The difficulty is that jit() requires output size to be known.
+ ## -> One can set the size= parameter of compress.
+ ## -> But how do we determine that?
}[self]
+ ####################
+ # - Transforms
+ ####################
+ def transform_func(
+ self,
+ func: ct.FuncFlow,
+ axis_0: int,
+ axis_1: int | None = None,
+ slice_tuple: tuple[int, int, int] | None = None,
+ ) -> ct.FuncFlow | None:
+ """Transform input function according to the current operation and output info characterization."""
+ FO = FilterOperation
+ match self:
+ # Slice
+ case FO.Slice | FO.SliceIdx if axis_0 is not None:
+ return func.compose_within(
+ self.jax_func(axis_0, slice_tuple=slice_tuple),
+ enclosing_func_output=func.func_output,
+ supports_jax=True,
+ )
+
+ # Pin
+ case FO.PinLen1 if axis_0 is not None:
+ return func.compose_within(
+ self.jax_func(axis_0),
+ enclosing_func_output=func.func_output,
+ supports_jax=True,
+ )
+
+ case FO.Pin | FO.PinIdx if axis_0 is not None:
+ return func.compose_within(
+ self.jax_func(axis_0),
+ enclosing_func_args=[sim_symbols.idx(None)],
+ enclosing_func_output=func.func_output,
+ supports_jax=True,
+ )
+
+ # Swizzle
+ case FO.Swap if axis_0 is not None and axis_1 is not None:
+ return func.compose_within(
+ self.jax_func(axis_0, axis_1),
+ enclosing_func_output=func.func_output,
+ supports_jax=True,
+ )
+
+ return None
+
def transform_info(
self,
info: ct.InfoFlow,
dim_0: sim_symbols.SimSymbol,
- dim_1: sim_symbols.SimSymbol,
+ dim_1: sim_symbols.SimSymbol | None = None,
pin_idx: int | None = None,
slice_tuple: tuple[int, int, int] | None = None,
):
@@ -225,9 +313,10 @@ class FilterOperation(enum.StrEnum):
FO.Slice: lambda: info.slice_dim(dim_0, slice_tuple),
FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple),
# Pin
- FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
+ FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=0),
FO.Pin: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
# Reinterpret
FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
+ # TODO: Axis Filters
}[self]()
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/map.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/map.py
index d708561..2054a35 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/map.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/map.py
@@ -14,11 +14,15 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements map operations for the `MapNode`."""
+
import enum
import typing as typ
import jax.numpy as jnp
+import jaxtyping as jtyp
import sympy as sp
+import sympy.physics.units as spu
from blender_maxwell.utils import logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
@@ -27,6 +31,9 @@ from .. import contracts as ct
log = logger.get(__name__)
+MT = spux.MathType
+PT = spux.PhysicalType
+
class MapOperation(enum.StrEnum):
"""Valid operations for the `MapMathNode`.
@@ -54,7 +61,8 @@ class MapOperation(enum.StrEnum):
SvdVals: Compute the singular values vector of the input matrix.
Inv: Compute the inverse matrix of the input matrix.
Tra: Compute the transpose matrix of the input matrix.
- Qr: Compute the QR-factorized matrices of the input matrix.
+ QR_Q: Compute the QR-factorized matrices of the input matrix, and extract the 'Q' component.
+ QR_R: Compute the QR-factorized matrices of the input matrix, and extract the 'R' component.
Chol: Compute the Cholesky-factorized matrices of the input matrix.
Svd: Compute the SVD-factorized matrices of the input matrix.
"""
@@ -64,6 +72,7 @@ class MapOperation(enum.StrEnum):
Imag = enum.auto()
Abs = enum.auto()
Sq = enum.auto()
+ Reciprocal = enum.auto()
Sqrt = enum.auto()
InvSqrt = enum.auto()
Cos = enum.auto()
@@ -85,16 +94,17 @@ class MapOperation(enum.StrEnum):
SvdVals = enum.auto()
Inv = enum.auto()
Tra = enum.auto()
- Qr = enum.auto()
- Chol = enum.auto()
- Svd = enum.auto()
+ QR_Q = enum.auto()
+ QR_R = enum.auto()
+ # Chol = enum.auto()
+ # Svd = enum.auto()
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
- """A human-readable UI-oriented name for a physical type."""
+ """A human-readable UI-oriented name."""
MO = MapOperation
return {
# By Number
@@ -103,6 +113,7 @@ class MapOperation(enum.StrEnum):
MO.Abs: '|v|',
MO.Sq: 'v²',
MO.Sqrt: '√v',
+ MO.Reciprocal: '1/v',
MO.InvSqrt: '1/√v',
MO.Cos: 'cos v',
MO.Sin: 'sin v',
@@ -123,9 +134,10 @@ class MapOperation(enum.StrEnum):
MO.SvdVals: 'svdvals V',
MO.Inv: 'V⁻¹',
MO.Tra: 'Vt',
- MO.Qr: 'qr V',
- MO.Chol: 'chol V',
- MO.Svd: 'svd V',
+ MO.QR_Q: 'qr[Q] V',
+ MO.QR_R: 'qr[R] V',
+ # MO.Chol: 'chol V',
+ # MO.Svd: 'svd V',
}[value]
@staticmethod
@@ -144,75 +156,105 @@ class MapOperation(enum.StrEnum):
i,
)
- ####################
- # - Ops from Shape
- ####################
@staticmethod
- def by_expr_info(info: ct.InfoFlow) -> list[typ.Self]:
- ## TODO: By info, not shape.
- ## TODO: Check valid domains/mathtypes for some functions.
- MO = MapOperation
- element_ops = [
- MO.Real,
- MO.Imag,
- MO.Abs,
- MO.Sq,
- MO.Sqrt,
- MO.InvSqrt,
- MO.Cos,
- MO.Sin,
- MO.Tan,
- MO.Acos,
- MO.Asin,
- MO.Atan,
- MO.Sinc,
+ def bl_enum_elements(info: ct.InfoFlow) -> list[ct.BLEnumElement]:
+ """Generate a list of guaranteed-valid operations based on the passed `InfoFlow`.
+
+ Returns a `bpy.props.EnumProperty.items`-compatible list.
+ """
+ return [
+ operation.bl_enum_element(i)
+ for i, operation in enumerate(MapOperation.from_info(info))
]
- match (info.output.rows, info.output.cols):
- case (1, 1):
- return element_ops
+ ####################
+ # - Derivation
+ ####################
+ @staticmethod
+ def from_info(info: ct.InfoFlow) -> list[typ.Self]:
+ """Derive valid mapping operations from the `InfoFlow` of the operand."""
+ MO = MapOperation
+ ops = []
- case (_, 1):
- return [*element_ops, MO.Norm2]
+ # Real/Imag
+ if info.output.mathtype is MT.Complex:
+ ops += [MO.Real, MO.Imag]
- case (rows, cols) if rows == cols:
- ## TODO: Check hermitian/posdef for cholesky.
- ## - Can we even do this with just the output symbol approach?
- return [
- *element_ops,
- MO.Det,
- MO.Cond,
- MO.NormFro,
- MO.Rank,
- MO.Diag,
- MO.EigVals,
- MO.SvdVals,
- MO.Inv,
- MO.Tra,
- MO.Qr,
- MO.Chol,
- MO.Svd,
- ]
+ # Absolute Value
+ ops += [MO.Abs]
- case (rows, cols):
- return [
- *element_ops,
- MO.Cond,
- MO.NormFro,
- MO.Rank,
- MO.SvdVals,
- MO.Inv,
- MO.Tra,
- MO.Svd,
- ]
+ # Square
+ ops += [MO.Sq]
- return []
+ # Reciprocal
+ if info.output.is_nonzero:
+ ops += [MO.Reciprocal]
+
+ # Square Root (Principal)
+ ops += [MO.Sqrt]
+
+ # Inverse Sqrt
+ if info.output.is_nonzero:
+ ops += [MO.InvSqrt]
+
+ # Cos/Sin/Tan/Sinc
+ if info.output.unit == spu.radian:
+ ops += [MO.Cos, MO.Sin, MO.Tan, MO.Sinc]
+
+ # Inverse Cos/Sin/Tan
+ ## -> Presume complex-extensions that aren't limited.
+ if info.output.physical_type is PT.NonPhysical and info.output.unit is None:
+ ops += [MO.Acos, MO.Asin, MO.Atan]
+
+ # By Vector
+ if info.output.shape_len == 1:
+ ops += [MO.Norm2]
+
+ # By Matrix
+ if info.output.shape_len == 2: # noqa: PLR2004
+ if info.output.rows == info.output.cols:
+ ops += [MO.Det]
+
+ # Square Matrix
+ if info.output.rows == info.output.cols:
+ # Det
+ ops += [MO.Det]
+
+ # Diag
+ ops += [MO.Diag]
+
+ # Inv
+ ops += [MO.Inv]
+
+ # Cond
+ ops += [MO.Cond]
+
+ # NormFro
+ ops += [MO.NormFro]
+
+ # Rank
+ ops += [MO.Rank]
+
+ # EigVals
+ ops += [MO.EigVals]
+
+ # SvdVals
+ ops += [MO.EigVals]
+
+ # Tra
+ ops += [MO.Tra]
+
+ # QR
+ ops += [MO.QR_Q, MO.QR_R]
+
+ return ops
####################
- # - Function Properties
+ # - Implementations
####################
@property
def sp_func(self):
+ """Implement the mapping operation for sympy expressions."""
MO = MapOperation
return {
# By Number
@@ -221,6 +263,7 @@ class MapOperation(enum.StrEnum):
MO.Abs: lambda expr: sp.Abs(expr),
MO.Sq: lambda expr: expr**2,
MO.Sqrt: lambda expr: sp.sqrt(expr),
+ MO.Reciprocal: lambda expr: 1 / expr,
MO.InvSqrt: lambda expr: 1 / sp.sqrt(expr),
MO.Cos: lambda expr: sp.cos(expr),
MO.Sin: lambda expr: sp.sin(expr),
@@ -240,19 +283,25 @@ class MapOperation(enum.StrEnum):
MO.Rank: lambda expr: expr.rank(),
# Matrix -> Vec
MO.Diag: lambda expr: expr.diagonal(),
- MO.EigVals: lambda expr: sp.Matrix(list(expr.eigenvals().keys())),
+ MO.EigVals: lambda expr: sp.ImmutableMatrix(list(expr.eigenvals().keys())),
MO.SvdVals: lambda expr: expr.singular_values(),
# Matrix -> Matrix
MO.Inv: lambda expr: expr.inv(),
MO.Tra: lambda expr: expr.T,
# Matrix -> Matrices
- MO.Qr: lambda expr: expr.QRdecomposition(),
- MO.Chol: lambda expr: expr.cholesky(),
- MO.Svd: lambda expr: expr.singular_value_decomposition(),
+ MO.QR_Q: lambda expr: expr.QRdecomposition()[0],
+ MO.QR_R: lambda expr: expr.QRdecomposition()[1],
+ # MO.Chol: lambda expr: expr.cholesky(),
+ # MO.Svd: lambda expr: expr.singular_value_decomposition(),
}[self]
@property
- def jax_func(self):
+ def jax_func(
+ self,
+ ) -> typ.Callable[
+ [jtyp.Shaped[jtyp.Array, '...'], int], jtyp.Shaped[jtyp.Array, '...']
+ ]:
+ """Implements the identified mapping using `jax`."""
MO = MapOperation
return {
# By Number
@@ -261,6 +310,7 @@ class MapOperation(enum.StrEnum):
MO.Abs: lambda expr: jnp.abs(expr),
MO.Sq: lambda expr: jnp.square(expr),
MO.Sqrt: lambda expr: jnp.sqrt(expr),
+ MO.Reciprocal: lambda expr: 1 / expr,
MO.InvSqrt: lambda expr: 1 / jnp.sqrt(expr),
MO.Cos: lambda expr: jnp.cos(expr),
MO.Sin: lambda expr: jnp.sin(expr),
@@ -286,80 +336,231 @@ class MapOperation(enum.StrEnum):
MO.Inv: lambda expr: jnp.linalg.inv(expr),
MO.Tra: lambda expr: jnp.matrix_transpose(expr),
# Matrix -> Matrices
- MO.Qr: lambda expr: jnp.linalg.qr(expr),
- MO.Chol: lambda expr: jnp.linalg.cholesky(expr),
- MO.Svd: lambda expr: jnp.linalg.svd(expr),
+ MO.QR_Q: lambda expr: jnp.linalg.qr(expr)[0],
+ MO.QR_R: lambda expr: jnp.linalg.qr(expr, mode='r'),
+ # MO.Chol: lambda expr: jnp.linalg.cholesky(expr),
+ # MO.Svd: lambda expr: jnp.linalg.svd(expr),
}[self]
+ ####################
+ # - Transforms: FlowKind
+ ####################
+ def transform_func(self, func: ct.FuncFlow) -> ct.FuncFlow:
+ """Transform input function according to the current operation and output info characterization."""
+ return func.compose_within(
+ self.jax_func,
+ enclosing_func_output=self.transform_output(func.func_output),
+ supports_jax=True,
+ )
+
def transform_info(self, info: ct.InfoFlow):
+ """Transform the `InfoFlow` characterizing the output."""
+ return info.update(output=self.transform_output(info.output))
+
+ def transform_params(self, params: ct.ParamsFlow):
+ """Transform the incoming function parameters to include output arguments."""
+ return params
+
+ ####################
+ # - Transforms: Symbolic
+ ####################
+ def transform_output(self, sym: sim_symbols.SimSymbol): # noqa: PLR0911
+ """Transform the `SimSymbol` characterizing the output."""
MO = MapOperation
- return {
+ dm = sym.domain
+ match self:
# By Number
- MO.Real: lambda: info.update_output(mathtype=spux.MathType.Real),
- MO.Imag: lambda: info.update_output(mathtype=spux.MathType.Real),
- MO.Abs: lambda: info.update_output(mathtype=spux.MathType.Real),
- MO.Sq: lambda: info,
- MO.Sqrt: lambda: info,
- MO.InvSqrt: lambda: info,
- MO.Cos: lambda: info,
- MO.Sin: lambda: info,
- MO.Tan: lambda: info,
- MO.Acos: lambda: info,
- MO.Asin: lambda: info,
- MO.Atan: lambda: info,
- MO.Sinc: lambda: info,
- # By Vector
- MO.Norm2: lambda: info.update_output(
- mathtype=spux.MathType.Real,
- rows=1,
- cols=1,
- # Interval
- interval_finite_re=(0, sim_symbols.float_max),
- interval_inf=(False, True),
- interval_closed=(True, False),
- ),
+ case MO.Real:
+ return sym.update(
+ mathtype=MT.Real,
+ domain=dm.real,
+ )
+ case MO.Imag:
+ return sym.update(
+ mathtype=MT.Real,
+ domain=dm.imag,
+ )
+ case MO.Abs:
+ return sym.update(
+ mathtype=MT.Real,
+ domain=dm.abs,
+ )
+
+ case MO.Sq:
+ return sym.update(
+ domain=dm**2,
+ )
+ case MO.Reciprocal:
+ orig_unit = sym.unit
+ new_unit = 1 / orig_unit if orig_unit is not None else None
+ new_phy_type = PT.from_unit(new_unit, optional=True)
+
+ return sym.update(
+ physical_type=new_phy_type,
+ unit=new_unit,
+ domain=dm.reciprocal,
+ )
+ case MO.Sqrt:
+ ## TODO: Complex -> Real MathType
+ return sym.update(domain=dm ** sp.Rational(1, 2))
+ case MO.InvSqrt:
+ ## TODO: Complex -> Real MathType
+ return sym.update(domain=(dm ** sp.Rational(1, 2)).reciprocal)
+
+ case MO.Cos:
+ return sym.update(
+ physical_type=PT.NonPhysical,
+ unit=None,
+ domain=dm.cos,
+ )
+ case MO.Sin:
+ return sym.update(
+ physical_type=PT.NonPhysical,
+ unit=None,
+ domain=dm.sin,
+ )
+ case MO.Tan:
+ return sym.update(
+ physical_type=PT.NonPhysical,
+ unit=None,
+ domain=dm.tan,
+ )
+
+ case MO.Acos:
+ return sym.update(
+ mathtype=MT.Complex if sym.mathtype is MT.Complex else MT.Real,
+ physical_type=PT.Angle,
+ unit=spu.radian,
+ domain=dm.acos,
+ )
+ case MO.Asin:
+ return sym.update(
+ mathtype=MT.Complex if sym.mathtype is MT.Complex else MT.Real,
+ physical_type=PT.Angle,
+ unit=spu.radian,
+ domain=dm.asin,
+ )
+ case MO.Atan:
+ return sym.update(
+ mathtype=MT.Complex if sym.mathtype is MT.Complex else MT.Real,
+ physical_type=PT.Angle,
+ unit=spu.radian,
+ domain=dm.atan,
+ )
+
+ case MO.Sinc:
+ return sym.update(
+ physical_type=PT.NonPhysical,
+ unit=None,
+ domain=dm.sinc,
+ )
+
+ # By Vector/Covector
+ case MO.Norm2:
+ size = max([sym.rows, sym.cols])
+ return sym.update(
+ mathtype=MT.Real,
+ rows=1,
+ cols=1,
+ domain=(size * dm**2) ** sp.Rational(1, 2),
+ )
+
# By Matrix
- MO.Det: lambda: info.update_output(
- rows=1,
- cols=1,
- ),
- MO.Cond: lambda: info.update_output(
- mathtype=spux.MathType.Real,
- rows=1,
- cols=1,
- physical_type=spux.PhysicalType.NonPhysical,
- unit=None,
- ),
- MO.NormFro: lambda: info.update_output(
- mathtype=spux.MathType.Real,
- rows=1,
- cols=1,
- # Interval
- interval_finite_re=(0, sim_symbols.float_max),
- interval_inf=(False, True),
- interval_closed=(True, False),
- ),
- MO.Rank: lambda: info.update_output(
- mathtype=spux.MathType.Integer,
- rows=1,
- cols=1,
- physical_type=spux.PhysicalType.NonPhysical,
- unit=None,
- # Interval
- interval_finite_re=(0, sim_symbols.int_max),
- interval_inf=(False, True),
- interval_closed=(True, False),
- ),
- # Matrix -> Vector ## TODO: ALL OF THESE
- MO.Diag: lambda: info,
- MO.EigVals: lambda: info,
- MO.SvdVals: lambda: info,
- # Matrix -> Matrix ## TODO: ALL OF THESE
- MO.Inv: lambda: info,
- MO.Tra: lambda: info,
- # Matrix -> Matrices ## TODO: ALL OF THESE
- MO.Qr: lambda: info,
- MO.Chol: lambda: info,
- MO.Svd: lambda: info,
- }[self]()
+ case MO.Det:
+ ## -> NOTE: Determinant only valid for square matrices.
+ size = sym.rows
+ orig_unit = sym.unit
+
+ new_unit = orig_unit**size if orig_unit is not None else None
+ _new_phy_type = PT.from_unit(new_unit, optional=True)
+ new_phy_type = (
+ _new_phy_type if _new_phy_type is not None else PT.NonPhysical
+ )
+
+ return sym.update(
+ physical_type=new_phy_type,
+ unit=new_unit,
+ rows=1,
+ cols=1,
+ domain=(size * dm**2) ** sp.Rational(1, 2),
+ )
+
+ case MO.Cond:
+ return sym.update(
+ mathtype=MT.Real,
+ physical_type=PT.NonPhysical,
+ unit=None,
+ rows=1,
+ cols=1,
+ domain=spux.BlessedSet(sp.Interval(1, sp.oo)),
+ )
+
+ case MO.NormFro:
+ return sym.update(
+ mathtype=MT.Real,
+ rows=1,
+ cols=1,
+ domain=(sym.rows * sym.cols * abs(dm) ** 2) ** sp.Rational(1, 2),
+ )
+
+ case MO.Rank:
+ return sym.update(
+ mathtype=MT.Integer,
+ physical_type=PT.NonPhysical,
+ unit=None,
+ rows=1,
+ cols=1,
+ domain=spux.BlessedSet(sp.Range(0, min([sym.rows, sym.cols]) + 1)),
+ )
+
+ case MO.Diag:
+ return sym.update(cols=1)
+
+ case MO.EigVals:
+ ## TODO: Gershgorin circle theorem?
+ return spux.BlessedSet(sp.Complexes)
+
+ case MO.SvdVals:
+ ## TODO: Domain bound on singular values?
+ ## -- We might also consider a 'nonzero singvals' operation.
+ ## -- Since singular values can be zero just fine.
+ return sym.update(
+ mathtype=MT.Real,
+ cols=1,
+ domain=spux.BlessedSet(sp.Interval(0, sp.oo)),
+ )
+
+ case MO.Inv:
+ ## -> Defined: Square non-singular matrices.
+ orig_unit = sym.unit
+ new_unit = 1 / orig_unit if orig_unit is not None else None
+ new_phy_type = PT.from_unit(new_unit, optional=True)
+
+ return sym.update(
+ physical_type=new_phy_type,
+ unit=new_unit,
+ domain=sym.mathtype.symbolic_set,
+ )
+
+ case MO.Tra:
+ return sym.update(
+ rows=sym.cols,
+ cols=sym.rows,
+ )
+
+ case MO.QR_Q:
+ return sym.update(
+ mathtype=MT.Complex if sym.mathtype is MT.Complex else MT.Real,
+ physical_type=PT.NonPhysical,
+ unit=None,
+ cols=min([sym.rows, sym.cols]),
+ domain=(
+ spux.BlessedSet(spux.ComplexRegion(sp.Interval(-1, 1) ** 2))
+ if sym.mathtype is MT.Complex
+ else spux.BlessedSet(sp.Interval(-1, 1))
+ ),
+ )
+
+ case MO.QR_R:
+ return sym
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/operate.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/operate.py
index cf3d228..a9d91f8 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/operate.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/operate.py
@@ -14,12 +14,14 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements transform operations for the `MapNode`."""
+
import enum
+import functools
import typing as typ
import jax.numpy as jnp
import sympy as sp
-import sympy.physics.quantum as spq
import sympy.physics.units as spu
from blender_maxwell.utils import logger, sim_symbols
@@ -29,6 +31,113 @@ from .. import contracts as ct
log = logger.get(__name__)
+MT = spux.MathType
+PT = spux.PhysicalType
+
+
+# @functools.lru_cache(maxsize=1024)
+# def expand_shapes(
+# shape_l: tuple[int, ...], shape_r: tuple[int, ...]
+# ) -> tuple[tuple[int, ...], tuple[int, ...]]:
+# """Transform each shape to two new shapes, whose lengths are identical, and for which operations between are well-defined, but which occupies the same amount of memory."""
+# axes = dict(
+# reversed(
+# list(
+# itertools.zip_longest(reversed(shape_l), reversed(shape_r), fillvalue=1)
+# )
+# )
+# )
+#
+# return (tuple(axes.keys()), tuple(axes.values()))
+#
+#
+# @functools.lru_cache(maxsize=1024)
+# def broadcast_shape(
+# expanded_shape_l: tuple[int, ...], expanded_shape_r: tuple[int, ...]
+# ) -> tuple[int, ...] | None:
+# """Deduce a common shape that an object of both expanded shapes can be broadcast to."""
+# new_shape = []
+# for ax_l, ax_r in itertools.zip_longest(
+# expanded_shape_l, expanded_shape_r, fillvalue=1
+# ):
+# if ax_l == 1 or ax_r == 1 or ax_l == ax_r: # noqa: PLR1714
+# new_shape.append(max([ax_l, ax_r]))
+# else:
+# return None
+#
+# return tuple(new_shape)
+#
+#
+# def broadcast_to_shape(
+# M: sp.NDimArray, compatible_shape: tuple[int, ...]
+# ) -> spux.SympyType:
+# """Conform an array with expanded shape to the given shape, expanding any axes that need expanding."""
+# L = M
+#
+# incremental_shape = ()
+# for orig_ax, new_ax in reversed(zip(M.shape, compatible_shape, strict=True)):
+# incremental_shape = (new_ax, *incremental_shape)
+# if orig_ax == 1 and new_ax > 1:
+# _L = sp.flatten(L) if L.shape == () else L.tolist()
+#
+# L = sp.ImmutableDenseNDimArray(_L * new_ax).reshape(*compatible_shape)
+#
+# return L
+#
+#
+# def sp_operation(op, lhs: spux.SympyType, rhs: spux.SympyType) -> spux.SympyType | None:
+# if not isinstance(lhs, sp.MatrixBase | sp.NDimArray) and not isinstance(
+# lhs, sp.MatrixBase | sp.NDimArray
+# ):
+# return op(lhs, rhs)
+#
+# # Deduce Expanded L/R Arrays
+# ## -> This conforms the shape of both operands to broadcastable shape.
+# ## -> The actual memory usage from doing this remains identical.
+# _L = sp.ImmutableDenseNDimArray(lhs)
+# _R = sp.ImmutableDenseNDimArray(rhs)
+# expanded_shape_l, expanded_shape_r = expand_shapes(_L.shape, _R.shape)
+# _L = _L.reshape(*expanded_shape_l)
+# _R = _R.reshape(*expanded_shape_r)
+#
+# # Broadcast Expanded L/R Arrays
+# ## -> Expanded dimensions will be conformed to the max of each.
+# ## -> This conforms the shape of both operands to broadcastable shape.
+# broadcasted_shape = broadcast_to_shape(expanded_shape_l, expanded_shape_r)
+# if broadcasted_shape is None:
+# return None
+#
+# L = broadcast_to_shape(_L, broadcasted_shape)
+# R = broadcast_to_shape(_R, broadcasted_shape)
+#
+# # Run Elementwise Operation
+# ## -> Elementwise operations can now cleanly run between both operands.
+# output = op(L, R)
+# if output.shape in [1, 2]:
+# return sp.ImmutableMatrix(output.tomatrix())
+# return output
+#
+#
+# def hadamard_product(lhs: spux.SympyType, rhs: spux.SympyType) -> spux.SympyType | None:
+# match (isinstance(lhs, sp.MatrixBase), isinstance(rhs, sp.MatrixBase)):
+# case (False, False):
+# return lhs * rhs
+#
+# case (True, False):
+# return lhs.applyfunc(lambda el: el * rhs)
+#
+# case (False, True):
+# return rhs.applyfunc(lambda el: lhs * el)
+#
+# case (True, True) if lhs.shape == rhs.shape:
+# common_shape = lhs.shape
+# return sp.ImmutableMatrix(
+# *common_shape, lambda i, j: lhs[i, j] ** rhs[i, j]
+# )
+#
+# msg = f'Incompatible lhs and rhs for hadamard power: {lhs} | {rhs}'
+# raise ValueError(msg)
+
def hadamard_power(lhs: spux.SympyType, rhs: spux.SympyType) -> spux.SympyType:
"""Implement the Hadamard Power.
@@ -37,8 +146,7 @@ def hadamard_power(lhs: spux.SympyType, rhs: spux.SympyType) -> spux.SympyType:
"""
match (isinstance(lhs, sp.MatrixBase), isinstance(rhs, sp.MatrixBase)):
case (False, False):
- msg = f"Hadamard Power for two scalars is valid, but shouldn't be used - use normal power instead: {lhs} | {rhs}"
- raise ValueError(msg)
+ return lhs**rhs
case (True, False):
return lhs.applyfunc(lambda el: el**rhs)
@@ -52,9 +160,8 @@ def hadamard_power(lhs: spux.SympyType, rhs: spux.SympyType) -> spux.SympyType:
*common_shape, lambda i, j: lhs[i, j] ** rhs[i, j]
)
- case _:
- msg = f'Incompatible lhs and rhs for hadamard power: {lhs} | {rhs}'
- raise ValueError(msg)
+ msg = f'Incompatible lhs and rhs for hadamard power: {lhs} | {rhs}'
+ raise ValueError(msg)
class BinaryOperation(enum.StrEnum):
@@ -74,7 +181,6 @@ class BinaryOperation(enum.StrEnum):
VecVecOuter: Vector-vector outer product.
LinSolve: Solve a linear system.
LsqSolve: Minimize error of an underdetermined linear system.
- VecMatOuter: Vector-matrix outer product.
MatMatDot: Matrix-matrix dot product.
"""
@@ -100,9 +206,6 @@ class BinaryOperation(enum.StrEnum):
LinSolve = enum.auto()
LsqSolve = enum.auto()
- # Vector | Matrix
- VecMatOuter = enum.auto()
-
# Matrix | Matrix
MatMatDot = enum.auto()
@@ -132,17 +235,20 @@ class BinaryOperation(enum.StrEnum):
# Matrix | Vector
BO.LinSolve: '𝐋 ∖ 𝐫',
BO.LsqSolve: 'argminₓ∥𝐋𝐱−𝐫∥₂',
- # Vector | Matrix
- BO.VecMatOuter: '𝐋 ⊗ 𝐫',
# Matrix | Matrix
BO.MatMatDot: '𝐋 · 𝐑',
}[value]
@staticmethod
- def to_icon(value: typ.Self) -> str:
+ def to_icon(_: typ.Self) -> str:
"""No icons."""
return ''
+ @functools.cached_property
+ def name(self) -> str:
+ """No icons."""
+ return BinaryOperation.to_name(self)
+
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
"""Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`."""
BO = BinaryOperation
@@ -154,8 +260,9 @@ class BinaryOperation(enum.StrEnum):
i,
)
+ @staticmethod
def bl_enum_elements(
- self, info_l: ct.InfoFlow, info_r: ct.InfoFlow
+ info_l: ct.InfoFlow, info_r: ct.InfoFlow
) -> list[ct.BLEnumElement]:
"""Generate a list of guaranteed-valid operations based on the passed `InfoFlow`s.
@@ -163,14 +270,14 @@ class BinaryOperation(enum.StrEnum):
"""
return [
operation.bl_enum_element(i)
- for i, operation in enumerate(BinaryOperation.by_infos(info_l, info_r))
+ for i, operation in enumerate(BinaryOperation.from_infos(info_l, info_r))
]
####################
# - Ops from Shape
####################
@staticmethod
- def by_infos(info_l: ct.InfoFlow, info_r: ct.InfoFlow) -> list[typ.Self]:
+ def from_infos(info_l: ct.InfoFlow, info_r: ct.InfoFlow) -> list[typ.Self]: # noqa: C901, PLR0912
"""Deduce valid binary operations from the shapes of the inputs."""
BO = BinaryOperation
ops = []
@@ -190,14 +297,15 @@ class BinaryOperation(enum.StrEnum):
ops += [BO.Mul, BO.Div]
case (ordl, ordr, True) if ordl == 0 and ordr > 0:
ops += [BO.Mul]
- case (ordl, ordr, True) if ordl > 0 and ordr > 0:
+ case (ordl, ordr, _) if ordl > 0 and ordr > 0:
+ ## TODO: _ is not correct
ops += [BO.HadamMul, BO.HadamDiv]
case (ordl, ordr, False) if ordl == 0 and ordr == 0:
ops += [BO.Mul]
case (ordl, ordr, False) if ordl > 0 and ordr == 0:
ops += [BO.Mul]
- case (ordl, ordr, True) if ordl == 0 and ordr > 0:
+ case (ordl, ordr, False) if ordl == 0 and ordr > 0:
ops += [BO.Mul]
case (ordl, ordr, False) if ordl > 0 and ordr > 0:
ops += [BO.HadamMul]
@@ -212,7 +320,7 @@ class BinaryOperation(enum.StrEnum):
case (ordl, ordr, _) if ordl == 0 and ordr == 0:
ops += [BO.Pow]
- case (ordl, ordr, spux.MathType.Integer) if (
+ case (ordl, ordr, MT.Integer) if (
ordl > 0 and ordr == 0 and info_l.output.rows == info_l.output.cols
):
ops += [BO.Pow, BO.HadamPow]
@@ -220,29 +328,30 @@ class BinaryOperation(enum.StrEnum):
case _:
ops += [BO.HadamPow]
+ # Atan2
+ if (
+ (
+ info_l.output.mathtype is not MT.Complex
+ and info_r.output.mathtype is not MT.Complex
+ )
+ and (
+ info_l.output.physical_type is PT.Length
+ and info_r.output.physical_type is PT.Length
+ )
+ or (
+ info_l.output.physical_type is PT.NonPhysical
+ and info_l.output.unit is None
+ and info_r.output.physical_type is PT.NonPhysical
+ and info_r.output.unit is None
+ )
+ ):
+ ops += [BO.Atan2]
+
# Operations by-Output Length
match (
info_l.output.shape_len,
info_r.output.shape_len,
):
- # Number | Number
- case (0, 0) if info_l.is_scalar and info_r.is_scalar:
- # atan2: PhysicalType Must Both be Length | NonPhysical
- ## -> atan2() produces radians from Cartesian coordinates.
- ## -> This wouldn't make sense on non-Length / non-Unitless.
- if (
- info_l.output.physical_type is spux.PhysicalType.Length
- and info_r.output.physical_type is spux.PhysicalType.Length
- ) or (
- info_l.output.physical_type is spux.PhysicalType.NonPhysical
- and info_l.output.unit is None
- and info_r.output.physical_type is spux.PhysicalType.NonPhysical
- and info_r.output.unit is None
- ):
- ops += [BO.Atan2]
-
- return ops
-
# Vector | Vector
case (1, 1) if info_l.compare_dims_identical(info_r):
outl = info_l.output
@@ -274,19 +383,11 @@ class BinaryOperation(enum.StrEnum):
## -> Works great element-wise.
## -> Enforce that both are 3x1 or 1x3.
## -> See https://docs.sympy.org/latest/modules/matrices/matrices.html#sympy.matrices.matrices.MatrixBase.cross
- if (outl.rows == 3 and outr.rows == 3) or (
- outl.cols == 3 and outl.cols == 3
+ if (outl.rows == 3 and outr.rows == 3) or ( # noqa: PLR2004
+ outl.cols == 3 and outl.cols == 3 # noqa: PLR2004
):
ops += [BO.Cross]
- # Vector | Matrix
- ## -> We can't do per-element outer product.
- ## -> However, it's still super useful on its own.
- case (1, 2) if info_l.compare_dims_identical(
- info_r
- ) and info_l.order == 1 and info_r.order == 2:
- ops += [BO.VecMatOuter]
-
# Matrix | Vector
case (2, 1) if info_l.compare_dims_identical(info_r):
# Mat-Vec Dot: Enforce RHS Column Vector
@@ -309,17 +410,20 @@ class BinaryOperation(enum.StrEnum):
"""Deduce an appropriate sympy-based function that implements the binary operation for symbolic inputs."""
BO = BinaryOperation
- ## TODO: Make this compatible with sp.Matrix inputs
+ ## BODO: Make this compatible with sp.Matrix inputs
return {
# Number | Number
BO.Mul: lambda exprs: exprs[0] * exprs[1],
BO.Div: lambda exprs: exprs[0] / exprs[1],
- BO.Pow: lambda exprs: exprs[0] ** exprs[1],
+ BO.Pow: lambda exprs: hadamard_power(exprs[0], exprs[1]),
# Elements | Elements
BO.Add: lambda exprs: exprs[0] + exprs[1],
BO.Sub: lambda exprs: exprs[0] - exprs[1],
BO.HadamMul: lambda exprs: exprs[0].multiply_elementwise(exprs[1]),
- BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
+ BO.HadamDiv: lambda exprs: exprs[0].multiply_elementwise(
+ exprs[1].applyfunc(lambda el: 1 / el)
+ ),
+ BO.HadamPow: lambda exprs: hadamard_power(exprs[0], exprs[1]),
BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]),
# Vector | Vector
BO.VecVecDot: lambda exprs: (exprs[0].T @ exprs[1])[0],
@@ -328,27 +432,27 @@ class BinaryOperation(enum.StrEnum):
# Matrix | Vector
BO.LinSolve: lambda exprs: exprs[0].solve(exprs[1]),
BO.LsqSolve: lambda exprs: exprs[0].solve_least_squares(exprs[1]),
- # Vector | Matrix
- BO.VecMatOuter: lambda exprs: spq.TensorProduct(exprs[0], exprs[1]),
# Matrix | Matrix
BO.MatMatDot: lambda exprs: exprs[0] @ exprs[1],
}[self]
@property
- def unit_func(self):
+ def scalar_sp_func(self):
"""The binary function to apply to both unit expressions, in order to deduce the unit expression of the output."""
BO = BinaryOperation
- ## TODO: Make this compatible with sp.Matrix inputs
+ ## BODO: Make this compatible with sp.Matrix inputs
return {
# Number | Number
BO.Mul: BO.Mul.sp_func,
BO.Div: BO.Div.sp_func,
BO.Pow: BO.Pow.sp_func,
# Elements | Elements
- BO.Add: BO.Add.sp_func,
- BO.Sub: BO.Sub.sp_func,
+ BO.Add: lambda exprs: exprs[0],
+ BO.Sub: lambda exprs: exprs[0],
BO.HadamMul: BO.Mul.sp_func,
+ BO.HadamDiv: BO.Div.sp_func,
+ BO.HadamPow: BO.Pow.sp_func,
# BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
BO.Atan2: lambda _: spu.radian,
# Vector | Vector
@@ -360,8 +464,6 @@ class BinaryOperation(enum.StrEnum):
## -> Therefore, A \ b must have the units [b]/[A].
BO.LinSolve: lambda exprs: exprs[1] / exprs[0],
BO.LsqSolve: lambda exprs: exprs[1] / exprs[0],
- # Vector | Matrix
- BO.VecMatOuter: BO.Mul.sp_func,
# Matrix | Matrix
BO.MatMatDot: BO.Mul.sp_func,
}[self]
@@ -369,7 +471,7 @@ class BinaryOperation(enum.StrEnum):
@property
def jax_func(self):
"""Deduce an appropriate jax-based function that implements the binary operation for array inputs."""
- ## TODO: Scale the units of one side to the other.
+ ## BODO: Scale the units of one side to the other.
BO = BinaryOperation
return {
@@ -380,11 +482,9 @@ class BinaryOperation(enum.StrEnum):
# Elements | Elements
BO.Add: lambda exprs: exprs[0] + exprs[1],
BO.Sub: lambda exprs: exprs[0] - exprs[1],
- BO.HadamMul: lambda exprs: exprs[0].multiply_elementwise(exprs[1]),
- BO.HadamDiv: lambda exprs: exprs[0].multiply_elementwise(
- exprs[1].applyfunc(lambda el: 1 / el)
- ),
- BO.HadamPow: lambda exprs: hadamard_power(exprs[0], exprs[1]),
+ BO.HadamMul: lambda exprs: exprs[0] * exprs[1],
+ BO.HadamDiv: lambda exprs: exprs[0] / exprs[1],
+ BO.HadamPow: lambda exprs: exprs[0] ** exprs[1],
BO.Atan2: lambda exprs: jnp.atan2(exprs[1], exprs[0]),
# Vector | Vector
BO.VecVecDot: lambda exprs: jnp.linalg.vecdot(exprs[0], exprs[1]),
@@ -393,8 +493,6 @@ class BinaryOperation(enum.StrEnum):
# Matrix | Vector
BO.LinSolve: lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]),
BO.LsqSolve: lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]),
- # Vector | Matrix
- BO.VecMatOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
# Matrix | Matrix
BO.MatMatDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]),
}[self]
@@ -415,62 +513,172 @@ class BinaryOperation(enum.StrEnum):
else:
norm_func_r = func_r
- return (func_l, norm_func_r).compose_within(
+ return (func_l | norm_func_r).compose_within(
self.jax_func,
enclosing_func_output=self.transform_outputs(
- func_l.func_output, norm_func_r.func_output
+ func_l.func_output, func_r.func_output
),
supports_jax=True,
)
- def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow):
- """Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression.
+ def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow) -> ct.InfoFlow:
+ """Transform the `InfoFlow` characterizing the output."""
+ if len(info_l.dims) == 0:
+ dims = info_r.dims
+ elif len(info_r.dims) == 0:
+ dims = info_l.dims
+ else:
+ dims = info_l.dims
- Warnings:
- `self` MUST be an element of `BinaryOperation.by_infos(info_l, info_r).
-
- If not, bad things will happen.
- """
- return info_l.operate_output(
- info_r,
- lambda a, b: self.sp_func([a, b]),
- lambda a, b: self.unit_func([a, b]),
+ return ct.InfoFlow(
+ dims=dims,
+ output=self.transform_outputs(info_l.output, info_r.output),
+ pinned_values=info_l.pinned_values | info_r.pinned_values,
)
+ def transform_params(
+ self, params_l: ct.ParamsFlow, params_r: ct.ParamsFlow
+ ) -> ct.ParamsFlow:
+ """Aggregate the incoming function parameters for the output."""
+ return params_l | params_r
+
####################
- # - InfoFlow Transform
+ # - Other Transforms
####################
def transform_outputs(
- self, output_l: sim_symbols.SimSymbol, output_r: sim_symbols.SimSymbol
+ self, sym_l: sim_symbols.SimSymbol, sym_r: sim_symbols.SimSymbol
) -> sim_symbols.SimSymbol:
- # TO = TransformOperation
- return None
- # match self:
- # # Number | Number
- # case TO.Mul:
- # return
- # case TO.Div:
- # case TO.Pow:
+ BO = BinaryOperation
- # # Elements | Elements
- # Add = enum.auto()
- # Sub = enum.auto()
- # HadamMul = enum.auto()
- # HadamPow = enum.auto()
- # HadamDiv = enum.auto()
- # Atan2 = enum.auto()
+ if sym_l.sym_name == sym_r.sym_name:
+ name = sym_l.sym_name
+ else:
+ name = sim_symbols.SimSymbolName.Expr
- # # Vector | Vector
- # VecVecDot = enum.auto()
- # Cross = enum.auto()
- # VecVecOuter = enum.auto()
+ dm_l = sym_l.domain
+ dm_r = sym_r.domain
+ match self:
+ case BO.Mul | BO.Div | BO.Pow | BO.HadamDiv:
+ # dm = self.scalar_sp_func([dm_l, dm_r])
+ unit_factor = self.scalar_sp_func(
+ [sym_l.unit_factor, sym_r.unit_factor]
+ )
+ unit = unit_factor if spux.uses_units(unit_factor) else None
+ physical_type = PT.from_unit(unit, optional=True, optional_nonphy=True)
- # # Matrix | Vector
- # LinSolve = enum.auto()
- # LsqSolve = enum.auto()
+ mathtype = MT.combine(
+ MT.from_symbolic_set(dm_l.bset), MT.from_symbolic_set(dm_r.bset)
+ )
+ return sim_symbols.SimSymbol(
+ sym_name=name,
+ mathtype=mathtype,
+ physical_type=physical_type,
+ unit=unit,
+ rows=max([sym_l.rows, sym_r.rows]),
+ cols=max([sym_l.cols, sym_r.cols]),
+ depths=tuple(
+ [
+ max([dp_l, dp_r])
+ for dp_l, dp_r in zip(
+ sym_l.depths, sym_r.depths, strict=True
+ )
+ ]
+ ),
+ domain=spux.BlessedSet(mathtype.symbolic_set),
+ )
- # # Vector | Matrix
- # VecMatOuter = enum.auto()
+ case BO.Add | BO.Sub:
+ fac_r_unit_to_l_unit = sp.S(spux.scaling_factor(sym_l.unit, sym_r.unit))
- # # Matrix | Matrix
- # MatMatDot = enum.auto()
+ dm = self.scalar_sp_func([dm_l, dm_r * fac_r_unit_to_l_unit])
+ unit_factor = self.scalar_sp_func(
+ [sym_l.unit_factor, sym_r.unit_factor]
+ )
+ unit = unit_factor if spux.uses_units(unit_factor) else None
+ physical_type = PT.from_unit(unit, optional=True, optional_nonphy=True)
+
+ return sym_l.update(
+ sym_name=name,
+ mathtype=MT.from_symbolic_set(dm.bset),
+ physical_type=physical_type,
+ unit=None if unit_factor == 1 else unit_factor,
+ domain=dm,
+ )
+
+ case BO.HadamMul | BO.HadamPow:
+ # fac_r_unit_to_l_unit = sp.S(spux.scaling_factor(sym_l.unit, sym_r.unit))
+
+ mathtype = MT.combine(
+ MT.from_symbolic_set(dm_l.bset), MT.from_symbolic_set(dm_r.bset)
+ )
+ # dm = self.scalar_sp_func([dm_l, dm_r * fac_r_unit_to_l_unit])
+ unit_factor = self.scalar_sp_func(
+ [sym_l.unit_factor, sym_r.unit_factor]
+ )
+ unit = unit_factor if spux.uses_units(unit_factor) else None
+ physical_type = PT.from_unit(unit, optional=True, optional_nonphy=True)
+
+ return sym_l.update(
+ sym_name=name,
+ mathtype=mathtype,
+ physical_type=physical_type,
+ unit=None if unit_factor == 1 else unit_factor,
+ domain=spux.BlessedSet(mathtype.symbolic_set),
+ )
+
+ case BO.Atan2:
+ dm = dm_l.atan2(dm_r)
+
+ return sym_l.update(
+ sym_name=name,
+ mathtype=MT.from_symbolic_set(dm.bset),
+ physical_type=PT.Angle,
+ unit=spu.radian,
+ domain=dm,
+ )
+
+ case BO.VecVecDot:
+ _dm = dm_l * dm_r
+ dm = _dm + _dm
+
+ return sym_l.update(
+ sym_name=name,
+ domain=dm,
+ rows=1,
+ cols=1,
+ )
+
+ case BO.Cross:
+ _dm = dm_l * dm_r
+ dm = _dm + _dm
+
+ return sym_l.update(
+ sym_name=name,
+ domain=dm,
+ )
+
+ case BO.VecVecOuter:
+ dm = dm_l * dm_r
+
+ return sym_l.update(
+ sym_name=name,
+ domain=dm,
+ rows=max([sym_l.rows, sym_r.rows]),
+ cols=max([sym_l.cols, sym_r.cols]),
+ )
+
+ case BO.LinSolve | BO.LsqSolve:
+ mathtype = MT.combine(
+ MT.from_symbolic_set(dm_l.bset), MT.from_symbolic_set(dm_r.bset)
+ ).symbolic_set
+ dm = spux.BlessedSet(mathtype.symbolic_set)
+
+ return sym_r.update(mathtype=mathtype, domain=dm)
+
+ case BO.MatMatDot:
+ mathtype = MT.combine(
+ MT.from_symbolic_set(dm_l.bset), MT.from_symbolic_set(dm_r.bset)
+ ).symbolic_set
+ dm = spux.BlessedSet(mathtype.symbolic_set)
+
+ return sym_r.update(mathtype=mathtype, domain=dm)
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/reduce.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/reduce.py
index 4d26985..49459df 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/reduce.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/reduce.py
@@ -18,6 +18,7 @@ import enum
import typing as typ
import jax.numpy as jnp
+import jaxtyping as jtyp
import sympy as sp
from blender_maxwell.utils import logger, sim_symbols
@@ -27,8 +28,30 @@ from .. import contracts as ct
log = logger.get(__name__)
+MT = spux.MathType
+PT = spux.PhysicalType
+
class ReduceOperation(enum.StrEnum):
+ """Valid operations for the `ReduceMathNode`.
+
+ Attributes:
+ Count: The number of discrete elements along an axis.
+ Mean: The average along an axis.
+ Std: The standard deviation along an axis.
+ Var: The variance along an axis.
+ Min: The minimum value along an axis.
+ Q25: The `25%` quantile along an axis.
+ Medium: The `50%` quantile along an axis.
+ Q75: The `75%` quantile along an axis.
+ Max: The `75%` quantile along an axis.
+ P2P: The peak-to-peak range along an axis.
+ Z15ToZ15: The range between z-scores of 1.5 in each direction.
+ Z30ToZ30: The range between z-scores of 3.0 in each direction.
+ Sum: The sum along an axis.
+ Prod: The product along an axis.
+ """
+
# Summary
Count = enum.auto()
@@ -37,15 +60,15 @@ class ReduceOperation(enum.StrEnum):
Std = enum.auto()
Var = enum.auto()
- StdErr = enum.auto()
-
Min = enum.auto()
Q25 = enum.auto()
Median = enum.auto()
Q75 = enum.auto()
Max = enum.auto()
- Mode = enum.auto()
+ P2P = enum.auto()
+ Z15ToZ15 = enum.auto()
+ Z30ToZ30 = enum.auto()
# Reductions
Sum = enum.auto()
@@ -61,22 +84,27 @@ class ReduceOperation(enum.StrEnum):
return {
# Summary
RO.Count: '# [a]',
- RO.Mode: 'mode [a]',
# Statistics
RO.Mean: 'μ [a]',
RO.Std: 'σ [a]',
RO.Var: 'σ² [a]',
- RO.StdErr: 'stderr [a]',
RO.Min: 'min [a]',
RO.Q25: 'q₂₅ [a]',
RO.Median: 'median [a]',
RO.Q75: 'q₇₅ [a]',
- RO.Min: 'max [a]',
+ RO.Max: 'max [a]',
+ RO.P2P: 'p2p [a]',
+ RO.Z15ToZ15: 'σ[1.5] [a]',
+ RO.Z30ToZ30: 'σ[3.0] [a]',
# Reductions
RO.Sum: 'sum [a]',
RO.Prod: 'prod [a]',
}[value]
+ @property
+ def name(self) -> str:
+ return ReduceOperation.to_name(self)
+
@staticmethod
def to_icon(_: typ.Self) -> str:
"""No icons."""
@@ -93,24 +121,165 @@ class ReduceOperation(enum.StrEnum):
i,
)
+ @staticmethod
+ def bl_enum_elements(info: ct.InfoFlow) -> list[ct.BLEnumElement]:
+ """Generate a list of guaranteed-valid operations based on the passed `InfoFlow`s.
+
+ Returns a `bpy.props.EnumProperty.items`-compatible list.
+ """
+ return [
+ operation.bl_enum_element(i)
+ for i, operation in enumerate(ReduceOperation.from_info(info))
+ ]
+
####################
# - Derivation
####################
@staticmethod
def from_info(info: ct.InfoFlow) -> list[typ.Self]:
"""Derive valid reduction operations from the `InfoFlow` of the operand."""
- pass
+ RO = ReduceOperation
+ ops = []
+
+ if info.dims and any(
+ info.has_idx_discrete(dim) or info.has_idx_labels(dim) for dim in info.dims
+ ):
+ # Summary
+ ops += [RO.Count]
+
+ # Statistics
+ ops += [
+ RO.Mean,
+ RO.Std,
+ RO.Var,
+ RO.Min,
+ RO.Q25,
+ RO.Median,
+ RO.Q75,
+ RO.Max,
+ RO.P2P,
+ RO.Z15ToZ15,
+ RO.Z30ToZ30,
+ ]
+
+ # Reductions
+ ops += [RO.Sum, RO.Prod]
+
+ ## I know, they can be combined.
+ ## But they may one day need more checks.
+
+ return ops
+
+ def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
+ """Valid dimensions that can be reduced."""
+ return [
+ dim
+ for dim in info.dims
+ if info.has_idx_discrete(dim) or info.has_idx_labels(dim)
+ ]
####################
# - Composable Functions
####################
@property
- def jax_func(self):
+ def jax_func(
+ self,
+ ) -> typ.Callable[
+ [jtyp.Shaped[jtyp.Array, '...'], int], jtyp.Shaped[jtyp.Array, '...']
+ ]:
+ """Implements the identified reduction using `jax`."""
RO = ReduceOperation
- return {}[self]
+ return {
+ # Summary
+ RO.Count: lambda el, axis: el.shape[axis],
+ # RO.Mode: lambda el, axis: jsc.stats.mode(el, axis=axis).
+ # Statistics
+ RO.Mean: lambda el, axis: jnp.mean(el, axis=axis),
+ RO.Std: lambda el, axis: jnp.std(el, axis=axis),
+ RO.Var: lambda el, axis: jnp.var(el, axis=axis),
+ RO.Min: lambda el, axis: jnp.min(el, axis=axis),
+ RO.Q25: lambda el, axis: jnp.quantile(el, 0.25, axis=axis),
+ RO.Median: lambda el, axis: jnp.median(el, axis=axis),
+ RO.Q75: lambda el, axis: jnp.quantile(el, 0.75, axis=axis),
+ RO.Max: lambda el, axis: jnp.max(el, axis=axis),
+ RO.P2P: lambda el, axis: jnp.ptp(el, axis=axis),
+ RO.Z15ToZ15: lambda el, axis: 2 * (3 / 2) * jnp.std(el, axis=axis),
+ RO.Z30ToZ30: lambda el, axis: 2 * 3 * jnp.std(el, axis=axis),
+ # Statistics
+ RO.Sum: lambda el, axis: jnp.sum(el, axis=axis),
+ RO.Prod: lambda el, axis: jnp.prod(el, axis=axis),
+ }[self]
####################
# - Transforms
####################
- def transform_info(self, info: ct.InfoFlow):
- pass
+ def transform_func(self, func: ct.InfoFlow):
+ """Transform the lazy `FuncFlow` to reduce the input."""
+ return func.compose_within(
+ self.jax_func,
+ enclosing_func_args=(sim_symbols.idx(None),),
+ enclosing_func_output=self.transform_output(func.func_output),
+ supports_jax=True,
+ )
+
+ def transform_info(self, info: ct.InfoFlow, dim: sim_symbols.SimSymbol):
+ """Transform the characterizing `InfoFlow` of the reduced operand."""
+ return info.delete_dim(dim).update(output=self.transform_output(info.output))
+
+ def transform_params(self, params: ct.ParamsFlow, axis: int) -> None:
+ """Transform the characterizing `InfoFlow` of the reduced operand."""
+ return params.compose_within(
+ enclosing_func_args=(sp.Integer(axis),),
+ )
+
+ def transform_output(self, sym: sim_symbols.SimSymbol) -> sim_symbols.SimSymbol:
+ """Transform the domain of the output symbol.
+
+ Parameters:
+ dom: Symbolic set representing the original output symbol's domain.
+ info: Characterization of the original expression.
+ dim: Dimension symbol being reduced away.
+
+ """
+ RO = ReduceOperation
+
+ match self:
+ # Summary
+ case RO.Count:
+ return sym.update(
+ sym_name=sim_symbols.SimSymbolName.Count,
+ mathtype=MT.Integer,
+ physical_type=PT.NonPhysical,
+ unit=None,
+ rows=1,
+ cols=1,
+ domain=spux.BlessedSet(sp.Naturals0),
+ )
+
+ # Statistics
+ case (
+ RO.Mean
+ | RO.Std
+ | RO.Var
+ | RO.Min
+ | RO.Q25
+ | RO.Median
+ | RO.Q75
+ | RO.Max
+ | RO.P2P
+ | RO.Z15ToZ15
+ | RO.Z30ToZ30
+ ):
+ ## -> Stats are enclosed by the original domain.
+ return sym
+
+ # Reductions
+ case RO.Sum:
+ return sym.update(
+ domain=spux.BlessedSet(sym.mathtype.symbolic_set),
+ )
+
+ case RO.Prod:
+ return sym.update(
+ domain=spux.BlessedSet(sym.mathtype.symbolic_set),
+ )
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/transform.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/transform.py
index 499c082..58bc095 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/transform.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/math_system/transform.py
@@ -19,6 +19,7 @@ import typing as typ
import jax.numpy as jnp
import jaxtyping as jtyp
+import sympy as sp
from blender_maxwell.utils import logger, sci_constants, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
@@ -66,14 +67,12 @@ class TransformOperation(enum.StrEnum):
FT1D = enum.auto()
InvFT1D = enum.auto()
- # TODO: Affine
- ## TODO
-
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
+ """A human-readable UI-oriented name."""
TO = TransformOperation
return {
# Covariant Transform
@@ -99,9 +98,11 @@ class TransformOperation(enum.StrEnum):
@staticmethod
def to_icon(_: typ.Self) -> str:
+ """No icons."""
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
+ """Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`."""
TO = TransformOperation
return (
str(self),
@@ -111,6 +112,17 @@ class TransformOperation(enum.StrEnum):
i,
)
+ @staticmethod
+ def bl_enum_elements(info: ct.InfoFlow) -> list[ct.BLEnumElement]:
+ """Generate a list of guaranteed-valid operations based on the passed `InfoFlow`s.
+
+ Returns a `bpy.props.EnumProperty.items`-compatible list.
+ """
+ return [
+ operation.bl_enum_element(i)
+ for i, operation in enumerate(TransformOperation.by_info(info))
+ ]
+
####################
# - Methods
####################
@@ -300,14 +312,16 @@ class TransformOperation(enum.StrEnum):
physical_type=physical_type,
unit=unit,
),
- ct.RangeFlow.try_from_array(ct.ArrayFlow(values=data_col, unit=unit)),
+ ct.RangeFlow.try_from_array(
+ ct.ArrayFlow(jax_bytes=data_col, unit=unit)
+ ),
).slice_dim(info.last_dim, (1, len(info.dims[info.last_dim]), 1)),
# Fold
TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output(
mathtype=spux.MathType.Complex
),
- TO.DimToVec: lambda: info.fold_last_input(),
- TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(),
+ TO.DimToVec: lambda: info.fold_last_input,
+ TO.DimsToMat: lambda: info.fold_last_input.fold_last_input,
# Fourier
TO.FT1D: lambda: info.replace_dim(
dim,
@@ -333,4 +347,38 @@ class TransformOperation(enum.StrEnum):
info.dims[dim].bound_inv_fourier_transform,
],
),
- }[self]()
+ }[self]().update_output(
+ domain=self.transform_output_domain(info.output.domain, info, dim)
+ )
+
+ def transform_output_domain(
+ self, dom: spux.BlessedSet, info: ct.InfoFlow, dim: sim_symbols.SimSymbol | None
+ ) -> sp.Set:
+ """Transform the domain of the output symbol.
+
+ Parameters:
+ dom: Symbolic set representing the original output symbol's domain.
+ info: Characterization of the original expression.
+ dim: Dimension symbol being reduced away.
+
+ """
+ TO = TransformOperation
+ match self:
+ # Summary
+ case (
+ TO.FreqToVacWL
+ | TO.VacWLToFreq
+ | TO.ConvertIdxUnit
+ | TO.SetIdxUnit
+ | TO.FirstColToFirstIdx
+ | TO.DimToVec
+ | TO.DimsToMat
+ | TO.IntDimToComplex
+ ):
+ return dom
+
+ case TO.FT1D if dim is not None:
+ return info[dim].bound_fourier_transform.symbolic_set
+
+ case TO.InvFT1D if dim is not None:
+ return info[dim].bound_inv_fourier_transform.symbolic_set
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/__init__.py
index 263d26b..a99787d 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/__init__.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/__init__.py
@@ -16,35 +16,38 @@
from . import (
analysis,
- bounds,
inputs,
mediums,
monitors,
outputs,
simulations,
+ solvers,
sources,
structures,
+ utilities,
)
BL_REGISTER = [
*analysis.BL_REGISTER,
+ *utilities.BL_REGISTER,
*inputs.BL_REGISTER,
+ *solvers.BL_REGISTER,
*outputs.BL_REGISTER,
*sources.BL_REGISTER,
*mediums.BL_REGISTER,
*structures.BL_REGISTER,
- *bounds.BL_REGISTER,
*monitors.BL_REGISTER,
*simulations.BL_REGISTER,
]
BL_NODES = {
**analysis.BL_NODES,
+ **utilities.BL_NODES,
**inputs.BL_NODES,
+ **solvers.BL_NODES,
**outputs.BL_NODES,
**sources.BL_NODES,
**mediums.BL_NODES,
**structures.BL_NODES,
- **bounds.BL_NODES,
**monitors.BL_NODES,
**simulations.BL_NODES,
}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py
index a2f54db..67247cd 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/extract_data.py
@@ -24,9 +24,9 @@ import bpy
import jax.numpy as jnp
import sympy.physics.units as spu
import tidy3d as td
+import xarray
from blender_maxwell.utils import bl_cache, logger, sim_symbols
-from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import sockets
@@ -36,33 +36,44 @@ log = logger.get(__name__)
TDMonitorData: typ.TypeAlias = td.components.data.monitor_data.MonitorData
+RealizedSymsVals: typ.TypeAlias = tuple[sim_symbols.SimSymbol, ...], tuple[typ.Any, ...]
+SimDataArray: typ.TypeAlias = dict[RealizedSymsVals, td.SimulationData]
+SimDataArrayInfo: typ.TypeAlias = dict[RealizedSymsVals, typ.Any]
+
+FK = ct.FlowKind
+FS = ct.FlowSignal
+
####################
-# - Monitor Label Arrays
+# - Monitor Labelling
####################
-def valid_monitor_attrs(sim_data: td.SimulationData, monitor_name: str) -> list[str]:
+def valid_monitor_attrs(
+ example_sim_data: td.SimulationData, monitor_name: str
+) -> tuple[str, ...]:
"""Retrieve the valid attributes of `sim_data.monitor_data' from a valid `sim_data` of type `td.SimulationData`.
Parameters:
monitor_type: The name of the monitor type, with the 'Data' prefix removed.
"""
- monitor_data = sim_data.monitor_data[monitor_name]
- monitor_type = monitor_data.type
+ monitor_data = example_sim_data.monitor_data[monitor_name]
+ monitor_type = monitor_data.type.removesuffix('Data')
match monitor_type:
case 'Field' | 'FieldTime' | 'Mode':
## TODO: flux, poynting, intensity
- return [
- field_component
- for field_component in ['Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz']
- if getattr(monitor_data, field_component, None) is not None
- ]
+ return tuple(
+ [
+ field_component
+ for field_component in ['Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz']
+ if getattr(monitor_data, field_component, None) is not None
+ ]
+ )
case 'Permittivity':
- return ['eps_xx', 'eps_yy', 'eps_zz']
+ return ('eps_xx', 'eps_yy', 'eps_zz')
case 'Flux' | 'FluxTime':
- return ['flux']
+ return ('flux',)
case (
'FieldProjectionAngle'
@@ -70,140 +81,246 @@ def valid_monitor_attrs(sim_data: td.SimulationData, monitor_name: str) -> list[
| 'FieldProjectionKSpace'
| 'Diffraction'
):
- return [
+ return (
'Er',
'Etheta',
'Ephi',
'Hr',
'Htheta',
'Hphi',
- ]
-
-
-def extract_info(monitor_data, monitor_attr: str) -> ct.InfoFlow | None: # noqa: PLR0911
- """Extract an InfoFlow encapsulating raw data contained in an attribute of the given monitor data."""
- xarr = getattr(monitor_data, monitor_attr, None)
- if xarr is None:
- return None
-
- def mk_idx_array(axis: str) -> ct.RangeFlow | ct.ArrayFlow:
- return ct.RangeFlow.try_from_array(
- ct.ArrayFlow(
- values=xarr.get_index(axis).values,
- unit=symbols[axis].unit,
- is_sorted=True,
)
+
+ raise TypeError
+
+
+####################
+# - Extract InfoFlow
+####################
+MONITOR_SYMBOLS: dict[str, sim_symbols.SimSymbol] = {
+ # Field Label
+ 'EH*': sim_symbols.sim_axis_idx(None),
+ # Cartesian
+ 'x': sim_symbols.space_x(spu.micrometer),
+ 'y': sim_symbols.space_y(spu.micrometer),
+ 'z': sim_symbols.space_z(spu.micrometer),
+ # Spherical
+ 'r': sim_symbols.ang_r(spu.micrometer),
+ 'theta': sim_symbols.ang_theta(spu.radian),
+ 'phi': sim_symbols.ang_phi(spu.radian),
+ # Freq|Time
+ 'f': sim_symbols.freq(spu.hertz),
+ 't': sim_symbols.t(spu.second),
+ # Power Flux
+ 'flux': sim_symbols.flux(spu.watt),
+ # Wavevector
+ 'ux': sim_symbols.dir_x(spu.watt),
+ 'uy': sim_symbols.dir_y(spu.watt),
+ # Diffraction Orders
+ 'orders_x': sim_symbols.diff_order_x(None),
+ 'orders_y': sim_symbols.diff_order_y(None),
+ # Cartesian Fields
+ 'field': sim_symbols.field_e(spu.volt / spu.micrometer), ## TODO: H???
+ 'field_e': sim_symbols.field_e(spu.volt / spu.micrometer),
+ 'field_h': sim_symbols.field_h(spu.ampere / spu.micrometer),
+}
+
+
+def _mk_idx_array(xarr: xarray.DataArray, axis: str) -> ct.RangeFlow | ct.ArrayFlow:
+ return ct.RangeFlow.try_from_array(
+ ct.ArrayFlow(
+ jax_bytes=xarr.get_index(axis).values,
+ unit=MONITOR_SYMBOLS[axis].unit,
+ is_sorted=True,
)
+ )
- # Compute InfoFlow from XArray
- symbols = {
- # Cartesian
- 'x': sim_symbols.space_x(spu.micrometer),
- 'y': sim_symbols.space_y(spu.micrometer),
- 'z': sim_symbols.space_z(spu.micrometer),
- # Spherical
- 'r': sim_symbols.ang_r(spu.micrometer),
- 'theta': sim_symbols.ang_theta(spu.radian),
- 'phi': sim_symbols.ang_phi(spu.radian),
- # Freq|Time
- 'f': sim_symbols.freq(spu.hertz),
- 't': sim_symbols.t(spu.second),
- # Power Flux
- 'flux': sim_symbols.flux(spu.watt),
- # Cartesian Fields
- 'Ex': sim_symbols.field_ex(spu.volt / spu.micrometer),
- 'Ey': sim_symbols.field_ey(spu.volt / spu.micrometer),
- 'Ez': sim_symbols.field_ez(spu.volt / spu.micrometer),
- 'Hx': sim_symbols.field_hx(spu.volt / spu.micrometer),
- 'Hy': sim_symbols.field_hy(spu.volt / spu.micrometer),
- 'Hz': sim_symbols.field_hz(spu.volt / spu.micrometer),
- # Spherical Fields
- 'Er': sim_symbols.field_er(spu.volt / spu.micrometer),
- 'Etheta': sim_symbols.ang_theta(spu.volt / spu.micrometer),
- 'Ephi': sim_symbols.field_ez(spu.volt / spu.micrometer),
- 'Hr': sim_symbols.field_hr(spu.volt / spu.micrometer),
- 'Htheta': sim_symbols.field_hy(spu.volt / spu.micrometer),
- 'Hphi': sim_symbols.field_hz(spu.volt / spu.micrometer),
- # Wavevector
- 'ux': sim_symbols.dir_x(spu.watt),
- 'uy': sim_symbols.dir_y(spu.watt),
- # Diffraction Orders
- 'orders_x': sim_symbols.diff_order_x(None),
- 'orders_y': sim_symbols.diff_order_y(None),
- }
- match monitor_data.type:
+def output_symbol_by_type(monitor_type: str) -> sim_symbols.SimSymbol:
+ match monitor_type:
+ case 'Field' | 'FieldProjectionCartesian' | 'Permittivity' | 'Mode':
+ return MONITOR_SYMBOLS['field_e']
+
+ case 'FieldTime':
+ return MONITOR_SYMBOLS['field']
+
+ case 'Flux':
+ return MONITOR_SYMBOLS['flux']
+
+ case 'FluxTime':
+ return MONITOR_SYMBOLS['flux']
+
+ case 'FieldProjectionAngle':
+ return MONITOR_SYMBOLS['field']
+
+ case 'FieldProjectionKSpace':
+ return MONITOR_SYMBOLS['field']
+
+ case 'Diffraction':
+ return MONITOR_SYMBOLS['field']
+
+ return None
+
+
+def _extract_info(
+ example_xarr: xarray.DataArray,
+ monitor_type: str,
+ monitor_attrs: tuple[str, ...],
+ batch_dims: dict[sim_symbols.SimSymbol, ct.RangeFlow | ct.ArrayFlow],
+) -> ct.InfoFlow | None:
+ log.debug([monitor_type, monitor_attrs, batch_dims])
+ mk_idx_array = functools.partial(_mk_idx_array, example_xarr)
+ match monitor_type:
case 'Field' | 'FieldProjectionCartesian' | 'Permittivity' | 'Mode':
return ct.InfoFlow(
- dims={
- symbols['x']: mk_idx_array('x'),
- symbols['y']: mk_idx_array('y'),
- symbols['z']: mk_idx_array('z'),
- symbols['f']: mk_idx_array('f'),
+ dims=batch_dims
+ | {
+ MONITOR_SYMBOLS['EH*']: monitor_attrs,
+ MONITOR_SYMBOLS['x']: mk_idx_array('x'),
+ MONITOR_SYMBOLS['y']: mk_idx_array('y'),
+ MONITOR_SYMBOLS['z']: mk_idx_array('z'),
+ MONITOR_SYMBOLS['f']: mk_idx_array('f'),
},
- output=symbols[monitor_attr],
+ output=MONITOR_SYMBOLS['field_e'],
)
case 'FieldTime':
return ct.InfoFlow(
- dims={
- symbols['x']: mk_idx_array('x'),
- symbols['y']: mk_idx_array('y'),
- symbols['z']: mk_idx_array('z'),
- symbols['t']: mk_idx_array('t'),
+ dims=batch_dims
+ | {
+ MONITOR_SYMBOLS['EH*']: monitor_attrs,
+ MONITOR_SYMBOLS['x']: mk_idx_array('x'),
+ MONITOR_SYMBOLS['y']: mk_idx_array('y'),
+ MONITOR_SYMBOLS['z']: mk_idx_array('z'),
+ MONITOR_SYMBOLS['t']: mk_idx_array('t'),
},
- output=symbols[monitor_attr],
+ output=MONITOR_SYMBOLS['field'],
)
case 'Flux':
return ct.InfoFlow(
- dims={
- symbols['f']: mk_idx_array('f'),
+ dims=batch_dims
+ | {
+ MONITOR_SYMBOLS['f']: mk_idx_array('f'),
},
- output=symbols[monitor_attr],
+ output=MONITOR_SYMBOLS['flux'],
)
case 'FluxTime':
return ct.InfoFlow(
- dims={
- symbols['t']: mk_idx_array('t'),
+ dims=batch_dims
+ | {
+ MONITOR_SYMBOLS['t']: mk_idx_array('t'),
},
- output=symbols[monitor_attr],
+ output=MONITOR_SYMBOLS['flux'],
)
case 'FieldProjectionAngle':
return ct.InfoFlow(
- dims={
- symbols['r']: mk_idx_array('r'),
- symbols['theta']: mk_idx_array('theta'),
- symbols['phi']: mk_idx_array('phi'),
- symbols['f']: mk_idx_array('f'),
+ dims=batch_dims
+ | {
+ MONITOR_SYMBOLS['EH*']: monitor_attrs,
+ MONITOR_SYMBOLS['r']: mk_idx_array('r'),
+ MONITOR_SYMBOLS['theta']: mk_idx_array('theta'),
+ MONITOR_SYMBOLS['phi']: mk_idx_array('phi'),
+ MONITOR_SYMBOLS['f']: mk_idx_array('f'),
},
- output=symbols[monitor_attr],
+ output=MONITOR_SYMBOLS['field'],
)
case 'FieldProjectionKSpace':
return ct.InfoFlow(
- dims={
- symbols['ux']: mk_idx_array('ux'),
- symbols['uy']: mk_idx_array('uy'),
- symbols['r']: mk_idx_array('r'),
- symbols['f']: mk_idx_array('f'),
+ dims=batch_dims
+ | {
+ MONITOR_SYMBOLS['EH*']: monitor_attrs,
+ MONITOR_SYMBOLS['ux']: mk_idx_array('ux'),
+ MONITOR_SYMBOLS['uy']: mk_idx_array('uy'),
+ MONITOR_SYMBOLS['r']: mk_idx_array('r'),
+ MONITOR_SYMBOLS['f']: mk_idx_array('f'),
},
- output=symbols[monitor_attr],
+ output=MONITOR_SYMBOLS['field'],
)
case 'Diffraction':
return ct.InfoFlow(
- dims={
- symbols['orders_x']: mk_idx_array('orders_x'),
- symbols['orders_y']: mk_idx_array('orders_y'),
- symbols['f']: mk_idx_array('f'),
+ dims=batch_dims
+ | {
+ MONITOR_SYMBOLS['EH*']: monitor_attrs,
+ MONITOR_SYMBOLS['orders_x']: mk_idx_array('orders_x'),
+ MONITOR_SYMBOLS['orders_y']: mk_idx_array('orders_y'),
+ MONITOR_SYMBOLS['f']: mk_idx_array('f'),
},
- output=symbols[monitor_attr],
+ output=MONITOR_SYMBOLS['field'],
)
- return None
+ raise TypeError
+
+
+def extract_monitor_xarrs(
+ monitor_datas: dict[RealizedSymsVals, typ.Any], monitor_attrs: tuple[str, ...]
+) -> dict[RealizedSymsVals, ct.InfoFlow]:
+ return {
+ syms_vals: {
+ monitor_attr: getattr(monitor_data, monitor_attr, None)
+ for monitor_attr in monitor_attrs
+ }
+ for syms_vals, monitor_data in monitor_datas.items()
+ }
+
+
+def extract_info(
+ monitor_datas: dict[RealizedSymsVals, typ.Any], monitor_attrs: tuple[str, ...]
+) -> dict[RealizedSymsVals, ct.InfoFlow]:
+ """Extract an InfoFlow describing monitor data from a batch of simulations."""
+ # Extract Dimension from Batched Values
+ ## -> Comb the data to expose each symbol's realized values as an array.
+ ## -> Each symbol: array can then become a dimension.
+ ## -> These are the "batch dimensions", which allows indexing across sims.
+ ## -> The retained sim symbol provides semantic index coordinates.
+ example_syms_vals = next(iter(monitor_datas.keys()))
+ syms = example_syms_vals[0] if example_syms_vals != () else ()
+ vals_per_sym_pos = (
+ [vals for _, vals in monitor_datas] if example_syms_vals != () else []
+ )
+
+ _batch_dims = dict(
+ zip(
+ syms,
+ zip(*vals_per_sym_pos, strict=True),
+ strict=True,
+ )
+ )
+ batch_dims = {
+ sym: ct.RangeFlow.try_from_array(
+ ct.ArrayFlow(
+ jax_bytes=vals,
+ unit=sym.unit,
+ is_sorted=True,
+ )
+ )
+ for sym, vals in _batch_dims.items()
+ }
+
+ # Extract Example Monitor Data | XArray
+ ## -> We presume all monitor attributes have the exact same dims + output.
+ ## -> Because of this, we only need one "example" xarray.
+ ## -> This xarray will be used to extract dimensional coordinates...
+ ## -> ...Presuming that these coords will generalize.
+ example_monitor_data = next(iter(monitor_datas.values()))
+ monitor_datas_xarrs = extract_monitor_xarrs(monitor_datas, monitor_attrs)
+
+ # Extract XArray for Each Monitor Attribute
+ example_monitor_data_xarrs = next(iter(monitor_datas_xarrs.values()))
+ example_xarr = next(iter(example_monitor_data_xarrs.values()))
+
+ # Extract InfoFlow of First
+ ## -> All of the InfoFlows should be identical...
+ ## -> ...Apart from the batched dimensions.
+ return _extract_info(
+ example_xarr,
+ example_monitor_data.type.removesuffix('Data'),
+ monitor_attrs,
+ batch_dims,
+ )
####################
@@ -224,73 +341,139 @@ class ExtractDataNode(base.MaxwellSimNode):
bl_label = 'Extract'
input_socket_sets: typ.ClassVar = {
- 'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(),
+ 'Single': {
+ 'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(),
+ },
+ 'Batch': {
+ 'Sim Datas': sockets.MaxwellFDTDSimDataSocketDef(active_kind=FK.Array),
+ },
}
+ # output_sockets: typ.ClassVar = {
+ # 'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
+ # }
output_socket_sets: typ.ClassVar = {
- 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
+ 'Single': {
+ 'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
+ 'Log': sockets.StringSocketDef(),
+ },
+ 'Batch': {
+ 'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
+ 'Logs': sockets.StringSocketDef(active_kind=FK.Array),
+ },
}
####################
- # - Properties: Monitor Name
+ # - Properties: Sim Datas
####################
@events.on_value_changed(
- socket_name='Sim Data',
- input_sockets={'Sim Data'},
- input_sockets_optional={'Sim Data': True},
+ socket_name={'Sim Data': FK.Value, 'Sim Datas': FK.Array},
)
- def on_sim_data_changed(self, input_sockets) -> None: # noqa: D102
- has_sim_data = not ct.FlowSignal.check(input_sockets['Sim Data'])
- if has_sim_data:
- self.sim_data = bl_cache.Signal.InvalidateCache
+ def on_sim_datas_changed(self) -> None: # noqa: D102
+ self.sim_datas = bl_cache.Signal.InvalidateCache
- @bl_cache.cached_bl_property()
- def sim_data(self) -> td.SimulationData | None:
+ @bl_cache.cached_bl_property(depends_on={'active_socket_set'})
+ def sim_datas(self) -> list[td.SimulationData] | None:
"""Extracts the simulation data from the input socket.
Return:
Either the simulation data, if available, or None.
"""
- sim_data = self._compute_input(
- 'Sim Data', kind=ct.FlowKind.Value, optional=True
- )
- has_sim_data = not ct.FlowSignal.check(sim_data)
- if has_sim_data:
- return sim_data
+ ## TODO: Check that syms are identical for all (aka. that we have a batch)
+ if self.active_socket_set == 'Single':
+ sim_data = self._compute_input('Sim Data', kind=FK.Value)
+ has_sim_data = not FS.check(sim_data)
+ if has_sim_data:
+ # Embedded Symbolic Realizations
+ ## -> ['realizations'] contains a 2-tuple
+ ## -> First should be the dict-dump of a SimSymbol.
+ ## -> Second should be either a value, or a list of values.
+ if 'realizations' in sim_data.attrs:
+ raw_realizations = sim_data.attrs['realizations']
+ syms_vals = {
+ sim_symbols.SimSymbol(**raw_sym): raw_val
+ if not isinstance(raw_val, tuple | list)
+ else jnp.array(raw_val)
+ for raw_sym, raw_val in raw_realizations
+ }
+ return {syms_vals: sim_data}
+
+ # No Embedded Realizations
+ return {(): sim_data}
+
+ if self.active_socket_set == 'Batch':
+ _sim_datas = self._compute_input('Sim Datas', kind=FK.Value)
+ has_sim_datas = not FS.check(_sim_datas)
+ if has_sim_datas:
+ sim_datas = {}
+ for sim_data in sim_datas:
+ # Embedded Symbolic Realizations
+ ## -> ['realizations'] contains a 2-tuple
+ ## -> First should be the dict-dump of a SimSymbol.
+ ## -> Second should be either a value, or a list of values.
+ if 'realizations' in sim_data.attrs:
+ raw_realizations = sim_data.attrs['realizations']
+ syms = {
+ sim_symbols.SimSymbol(**raw_sym): raw_val
+ if not isinstance(raw_val, tuple | list)
+ else jnp.array(raw_val)
+ for raw_sym, raw_val in raw_realizations
+ }
+ sim_datas |= {syms_vals: sim_data}
+
+ # No Embedded Realizations
+ sim_datas |= {(): sim_data}
return None
- @bl_cache.cached_bl_property(depends_on={'sim_data'})
- def sim_data_monitor_nametype(self) -> dict[str, str] | None:
- """Dictionary from monitor names on `self.sim_data` to their associated type name (with suffix 'Data' removed).
+ @bl_cache.cached_bl_property(depends_on={'sim_datas'})
+ def example_sim_data(self) -> list[td.SimulationData] | None:
+ """Extracts a single, example simulation data from the input socket.
+
+ All simulation datas share certain properties, ex. names and types of monitors.
+ Therefore, we may often only need an example simulation data object.
+
+ Return:
+ Either the simulation data, if available, or None.
+ """
+ if self.sim_datas:
+ return next(iter(self.sim_datas.values()))
+ return None
+
+ ####################
+ # - Properties: Monitor Name
+ ####################
+ @bl_cache.cached_bl_property(depends_on={'example_sim_data'})
+ def monitor_types(self) -> dict[str, str] | None:
+ """Dictionary from monitor names on `self.sim_datas` to their associated type name (with suffix 'Data' removed).
Return:
The name to type of monitors in the simulation data.
"""
- if self.sim_data is not None:
+ if self.example_sim_data is not None:
return {
monitor_name: monitor_data.type.removesuffix('Data')
- for monitor_name, monitor_data in self.sim_data.monitor_data.items()
+ for monitor_name, monitor_data in self.example_sim_data.monitor_data.items()
}
return None
monitor_name: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_monitor_names(),
- cb_depends_on={'sim_data_monitor_nametype'},
+ cb_depends_on={'monitor_types'},
)
def search_monitor_names(self) -> list[ct.BLEnumElement]:
"""Compute valid values for `self.monitor_attr`, for a dynamic `EnumProperty`.
Notes:
- Should be reset (via `self.monitor_attr`) with (after) `self.sim_data_monitor_nametype`, `self.monitor_data_attrs`, and (implicitly) `self.monitor_type`.
+ Should be reset (via `self.monitor_attr`) with (after) `self.sim_data_monitor_nametypes`, `self.monitor_data_attrs`, and (implicitly) `self.monitor_type`.
See `bl_cache.BLField` for more on dynamic `EnumProperty`.
Returns:
Valid `self.monitor_attr` in a format compatible with dynamic `EnumProperty`.
"""
- if self.sim_data_monitor_nametype is not None:
+ if self.monitor_types is not None:
return [
(
monitor_name,
@@ -300,12 +483,32 @@ class ExtractDataNode(base.MaxwellSimNode):
i,
)
for i, (monitor_name, monitor_type) in enumerate(
- self.sim_data_monitor_nametype.items()
+ self.monitor_types.items()
)
]
return []
+ ####################
+ # - Properties: Monitor Information
+ ####################
+ @bl_cache.cached_bl_property(depends_on={'sim_datas', 'monitor_name'})
+ def monitor_datas(self) -> SimDataArrayInfo | None:
+ """Extract the currently selected monitor's data from all simulation datas in the batch."""
+ if self.sim_datas is not None and self.monitor_name is not None:
+ return {
+ syms_vals: sim_data.monitor_data.get(self.monitor_name)
+ for syms_vals, sim_data in self.sim_datas.items()
+ }
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'example_sim_data', 'monitor_name'})
+ def valid_monitor_attrs(self) -> SimDataArrayInfo | None:
+ """Valid attributes of the monitor, from the example sim data under the presumption that the entire batch shares the same attribute validity."""
+ if self.example_sim_data is not None and self.monitor_name is not None:
+ return valid_monitor_attrs(self.example_sim_data, self.monitor_name)
+ return None
+
####################
# - UI
####################
@@ -315,9 +518,7 @@ class ExtractDataNode(base.MaxwellSimNode):
Notes:
Called by Blender to determine the text to place in the node's header.
"""
- has_sim_data = self.sim_data_monitor_nametype is not None
-
- if has_sim_data:
+ if self.monitor_name is not None:
return f'Extract: {self.monitor_name}'
return self.bl_label
@@ -335,114 +536,181 @@ class ExtractDataNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
# Loaded
- props={'monitor_name'},
- input_sockets={'Sim Data'},
- input_socket_kinds={'Sim Data': ct.FlowKind.Value},
+ props={'monitor_datas', 'valid_monitor_attrs'},
)
- def compute_expr(
- self, props: dict, input_sockets: dict
- ) -> ct.FuncFlow | ct.FlowSignal:
- sim_data = input_sockets['Sim Data']
- monitor_name = props['monitor_name']
+ def compute_extracted_data_func(self, props) -> ct.FuncFlow | FS:
+ """Aggregates the selected monitor's data across all batched symbolic realizations, into a single FuncFlow."""
+ monitor_datas = props['monitor_datas']
+ valid_monitor_attrs = props['valid_monitor_attrs']
- has_sim_data = not ct.FlowSignal.check(sim_data)
+ if monitor_datas is not None and valid_monitor_attrs is not None:
+ monitor_datas_xarrs = extract_monitor_xarrs(
+ monitor_datas, valid_monitor_attrs
+ )
- if has_sim_data and monitor_name is not None:
- monitor_data = sim_data.get(monitor_name)
- if monitor_data is not None:
- # Extract Valid Index Labels
- ## -> The first output axis will be integer-indexed.
- ## -> Each integer will have a string label.
- ## -> Those string labels explain the integer as ex. Ex, Ey, Hy.
- idx_labels = valid_monitor_attrs(sim_data, monitor_name)
+ example_monitor_data = next(iter(monitor_datas.values()))
+ monitor_type = example_monitor_data.type.removesuffix('Data')
+ output_sym = output_symbol_by_type(monitor_type)
- # Extract Info
- ## -> We only need the output symbol.
- ## -> All labelled outputs have the same output SimSymbol.
- info = extract_info(monitor_data, idx_labels[0])
+ # Stack Inner Dimensions: components | *
+ ## -> Each realization maps to exactly one xarray.
+ ## -> We extract its data, and wrap it into a FuncFlow.
+ ## -> This represents the "inner" dimensions, with components|data.
+ ## -> We then attach this singular FuncFlow to that realization.
+ inner_funcs = {}
+ for syms_vals, attr_xarrs in monitor_datas_xarrs.items():
+ # XArray Capture Function
+ ## -> We can't generally capture a loop variable inline.
+ ## -> By making a new function, we get a new scope.
+ def _xarr_values(xarr):
+ return lambda: xarr.values
- # Generate FuncFlow Per Index Label
- ## -> We extract each XArray as an attribute of monitor_data.
- ## -> We then bind its values into a unique func_flow.
- ## -> This lets us 'stack' then all along the first axis.
- func_flows = []
- for idx_label in idx_labels:
- xarr = getattr(monitor_data, idx_label)
- func_flows.append(
- ct.FuncFlow(
- func=lambda xarr=xarr: xarr.values,
- supports_jax=True,
- )
+ # Bind XArray Values into FuncFlows
+ ## -> Each monitor component has an xarray.
+ funcs = [
+ ct.FuncFlow(
+ func=_xarr_values(xarr),
+ func_output=output_sym,
+ supports_jax=True,
+ )
+ for xarr in attr_xarrs.values()
+ ]
+ log.critical(['FUNCS', funcs])
+ # Single Component: No Stacking of Dimensions - *
+ if len(funcs) == 1:
+ inner_funcs[syms_vals] = funcs[0]
+
+ # Many Components: Stack Dimensions - components | *
+ else:
+ inner_funcs[syms_vals] = functools.reduce(
+ lambda a, b: a | b, funcs
+ ).compose_within(
+ lambda els: jnp.stack(els, axis=0),
+ enclosing_func_output=output_sym,
)
- # Concatenate and Stack Unified FuncFlow
- ## -> First, 'reduce' lets us __or__ all the FuncFlows together.
- ## -> Then, 'compose_within' lets us stack them along axis=0.
- ## -> The "new" axis=0 is int-indexed axis w/idx_labels labels!
- return functools.reduce(lambda a, b: a | b, func_flows).compose_within(
- lambda data: jnp.stack(data, axis=0),
- func_output=info.output,
- )
- return ct.FlowSignal.FlowPending
- return ct.FlowSignal.FlowPending
+ # Stack Batch Dims: vals0 | vals1 | ... | valsN | components | *
+ ## -> We stack the inner-dimensional object together, backwards.
+ ## -> Each stack prepends a new dimension.
+ ## -> Here, everything is integer-indexed.
+ ## -> But in the InfoFlow, a similar process contextualizes idxs.
+ example_syms_vals = next(iter(monitor_datas.keys()))
+ syms = example_syms_vals[0] if example_syms_vals != () else ()
+ outer_funcs = inner_funcs
+ log.critical(['INNER FUNCS', inner_funcs])
+ for _, axis in reversed(list(enumerate(syms))):
+ log.critical([axis, outer_funcs])
+ new_outer_funcs = {}
+ # Collect Funcs Along *vals[axis] | ...
+ ## -> Grab ONLY up to 'axis' syms_vals.
+ ## -> '|' all functions that share axis-deficient syms_vals.
+ ## -> Thus, we '|' functions along the last axis.
+ for unreduced_syms_vals, func in outer_funcs:
+ reduced_syms_vals = unreduced_syms_vals[:axis]
+ if reduced_syms_vals in new_outer_funcs:
+ new_outer_funcs[reduced_syms_vals] = (
+ new_outer_funcs[reduced_syms_vals] | func
+ )
+ else:
+ new_outer_funcs[reduced_syms_vals] = func
- ####################
- # - FlowKind.Params
- ####################
- @events.computes_output_socket(
- 'Expr',
- kind=ct.FlowKind.Params,
- input_sockets={'Sim Data'},
- input_socket_kinds={'Sim Data': ct.FlowKind.Params},
- )
- def compute_data_params(self, input_sockets) -> ct.ParamsFlow:
- """Declare an empty `Data:Params`, to indicate the start of a function-composition pipeline.
+ # Aggregate All Collected Funcs
+ ## -> Any functions that went through a | are stacked.
+ ## -> Otherwise, just add a len=1 dimension.
+ new_reduced_outer_funcs = {
+ reduced_syms_vals: (
+ combined_func.compose_within(
+ lambda els: jnp.stack(els, axis=0),
+ enclosing_func_output=output_sym,
+ )
+ if combined_func.is_concatenated ## Went through a |
+ else combined_func.compose_within(
+ lambda el: jnp.expand_dims(el, axis=0),
+ enclosing_func_output=output_sym,
+ )
+ )
+ for reduced_syms_vals, combined_func in new_outer_funcs.items()
+ }
- Returns:
- A completely empty `ParamsFlow`, ready to be composed.
- """
- sim_params = input_sockets['Sim Data']
- has_sim_params = not ct.FlowSignal.check(sim_params)
-
- if has_sim_params:
- return sim_params
- return ct.ParamsFlow()
+ # Reset Outer Funcs to Axis-Deficient Reduction
+ ## -> This effectively removes + aggregates the last axis.
+ ## -> When the loop is done, only {(): val} will be left.
+ outer_funcs = new_reduced_outer_funcs
+ return next(iter(outer_funcs.values()))
+ return FS.FlowPending
####################
# - FlowKind.Info
####################
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Info,
+ kind=FK.Info,
# Loaded
- props={'monitor_name'},
- input_sockets={'Sim Data'},
- input_socket_kinds={'Sim Data': ct.FlowKind.Value},
+ props={'monitor_datas', 'valid_monitor_attrs'},
)
- def compute_extracted_data_info(self, props, input_sockets) -> ct.InfoFlow:
+ def compute_extracted_data_info(self, props) -> ct.InfoFlow | FS:
"""Declare `Data:Info` by manually selecting appropriate axes, units, etc. for each monitor type.
Returns:
- Information describing the `Data:Func`, if available, else `ct.FlowSignal.FlowPending`.
+ Information describing the `Data:Func`, if available, else `FS.FlowPending`.
"""
- sim_data = input_sockets['Sim Data']
- monitor_name = props['monitor_name']
+ monitor_datas = props['monitor_datas']
+ valid_monitor_attrs = props['valid_monitor_attrs']
- has_sim_data = not ct.FlowSignal.check(sim_data)
+ if monitor_datas is not None and valid_monitor_attrs is not None:
+ return extract_info(monitor_datas, valid_monitor_attrs)
+ return FS.FlowPending
- if not has_sim_data or monitor_name is None:
- return ct.FlowSignal.FlowPending
+ ####################
+ # - FlowKind.Params
+ ####################
+ @events.computes_output_socket(
+ 'Expr',
+ kind=FK.Params,
+ )
+ def compute_params(self) -> ct.ParamsFlow:
+ """Declare an empty `Data:Params`, to indicate the start of a function-composition pipeline.
- # Extract Data
- ## -> All monitor_data. have the exact same InfoFlow.
- ## -> So, just construct an InfoFlow w/prepended labelled dimension.
- monitor_data = sim_data.get(monitor_name)
- idx_labels = valid_monitor_attrs(sim_data, monitor_name)
- info = extract_info(monitor_data, idx_labels[0])
+ Returns:
+ A completely empty `ParamsFlow`, ready to be composed.
+ """
+ return ct.ParamsFlow()
- return info.prepend_dim(sim_symbols.idx, idx_labels)
+ ####################
+ # - Log: FlowKind.Value|Array
+ ####################
+ @events.computes_output_socket(
+ 'Log',
+ kind=FK.Value,
+ # Loaded
+ props={'sim_datas'},
+ )
+ def compute_extracted_log(self, props) -> str | FS:
+ """Extract the log from a single simulation that ran."""
+ sim_datas = props['sim_datas']
+ if sim_datas is not None and len(sim_datas) == 1:
+ sim_data = next(iter(sim_datas.values()))
+ if sim_data.log is not None:
+ return sim_data.log
+ return FS.FlowPending
+
+ @events.computes_output_socket(
+ 'Log',
+ kind=FK.Array,
+ # Loaded
+ props={'sim_datas'},
+ )
+ def compute_extracted_logs(self, props) -> dict[RealizedSymsVals, str] | FS:
+ """Extract the log from all simulation that ran in the batch."""
+ sim_datas = props['sim_datas']
+ if sim_datas is not None and sim_datas:
+ return {
+ syms_vals: sim_data.log if sim_data.log is not None else ''
+ for syms_vals, sim_data in sim_datas.items()
+ }
+ return FS.FlowPending
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/__init__.py
index a1fb2e6..7f20a58 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/__init__.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/__init__.py
@@ -14,19 +14,19 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from . import filter_math, map_math, operate_math, transform_math
+from . import filter_math, map_math, operate_math, reduce_math, transform_math
BL_REGISTER = [
*operate_math.BL_REGISTER,
*map_math.BL_REGISTER,
*filter_math.BL_REGISTER,
- # *reduce_math.BL_REGISTER,
+ *reduce_math.BL_REGISTER,
*transform_math.BL_REGISTER,
]
BL_NODES = {
**operate_math.BL_NODES,
**map_math.BL_NODES,
**filter_math.BL_NODES,
- # **reduce_math.BL_NODES,
+ **reduce_math.BL_NODES,
**transform_math.BL_NODES,
}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py
index 6c02ec1..d2fccdc 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/filter_math.py
@@ -22,13 +22,20 @@ import typing as typ
import bpy
import sympy as sp
-from blender_maxwell.utils import bl_cache, sim_symbols
+from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import math_system, sockets
from ... import base, events
+log = logger.get(__name__)
+
+FK = ct.FlowKind
+FS = ct.FlowSignal
+FO = math_system.FilterOperation
+MT = spux.MathType
+
class FilterMathNode(base.MaxwellSimNode):
r"""Applies a function that operates on the shape of the array.
@@ -52,10 +59,10 @@ class FilterMathNode(base.MaxwellSimNode):
bl_label = 'Filter Math'
input_sockets: typ.ClassVar = {
- 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
+ 'Expr': sockets.ExprSocketDef(active_kind=FK.Func, show_func_ui=False),
}
output_sockets: typ.ClassVar = {
- 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
+ 'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
}
####################
@@ -63,50 +70,43 @@ class FilterMathNode(base.MaxwellSimNode):
####################
@events.on_value_changed(
# Trigger
- socket_name={'Expr'},
+ socket_name={'Expr': FK.Info},
# Loaded
- input_sockets={'Expr'},
- input_socket_kinds={'Expr': ct.FlowKind.Info},
- input_sockets_optional={'Expr': True},
+ inscks_kinds={'Expr': FK.Info},
+ input_sockets_optional={'Expr'},
# Flow
## -> See docs in TransformMathNode
stop_propagation=True,
)
- def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
- has_info = not ct.FlowSignal.check(input_sockets['Expr'])
-
- info_pending = ct.FlowSignal.check_single(
- input_sockets['Expr'], ct.FlowSignal.FlowPending
- )
+ def on_input_expr_changed(self, input_sockets) -> None: # noqa: D102
+ info = input_sockets['Expr']
+ has_info = not FS.check(info)
+ info_pending = FS.check_single(info, FS.FlowPending)
if has_info and not info_pending:
self.expr_info = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def expr_info(self) -> ct.InfoFlow | None:
- info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True)
- has_info = not ct.FlowSignal.check(info)
+ """Retrieve the input expression's `InfoFlow`."""
+ info = self._compute_input('Expr', kind=FK.Info)
+ has_info = not FS.check(info)
if has_info:
return info
-
return None
####################
# - Properties: Operation
####################
- operation: math_system.FilterOperation = bl_cache.BLField(
+ operation: FO = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_info'},
)
def search_operations(self) -> list[ct.BLEnumElement]:
+ """Determine all valid operations from the input expression."""
if self.expr_info is not None:
- return [
- operation.bl_enum_element(i)
- for i, operation in enumerate(
- math_system.FilterOperation.by_info(self.expr_info)
- )
- ]
+ return FO.bl_enum_elements(self.expr_info)
return []
####################
@@ -122,6 +122,7 @@ class FilterMathNode(base.MaxwellSimNode):
)
def search_dims(self) -> list[ct.BLEnumElement]:
+ """Determine all valid dimensions from the input expression."""
if self.expr_info is not None and self.operation is not None:
return [
(dim.name, dim.name_pretty, dim.name, '', i)
@@ -131,16 +132,32 @@ class FilterMathNode(base.MaxwellSimNode):
@bl_cache.cached_bl_property(depends_on={'active_dim_0'})
def dim_0(self) -> sim_symbols.SimSymbol | None:
+ """The first currently active dimension, if any is selected; otherwise `None`."""
if self.expr_info is not None and self.active_dim_0 is not None:
return self.expr_info.dim_by_name(self.active_dim_0)
return None
@bl_cache.cached_bl_property(depends_on={'active_dim_1'})
def dim_1(self) -> sim_symbols.SimSymbol | None:
+ """The second currently active dimension, if any is selected; otherwise `None`."""
if self.expr_info is not None and self.active_dim_1 is not None:
return self.expr_info.dim_by_name(self.active_dim_1)
return None
+ @bl_cache.cached_bl_property(depends_on={'dim_0'})
+ def axis_0(self) -> sim_symbols.SimSymbol | None:
+ """The first currently active axis, derived from `self.dim_0`."""
+ if self.expr_info is not None and self.dim_0 is not None:
+ return self.expr_info.dim_axis(self.dim_0)
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'dim_1'})
+ def axis_1(self) -> sim_symbols.SimSymbol | None:
+ """The second currently active dimension, if any is selected; otherwise `None`."""
+ if self.expr_info is not None and self.active_dim_1 is not None:
+ return self.expr_info.dim_axis(self.dim_1)
+ return None
+
####################
# - Properties: Slice
####################
@@ -149,8 +166,12 @@ class FilterMathNode(base.MaxwellSimNode):
####################
# - UI
####################
- def draw_label(self):
- FO = math_system.FilterOperation
+ def draw_label(self): # noqa: PLR0911
+ """Show the active filter operation in the node's header label.
+
+ Notes:
+ Called by Blender to determine the text to place in the node's header.
+ """
match self.operation:
# Slice
case FO.SliceIdx:
@@ -163,10 +184,8 @@ class FilterMathNode(base.MaxwellSimNode):
case FO.Pin:
return f'Filter: Pin {self.active_dim_0}[...]'
case FO.PinIdx:
- pin_idx_axis = self._compute_input(
- 'Axis', kind=ct.FlowKind.Value, optional=True
- )
- has_pin_idx_axis = not ct.FlowSignal.check(pin_idx_axis)
+ pin_idx_axis = self._compute_input('Index', kind=FK.Value)
+ has_pin_idx_axis = not FS.check(pin_idx_axis)
if has_pin_idx_axis:
return f'Filter: Pin {self.active_dim_0}[{pin_idx_axis}]'
return self.bl_label
@@ -179,6 +198,11 @@ class FilterMathNode(base.MaxwellSimNode):
return self.bl_label
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
+ """Draw the user interfaces of the node's properties inside of the node itself.
+
+ Parameters:
+ layout: UI target for drawing.
+ """
layout.prop(self, self.blfields['operation'], text='')
if self.operation is not None:
@@ -190,7 +214,7 @@ class FilterMathNode(base.MaxwellSimNode):
row.prop(self, self.blfields['active_dim_0'], text='')
row.prop(self, self.blfields['active_dim_1'], text='')
- if self.operation is math_system.FilterOperation.SliceIdx:
+ if self.operation is FO.SliceIdx:
layout.prop(self, self.blfields['slice_tuple'], text='')
####################
@@ -198,220 +222,190 @@ class FilterMathNode(base.MaxwellSimNode):
####################
@events.on_value_changed(
# Trigger
- socket_name='Expr',
- prop_name={'operation', 'dim_0', 'dim_1'},
+ socket_name={'Expr': FK.Info},
+ prop_name={'operation', 'dim_0'},
# Loaded
- props={'operation', 'dim_0', 'dim_1'},
- input_sockets={'Expr'},
- input_socket_kinds={'Expr': ct.FlowKind.Info},
+ props={'operation', 'dim_0'},
+ inscks_kinds={'Expr': FK.Info},
+ input_sockets_optional={'Expr'},
)
- def on_pin_factors_changed(self, props: dict, input_sockets: dict):
- """Synchronize loose input sockets to match the dimension-pinning method declared in `self.operation`.
-
- To "pin" an axis, a particular index must be chosen to "extract".
- One might choose axes of length 1 ("squeeze"), choose a particular index, or choose a coordinate that maps to a particular index.
-
- Those last two options requires more information from the user: Which index?
- Which coordinate?
- To answer these questions, we create an appropriate loose input socket containing this data, so the user can make their decision.
- """
+ def on_pin_factors_changed(self, props, input_sockets) -> None:
+ """Synchronize loose input sockets to match the dimension-pinning method declared in `self.operation`."""
info = input_sockets['Expr']
- has_info = not ct.FlowSignal.check(info)
- if not has_info:
- return
+ has_info = not FS.check(info)
dim_0 = props['dim_0']
- # Loose Sockets: Pin Dim by-Value
- ## -> Works with continuous / discrete indexes.
- ## -> The user will be given a socket w/correct mathtype, unit, etc. .
- if (
- props['operation'] is math_system.FilterOperation.Pin
- and dim_0 is not None
- and (info.has_idx_cont(dim_0) or info.has_idx_discrete(dim_0))
- ):
- dim = dim_0
- current_bl_socket = self.loose_input_sockets.get('Value')
-
- if (
- current_bl_socket is None
- or current_bl_socket.active_kind != ct.FlowKind.Value
- or current_bl_socket.size is not spux.NumberSize1D.Scalar
- or current_bl_socket.physical_type != dim.physical_type
- or current_bl_socket.mathtype != dim.mathtype
+ operation = props['operation']
+ match operation:
+ case FO.Pin if (
+ has_info
+ and dim_0 is not None
+ and (info.has_idx_cont(dim_0) or info.has_idx_discrete(dim_0))
):
self.loose_input_sockets = {
'Value': sockets.ExprSocketDef(
- active_kind=ct.FlowKind.Value,
- physical_type=dim.physical_type,
- mathtype=dim.mathtype,
- default_unit=dim.unit,
+ active_kind=FK.Value,
+ **dim_0.expr_info,
),
}
- # Loose Sockets: Pin Dim by-Value
- ## -> Works with discrete points / labelled integers.
- elif (
- props['operation'] is math_system.FilterOperation.PinIdx
- and dim_0 is not None
- and (info.has_idx_discrete(dim_0) or info.has_idx_labels(dim_0))
- ):
- dim = dim_0
- current_bl_socket = self.loose_input_sockets.get('Axis')
- if (
- current_bl_socket is None
- or current_bl_socket.active_kind != ct.FlowKind.Value
- or current_bl_socket.size is not spux.NumberSize1D.Scalar
- or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical
- or current_bl_socket.mathtype != spux.MathType.Integer
+ case FO.PinIdx if (
+ has_info
+ and dim_0 is not None
+ and (info.has_idx_labels(dim_0) or info.has_idx_discrete(dim_0))
):
self.loose_input_sockets = {
- 'Axis': sockets.ExprSocketDef(
- active_kind=ct.FlowKind.Value,
- mathtype=spux.MathType.Integer,
- )
+ 'Index': sockets.ExprSocketDef(
+ active_kind=FK.Value,
+ output_name=dim_0.expr_info['output_name'],
+ mathtype=MT.Integer,
+ abs_min=0,
+ abs_max=len(info.dims[dim_0]) - 1,
+ ),
}
- # No Loose Value: Remove Input Sockets
- elif self.loose_input_sockets:
- self.loose_input_sockets = {}
+ case _ if self.loose_input_sockets:
+ self.loose_input_sockets = {}
####################
- # - FlowKind.Value|Func
+ # - FlowKind.Func
####################
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Func,
- props={'operation', 'dim_0', 'dim_1', 'slice_tuple'},
- input_sockets={'Expr'},
- input_socket_kinds={'Expr': {ct.FlowKind.Func, ct.FlowKind.Info}},
+ kind=FK.Func,
+ # Loaded
+ props={'operation', 'axis_0', 'axis_1', 'slice_tuple'},
+ inscks_kinds={'Expr': FK.Func},
)
- def compute_lazy_func(self, props: dict, input_sockets: dict):
- operation = props['operation']
- lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
- info = input_sockets['Expr'][ct.FlowKind.Info]
+ def compute_func(self, props, input_sockets) -> None:
+ """Filter operation on lazy-defined input expression."""
+ lazy_func = input_sockets['Expr']
- has_lazy_func = not ct.FlowSignal.check(lazy_func)
- has_info = not ct.FlowSignal.check(info)
-
- dim_0 = props['dim_0']
- dim_1 = props['dim_1']
+ axis_0 = props['axis_0']
+ axis_1 = props['axis_1']
slice_tuple = props['slice_tuple']
- if (
- has_lazy_func
- and has_info
- and operation is not None
- and operation.are_dims_valid(info, dim_0, dim_1)
- ):
- axis_0 = info.dim_axis(dim_0) if dim_0 is not None else None
- axis_1 = info.dim_axis(dim_1) if dim_1 is not None else None
- return lazy_func.compose_within(
- operation.jax_func(axis_0, axis_1, slice_tuple=slice_tuple),
- enclosing_func_args=operation.func_args,
- enclosing_func_output=info.output,
- supports_jax=True,
+ operation = props['operation']
+ if operation is not None:
+ new_func = operation.transform_func(
+ lazy_func, axis_0, axis_1=axis_1, slice_tuple=slice_tuple
)
- return ct.FlowSignal.FlowPending
+ if new_func is not None:
+ return new_func
+
+ return FS.FlowPending
####################
# - FlowKind.Info
####################
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Info,
+ kind=FK.Info,
+ # Loaded
props={
'dim_0',
'dim_1',
'operation',
'slice_tuple',
},
- input_sockets={'Expr', 'Dim'},
- input_socket_kinds={
- 'Expr': ct.FlowKind.Info,
- 'Dim': {ct.FlowKind.Func, ct.FlowKind.Params, ct.FlowKind.Info},
+ inscks_kinds={
+ 'Expr': FK.Info,
+ 'Value': {FK.Func, FK.Params},
+ 'Index': {FK.Func, FK.Params},
},
- input_sockets_optional={'Dim': True},
+ input_sockets_optional={'Index', 'Value'},
)
def compute_info(self, props, input_sockets) -> ct.InfoFlow:
- operation = props['operation']
+ """Transform `InfoFlow` based on the current filtering operation."""
info = input_sockets['Expr']
- has_info = not ct.FlowSignal.check(info)
-
- # Dimension(s)
dim_0 = props['dim_0']
dim_1 = props['dim_1']
- slice_tuple = props['slice_tuple']
- if has_info and operation is not None:
- return operation.transform_info(info, dim_0, dim_1, slice_tuple=slice_tuple)
- return ct.FlowSignal.FlowPending
+
+ operation = props['operation']
+ match operation:
+ # Slice
+ case FO.Slice | FO.SliceIdx if dim_0 is not None:
+ slice_tuple = props['slice_tuple']
+ return operation.transform_info(info, dim_0, slice_tuple=slice_tuple)
+
+ # Pin
+ case FO.PinLen1 if dim_0 is not None:
+ return operation.transform_info(info, dim_0)
+
+ case FO.Pin if dim_0 is not None:
+ pinned_value = events.realize_known(
+ input_sockets['Value'], conformed=True
+ )
+ if pinned_value is not None:
+ nearest_idx_to_value = info.dims[dim_0].nearest_idx_of(
+ pinned_value, require_sorted=True
+ )
+ return operation.transform_info(
+ info, dim_0, pin_idx=nearest_idx_to_value
+ )
+
+ case FO.PinIdx if dim_0 is not None:
+ pinned_idx = int(events.realize_known(input_sockets['Index']))
+ if pinned_idx is not None:
+ return operation.transform_info(info, dim_0, pin_idx=pinned_idx)
+
+ # Swizzle
+ case FO.Swap if dim_0 is not None and dim_1 is not None:
+ return operation.transform_info(info, dim_0, dim_1)
+
+ return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
+ # Loaded
props={'dim_0', 'dim_1', 'operation'},
- input_sockets={'Expr', 'Value', 'Axis'},
- input_socket_kinds={
- 'Expr': {ct.FlowKind.Info, ct.FlowKind.Params},
+ inscks_kinds={
+ 'Value': {FK.Func, FK.Params},
+ 'Index': {FK.Func, FK.Params},
+ 'Expr': {FK.Info, FK.Params},
},
- input_sockets_optional={'Value': True, 'Axis': True},
+ input_sockets_optional={'Value', 'Index'},
)
- def compute_params(self, props: dict, input_sockets: dict) -> ct.ParamsFlow:
- operation = props['operation']
- info = input_sockets['Expr'][ct.FlowKind.Info]
- params = input_sockets['Expr'][ct.FlowKind.Params]
+ def compute_params(self, props, input_sockets) -> ct.ParamsFlow:
+ """Compute tracked function argument parameters of input parameters."""
+ info = input_sockets['Expr'][FK.Info]
+ params = input_sockets['Expr'][FK.Params]
- has_info = not ct.FlowSignal.check(info)
- has_params = not ct.FlowSignal.check(params)
-
- # Dimension(s)
dim_0 = props['dim_0']
- dim_1 = props['dim_1']
- if (
- has_info
- and has_params
- and operation is not None
- and operation.are_dims_valid(info, dim_0, dim_1)
- ):
- # Retrieve Pinned Value
- pinned_value = input_sockets['Value']
- has_pinned_value = not ct.FlowSignal.check(pinned_value)
- pinned_axis = input_sockets['Axis']
- has_pinned_axis = not ct.FlowSignal.check(pinned_axis)
+ operation = props['operation']
+ match operation:
+ # *
+ case FO.Slice | FO.SliceIdx | FO.PinLen1 | FO.Swap:
+ return params
- # Pin by-Value: Compute Nearest IDX
- ## -> Presume a sorted index array to be able to use binary search.
- if (
- props['operation'] is math_system.FilterOperation.Pin
- and has_pinned_value
- ):
- nearest_idx_to_value = info.dims[dim_0].nearest_idx_of(
- pinned_value, require_sorted=True
+ # Pin
+ case FO.Pin:
+ pinned_value = events.realize_known(
+ input_sockets['Value'], conformed=True
)
+ if pinned_value is not None:
+ nearest_idx_to_value = info.dims[dim_0].nearest_idx_of(
+ pinned_value, require_sorted=True
+ )
+ return params.compose_within(
+ enclosing_func_args=(sp.Integer(nearest_idx_to_value),),
+ )
- return params.compose_within(
- enclosing_arg_targets=[sim_symbols.idx(None)],
- enclosing_func_args=[sp.S(nearest_idx_to_value)],
- )
+ case FO.PinIdx:
+ pinned_idx = events.realize_known(input_sockets['Index'])
+ if pinned_idx is not None:
+ return params.compose_within(
+ enclosing_func_args=(sp.Integer(pinned_idx),),
+ )
- # Pin by-Index
- if (
- props['operation'] is math_system.FilterOperation.PinIdx
- and has_pinned_axis
- ):
- return params.compose_within(
- enclosing_arg_targets=[sim_symbols.idx(None)],
- enclosing_func_args=[sp.S(pinned_axis)],
- )
-
- return params
-
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py
index 4959822..d184fbe 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py
@@ -21,6 +21,7 @@ import typing as typ
import bpy
from blender_maxwell.utils import bl_cache, logger
+from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import math_system, sockets
@@ -28,6 +29,11 @@ from ... import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MO = math_system.MapOperation
+MT = spux.MathType
+
class MapMathNode(base.MaxwellSimNode):
r"""Applies a function by-structure to the data.
@@ -102,7 +108,7 @@ class MapMathNode(base.MaxwellSimNode):
The name and type of the available symbol is clearly shown, and most valid `sympy` expressions that you would expect to work, should work.
Use of expressions generally imposes no performance penalty: Just like the baked-in operations, it is compiled to a high-performance `jax` function.
- Thus, it participates in the `ct.FlowKind.Func` composition chain.
+ Thus, it participates in the `FK.Func` composition chain.
Attributes:
@@ -113,69 +119,76 @@ class MapMathNode(base.MaxwellSimNode):
bl_label = 'Map Math'
input_sockets: typ.ClassVar = {
- 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
+ 'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
}
output_sockets: typ.ClassVar = {
- 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
+ 'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
}
####################
- # - Properties
+ # - Properties: Incoming InfoFlow
####################
@events.on_value_changed(
# Trigger
- socket_name={'Expr'},
+ socket_name={'Expr': FK.Info},
# Loaded
- input_sockets={'Expr'},
- input_socket_kinds={'Expr': ct.FlowKind.Info},
- input_sockets_optional={'Expr': True},
+ inscks_kinds={'Expr': FK.Info},
+ input_sockets_optional={'Expr'},
# Flow
## -> See docs in TransformMathNode
stop_propagation=True,
)
- def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
- has_info = not ct.FlowSignal.check(input_sockets['Expr'])
-
- info_pending = ct.FlowSignal.check_single(
- input_sockets['Expr'], ct.FlowSignal.FlowPending
- )
+ def on_input_expr_changed(self, input_sockets) -> None: # noqa: D102
+ info = input_sockets['Expr']
+ has_info = not FS.check(info)
+ info_pending = FS.check_single(info, FS.FlowPending)
if has_info and not info_pending:
self.expr_info = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def expr_info(self) -> ct.InfoFlow | None:
- info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True)
- has_info = not ct.FlowSignal.check(info)
+ """Retrieve the input expression's `InfoFlow`."""
+ info = self._compute_input('Expr', kind=FK.Info)
+ has_info = not FS.check(info)
if has_info:
return info
return None
- operation: math_system.MapOperation = bl_cache.BLField(
+ ####################
+ # - Property: Operation
+ ####################
+ operation: MO = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_info'},
)
def search_operations(self) -> list[ct.BLEnumElement]:
+ """Retrieve valid operations based on the input `InfoFlow`."""
if self.expr_info is not None:
- return [
- operation.bl_enum_element(i)
- for i, operation in enumerate(
- math_system.MapOperation.by_expr_info(self.expr_info)
- )
- ]
+ return MO.bl_enum_elements(self.expr_info)
return []
####################
# - UI
####################
def draw_label(self):
+ """Show the current operation (if any) in the node's header label.
+
+ Notes:
+ Called by Blender to determine the text to place in the node's header.
+ """
if self.operation is not None:
- return 'Map: ' + math_system.MapOperation.to_name(self.operation)
+ return 'Map: ' + MO.to_name(self.operation)
return self.bl_label
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
+ """Draw the user interfaces of the node's properties inside of the node itself.
+
+ Parameters:
+ layout: UI target for drawing.
+ """
layout.prop(self, self.blfields['operation'], text='')
####################
@@ -183,91 +196,81 @@ class MapMathNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Value,
+ kind=FK.Value,
+ # Loaded
props={'operation'},
- input_sockets={'Expr'},
+ inscks_kinds={'Expr': FK.Value},
)
- def compute_value(self, props, input_sockets) -> ct.ValueFlow | ct.FlowSignal:
- operation = props['operation']
+ def compute_value(self, props, input_sockets) -> ct.ValueFlow | FS:
+ """Mapping operation on symbolic input expression."""
expr = input_sockets['Expr']
- has_expr_value = not ct.FlowSignal.check(expr)
-
- # Compute Sympy Function
- ## -> The operation enum directly provides the appropriate function.
- if has_expr_value and operation is not None:
+ operation = props['operation']
+ if operation is not None:
return operation.sp_func(expr)
-
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Expr',
+ kind=FK.Func,
# Loaded
- kind=ct.FlowKind.Func,
props={'operation'},
- input_sockets={'Expr'},
- input_socket_kinds={
- 'Expr': ct.FlowKind.Func,
+ inscks_kinds={
+ 'Expr': FK.Func,
},
- output_sockets={'Expr'},
- output_socket_kinds={'Expr': ct.FlowKind.Info},
)
- def compute_func(
- self, props, input_sockets, output_sockets
- ) -> ct.FuncFlow | ct.FlowSignal:
- expr = input_sockets['Expr']
- output_info = output_sockets['Expr']
-
- has_expr = not ct.FlowSignal.check(expr)
- has_output_info = not ct.FlowSignal.check(output_info)
+ def compute_func(self, props, input_sockets) -> ct.FuncFlow | FS:
+ """Mapping operation on lazy-defined input expression."""
+ func = input_sockets['Expr']
operation = props['operation']
- if has_expr and operation is not None:
- return expr.compose_within(
- operation.jax_func,
- enclosing_func_output=output_info.output,
- supports_jax=True,
- )
- return ct.FlowSignal.FlowPending
+ if operation is not None:
+ return operation.transform_func(func)
+ return FS.FlowPending
####################
# - FlowKind.Info
####################
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Info,
+ kind=FK.Info,
+ # Loaded
props={'operation'},
- input_sockets={'Expr'},
- input_socket_kinds={'Expr': ct.FlowKind.Info},
+ inscks_kinds={
+ 'Expr': FK.Info,
+ },
)
- def compute_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
- operation = props['operation']
+ def compute_info(self, props, input_sockets) -> ct.InfoFlow:
+ """Transform the info chracterization of the input."""
info = input_sockets['Expr']
- has_info = not ct.FlowSignal.check(info)
-
- if has_info and operation is not None:
+ operation = props['operation']
+ if operation is not None:
return operation.transform_info(info)
-
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
+ # Loaded
+ props={'operation'},
input_sockets={'Expr'},
- input_socket_kinds={'Expr': ct.FlowKind.Params},
+ input_socket_kinds={'Expr': FK.Params},
)
- def compute_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal:
- has_params = not ct.FlowSignal.check(input_sockets['Expr'])
- if has_params:
- return input_sockets['Expr']
- return ct.FlowSignal.FlowPending
+ def compute_params(self, props, input_sockets) -> ct.ParamsFlow | FS:
+ """Transform the parameters of the input."""
+ params = input_sockets['Expr']
+
+ operation = props['operation']
+ if operation is not None:
+ return operation.transform_params(params)
+ return FS.FlowPending
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py
index 920f863..1ebbfac 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/operate_math.py
@@ -24,6 +24,7 @@ import typing as typ
import bpy
from blender_maxwell.utils import bl_cache, logger
+from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import math_system, sockets
@@ -31,6 +32,11 @@ from ... import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+BO = math_system.BinaryOperation
+MT = spux.MathType
+
class OperateMathNode(base.MaxwellSimNode):
r"""Applies a binary function between two expressions.
@@ -46,13 +52,11 @@ class OperateMathNode(base.MaxwellSimNode):
bl_label = 'Operate Math'
input_sockets: typ.ClassVar = {
- 'Expr L': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
- 'Expr R': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
+ 'Expr L': sockets.ExprSocketDef(active_kind=FK.Func),
+ 'Expr R': sockets.ExprSocketDef(active_kind=FK.Func),
}
output_sockets: typ.ClassVar = {
- 'Expr': sockets.ExprSocketDef(
- active_kind=ct.FlowKind.Func, show_info_columns=True
- ),
+ 'Expr': sockets.ExprSocketDef(active_kind=FK.Func, show_info_columns=True),
}
####################
@@ -60,25 +64,21 @@ class OperateMathNode(base.MaxwellSimNode):
####################
@events.on_value_changed(
# Trigger
- socket_name={'Expr L', 'Expr R'},
+ socket_name={'Expr L': FK.Info, 'Expr R': FK.Info},
# Loaded
- input_sockets={'Expr L', 'Expr R'},
- input_socket_kinds={'Expr L': ct.FlowKind.Info, 'Expr R': ct.FlowKind.Info},
- input_sockets_optional={'Expr L': True, 'Expr R': True},
+ inscks_kinds={'Expr L': FK.Info, 'Expr R': FK.Info},
+ input_sockets_optional={'Expr L', 'Expr R'},
# Flow
## -> See docs in TransformMathNode
stop_propagation=True,
)
- def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
- has_info_l = not ct.FlowSignal.check(input_sockets['Expr L'])
- has_info_r = not ct.FlowSignal.check(input_sockets['Expr R'])
+ def on_input_exprs_changed(self, input_sockets) -> None:
+ """Queue an update of the cached expression infos whenever data changed."""
+ has_info_l = not FS.check(input_sockets['Expr L'])
+ has_info_r = not FS.check(input_sockets['Expr R'])
- info_l_pending = ct.FlowSignal.check_single(
- input_sockets['Expr L'], ct.FlowSignal.FlowPending
- )
- info_r_pending = ct.FlowSignal.check_single(
- input_sockets['Expr R'], ct.FlowSignal.FlowPending
- )
+ info_l_pending = FS.check_single(input_sockets['Expr L'], FS.FlowPending)
+ info_r_pending = FS.check_single(input_sockets['Expr R'], FS.FlowPending)
if has_info_l and has_info_r and not info_l_pending and not info_r_pending:
self.expr_infos = bl_cache.Signal.InvalidateCache
@@ -86,21 +86,20 @@ class OperateMathNode(base.MaxwellSimNode):
@bl_cache.cached_bl_property()
def expr_infos(self) -> tuple[ct.InfoFlow, ct.InfoFlow] | None:
"""Computed `InfoFlow`s of both expressions."""
- info_l = self._compute_input('Expr L', kind=ct.FlowKind.Info)
- info_r = self._compute_input('Expr R', kind=ct.FlowKind.Info)
+ info_l = self._compute_input('Expr L', kind=FK.Info)
+ info_r = self._compute_input('Expr R', kind=FK.Info)
- has_info_l = not ct.FlowSignal.check(info_l)
- has_info_r = not ct.FlowSignal.check(info_r)
+ has_info_l = not FS.check(info_l)
+ has_info_r = not FS.check(info_r)
if has_info_l and has_info_r:
return (info_l, info_r)
-
return None
####################
# - Property: Operation
####################
- operation: math_system.BinaryOperation = bl_cache.BLField(
+ operation: BO = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_infos'},
)
@@ -108,7 +107,7 @@ class OperateMathNode(base.MaxwellSimNode):
def search_operations(self) -> list[ct.BLEnumElement]:
"""Retrieve valid operations based on the input `InfoFlow`s."""
if self.expr_infos is not None:
- return math_system.BinaryOperation.bl_enum_elements(*self.expr_infos)
+ return BO.bl_enum_elements(*self.expr_infos)
return []
####################
@@ -121,12 +120,12 @@ class OperateMathNode(base.MaxwellSimNode):
Called by Blender to determine the text to place in the node's header.
"""
if self.operation is not None:
- return 'Op: ' + math_system.BinaryOperation.to_name(self.operation)
+ return 'Op: ' + BO.to_name(self.operation)
return self.bl_label
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
- """Draw node properties in the node.
+ """Draw properties in the node.
Parameters:
col: UI target for drawing.
@@ -138,70 +137,59 @@ class OperateMathNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Value,
+ kind=FK.Value,
+ # Loaded
props={'operation'},
- input_sockets={'Expr L', 'Expr R'},
- input_socket_kinds={
- 'Expr L': ct.FlowKind.Value,
- 'Expr R': ct.FlowKind.Value,
+ inscks_kinds={
+ 'Expr L': FK.Value,
+ 'Expr R': FK.Value,
},
)
- def compute_value(self, props: dict, input_sockets: dict):
+ def compute_value(self, props, input_sockets) -> ct.InfoFlow | FS:
"""Binary operation on two symbolic input expressions."""
expr_l = input_sockets['Expr L']
expr_r = input_sockets['Expr R']
- has_expr_l_value = not ct.FlowSignal.check(expr_l)
- has_expr_r_value = not ct.FlowSignal.check(expr_r)
-
operation = props['operation']
- if has_expr_l_value and has_expr_r_value and operation is not None:
+ if operation is not None:
return operation.sp_func([expr_l, expr_r])
-
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
# Loaded
props={'operation'},
- input_sockets={'Expr L', 'Expr R'},
- input_socket_kinds={
- 'Expr L': ct.FlowKind.Func,
- 'Expr R': ct.FlowKind.Func,
+ inscks_kinds={
+ 'Expr L': FK.Func,
+ 'Expr R': FK.Func,
},
- output_sockets={'Expr'},
- output_socket_kinds={'Expr': ct.FlowKind.Info},
)
- def compute_func(self, props, input_sockets, output_sockets):
+ def compute_func(self, props, input_sockets) -> ct.InfoFlow | FS:
"""Binary operation on two lazy-defined input expressions."""
expr_l = input_sockets['Expr L']
expr_r = input_sockets['Expr R']
- output_info = output_sockets['Expr']
-
- has_expr_l = not ct.FlowSignal.check(expr_l)
- has_expr_r = not ct.FlowSignal.check(expr_r)
- has_output_info = not ct.FlowSignal.check(output_info)
operation = props['operation']
- if operation is not None and has_expr_l and has_expr_r and has_output_info:
+ if operation is not None:
return self.operation.transform_funcs(expr_l, expr_r)
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
# - FlowKind.Info
####################
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Info,
+ kind=FK.Info,
+ # Loaded
props={'operation'},
input_sockets={'Expr L', 'Expr R'},
input_socket_kinds={
- 'Expr L': ct.FlowKind.Info,
- 'Expr R': ct.FlowKind.Info,
+ 'Expr L': FK.Info,
+ 'Expr R': FK.Info,
},
)
def compute_info(self, props, input_sockets) -> ct.InfoFlow:
@@ -209,44 +197,43 @@ class OperateMathNode(base.MaxwellSimNode):
info_l = input_sockets['Expr L']
info_r = input_sockets['Expr R']
- has_info_l = not ct.FlowSignal.check(info_l)
- has_info_r = not ct.FlowSignal.check(info_r)
+ has_info_l = not FS.check(info_l)
+ has_info_r = not FS.check(info_r)
operation = props['operation']
if (
- has_info_l and has_info_r and operation is not None
- # and operation in BO.by_infos(info_l, info_r)
+ has_info_l
+ and has_info_r
+ and operation is not None
+ and operation in BO.from_infos(info_l, info_r)
):
return operation.transform_infos(info_l, info_r)
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
+ # Loaded
props={'operation'},
input_sockets={'Expr L', 'Expr R'},
input_socket_kinds={
- 'Expr L': ct.FlowKind.Params,
- 'Expr R': ct.FlowKind.Params,
+ 'Expr L': FK.Params,
+ 'Expr R': FK.Params,
},
)
- def compute_params(self, props, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
+ def compute_params(self, props, input_sockets) -> ct.ParamsFlow | FS:
"""Merge the lazy input parameters."""
params_l = input_sockets['Expr L']
params_r = input_sockets['Expr R']
- has_params_l = not ct.FlowSignal.check(params_l)
- has_params_r = not ct.FlowSignal.check(params_r)
-
operation = props['operation']
- if has_params_l and has_params_r and operation is not None:
- return params_l | params_r
-
- return ct.FlowSignal.FlowPending
+ if operation is not None:
+ return operation.transform_params(params_l, params_r)
+ return FS.FlowPending
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/reduce_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/reduce_math.py
index c44e5a4..0800d57 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/reduce_math.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/reduce_math.py
@@ -14,126 +14,234 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Declares `ReduceMathNode`."""
+
+import enum
import typing as typ
import bpy
+import jax
import jax.numpy as jnp
-import sympy as sp
+import numpy as np
-from blender_maxwell.utils import logger
+from blender_maxwell.utils import bl_cache, logger, sim_symbols
from .... import contracts as ct
-from .... import sockets
+from .... import math_system, sockets
from ... import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+RO = math_system.ReduceOperation
class ReduceMathNode(base.MaxwellSimNode):
+ r"""Applies a function to the array as a whole, with arbitrary results.
+
+ The shape, type, and interpretation of the input/output data is dynamically shown.
+
+ Attributes:
+ operation: Operation to apply to the input.
+ """
+
node_type = ct.NodeType.ReduceMath
bl_label = 'Reduce Math'
input_sockets: typ.ClassVar = {
- 'Data': sockets.AnySocketDef(),
- 'Axis': sockets.IntegerNumberSocketDef(),
- }
- input_socket_sets: typ.ClassVar = {
- 'By Axis': {
- 'Axis': sockets.IntegerNumberSocketDef(),
- },
- 'Expr': {
- 'Reducer': sockets.ExprSocketDef(
- symbols=[sp.Symbol('a'), sp.Symbol('b')],
- default_expr=sp.Symbol('a') + sp.Symbol('b'),
- ),
- 'Axis': sockets.IntegerNumberSocketDef(),
- },
+ 'Expr': sockets.ExprSocketDef(active_kind=FK.Func, show_func_ui=False),
}
output_sockets: typ.ClassVar = {
- 'Data': sockets.AnySocketDef(),
+ 'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
}
####################
- # - Properties
+ # - Properties: Expr InfoFlow
####################
- operation: bpy.props.EnumProperty(
- name='Op',
- description='Operation to reduce the input axis with',
- items=lambda self, _: self.search_operations(),
- update=lambda self, context: self.on_prop_changed('operation', context),
+ @events.on_value_changed(
+ # Trigger
+ socket_name={'Expr': FK.Info},
+ # Loaded
+ inscks_kinds={'Expr': FK.Info},
+ input_sockets_optional={'Expr'},
+ # Flow
+ ## -> Expr wants to emit DataChanged, which is usually fine.
+ ## -> However, this node sets `expr_info`, which causes DC to emit.
+ ## -> One action should emit one DataChanged pipe.
+ ## -> Therefore, defer responsibility for DataChanged to self.expr_info.
+ # stop_propagation=True,
+ )
+ def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
+ has_info = not FS.check(input_sockets['Expr'])
+ info_pending = FS.check_single(input_sockets['Expr'], FS.FlowPending)
+
+ if has_info and not info_pending:
+ self.expr_info = bl_cache.Signal.InvalidateCache
+
+ @bl_cache.cached_bl_property()
+ def expr_info(self) -> ct.InfoFlow | None:
+ """Retrieve the input expression's `InfoFlow`."""
+ info = self._compute_input('Expr', kind=FK.Info)
+ has_info = not FS.check(info)
+ if has_info:
+ return info
+
+ return None
+
+ ####################
+ # - Properties: Operation
+ ####################
+ operation: RO = bl_cache.BLField(
+ enum_cb=lambda self, _: self.search_operations(),
+ cb_depends_on={'expr_info'},
)
- def search_operations(self) -> list[tuple[str, str, str]]:
- items = []
- if self.active_socket_set == 'By Axis':
- items += [
- # Accumulation
- ('SUM', 'Sum', 'sum(*, N, *) -> (*, 1, *)'),
- ('PROD', 'Prod', 'prod(*, N, *) -> (*, 1, *)'),
- ('MIN', 'Axis-Min', '(*, N, *) -> (*, 1, *)'),
- ('MAX', 'Axis-Max', '(*, N, *) -> (*, 1, *)'),
- ('P2P', 'Peak-to-Peak', '(*, N, *) -> (*, 1 *)'),
- # Stats
- ('MEAN', 'Mean', 'mean(*, N, *) -> (*, 1, *)'),
- ('MEDIAN', 'Median', 'median(*, N, *) -> (*, 1, *)'),
- ('STDDEV', 'Std Dev', 'stddev(*, N, *) -> (*, 1, *)'),
- ('VARIANCE', 'Variance', 'var(*, N, *) -> (*, 1, *)'),
- # Dimension Reduction
- ('SQUEEZE', 'Squeeze', '(*, 1, *) -> (*, *)'),
- ]
- else:
- items += [('NONE', 'None', 'No operations...')]
+ def search_operations(self) -> list[ct.BLEnumElement]:
+ """Retrieve valid operations based on the input `InfoFlow`."""
+ if self.expr_info is not None:
+ return RO.bl_enum_elements(self.expr_info)
+ return []
- return items
+ ####################
+ # - Properties: Dimension Selection
+ ####################
+ active_dim: enum.StrEnum = bl_cache.BLField(
+ enum_cb=lambda self, _: self.search_dims(),
+ cb_depends_on={'operation', 'expr_info'},
+ )
+
+ def search_dims(self) -> list[ct.BLEnumElement]:
+ """Search valid dimensions for reduction."""
+ if self.expr_info is not None and self.operation is not None:
+ return [
+ (dim.name, dim.name_pretty, dim.name, '', i)
+ for i, dim in enumerate(self.operation.valid_dims(self.expr_info))
+ ]
+ return []
+
+ @bl_cache.cached_bl_property(depends_on={'expr_info', 'active_dim'})
+ def dim(self) -> sim_symbols.SimSymbol | None:
+ """Deduce the valid dimension."""
+ if self.expr_info is not None and self.active_dim is not None:
+ return self.expr_info.dim_by_name(self.active_dim, optional=True)
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'dim_0'})
+ def axis(self) -> int | None:
+ """The first currently active axis, derived from `self.dim_0`."""
+ if self.expr_info is not None and self.dim is not None:
+ return self.expr_info.dim_axis(self.dim)
+ return None
+
+ ####################
+ # - UI
+ ####################
+ def draw_label(self):
+ """Show the active reduce operation in the node's header label.
+
+ Notes:
+ Called by Blender to determine the text to place in the node's header.
+ """
+ if self.operation is not None:
+ if self.dim is not None:
+ return self.operation.name.replace('[a]', f'[{self.dim.name_pretty}]')
+ return self.operation.name
+ return self.bl_label
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
- if self.active_socket_set != 'Axis Expr':
- layout.prop(self, 'operation')
+ """Draw the user interfaces of the node's properties inside of the node itself.
+
+ Parameters:
+ layout: UI target for drawing.
+ """
+ layout.prop(self, self.blfields['operation'], text='')
+ layout.prop(self, self.blfields['active_dim'], text='')
####################
- # - Compute
+ # - Compute: Array
####################
@events.computes_output_socket(
- 'Data',
- props={'active_socket_set', 'operation'},
- input_sockets={'Data', 'Axis', 'Reducer'},
- input_socket_kinds={'Reducer': ct.FlowKind.Func},
- input_sockets_optional={'Reducer': True},
+ 'Expr',
+ kind=FK.Array,
+ # Loaded
+ outscks_kinds={
+ 'Expr': {FK.Func, FK.Params},
+ },
)
- def compute_data(self, props: dict, input_sockets: dict):
- if props['active_socket_set'] == 'By Axis':
- # Simple Accumulation
- if props['operation'] == 'SUM':
- return jnp.sum(input_sockets['Data'], axis=input_sockets['Axis'])
- if props['operation'] == 'PROD':
- return jnp.prod(input_sockets['Data'], axis=input_sockets['Axis'])
- if props['operation'] == 'MIN':
- return jnp.min(input_sockets['Data'], axis=input_sockets['Axis'])
- if props['operation'] == 'MAX':
- return jnp.max(input_sockets['Data'], axis=input_sockets['Axis'])
- if props['operation'] == 'P2P':
- return jnp.p2p(input_sockets['Data'], axis=input_sockets['Axis'])
+ def compute_array(self, output_sockets) -> ct.ArrayFlow | FS:
+ """Realize an `ArrayFlow` containing the array."""
+ array = events.realize_known(output_sockets['Expr'])
+ if array is not None:
+ return ct.ArrayFlow(
+ jax_bytes=(
+ array
+ if isinstance(array, np.ndarray | jax.Array)
+ else jnp.array(array)
+ ),
+ unit=output_sockets['Expr'][FK.Func].func_output.unit,
+ is_sorted=True,
+ )
+ return FS.FlowPending
- # Stats
- if props['operation'] == 'MEAN':
- return jnp.mean(input_sockets['Data'], axis=input_sockets['Axis'])
- if props['operation'] == 'MEDIAN':
- return jnp.median(input_sockets['Data'], axis=input_sockets['Axis'])
- if props['operation'] == 'STDDEV':
- return jnp.std(input_sockets['Data'], axis=input_sockets['Axis'])
- if props['operation'] == 'VARIANCE':
- return jnp.var(input_sockets['Data'], axis=input_sockets['Axis'])
+ ####################
+ # - Compute: Func
+ ####################
+ @events.computes_output_socket(
+ 'Expr',
+ kind=FK.Func,
+ # Loaded
+ props={'operation'},
+ inscks_kinds={
+ 'Expr': FK.Func,
+ },
+ )
+ def compute_func(self, props, input_sockets) -> ct.FuncFlow | FS:
+ """Transform the input `FuncFlow` depending on the reduce operation."""
+ func = input_sockets['Expr']
- # Dimension Reduction
- if props['operation'] == 'SQUEEZE':
- return jnp.squeeze(input_sockets['Data'], axis=input_sockets['Axis'])
+ operation = props['operation']
+ if operation is not None:
+ return operation.transform_func(func)
+ return FS.FlowPending
- if props['active_socket_set'] == 'Expr':
- ufunc = jnp.ufunc(input_sockets['Reducer'], nin=2, nout=1)
- return ufunc.reduce(input_sockets['Data'], axis=input_sockets['Axis'])
+ ####################
+ # - FlowKind.Info
+ ####################
+ @events.computes_output_socket(
+ 'Expr',
+ kind=FK.Info,
+ # Loaded
+ props={'operation', 'dim', 'expr_info'},
+ )
+ def compute_info(self, props) -> ct.InfoFlow | FS:
+ """Transform the input `InfoFlow` depending on the reduce operation."""
+ info = props['expr_info']
+ dim = props['dim']
- msg = 'Operation invalid'
- raise ValueError(msg)
+ operation = props['operation']
+ if operation is not None:
+ return operation.transform_info(info, dim)
+ return FS.FlowPending
+
+ ####################
+ # - FlowKind.Params
+ ####################
+ @events.computes_output_socket(
+ 'Expr',
+ kind=FK.Params,
+ # Loaded
+ props={'operation', 'axis'},
+ inscks_kinds={'Expr': FK.Params},
+ )
+ def compute_params(self, props, input_sockets) -> ct.ParamsFlow | FS:
+ """Transform the input `InfoFlow` depending on the reduce operation."""
+ params = input_sockets['Expr']
+ axis = props['axis']
+
+ operation = props['operation']
+ if operation is not None and axis is not None:
+ return operation.transform_params(params, axis)
+ return FS.FlowPending
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py
index aa6ed07..fe03c46 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/transform_math.py
@@ -31,6 +31,9 @@ from ... import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+
class TransformMathNode(base.MaxwellSimNode):
r"""Applies a function to the array as a whole, with arbitrary results.
@@ -49,10 +52,10 @@ class TransformMathNode(base.MaxwellSimNode):
bl_label = 'Transform Math'
input_sockets: typ.ClassVar = {
- 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
+ 'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
}
output_sockets: typ.ClassVar = {
- 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
+ 'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
}
####################
@@ -62,9 +65,8 @@ class TransformMathNode(base.MaxwellSimNode):
# Trigger
socket_name={'Expr'},
# Loaded
- input_sockets={'Expr'},
- input_socket_kinds={'Expr': ct.FlowKind.Info},
- input_sockets_optional={'Expr': True},
+ inscks_kinds={'Expr': FK.Info},
+ input_sockets_optional={'Expr'},
# Flow
## -> Expr wants to emit DataChanged, which is usually fine.
## -> However, this node sets `expr_info`, which causes DC to emit.
@@ -73,18 +75,16 @@ class TransformMathNode(base.MaxwellSimNode):
stop_propagation=True,
)
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
- has_info = not ct.FlowSignal.check(input_sockets['Expr'])
- info_pending = ct.FlowSignal.check_single(
- input_sockets['Expr'], ct.FlowSignal.FlowPending
- )
+ has_info = not FS.check(input_sockets['Expr'])
+ info_pending = FS.check_single(input_sockets['Expr'], FS.FlowPending)
if has_info and not info_pending:
self.expr_info = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def expr_info(self) -> ct.InfoFlow | None:
- info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True)
- has_info = not ct.FlowSignal.check(info)
+ info = self._compute_input('Expr', kind=FK.Info)
+ has_info = not FS.check(info)
if has_info:
return info
@@ -100,12 +100,7 @@ class TransformMathNode(base.MaxwellSimNode):
def search_operations(self) -> list[ct.BLEnumElement]:
if self.expr_info is not None:
- return [
- operation.bl_enum_element(i)
- for i, operation in enumerate(
- math_system.TransformOperation.by_info(self.expr_info)
- )
- ]
+ return math_system.TransformOperation.bl_enum_elements(self.expr_info)
return []
####################
@@ -283,29 +278,27 @@ class TransformMathNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
# Loaded
props={'operation', 'dim'},
input_sockets={'Expr'},
input_socket_kinds={
- 'Expr': {ct.FlowKind.Func, ct.FlowKind.Info},
+ 'Expr': {FK.Func, FK.Info},
},
output_sockets={'Expr'},
- output_socket_kinds={'Expr': ct.FlowKind.Info},
+ output_socket_kinds={'Expr': FK.Info},
)
- def compute_func(
- self, props, input_sockets, output_sockets
- ) -> ct.FuncFlow | ct.FlowSignal:
+ def compute_func(self, props, input_sockets, output_sockets) -> ct.FuncFlow | FS:
"""Transform the input `InfoFlow` depending on the transform operation."""
TO = math_system.TransformOperation
- lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
- info = input_sockets['Expr'][ct.FlowKind.Info]
+ lazy_func = input_sockets['Expr'][FK.Func]
+ info = input_sockets['Expr'][FK.Info]
output_info = output_sockets['Expr']
- has_info = not ct.FlowSignal.check(info)
- has_lazy_func = not ct.FlowSignal.check(lazy_func)
- has_output_info = not ct.FlowSignal.check(output_info)
+ has_info = not FS.check(info)
+ has_lazy_func = not FS.check(lazy_func)
+ has_output_info = not FS.check(output_info)
operation = props['operation']
if operation is not None and has_lazy_func and has_info and has_output_info:
@@ -318,7 +311,7 @@ class TransformMathNode(base.MaxwellSimNode):
enclosing_func_output=output_info.output,
supports_jax=True,
)
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
case _:
return lazy_func.compose_within(
@@ -327,30 +320,28 @@ class TransformMathNode(base.MaxwellSimNode):
supports_jax=True,
)
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
# - FlowKind.Info
####################
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Info,
+ kind=FK.Info,
# Loaded
props={'operation', 'dim', 'new_name', 'new_unit', 'new_physical_type'},
input_sockets={'Expr'},
- input_socket_kinds={
- 'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
- },
+ input_socket_kinds={'Expr': {FK.Func, FK.Info, FK.Params}},
)
def compute_info( # noqa: PLR0911
self, props: dict, input_sockets: dict
- ) -> ct.InfoFlow | typ.Literal[ct.FlowSignal.FlowPending]:
+ ) -> ct.InfoFlow | typ.Literal[FS.FlowPending]:
"""Transform the input `InfoFlow` depending on the transform operation."""
TO = math_system.TransformOperation
operation = props['operation']
- info = input_sockets['Expr'][ct.FlowKind.Info]
+ info = input_sockets['Expr'][FK.Info]
- has_info = not ct.FlowSignal.check(info)
+ has_info = not FS.check(info)
if has_info and operation is not None:
# Retrieve Properties
dim = props['dim']
@@ -359,11 +350,11 @@ class TransformMathNode(base.MaxwellSimNode):
new_physical_type = props['new_physical_type']
# Retrieve Expression Data
- lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
- params = input_sockets['Expr'][ct.FlowKind.Params]
+ lazy_func = input_sockets['Expr'][FK.Func]
+ params = input_sockets['Expr'][FK.Params]
- has_lazy_func = not ct.FlowSignal.check(lazy_func)
- has_params = not ct.FlowSignal.check(lazy_func)
+ has_lazy_func = not FS.check(lazy_func)
+ has_params = not FS.check(lazy_func)
# Match Pattern by Operation
match operation:
@@ -376,7 +367,7 @@ class TransformMathNode(base.MaxwellSimNode):
and new_unit in physical_type.valid_units
):
return operation.transform_info(info, dim=dim, unit=new_unit)
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
case TO.FreqToVacWL if dim is not None and new_unit is not None and new_unit in spux.PhysicalType.Length.valid_units:
return operation.transform_info(info, dim=dim, unit=new_unit)
@@ -430,28 +421,28 @@ class TransformMathNode(base.MaxwellSimNode):
case TO.FT1D | TO.InvFT1D if dim is not None:
return operation.transform_info(info, dim=dim)
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
# Loaded
props={'operation'},
input_sockets={'Expr'},
- input_socket_kinds={'Expr': ct.FlowKind.Params},
+ input_socket_kinds={'Expr': FK.Params},
)
- def compute_params(self, props, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
+ def compute_params(self, props, input_sockets) -> ct.ParamsFlow | FS:
operation = props['operation']
params = input_sockets['Expr']
- has_params = not ct.FlowSignal.check(params)
+ has_params = not FS.check(params)
if has_params and operation is not None:
return params
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py
index dda78ff..87fb726 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py
@@ -22,6 +22,7 @@ import jaxtyping as jtyp
import matplotlib.axis as mpl_ax
import sympy as sp
import sympy.physics.units as spu
+from frozendict import frozendict
from blender_maxwell.utils import bl_cache, image_ops, logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
@@ -32,6 +33,9 @@ from .. import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+
class VizMode(enum.StrEnum):
"""Available visualization modes.
@@ -218,7 +222,7 @@ class VizNode(base.MaxwellSimNode):
####################
input_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(
- active_kind=ct.FlowKind.Func,
+ active_kind=FK.Func,
default_symbols=[sym_x_um],
default_value=sp.exp(-(x_um**2)),
),
@@ -239,26 +243,24 @@ class VizNode(base.MaxwellSimNode):
socket_name={'Expr'},
# Loaded
input_sockets={'Expr'},
- input_socket_kinds={'Expr': ct.FlowKind.Info},
+ input_socket_kinds={'Expr': FK.Info},
input_sockets_optional={'Expr': True},
# Flow
## -> See docs in TransformMathNode
stop_propagation=True,
)
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
- has_info = not ct.FlowSignal.check(input_sockets['Expr'])
+ has_info = not FS.check(input_sockets['Expr'])
- info_pending = ct.FlowSignal.check_single(
- input_sockets['Expr'], ct.FlowSignal.FlowPending
- )
+ info_pending = FS.check_single(input_sockets['Expr'], FS.FlowPending)
if has_info and not info_pending:
self.expr_info = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def expr_info(self) -> ct.InfoFlow | None:
- info = self._compute_input('Expr', kind=ct.FlowKind.Info)
- if not ct.FlowSignal.check(info):
+ info = self._compute_input('Expr', kind=FK.Info)
+ if not FS.check(info):
return info
return None
@@ -349,32 +351,28 @@ class VizNode(base.MaxwellSimNode):
####################
@events.on_value_changed(
# Trigger
- socket_name='Expr',
+ socket_name={'Expr': {FK.Info, FK.Params}},
run_on_init=True,
# Loaded
- input_sockets={'Expr'},
- input_socket_kinds={'Expr': {ct.FlowKind.Info, ct.FlowKind.Params}},
- input_sockets_optional={'Expr': True},
+ inscks_kinds={'Expr': {FK.Info, FK.Params}},
+ input_sockets_optional={'Expr'},
)
- def on_any_changed(self, input_sockets: dict):
- info = input_sockets['Expr'][ct.FlowKind.Info]
- params = input_sockets['Expr'][ct.FlowKind.Params]
-
- has_info = not ct.FlowSignal.check(info)
- has_params = not ct.FlowSignal.check(params)
+ def on_expr_changed(self, input_sockets: dict):
+ """Declare sockets for realizing all unknown symbols."""
+ info = input_sockets['Expr'][FK.Info]
+ params = input_sockets['Expr'][FK.Params]
# Declare Loose Sockets that Realize Symbols
## -> This happens if Params contains not-yet-realized symbols.
- if has_info and has_params and params.symbols:
+ if params.symbols:
if set(self.loose_input_sockets) != {sym.name for sym in params.symbols}:
self.loose_input_sockets = {
sym.name: sockets.ExprSocketDef(
**(
expr_info
| {
- 'active_kind': ct.FlowKind.Range
- if sym in info.dims
- else ct.FlowKind.Value
+ 'active_kind': FK.Value,
+ 'use_value_range_swapper': sym in info.dims,
}
)
)
@@ -389,25 +387,13 @@ class VizNode(base.MaxwellSimNode):
#####################
@events.computes_output_socket(
'Preview',
- kind=ct.FlowKind.Previews,
+ kind=FK.Previews,
# Loaded
- props={
- 'sim_node_name',
- 'viz_mode',
- 'viz_target',
- 'colormap',
- 'plot_width',
- 'plot_height',
- 'plot_dpi',
- },
- input_sockets={'Expr'},
- input_socket_kinds={
- 'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
- },
- all_loose_input_sockets=True,
+ props={'sim_node_name', 'expr_info'},
)
- def compute_previews(self, props, input_sockets, loose_input_sockets):
+ def compute_previews(self, props):
"""Needed for the plot to regenerate in the viewer."""
+ # info = props['info']
return ct.PreviewsFlow(bl_image_name=props['sim_node_name'])
#####################
@@ -424,22 +410,19 @@ class VizNode(base.MaxwellSimNode):
'plot_dpi',
},
input_sockets={'Expr'},
- input_socket_kinds={
- 'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
- },
+ input_socket_kinds={'Expr': {FK.Func, FK.Info, FK.Params}},
all_loose_input_sockets=True,
- stop_propagation=True,
)
def on_show_plot(
self, managed_objs, props, input_sockets, loose_input_sockets
) -> None:
log.debug('Show Plot')
- lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
- info = input_sockets['Expr'][ct.FlowKind.Info]
- params = input_sockets['Expr'][ct.FlowKind.Params]
+ lazy_func = input_sockets['Expr'][FK.Func]
+ info = input_sockets['Expr'][FK.Info]
+ params = input_sockets['Expr'][FK.Params]
- has_info = not ct.FlowSignal.check(info)
- has_params = not ct.FlowSignal.check(params)
+ has_info = not FS.check(info)
+ has_params = not FS.check(params)
plot = managed_objs['plot']
viz_mode = props['viz_mode']
@@ -452,11 +435,13 @@ class VizNode(base.MaxwellSimNode):
data = lazy_func.realize_as_data(
info,
params,
- symbol_values={
- sym: loose_input_sockets[sym.name] for sym in params.sorted_symbols
- },
+ symbol_values=frozendict(
+ {
+ sym: loose_input_sockets[sym.name]
+ for sym in params.sorted_symbols
+ }
+ ),
)
- ## TODO: CACHE entries that don't change, PLEASEEE
# Match Viz Type & Perform Visualization
## -> Viz Target determines how to plot.
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py
index 764f62d..839e8e1 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py
@@ -27,9 +27,9 @@ from collections import defaultdict
from types import MappingProxyType
import bpy
-import sympy as sp
from blender_maxwell.utils import bl_cache, bl_instance, logger
+from blender_maxwell.utils import sympy_extra as spux
from .. import contracts as ct
from .. import managed_objs as _managed_objs
@@ -39,6 +39,12 @@ from . import presets as _presets
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+FE = ct.FlowEvent
+MT = spux.MathType
+PT = spux.PhysicalType
+
####################
# - Types
####################
@@ -82,7 +88,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
input_socket_sets: typ.ClassVar[dict[str, Sockets]] = MappingProxyType({})
output_socket_sets: typ.ClassVar[dict[str, Sockets]] = MappingProxyType({})
- managed_obj_types: typ.ClassVar[ManagedObjs] = MappingProxyType({})
+ managed_obj_types: typ.ClassVar[dict[str, ManagedObjs]] = MappingProxyType({})
presets: typ.ClassVar[dict[str, Preset]] = MappingProxyType({})
## __init_subclass__ Computed
@@ -168,11 +174,12 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
event_methods = [
method
for attr_name in dir(cls)
- if hasattr(method := getattr(cls, attr_name), 'event')
- and method.event in set(ct.FlowEvent)
- ## Forbidding blfields prevents triggering __get__ on bl_property
+ if hasattr(method := getattr(cls, attr_name), 'identifier')
+ and isinstance(method.identifier, str)
+ and method.identifier == events.EVENT_METHOD_IDENTIFIER
+ ## We must not trigger __get__ on any blfields here.
]
- event_methods_by_event = {event: [] for event in set(ct.FlowEvent)}
+ event_methods_by_event = {event: [] for event in set(FE)}
for method in event_methods:
event_methods_by_event[method.event].append(method)
@@ -216,16 +223,16 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
# (Re)Construct Managed Objects
## -> Due to 'prev_name', the new MObjs will be renamed on construction
+ managed_objs = props['managed_objs']
+ managed_obj_types = props['managed_obj_types']
self.managed_objs = {
mobj_name: mobj_type(
- self.sim_node_name,
+ self.sim_node_name + (f'_{i}' if i > 0 else ''),
prev_name=(
- props['managed_objs'][mobj_name].name
- if mobj_name in props['managed_objs']
- else None
+ managed_objs[mobj_name].name if mobj_name in managed_objs else None
),
)
- for mobj_name, mobj_type in props['managed_obj_types'].items()
+ for i, (mobj_name, mobj_type) in enumerate(managed_obj_types.items())
}
@events.on_value_changed(prop_name='active_socket_set')
@@ -447,12 +454,18 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
## Only Trigger: If method loads no socket data, it runs.
## Optional: If method optional-loads socket, it runs.
triggered_event_methods = [
- event_method
- for event_method in self.filtered_event_methods_by_event(
- ct.FlowEvent.DataChanged, (active_sckname, None, None)
+ value_changed_method
+ for value_changed_method in self.event_methods_by_event[
+ FE.DataChanged
+ ]
+ if value_changed_method.callback_info.should_run(
+ None,
+ active_sckname,
+ frozenset(FK),
+ active_sckname in self.loose_input_sockets,
)
- if active_sckname
- not in event_method.callback_info.must_load_sockets
+ and active_sckname
+ in value_changed_method.callback_info.optional_sockets_kinds
]
for event_method in triggered_event_methods:
event_method(self)
@@ -464,6 +477,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
self.compute_output.invalidate(
input_socket_name=active_sckname,
kind=...,
+ unit_system=...,
)
# Update Sockets
@@ -511,6 +525,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
self.compute_output.invalidate(
input_socket_name=active_sckname,
kind=...,
+ unit_system=...,
)
def _add_new_active_sockets(self):
@@ -575,88 +590,59 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
if i != current_idx_of_socket:
self.outputs.move(current_idx_of_socket, i)
- ####################
- # - Event Methods
- ####################
- @property
- def _event_method_filter_by_event(self) -> dict[ct.FlowEvent, typ.Callable]:
- """Compute a map of FlowEvents, to a function that filters its event methods.
-
- The returned filter functions are hard-coded, and must always return a `bool`.
- They may use attributes of `self`, always return `True` or `False`, or something different.
-
- Notes:
- This is an internal method; you probably want `self.filtered_event_methods_by_event`.
-
- Returns:
- The map of `ct.FlowEvent` to a function that can determine whether any `event_method` should be run.
- """
- return {
- ct.FlowEvent.EnableLock: lambda *_: True,
- ct.FlowEvent.DisableLock: lambda *_: True,
- ct.FlowEvent.DataChanged: lambda event_method, socket_name, prop_names, _: (
- (
- socket_name
- and socket_name in event_method.callback_info.on_changed_sockets
- )
- or (
- prop_names
- and any(
- prop_name in event_method.callback_info.on_changed_props
- for prop_name in prop_names
- )
- )
- or (
- socket_name
- and event_method.callback_info.on_any_changed_loose_input
- and socket_name in self.loose_input_sockets
- )
- ),
- # Non-Triggered
- ct.FlowEvent.OutputRequested: lambda output_socket_method,
- output_socket_name,
- _,
- kind: (
- kind == output_socket_method.callback_info.kind
- and (
- output_socket_name
- == output_socket_method.callback_info.output_socket_name
- )
- ),
- ct.FlowEvent.ShowPlot: lambda *_: True,
- }
-
- def filtered_event_methods_by_event(
- self,
- event: ct.FlowEvent,
- _filter: tuple[ct.SocketName, str],
- ) -> list[typ.Callable]:
- """Return all event methods that should run, given the context provided by `_filter`.
-
- The inclusion decision is made by the internal property `self._event_method_filter_by_event`.
-
- Returns:
- All `event_method`s that should run, as callable objects (they can be run using `event_method(self)`).
- """
- return [
- event_method
- for event_method in self.event_methods_by_event[event]
- if self._event_method_filter_by_event[event](event_method, *_filter)
- ]
-
####################
# - Compute: Input Socket
####################
@bl_cache.keyed_cache(
- exclude={'self', 'optional'},
- encode={'unit_system'},
+ exclude={'self'},
+ )
+ def compute_prop(
+ self,
+ prop_name: ct.PropName,
+ unit_system: spux.UnitSystem | None = None,
+ ) -> typ.Any:
+ """Computes the data of a property, with relevant unit system scaling.
+
+ Some properties return a `sympy` expression, which needs to be conformed to some unit system before it can be used.
+ For these cases,
+
+ When no unit system is in use, the cache of `compute_prop` is a transparent layer on top of the `BLField` cache, taking no extra memory.
+
+ Warnings:
+ **MUST** be invalidated whenever a property changed.
+ If not, then "phantom" values will be produced.
+
+ Parameters:
+ prop_name: The name of the property to compute the value of.
+ unit_system: The unit system to convert it to, if any.
+
+ Returns:
+ The property value, possibly scaled to the unit system.
+ """
+ # Retrieve Unit System and Property
+ if hasattr(self, prop_name):
+ prop_value = getattr(self, prop_name)
+ else:
+ msg = f'The node {self.sim_node_name} has no property {prop_name}.'
+ raise ValueError
+
+ if unit_system is not None:
+ if isinstance(prop_value, spux.SympyType):
+ return spux.scale_to_unit_system(prop_value)
+
+ msg = f'Cannot scale property {prop_name}={prop_value} (type={type(prop_value)} to a unit system, since it is not a sympy object (unit_system={unit_system})'
+ raise ValueError(msg)
+
+ return prop_value
+
+ @bl_cache.keyed_cache(
+ exclude={'self'},
)
def _compute_input(
self,
input_socket_name: ct.SocketName,
- kind: ct.FlowKind = ct.FlowKind.Value,
- unit_system: dict[ct.SocketType, sp.Expr] | None = None,
- optional: bool = False,
+ kind: FK = FK.Value,
+ unit_system: spux.UnitSystem | None = None,
) -> typ.Any:
"""Computes the data of an input socket, following links if needed.
@@ -667,15 +653,16 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
input_socket_name: The name of the input socket to compute the value of.
It must be currently active.
kind: The data flow kind to compute.
+ unit_system: The unit system to scale the computed input socket to.
"""
bl_socket = self.inputs.get(input_socket_name)
if bl_socket is not None:
if bl_socket.instance_id:
- if kind is ct.FlowKind.Previews:
+ if kind is FK.Previews:
return bl_socket.compute_data(kind=kind)
return (
- ct.FlowKind.scale_to_unit_system(
+ FK.scale_to_unit_system(
kind,
bl_socket.compute_data(kind=kind),
unit_system,
@@ -687,91 +674,107 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
# No Socket Instance ID
## -> Indicates that socket_def.preinit() has not yet run.
## -> Anyone needing results will need to wait on preinit().
- return ct.FlowSignal.FlowInitializing
+ return FS.FlowInitializing
- if kind is ct.FlowKind.Previews:
+ if kind is FK.Previews:
return ct.PreviewsFlow()
- return ct.FlowSignal.NoFlow
+ return FS.NoFlow
####################
# - Compute Event: Output Socket
####################
@bl_cache.keyed_cache(
- exclude={'self', 'optional'},
+ exclude={'self'},
)
def compute_output(
self,
output_socket_name: ct.SocketName,
- kind: ct.FlowKind = ct.FlowKind.Value,
- optional: bool = False,
- ) -> typ.Any:
+ kind: FK = FK.Value,
+ unit_system: spux.UnitSystem | None = None,
+ ) -> typ.Any | FS:
"""Computes the value of an output socket.
Parameters:
output_socket_name: The name declaring the output socket, for which this method computes the output.
kind: The FlowKind to use when computing the output socket value.
+ unit_system: The unit system to scale the computed output socket to.
Returns:
The value of the output socket, as computed by the dedicated method
registered using the `@computes_output_socket` decorator.
"""
- # Previews: Aggregate All Input Sockets
- ## -> All PreviewsFlows on all input sockets are combined.
- ## -> Output Socket Methods can add additional PreviewsFlows.
- if kind is ct.FlowKind.Previews:
- input_previews = functools.reduce(
- lambda a, b: a | b,
- [
- self._compute_input(
- socket_name,
- kind=ct.FlowKind.Previews,
- )
- for socket_name in [bl_socket.name for bl_socket in self.inputs]
- ],
- ct.PreviewsFlow(),
- )
-
- # No Output Socket: No Flow
- ## -> All PreviewsFlows on all input sockets are combined.
- ## -> Output Socket Methods can add additional PreviewsFlows.
- if self.outputs.get(output_socket_name) is None:
- return ct.FlowSignal.NoFlow
-
- output_socket_methods = self.filtered_event_methods_by_event(
- ct.FlowEvent.OutputRequested,
- (output_socket_name, None, kind),
+ log.debug(
+ '[%s] Computing Output (socket_name=%s, socket_kinds=%s, unit_system=%s)',
+ self.sim_node_name,
+ str(output_socket_name),
+ str(kind),
+ str(unit_system),
)
- # Exactly One Output Socket Method
- ## -> All PreviewsFlows on all input sockets are combined.
- ## -> Output Socket Methods can add additional PreviewsFlows.
- if len(output_socket_methods) == 1:
- res = output_socket_methods[0](self)
+ bl_socket = self.outputs.get(output_socket_name)
+ if bl_socket is not None:
+ if bl_socket.instance_id:
+ # Previews: Computed Aggregated Input Sockets
+ ## -> All sockets w/output get Previews from all inputs.
+ ## -> The user can also assign certain sockets.
+ if kind is FK.Previews:
+ input_previews = functools.reduce(
+ lambda a, b: a | b,
+ [
+ self._compute_input(
+ socket_name,
+ kind=FK.Previews,
+ )
+ for socket_name in [
+ bl_socket.name for bl_socket in self.inputs
+ ]
+ ],
+ ct.PreviewsFlow(),
+ )
- # Res is PreviewsFlow: Concatenate
- ## -> This will add the elements within the returned PreviewsFluw.
- if kind is ct.FlowKind.Previews and not ct.FlowSignal.check(res):
- input_previews |= res
+ # Retrieve Valid Output Socket Method
+ ## -> We presume that there is exactly one method per socket|kind.
+ ## -> We presume that there is exactly one method per socket|kind.
+ outsck_methods = [
+ method
+ for method in self.event_methods_by_event[FE.OutputRequested]
+ if method.callback_info.should_run(output_socket_name, kind)
+ ]
+ if len(outsck_methods) != 1:
+ if kind is FK.Previews:
+ return input_previews
+ return FS.NoFlow
- return res
+ outsck_method = outsck_methods[0]
- # > One Output Socket Method: Error
- if len(output_socket_methods) > 1:
- msg = (
- f'More than one method found for ({output_socket_name}, {kind.value!s}.'
- )
- raise RuntimeError(msg)
+ # Compute Flow w/Output Socket Method
+ flow = outsck_method(self)
+ has_flow = not FS.check(flow)
- if kind is ct.FlowKind.Previews:
- return input_previews
- return ct.FlowSignal.NoFlow
+ if kind is FK.Previews:
+ if has_flow:
+ return input_previews | flow
+ return input_previews
+
+ # *: Compute Flow
+ ## -> Perform unit-system scaling (maybe)
+ ## -> Otherwise, return flow (even if FlowSignal).
+ if has_flow and unit_system is not None:
+ return kind.scale_to_unit_system(flow, unit_system)
+ return flow
+
+ # No Socket Instance ID
+ ## -> Indicates that socket_def.preinit() has not yet run.
+ ## -> Anyone needing results will need to wait on preinit().
+ return FS.FlowInitializing
+ return FS.NoFlow
####################
# - Plot
####################
def compute_plot(self):
- plot_methods = self.filtered_event_methods_by_event(ct.FlowEvent.ShowPlot, ())
-
+ """Run all `on_show_plot` event methods."""
+ plot_methods = self.event_methods_by_event[FE.ShowPlot]
for plot_method in plot_methods:
plot_method(self)
@@ -782,8 +785,8 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
self,
method_info: events.InfoOutputRequested,
input_socket_name: ct.SocketName | None,
- input_socket_kinds: set[ct.FlowKind] | None,
- prop_names: set[str] | None,
+ input_socket_kinds: set[FK] | None,
+ prop_names: set[ct.PropName] | None,
) -> bool:
return (
prop_names is not None
@@ -795,14 +798,14 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
input_socket_kinds is None
or (
isinstance(
- _kind := method_info.depon_input_socket_kinds.get(
- input_socket_name, ct.FlowKind.Value
+ _kind := method_info.depon_input_sockets_kinds.get(
+ input_socket_name, FK.Value
),
set,
)
and input_socket_kinds.intersection(_kind)
)
- or _kind == ct.FlowKind.Value
+ or _kind == FK.Value
or _kind in input_socket_kinds
)
or (
@@ -815,9 +818,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
@bl_cache.cached_bl_property()
def output_socket_invalidates(
self,
- ) -> dict[
- tuple[ct.SocketName, ct.FlowKind], set[tuple[ct.SocketName, ct.FlowKind]]
- ]:
+ ) -> dict[tuple[ct.SocketName, FK], set[tuple[ct.SocketName, FK]]]:
"""Deduce which output socket | `FlowKind` combos are altered in response to a given output socket | `FlowKind` combo.
Returns:
@@ -830,9 +831,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
# Iterate ALL Methods that Compute Output Sockets
## -> We call it the "altered method".
## -> Our approach will be to deduce what relies on it.
- output_requested_methods = self.event_methods_by_event[
- ct.FlowEvent.OutputRequested
- ]
+ output_requested_methods = self.event_methods_by_event[FE.OutputRequested]
for altered_method in output_requested_methods:
altered_info = altered_method.callback_info
altered_key = (altered_info.output_socket_name, altered_info.kind)
@@ -859,12 +858,15 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
is_same_kind = (
altered_info.kind
is (
- _kind := invalidated_info.depon_output_socket_kinds.get(
+ _kind := invalidated_info.depon_output_sockets_kinds.get(
altered_info.output_socket_name
)
)
- or (isinstance(_kind, set) and altered_info.kind in _kind)
- or altered_info.kind is ct.FlowKind.Value
+ or (
+ isinstance(_kind, set | frozenset)
+ and altered_info.kind in _kind
+ )
+ or altered_info.kind is FK.Value
)
# Check Success: Add Invalidated (name,kind) to Altered Set
@@ -881,14 +883,14 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
def trigger_event(
self,
- event: ct.FlowEvent,
+ event: FE,
socket_name: ct.SocketName | None = None,
- socket_kinds: set[ct.FlowKind] | None = None,
- prop_names: set[str] | None = None,
+ socket_kinds: set[FK] | None = None,
+ prop_names: set[ct.PropName] | None = None,
) -> None:
"""Recursively triggers events forwards or backwards along the node tree, allowing nodes in the update path to react.
- Use `events` decorators to define methods that react to particular `ct.FlowEvent`s.
+ Use `events` decorators to define methods that react to particular `FE`s.
Notes:
This can be an unpredictably heavy function, depending on the node graph topology.
@@ -901,6 +903,14 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
socket_name: The input socket that was altered, if any, in order to trigger this event.
pop_name: The property that was altered, if any, in order to trigger this event.
"""
+ if socket_kinds is not None:
+ socket_kinds = frozenset(socket_kinds)
+ if prop_names is not None:
+ prop_names = frozenset(prop_names)
+
+ ## -> Track actual alterations per output socket|kind.
+ altered_outscks_kinds: dict[ct.SocketName, set[FK]] = defaultdict(set)
+
# log.debug(
# '[%s] [%s] Triggered (socket_name=%s, socket_kinds=%s, prop_names=%s)',
# self.sim_node_name,
@@ -910,15 +920,11 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
# str(prop_names),
# )
- # Invalidate Caches on DataChanged
- ## -> socket_kinds MUST NOT be None
- ## -> Trigger direction is always 'forwards' for DataChanged
- ## -> Track which FlowKinds are actually altered per-output-socket.
- altered_socket_kinds: dict[ct.SocketName, set[ct.FlowKind]] = defaultdict(set)
- if event is ct.FlowEvent.DataChanged:
+ # Event: DataChanged
+ if event is FE.DataChanged:
in_sckname = socket_name
- # Clear Input Socket Cache(s)
+ # Input Socket: Clear Cache(s)
## -> The input socket cache for each altered FlowKinds is cleared.
## -> Since it's non-persistent, it will be lazily re-filled.
if in_sckname is not None:
@@ -935,80 +941,76 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
unit_system=...,
)
- # Clear Output Socket Cache(s)
- for output_socket_method in self.event_methods_by_event[
- ct.FlowEvent.OutputRequested
- ]:
+ # Output Sockets: Clear Cache(s)
+ for output_socket_method in self.event_methods_by_event[FE.OutputRequested]:
# Determine Consequences of Changed (Socket|Kind) / Prop
## -> Each '@computes_output_socket' declares data to load.
- ## -> Compare what was changed to what each output socket needs.
- ## -> IF what is needed, was changed, THEN:
- ## --- The output socket needs recomputing.
+ ## -> Ask if the altered data should cause it to reload.
method_info = output_socket_method.callback_info
- if self._should_recompute_output_socket(
- method_info, socket_name, socket_kinds, prop_names
+ if method_info.should_recompute(
+ prop_names,
+ socket_name,
+ socket_kinds,
+ socket_name in self.loose_input_sockets,
):
out_sckname = method_info.output_socket_name
out_kind = method_info.kind
- # log.debug(
- # '![%s] Clear Output Socket Cache (%s, %s)',
- # self.sim_node_name,
- # out_sckname,
- # out_kind,
- # )
self.compute_output.invalidate(
output_socket_name=out_sckname,
kind=out_kind,
+ unit_system=...,
)
- altered_socket_kinds[out_sckname].add(out_kind)
+ altered_outscks_kinds[out_sckname].add(out_kind)
- # Invalidate Dependent Output Sockets
- ## -> Other outscks may depend on the altered outsck.
- ## -> The property 'output_socket_invalidates' encodes this.
- ## -> The property 'output_socket_invalidates' encodes this.
+ # Recurse: Output-Output Dependencies
+ ## -> Outscks may depend on each other.
+ ## -> Pre-build an invalidation map per-node.
cleared_outscks_kinds = self.output_socket_invalidates.get(
(out_sckname, out_kind)
)
- if cleared_outscks_kinds is not None:
+ if cleared_outscks_kinds:
for dep_out_sckname, dep_out_kind in cleared_outscks_kinds:
- # log.debug(
- # '!![%s] Clear Output Socket Cache (%s, %s)',
- # self.sim_node_name,
- # out_sckname,
- # out_kind,
- # )
self.compute_output.invalidate(
output_socket_name=dep_out_sckname,
kind=dep_out_kind,
+ unit_system=...,
)
- altered_socket_kinds[dep_out_sckname].add(dep_out_kind)
+ altered_outscks_kinds[dep_out_sckname].add(dep_out_kind)
- # Clear Output Socket Cache(s)
- ## -> We aggregate it manually, so it needs a special invl.
- ## -> See self.compute_output()
- if socket_kinds is not None and ct.FlowKind.Previews in socket_kinds:
+ # Any Preview Change -> All Output Previews Regenerate
+ ## -> All output sockets aggregate all input socket previews.
+ if socket_kinds is not None and FK.Previews in socket_kinds:
for out_sckname in self.outputs.keys(): # noqa: SIM118
self.compute_output.invalidate(
output_socket_name=out_sckname,
- kind=ct.FlowKind.Previews,
+ kind=FK.Previews,
+ unit_system=...,
)
- altered_socket_kinds[out_sckname].add(ct.FlowKind.Previews)
+ altered_outscks_kinds[out_sckname].add(FK.Previews)
- # Run Triggered Event Methods
- ## -> A triggered event method may request to stop propagation.
- stop_propagation = False
- triggered_event_methods = self.filtered_event_methods_by_event(
- event, (socket_name, prop_names, None)
- )
- for event_method in triggered_event_methods:
- stop_propagation |= event_method.stop_propagation
- # log.debug(
- # '![%s] Running: %s',
- # self.sim_node_name,
- # str(event_method.callback_info),
- # )
- event_method(self)
+ # Run 'on_value_changed' Callbacks
+ ## -> These event methods specifically respond to DataChanged.
+ stop_propagation = False
+ for value_changed_method in (
+ method
+ for method in self.event_methods_by_event[FE.DataChanged]
+ if method.callback_info.should_run(
+ prop_names,
+ socket_name,
+ socket_kinds,
+ socket_name in self.loose_input_sockets,
+ )
+ ):
+ stop_propagation |= value_changed_method.callback_info.stop_propagation
+ value_changed_method(self)
+
+ if stop_propagation:
+ return
+
+ elif event is FE.EnableLock or event is FE.DisableLock:
+ for lock_method in self.event_methods_by_event[event]:
+ lock_method(self)
# Propagate Event
## -> If 'stop_propagation' was tripped, don't propagate.
@@ -1016,36 +1018,35 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
## -> Each FlowEvent decides whether to flow forwards/backwards.
## -> The trigger chain goes node/socket/socket/node/socket/...
## -> Unlinked sockets naturally stop the propagation.
- if not stop_propagation:
- direc = ct.FlowEvent.flow_direction[event]
- for bl_socket in self._bl_sockets(direc=direc):
- # DataChanged: Propagate Altered SocketKinds
- ## -> Only altered FlowKinds for the socket will propagate.
- ## -> In this way, we guarantee no extraneous (noop) flow.
- if event is ct.FlowEvent.DataChanged:
- if bl_socket.name in altered_socket_kinds:
- # log.debug(
- # '![%s] [%s] Propagating (direction=%s, altered_socket_kinds=%s)',
- # self.sim_node_name,
- # event,
- # direc,
- # altered_socket_kinds[bl_socket.name],
- # )
- bl_socket.trigger_event(
- event, socket_kinds=altered_socket_kinds[bl_socket.name]
- )
-
- ## -> Otherwise, do nothing - guarantee no extraneous flow.
-
- # Propagate Normally
- else:
+ direc = FE.flow_direction[event]
+ for bl_socket in self._bl_sockets(direc=direc):
+ # DataChanged: Propagate Altered SocketKinds
+ ## -> Only altered FlowKinds for the socket will propagate.
+ ## -> In this way, we guarantee no extraneous (noop) flow.
+ if event is FE.DataChanged:
+ if bl_socket.name in altered_outscks_kinds:
# log.debug(
- # '![%s] [%s] Propagating (direction=%s)',
+ # '![%s] [%s] Propagating (direction=%s, altered_socket_kinds=%s)',
# self.sim_node_name,
# event,
# direc,
+ # altered_socket_kinds[bl_socket.name],
# )
- bl_socket.trigger_event(event)
+ bl_socket.trigger_event(
+ event, socket_kinds=altered_outscks_kinds[bl_socket.name]
+ )
+
+ ## -> Otherwise, do nothing - guarantee no extraneous flow.
+
+ # Propagate Normally
+ else:
+ # log.debug(
+ # '![%s] [%s] Propagating (direction=%s)',
+ # self.sim_node_name,
+ # event,
+ # direc,
+ # )
+ bl_socket.trigger_event(event)
####################
# - Property Event: On Update
@@ -1068,13 +1069,19 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
if prop_name in self.blfields:
cleared_blfields = self.clear_blfields_after(prop_name)
+ for prop_name, _ in cleared_blfields:
+ self.compute_prop.invalidate(
+ prop_name=prop_name,
+ unit_system=...,
+ )
+
# log.debug(
# '%s (Node): Set of Cleared BLFields: %s',
# self.bl_label,
# str(cleared_blfields),
# )
self.trigger_event(
- ct.FlowEvent.DataChanged,
+ FE.DataChanged,
prop_names={prop_name for prop_name, _ in cleared_blfields},
)
@@ -1204,7 +1211,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
## -> Compromise: Users explicitly say 'run_on_init' in @on_value_changed
for event_method in [
event_method
- for event_method in self.event_methods_by_event[ct.FlowEvent.DataChanged]
+ for event_method in self.event_methods_by_event[FE.DataChanged]
if event_method.callback_info.run_on_init
]:
event_method(self)
@@ -1244,7 +1251,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
## -> Copying a node ~ re-initializing the new node.
for event_method in [
event_method
- for event_method in self.event_methods_by_event[ct.FlowEvent.DataChanged]
+ for event_method in self.event_methods_by_event[FE.DataChanged]
if event_method.callback_info.run_on_init
]:
event_method(self)
@@ -1271,7 +1278,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
bl_socket.is_linked and bl_socket.locked
for bl_socket in self.inputs.values()
):
- self.trigger_event(ct.FlowEvent.DisableLock)
+ self.trigger_event(FE.DisableLock)
# Free Managed Objects
for managed_obj in self.managed_objs.values():
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py
index 9a94de8..354125d 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/events.py
@@ -14,82 +14,439 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-import dataclasses
+import functools
import inspect
import typing as typ
+import uuid
+from collections import defaultdict
+from fractions import Fraction
from types import MappingProxyType
-from blender_maxwell.utils import sympy_extra as spux
+import bpy
+import jax
+import numpy as np
+import pydantic as pyd
+import sympy as sp
+
from blender_maxwell.utils import logger
+from blender_maxwell.utils import sympy_extra as spux
+from blender_maxwell.utils.frozendict import FrozenDict, frozendict
+from blender_maxwell.utils.lru_method import method_lru
from .. import contracts as ct
log = logger.get(__name__)
+ManagedObjName: typ.TypeAlias = str
UnitSystemID = str
+FK = ct.FlowKind
+FS = ct.FlowSignal
+
+EVENT_METHOD_IDENTIFIER: str = str(uuid.uuid4()) ## Changes on startup.
+
####################
# - Event Callback Information
####################
-@dataclasses.dataclass(kw_only=True, frozen=True)
-class InfoDataChanged:
+class CallbackInfo(pyd.BaseModel):
+ """Base for information associated with an event method callback."""
+
+ model_config = pyd.ConfigDict(frozen=True)
+
+ # def parse_for_method(self, node):
+
+
+class InfoDataChanged(CallbackInfo):
+ """Information for determining whether a particular `DataChanged` callback should run."""
+
+ model_config = pyd.ConfigDict(frozen=True)
+
+ on_changed_props: frozenset[ct.PropName]
+ on_changed_sockets_kinds: FrozenDict[ct.SocketName, frozenset[FK]]
+ on_any_changed_loose_socket: bool
run_on_init: bool
- on_changed_sockets: set[ct.SocketName]
- on_changed_props: set[str]
- on_any_changed_loose_input: set[str]
- must_load_sockets: set[str]
+
+ optional_sockets: frozenset[ct.SocketName]
+ stop_propagation: bool = False
+
+ ####################
+ # - Computed Properties
+ ####################
+ @functools.cached_property
+ def on_changed_sockets(self) -> frozenset[ct.SocketName]:
+ """Input sockets with a `FlowKind` for which the method will run."""
+ return frozenset(self.on_changed_sockets_kinds.keys())
+
+ @functools.cached_property
+ def optional_sockets_kinds(self) -> frozendict[ct.SocketName, frozenset[FK]]:
+ """Input `socket|kind`s for which the method can run, even when only a `FlowSignal` is available."""
+ return {
+ changed_socket: kinds
+ for changed_socket, kinds in self.on_changed_sockets_kinds
+ if changed_socket in self.optional_sockets
+ }
+
+ ####################
+ # - Methods
+ ####################
+ @method_lru(maxsize=2048)
+ def should_run(
+ self,
+ changed_props: frozenset[str] | None,
+ changed_socket: ct.SocketName | None,
+ changed_kinds: frozenset[FK] | None = None,
+ socket_is_loose: bool = False,
+ ):
+ """Deduce whether this method should run in response to a particular set of changed inputs."""
+ prop_triggered = changed_props is not None and any(
+ changed_prop in self.on_changed_props for changed_prop in changed_props
+ )
+
+ socket_triggered = (
+ changed_socket is not None
+ and changed_kinds is not None
+ and changed_socket in self.on_changed_sockets
+ and any(
+ changed_kind in self.on_changed_sockets_kinds[changed_socket]
+ for changed_kind in changed_kinds
+ )
+ )
+
+ loose_socket_triggered = (
+ socket_is_loose
+ and changed_socket is not None
+ and self.on_any_changed_loose_socket
+ )
+ return socket_triggered or prop_triggered or loose_socket_triggered
-@dataclasses.dataclass(kw_only=True, frozen=True)
-class InfoOutputRequested:
+class InfoOutputRequested(CallbackInfo):
+ """Information for determining which output socket method should run."""
+
+ model_config = pyd.ConfigDict(frozen=True)
+
output_socket_name: ct.SocketName
- kind: ct.FlowKind
+ kind: FK
- depon_props: set[str]
-
- depon_input_sockets: set[ct.SocketName]
- depon_input_socket_kinds: dict[ct.SocketName, ct.FlowKind | set[ct.FlowKind]]
+ depon_props: frozenset[str]
+ depon_input_sockets_kinds: FrozenDict[ct.SocketName, FK | frozenset[FK]]
+ depon_output_sockets_kinds: FrozenDict[ct.SocketName, FK | frozenset[FK]]
depon_all_loose_input_sockets: bool
-
- depon_output_sockets: set[ct.SocketName]
- depon_output_socket_kinds: dict[ct.SocketName, ct.FlowKind | set[ct.FlowKind]]
depon_all_loose_output_sockets: bool
+ ####################
+ # - Computed Properties
+ ####################
+ @functools.cached_property
+ def depon_input_sockets(self) -> frozenset[ct.SocketName]:
+ """The input sockets depended on by this output socket method."""
+ return frozenset(self.depon_input_sockets_kinds.keys())
+
+ @functools.cached_property
+ def depon_output_sockets(self) -> frozenset[ct.SocketName]:
+ """The output sockets depended on by this output socket method."""
+ return frozenset(self.depon_output_sockets_kinds.keys())
+
+ ####################
+ # - Methods
+ ####################
+ @method_lru(maxsize=2048)
+ def should_run(
+ self,
+ requested_socket: ct.SocketName,
+ requested_kind: FK,
+ ):
+ """Deduce whether this method can compute the requested socket and kind."""
+ return (
+ requested_kind is self.kind and requested_socket == self.output_socket_name
+ )
+
+ @method_lru(maxsize=2048)
+ def should_recompute(
+ self,
+ changed_props: frozenset[str] | None,
+ changed_socket: ct.SocketName | None,
+ changed_kinds: frozenset[FK] | None = None,
+ socket_is_loose: bool = False,
+ ):
+ """Deduce whether this method needs to be recomputed after a change in a particular set of changed inputs."""
+ prop_altered = changed_props is not None and any(
+ changed_prop in self.depon_props for changed_prop in changed_props
+ )
+
+ socket_altered = (
+ changed_socket is not None
+ and changed_kinds is not None
+ and changed_socket in self.depon_input_sockets
+ and any(
+ kind in self.depon_input_sockets_kinds[changed_socket]
+ for kind in changed_kinds
+ )
+ )
+
+ loose_socket_altered = (
+ socket_is_loose
+ and changed_socket is not None
+ and self.depon_all_loose_input_sockets
+ )
+
+ return prop_altered or socket_altered or loose_socket_altered
+
EventCallbackInfo: typ.TypeAlias = InfoDataChanged | InfoOutputRequested
####################
-# - Event Decorator
+# - Node Parsers
####################
-ManagedObjName: typ.TypeAlias = str
-PropName: typ.TypeAlias = str
+def parse_node_mobjs(node: bpy.types.Node, mobjs: frozenset[ManagedObjName]) -> typ.Any:
+ """Retrieve the given managed objects."""
+ return {mobj_name: node.managed_objs[mobj_name] for mobj_name in mobjs}
-def event_decorator( # noqa: PLR0913
+def parse_node_props(
+ node: bpy.types.Node,
+ props: frozenset[ct.PropName],
+ prop_unit_systems: frozendict[ct.PropName, spux.UnitSystem | None],
+) -> typ.Any:
+ """Compute the values of the given property names, w/optional scaling to a unit system.
+
+ Raises:
+ ValueError: If a unit system is specified for the property value, but the property value is not a `sympy` type.
+ """
+ return frozendict(
+ {
+ prop: node.compute_prop(prop, unit_system=prop_unit_systems.get(prop))
+ for prop in props
+ }
+ )
+
+
+def parse_node_sck(
+ node: bpy.types.Node,
+ direc: typ.Literal['input', 'output'],
+ sck: ct.SocketName,
+ _kind: FK | None,
+ unit_system: spux.UnitSystem | None,
+) -> typ.Any:
+ """Compute a single value for `sck|kind|unit_system`.
+
+ Parameters:
+ node: The node to parse a socket value from.
+ direc: Whether the socket to parse is an input or output socket.
+ sck: The name of the socket to parse.
+ _kind: The `FlowKind` of the socket to parse.
+ When `None`, use the `.active_kind` attribute of the socket to determine what to parse.
+ unit_system: The unit system with which to scale the socket value.
+
+ Returns:
+ The value of the socket over the given `FlowKind` lane, potentially scaled to a unitless scalar (if requested) in a manner specific to the `FlowKind`.
+ """
+ # Deduce Kind
+ ## -> _kind=None denotes "use active_kind of socket".
+ kind = node._bl_sockets(direc=direc)[sck].active_kind if _kind is None else _kind # noqa: SLF001
+
+ # Compute Socket Value
+ if direc == 'input':
+ return node._compute_input(sck, kind=kind, unit_system=unit_system) # noqa: SLF001
+ if direc == 'output':
+ return node.compute_output(sck, kind=kind, unit_system=unit_system)
+
+ raise TypeError
+
+
+def parse_node_scks_kinds(
+ node: bpy.types.Node,
+ direc: typ.Literal['input', 'output'],
+ scks_kinds: frozendict[ct.SocketName, frozenset[FK] | FK | None],
+ scks_unit_system: spux.UnitSystem
+ | frozendict[ct.SocketName, spux.UnitSystem | None],
+) -> (
+ frozendict[ct.SocketName, typ.Any]
+ | frozendict[ct.SocketName, frozenset[typ.Any]]
+ | None
+):
+ """Retrieve the values for given input sockets and kinds, w/optional scaling to a unit system.
+
+ In general, unless the socket name is specified in `scks_optional`, then whenever `FlowSignal` is encountered while computing sockets, the function will return `None` immediately.
+ This process is "short-circuit", which is a partial optimization causing an immediate return before any other computing any other sockets.
+ """
+ # Compute Socket Values
+ ## -> Every time we run `compute()`, we might encounter a FlowSignal.
+ ## -> If so, and the socket is 'optional', we let it be.
+ ## -> If not, and the socket is not 'optional', we return immediately.
+ computed_scks = {}
+ for sck, kinds in scks_kinds.items():
+ # Extract Unit System
+ if isinstance(scks_unit_system, dict | frozendict):
+ unit_system = scks_unit_system.get(sck)
+ else:
+ unit_system = scks_unit_system
+
+ flow_values = {}
+ for kind in kinds if isinstance(kinds, set | frozenset) else {kinds}:
+ flow = parse_node_sck(node, direc, sck, kind, unit_system)
+ flow_values[kind] = flow
+
+ if len(flow_values) == 1:
+ computed_scks[sck] = next(iter(flow_values.values()))
+ else:
+ computed_scks[sck] = frozendict(flow_values)
+
+ return frozendict(computed_scks)
+
+
+####################
+# - Utilities
+####################
+def freeze_pytree(
+ pytree: None
+ | str
+ | int
+ | Fraction
+ | float
+ | complex
+ | set
+ | frozenset
+ | list
+ | tuple
+ | dict
+ | frozendict,
+):
+ """Conform an arbitrarily nested containers into their immutable equivalent."""
+ if isinstance(pytree, set | frozenset):
+ return frozenset(freeze_pytree(el) for el in pytree)
+ if isinstance(pytree, list | tuple):
+ return tuple(freeze_pytree(el) for el in pytree)
+ if isinstance(pytree, np.ndarray | jax.Array):
+ return tuple(freeze_pytree(el) for el in pytree.tolist())
+ if isinstance(pytree, dict | frozendict):
+ return frozendict({k: freeze_pytree(v) for k, v in pytree.items()})
+
+ if isinstance(pytree, None | str | int | Fraction | float | complex):
+ return pytree
+
+ raise TypeError
+
+
+def realize_known(
+ sck: frozendict[typ.Literal[FK.Func, FK.Params], ct.FuncFlow | ct.ParamsFlow],
+ freeze: bool = False,
+ conformed: bool = False,
+) -> int | float | tuple[int | float] | None:
+ """Realize a concrete preview-value from a `FuncFlow` and a `ParamsFlow`, when there are no unrealized symbols in `ParamsFlow`.
+
+ It often happens that we absolutely need an _unrealized_ value for a node to work, eg. when producing a `Value` from the middle of a fully-lazy chain.
+ Several complications arise when doing so, not least of which is how to handle the case where there are still-unrealized symbols.
+
+ This method encapsulates all of these complexities into a single call, whose availability can be handled with a simple `None`-check.
+
+ Parameters:
+ sck: Mapping from dictionaries
+
+ Examples:
+ Within an event method depending on a socket `Socket: {FK.Func, FK.Params, ...}`, a realized value can
+ Generally accessed by calling `event.realize_known
+
+ Returns `None` when there are unrealized symbols, or either `Func` or `Params` is a `FlowSignal`.
+ """
+ has_func = not FS.check(sck[FK.Func])
+ has_params = not FS.check(sck[FK.Params])
+
+ if has_func and has_params:
+ realized = sck[FK.Func].realize(sck[FK.Params])
+ if freeze:
+ return freeze_pytree(realized)
+ if conformed:
+ func_output = sck[FK.Func].func_output
+ if func_output is not None:
+ return func_output.conform(realized)
+ return realized
+ return realized
+ return None
+
+
+def realize_preview(
+ sck: frozendict[typ.Literal[FK.Func, FK.Params], ct.FuncFlow | ct.ParamsFlow],
+) -> int | float | tuple[int | float]:
+ """Realize a concrete preview-value from a `FuncFlow` and a `ParamsFlow`.
+
+ This particular operation is widely used in `on_value_changed` methods that update previews, since they must intercept both the function and parameter flows in order to respect ex. partially relized symbols and units.
+ Usually, when such a thing happens, we support it in `event_decorator`.
+ But in this case, the designs required were not resonating quite right - adding either too much specific complexity to input/output/loose fields, requiring too much "magic", etc. .
+
+ Why not just ask users to intercept intercept the `Value` output?
+ Several reasons.
+ Firstly, it alone _absolutely cannot_ handle unrealized symbols, which while reasonable for a structure that should be **fully realized**, is not quite the desired functionality with preview-oriented workflows: In particular, we want we want to use the the stand-in `SimSymbol.preview_value_phy`, even though it's not always "accurate".
+ Secondly, constructing the full `Value` output is slow, and introduces superfluous preview-dependencies.
+
+ In this situation, the best balance is to provide this utility function.
+ The user needs to deconstruct the eg. `input_sockets` parameter _anyway_ at least once; with this function, what would otherwise be an unweildy piece of realization logic is cleanly encapsulated.
+ """
+ return sck[FK.Func].realize_preview(sck[FK.Params])
+
+
+def mk_sockets_kinds(
+ sockets: dict[ct.SocketName, frozenset[FK]]
+ | dict[ct.SocketName, FK]
+ | ct.SocketName
+ | None,
+ default_kinds: frozenset[FK] = frozenset(FK),
+) -> frozendict[ct.SocketName, frozenset[FK]]:
+ """Normalize the given parameters to a standardized type."""
+ # Deduce Triggered Socket -> SocketKinds
+ ## -> Normalize all valid inputs to frozendict[SocketName, set[FlowKind]].
+ if sockets is None:
+ return {}
+ if isinstance(sockets, dict | frozendict):
+ sockets_kinds = {
+ socket: (
+ kinds if isinstance(kinds := _kinds, set | frozenset) else {_kinds}
+ )
+ for socket, _kinds in sockets.items()
+ }
+ else:
+ sockets_kinds = {
+ socket: default_kinds
+ for socket in (
+ sockets if isinstance(sockets, set | frozenset) else [sockets]
+ )
+ }
+
+ return frozendict(sockets_kinds)
+
+
+####################
+# - General Event Callbacks
+####################
+def event_decorator( # noqa: C901, PLR0913, PLR0915
event: ct.FlowEvent,
callback_info: EventCallbackInfo | None,
- stop_propagation: bool = False,
- # Request Data for Callback
- managed_objs: set[ManagedObjName] = frozenset(),
- props: set[PropName] = frozenset(),
- input_sockets: set[ct.SocketName] = frozenset(),
- input_sockets_optional: dict[ct.SocketName, bool] = MappingProxyType({}),
- input_socket_kinds: dict[
- ct.SocketName, ct.FlowKind | set[ct.FlowKind]
- ] = MappingProxyType({}),
- output_sockets: set[ct.SocketName] = frozenset(),
- output_sockets_optional: dict[ct.SocketName, bool] = MappingProxyType({}),
- output_socket_kinds: dict[
- ct.SocketName, ct.FlowKind | set[ct.FlowKind]
- ] = MappingProxyType({}),
+ # Loading: Internal Data
+ managed_objs: frozenset[ManagedObjName] = frozenset(),
+ props: frozenset[ct.PropName] = frozenset(),
+ # Loading: Input Sockets
+ input_sockets: frozenset[ct.SocketName] = frozenset(),
+ input_socket_kinds: frozendict[ct.SocketName, FK | frozenset[FK]] = frozendict(),
+ inscks_kinds: frozendict[ct.SocketName, FK | frozenset[FK]] | None = None,
+ input_sockets_optional: frozenset[ct.SocketName] = frozendict(),
+ # Loading: Output Sockets
+ output_sockets: frozenset[ct.SocketName] = frozenset(),
+ output_socket_kinds: frozendict[ct.SocketName, FK | frozenset[FK]] = frozendict(),
+ outscks_kinds: frozendict[ct.SocketName, FK | frozenset[FK]] | None = None,
+ output_sockets_optional: frozenset[ct.SocketName] = frozenset(),
+ # Loading: Loose Sockets
all_loose_input_sockets: bool = False,
+ loose_input_sockets_kind: frozenset[FK] | FK | None = None,
all_loose_output_sockets: bool = False,
- # Request Unit System Scaling
- unit_systems: dict[UnitSystemID, spux.UnitSystem] = MappingProxyType({}),
- scale_input_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}),
- scale_output_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}),
+ loose_output_sockets_kind: frozenset[FK] | FK | None = None,
+ # Loading: Unit System Scaling
+ scale_props: frozendict[ct.PropName, spux.UnitSystem] = frozendict(),
+ scale_input_sockets: frozendict[ct.SocketName, spux.UnitSystem] | None = None,
+ scale_output_sockets: frozendict[ct.SocketName, spux.UnitSystem] | None = None,
+ scale_loose_input_sockets: spux.UnitSystem | None = None,
+ scale_loose_output_sockets: spux.UnitSystem | None = None,
):
"""Low-level decorator declaring a special "event method" of `MaxwellSimNode`, which is able to handle `ct.FlowEvent`s passing through.
@@ -101,25 +458,22 @@ def event_decorator( # noqa: PLR0913
event: A name describing which event the decorator should respond to.
callback_info: A dictionary that provides the caller with additional per-`event` information.
This might include parameters to help select the most appropriate method(s) to respond to an event with, or events to take after running the callback.
- stop_propagation: Whether or stop propagating the event through the graph after encountering this method.
- Other methods defined on the same node will still run.
managed_objs: Set of `managed_objs` to retrieve, then pass to the decorated method.
props: Set of `props` to compute, then pass to the decorated method.
input_sockets: Set of `input_sockets` to compute, then pass to the decorated method.
- input_sockets_optional: Whether an input socket is required to exist.
- When True, lack of socket will produce `ct.FlowSignal.NoFlow`, instead of throwing an error.
- input_socket_kinds: The `ct.FlowKind` to compute per-input-socket.
- If an input socket isn't specified, it defaults to `ct.FlowKind.Value`.
+ input_sockets_optional: Allow the method will run even if one of these input socket values are a `ct.FlowSignal`.
+ input_socket_kinds: The `FK` to compute per-input-socket.
+ If an input socket isn't specified, it defaults to `FK.Value`.
output_sockets: Set of `output_sockets` to compute, then pass to the decorated method.
- output_sockets_optional: Whether an output socket is required to exist.
+ output_sockets_optional: Allow the method will run even if one of these output socket values are a `ct.FlowSignal`.
When True, lack of socket will produce `ct.FlowSignal.NoFlow`, instead of throwing an error.
- output_socket_kinds: The `ct.FlowKind` to compute per-output-socket.
- If an output socket isn't specified, it defaults to `ct.FlowKind.Value`.
+ output_socket_kinds: The `FK` to compute per-output-socket.
+ If an output socket isn't specified, it defaults to `FK.Value`.
all_loose_input_sockets: Whether to compute all loose input sockets and pass them to the decorated method.
Used when the names of the loose input sockets are unknown, but all of their values are needed.
all_loose_output_sockets: Whether to compute all loose output sockets and pass them to the decorated method.
Used when the names of the loose output sockets are unknown, but all of their values are needed.
- unit_systems: String identifiers under which to load a unit system, made available to the method.
+ scale_props: A mapping of input sockets to unit system string idenfiers, which causes the output of that input socket to be scaled to the given unit system.
scale_input_sockets: A mapping of input sockets to unit system string idenfiers, which causes the output of that input socket to be scaled to the given unit system.
This greatly simplifies the conformance of particular sockets to particular unit systems, when the socket value must be used in a unit-unaware manner.
scale_output_sockets: A mapping of output sockets to unit system string idenfiers, which causes the output of that input socket to be scaled to the given unit system.
@@ -130,204 +484,152 @@ def event_decorator( # noqa: PLR0913
"""
req_params = (
{'self'}
- | ({'props'} if props else set())
- | ({'managed_objs'} if managed_objs else set())
- | ({'input_sockets'} if input_sockets else set())
- | ({'output_sockets'} if output_sockets else set())
- | ({'loose_input_sockets'} if all_loose_input_sockets else set())
- | ({'loose_output_sockets'} if all_loose_output_sockets else set())
- | ({'unit_systems'} if unit_systems else set())
+ | ({'props'} if props else frozenset())
+ | ({'managed_objs'} if managed_objs else frozenset())
+ | ({'input_sockets'} if input_sockets or inscks_kinds else frozenset())
+ | ({'output_sockets'} if output_sockets or outscks_kinds else frozenset())
+ | ({'loose_input_sockets'} if all_loose_input_sockets else frozenset())
+ | ({'loose_output_sockets'} if all_loose_output_sockets else frozenset())
)
- # TODO: Check that all Unit System IDs referenced are also defined in 'unit_systems'.
+ # Simplify I/O Under Naming
+ if inscks_kinds is None:
+ inscks_kinds = frozendict(
+ {
+ socket: input_socket_kinds.get(socket, FK.Value)
+ for socket in input_sockets
+ }
+ )
+ inscks_unit_system = (
+ frozendict(scale_input_sockets) if scale_input_sockets is not None else None
+ )
+ inscks_optional = frozenset(input_sockets_optional)
+
+ if outscks_kinds is None:
+ outscks_kinds = frozendict(
+ {
+ output_socket: output_socket_kinds.get(output_socket, FK.Value)
+ for output_socket in output_sockets
+ }
+ )
+ outscks_unit_system = (
+ frozendict(scale_output_sockets) if scale_output_sockets is not None else None
+ )
+ outscks_optional = frozenset(output_sockets_optional)
+
## TODO: More ex. introspective checks and such, to make it really hard to write invalid methods.
# TODO: Check Function Annotation Validity
## - socket capabilities
- def decorator(method: typ.Callable) -> typ.Callable:
+ def decorator(method: typ.Callable) -> typ.Callable: # noqa: C901
# Check Function Signature Validity
- func_sig = set(inspect.signature(method).parameters.keys())
+ func_sig = frozenset(inspect.signature(method).parameters.keys())
- ## Too Few Arguments
+ ## -> Too Few Arguments
if func_sig != req_params and func_sig.issubset(req_params):
msg = f'Decorated method {method.__name__} is missing arguments {req_params - func_sig}'
- ## Too Many Arguments
+ ## -> Too Many Arguments
if func_sig != req_params and func_sig.issuperset(req_params):
msg = f'Decorated method {method.__name__} has superfluous arguments {func_sig - req_params}'
raise ValueError(msg)
- def decorated(node):
- method_kw_args = {} ## Keyword Arguments for Decorated Method
-
- # Unit Systems
- method_kw_args |= {'unit_systems': unit_systems} if unit_systems else {}
-
- # Properties
- method_kw_args |= (
- {'props': {prop_name: getattr(node, prop_name) for prop_name in props}}
- if props
- else {}
- )
+ def decorated(node: bpy.types.Node): # noqa: C901, PLR0912
+ method_kwargs = defaultdict(dict)
# Managed Objects
- method_kw_args |= (
- {
- 'managed_objs': {
- managed_obj_name: node.managed_objs[managed_obj_name]
- for managed_obj_name in managed_objs
- }
- }
- if managed_objs
- else {}
- )
+ if managed_objs:
+ method_kwargs['managed_objs'] = {}
+ for mobj_name, mobj in parse_node_mobjs(node, managed_objs).items():
+ method_kwargs['managed_objs'] |= {mobj_name: mobj}
+
+ # Properties
+ if props:
+ method_kwargs['props'] = {}
+ for prop, value in parse_node_props(node, props, scale_props).items():
+ method_kwargs['props'] |= {prop: value}
# Sockets
- ## Input Sockets
- method_kw_args |= (
- {
- 'input_sockets': {
- input_socket_name: node._compute_input(
- input_socket_name,
- kind=_kind,
- unit_system=(
- unit_systems.get(
- scale_input_sockets.get(input_socket_name)
- )
- ),
- optional=input_sockets_optional.get(
- input_socket_name, False
- ),
- )
- if not isinstance(
- _kind := input_socket_kinds.get(
- input_socket_name, ct.FlowKind.Value
- ),
- set,
- )
- else {
- kind: node._compute_input(
- input_socket_name,
- kind=kind,
- unit_system=unit_systems.get(
- scale_input_sockets.get(input_socket_name)
- ),
- optional=input_sockets_optional.get(
- input_socket_name, False
- ),
- )
- for kind in _kind
- }
- for input_socket_name in input_sockets
- }
- }
- if input_sockets
- else {}
- )
+ if inscks_kinds:
+ method_kwargs['input_sockets'] = {}
+ for insck, flow in parse_node_scks_kinds(
+ node,
+ 'input',
+ inscks_kinds,
+ inscks_unit_system,
+ ).items():
+ has_flow = not FS.check(flow)
+ if has_flow or insck in inscks_optional:
+ method_kwargs['input_sockets'] |= {insck: flow}
+ else:
+ flow_signal = flow
+ return flow_signal # noqa: RET504
- ## Output Sockets
- def _g_output_socket(output_socket_name: ct.SocketName, kind: ct.FlowKind):
- if scale_output_sockets.get(output_socket_name) is None:
- return node.compute_output(
- output_socket_name,
- kind=kind,
- optional=output_sockets_optional.get(output_socket_name, False),
- )
-
- return ct.FlowKind.scale_to_unit_system(
- kind,
- node.compute_output(
- output_socket_name,
- kind=kind,
- optional=output_sockets_optional.get(output_socket_name, False),
- ),
- unit_systems.get(scale_output_sockets.get(output_socket_name)),
- )
-
- method_kw_args |= (
- {
- 'output_sockets': {
- output_socket_name: _g_output_socket(output_socket_name, _kind)
- if not isinstance(
- _kind := output_socket_kinds.get(
- output_socket_name, ct.FlowKind.Value
- ),
- set,
- )
- else {
- kind: _g_output_socket(output_socket_name, kind)
- for kind in _kind
- }
- for output_socket_name in output_sockets
- }
- }
- if output_sockets
- else {}
- )
+ if outscks_kinds:
+ method_kwargs['output_sockets'] = {}
+ for outsck, flow in parse_node_scks_kinds(
+ node,
+ 'output',
+ outscks_kinds,
+ outscks_unit_system,
+ ).items():
+ has_flow = not FS.check(flow)
+ if has_flow or outsck in outscks_optional:
+ method_kwargs['output_sockets'] |= {outsck: flow}
+ else:
+ flow_signal = flow
+ return flow_signal # noqa: RET504
# Loose Sockets
- ## -> Determined by the active_kind of each loose input socket.
- method_kw_args |= (
- {
- 'loose_input_sockets': {
- input_socket_name: node._compute_input(
- input_socket_name,
- kind=node.inputs[input_socket_name].active_kind,
- )
- for input_socket_name in node.loose_input_sockets
- }
+ if all_loose_input_sockets:
+ method_kwargs['loose_input_sockets'] = {}
+ loose_inscks = frozenset(node.loose_input_sockets.keys())
+ loose_inscks_kinds = {
+ loose_insck: loose_input_sockets_kind
+ for loose_insck in loose_inscks
}
- if all_loose_input_sockets
- else {}
- )
- ## Compute All Loose Output Sockets
- method_kw_args |= (
- {
- 'loose_output_sockets': {
- output_socket_name: node.compute_output(
- output_socket_name,
- kind=node.outputs[output_socket_name].active_kind,
- )
- for output_socket_name in node.loose_output_sockets
- }
+ for loose_insck, flow in parse_node_scks_kinds(
+ node,
+ 'input',
+ loose_inscks_kinds,
+ scale_loose_input_sockets,
+ ).items():
+ method_kwargs['loose_input_sockets'] |= {loose_insck: flow}
+
+ if all_loose_output_sockets:
+ method_kwargs['loose_output_sockets'] = {}
+ loose_outscks = frozenset(node.loose_output_sockets.keys())
+ loose_outscks_kinds = {
+ loose_outsck: loose_output_sockets_kind
+ for loose_outsck in loose_outscks
}
- if all_loose_output_sockets
- else {}
- )
- # Propagate Initialization
- ## If there is a FlowInitializing, then the method would fail.
- ## Therefore, propagate FlowInitializing if found.
- if any(
- ct.FlowSignal.check_single(value, ct.FlowSignal.FlowInitializing)
- for sockets in [
- method_kw_args.get('input_sockets', {}),
- method_kw_args.get('loose_input_sockets', {}),
- method_kw_args.get('output_sockets', {}),
- method_kw_args.get('loose_output_sockets', {}),
- ]
- for value in sockets.values()
- ):
- return ct.FlowSignal.FlowInitializing
+ for loose_outsck, flow in parse_node_scks_kinds(
+ node,
+ 'input',
+ loose_outscks_kinds,
+ scale_loose_output_sockets,
+ ).items():
+ method_kwargs['loose_output_sockets'] |= {loose_outsck: flow}
# Call Method
return method(
node,
- **method_kw_args,
+ **method_kwargs,
)
- # Set Decorated Attributes and Return
- ## TODO: Fix Introspection + Documentation
- # decorated.__name__ = method.__name__
+ # Wrap Decorated Method Attributes
+ ## -> We can't just @wraps(), since the call signature changed.
+ decorated.__name__ = method.__name__
# decorated.__module__ = method.__module__
# decorated.__qualname__ = method.__qualname__
decorated.__doc__ = method.__doc__
- ## Add Spice
+ # Add Spice
+ decorated.identifier = EVENT_METHOD_IDENTIFIER
decorated.event = event
decorated.callback_info = callback_info
- decorated.stop_propagation = stop_propagation
return decorated
@@ -335,11 +637,12 @@ def event_decorator( # noqa: PLR0913
####################
-# - Simplified Event Callbacks
+# - Specific Event Callbacks
####################
def on_enable_lock(
**kwargs,
):
+ """Declare a method that reacts to the enabling of the interface lock."""
return event_decorator(
event=ct.FlowEvent.EnableLock,
callback_info=None,
@@ -350,6 +653,7 @@ def on_enable_lock(
def on_disable_lock(
**kwargs,
):
+ """Declare a method that reacts to the disabling of the interface lock."""
return event_decorator(
event=ct.FlowEvent.DisableLock,
callback_info=None,
@@ -359,28 +663,50 @@ def on_disable_lock(
## TODO: Consider changing socket_name and prop_name to more obvious names.
def on_value_changed(
- socket_name: set[ct.SocketName] | ct.SocketName | None = None,
- prop_name: set[str] | str | None = None,
+ prop_name: frozenset[str] | str | None = None,
+ socket_name: frozendict[ct.SocketName, frozenset[FK]]
+ | frozendict[ct.SocketName, FK]
+ | frozenset[ct.SocketName]
+ | ct.SocketName
+ | None = None,
+ socket_name_kinds: frozenset[FK] | FK = frozenset(FK),
any_loose_input_socket: bool = False,
run_on_init: bool = False,
+ stop_propagation: bool = False,
**kwargs,
):
+ """Declare a method that reacts to a change in node data.
+
+ Can be configured to react to changes in:
+
+ - Any particular input `socket|kind`s.
+ - Any `loose_socket|active_kind`.
+ - Any particular property.
+
+ In addition, the method can be configured to run when its node is created and initialized for the first time.
+ """
return event_decorator(
event=ct.FlowEvent.DataChanged,
callback_info=InfoDataChanged(
- # Triggers
- run_on_init=run_on_init,
- on_changed_sockets=(
- socket_name if isinstance(socket_name, set) else {socket_name}
+ # Trigger: Props
+ on_changed_props=(
+ prop_name
+ if isinstance(prop_name, set | frozenset)
+ else ({prop_name} if prop_name is not None else set())
),
- on_changed_props=(prop_name if isinstance(prop_name, set) else {prop_name}),
- on_any_changed_loose_input=any_loose_input_socket,
- # Loaded
- must_load_sockets={
+ # Trigger: Sockets
+ on_changed_sockets_kinds=mk_sockets_kinds(socket_name, socket_name_kinds),
+ # Trigger: Loose Sockets
+ on_any_changed_loose_socket=any_loose_input_socket,
+ # Trigger: Init
+ run_on_init=run_on_init,
+ # Hints
+ optional_sockets={
socket_name
- for socket_name in kwargs.get('input_sockets', {})
- if socket_name not in kwargs.get('input_sockets_optional', {})
+ for socket_name in kwargs.get('input_sockets', set())
+ if socket_name in kwargs.get('input_sockets_optional', set())
},
+ stop_propagation=stop_propagation,
),
**kwargs,
)
@@ -388,45 +714,63 @@ def on_value_changed(
def computes_output_socket(
output_socket_name: ct.SocketName | None,
- kind: ct.FlowKind = ct.FlowKind.Value,
+ kind: FK = FK.Value,
**kwargs,
):
+ """Declare a method used to compute the value of a particular output `socket|kind`.
+
+ The method's dependencies on properties, input/output `socket|kind`s, and loose input/output `socket|active_kind`s are recorded in its `callback_info`, such that the associated output socket cache can be appropriately invalidated whenever an input dependency changes.
+ """
+ input_socket_kinds = kwargs.get('input_socket_kinds', {})
+ depon_inscks_kinds = {
+ insck: input_socket_kinds.get(insck, FK.Value)
+ for insck in kwargs.get('input_sockets', set())
+ } | kwargs.get('inscks_kinds', {})
+
+ output_sockets_kinds = kwargs.get('output_sockets_kinds', {})
+ depon_outscks_kinds = {
+ outsck: output_sockets_kinds.get(outsck, FK.Value)
+ for outsck in kwargs.get('output_sockets', set())
+ } | kwargs.get('outscks_kinds', {})
+
+ log.debug(
+ [
+ output_socket_name,
+ kind,
+ kwargs.get('props', set()),
+ mk_sockets_kinds(depon_inscks_kinds),
+ mk_sockets_kinds(depon_outscks_kinds),
+ kwargs.get('all_loose_input_sockets', False),
+ kwargs.get('all_loose_output_sockets', False),
+ ]
+ )
return event_decorator(
event=ct.FlowEvent.OutputRequested,
callback_info=InfoOutputRequested(
output_socket_name=output_socket_name,
kind=kind,
+ # Dependency: Props
depon_props=kwargs.get('props', set()),
- depon_input_sockets=kwargs.get('input_sockets', set()),
- depon_input_socket_kinds=kwargs.get('input_socket_kinds', {}),
- depon_output_sockets=kwargs.get('output_sockets', set()),
- depon_output_socket_kinds=kwargs.get('output_socket_kinds', {}),
+ # Dependency: Input Sockets
+ depon_input_sockets_kinds=mk_sockets_kinds(depon_inscks_kinds),
+ # Dependency: Output Sockets
+ depon_output_sockets_kinds=mk_sockets_kinds(depon_outscks_kinds),
+ # Dependency: Loose Sockets
depon_all_loose_input_sockets=kwargs.get('all_loose_input_sockets', False),
depon_all_loose_output_sockets=kwargs.get(
'all_loose_output_sockets', False
),
),
- **kwargs, ## stop_propagation has no effect.
- )
-
-
-def on_show_preview(
- **kwargs,
-):
- return event_decorator(
- event=ct.FlowEvent.ShowPreview,
- callback_info={},
**kwargs,
)
def on_show_plot(
- stop_propagation: bool = True,
**kwargs,
):
+ """Declare a method that reacts to a request to show a plot."""
return event_decorator(
event=ct.FlowEvent.ShowPlot,
- callback_info={},
- stop_propagation=stop_propagation,
+ callback_info=None,
**kwargs,
)
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/__init__.py
index 163ba23..611fb26 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/__init__.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/__init__.py
@@ -18,21 +18,16 @@ from . import (
constants,
file_importers,
scene,
- wave_constant,
web_importers,
)
-# from . import file_importers
-
BL_REGISTER = [
- *wave_constant.BL_REGISTER,
- *scene.BL_REGISTER,
*constants.BL_REGISTER,
*web_importers.BL_REGISTER,
*file_importers.BL_REGISTER,
+ *scene.BL_REGISTER,
]
BL_NODES = {
- **wave_constant.BL_NODES,
**scene.BL_NODES,
**constants.BL_NODES,
**web_importers.BL_NODES,
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/blender_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/blender_constant.py
index 8b9a033..9b5d23f 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/blender_constant.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/blender_constant.py
@@ -20,6 +20,9 @@ from .... import contracts as ct
from .... import sockets
from ... import base, events
+FK = ct.FlowKind
+FS = ct.FlowSignal
+
class BlenderConstantNode(base.MaxwellSimNode):
node_type = ct.NodeType.BlenderConstant
@@ -47,7 +50,7 @@ class BlenderConstantNode(base.MaxwellSimNode):
####################
# - Callbacks
####################
- @events.computes_output_socket('Value', input_sockets={'Value'})
+ @events.computes_output_socket('Value', kind=FK.Params, input_sockets={'Value'})
def compute_value(self, input_sockets) -> typ.Any:
return input_sockets['Value']
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/expr_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/expr_constant.py
index 75ebb98..cdb11cc 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/expr_constant.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/expr_constant.py
@@ -16,55 +16,55 @@
import typing as typ
-import sympy as sp
-
from .... import contracts as ct
from .... import sockets
from ... import base, events
+FK = ct.FlowKind
+
class ExprConstantNode(base.MaxwellSimNode):
+ """An expression constant."""
+
node_type = ct.NodeType.ExprConstant
bl_label = 'Expr Constant'
input_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(
- active_kind=ct.FlowKind.Func,
+ active_kind=FK.Func,
show_name_selector=True,
),
}
output_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(
- active_kind=ct.FlowKind.Func,
+ active_kind=FK.Func,
show_info_columns=True,
),
}
- ## TODO: Allow immediately realizing any symbol, or just passing it along.
- ## TODO: Alter output physical_type when the input PhysicalType changes.
-
####################
# - FlowKinds
####################
@events.computes_output_socket(
# Trigger
'Expr',
- kind=ct.FlowKind.Value,
+ kind=FK.Value,
# Loaded
- input_sockets={'Expr'},
+ inscks_kinds={'Expr': FK.Value},
)
def compute_value(self, input_sockets: dict) -> typ.Any:
+ """Compute the symbolic expression value."""
return input_sockets['Expr']
@events.computes_output_socket(
# Trigger
'Expr',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
# Loaded
- input_sockets={'Expr'},
- input_socket_kinds={'Expr': ct.FlowKind.Func},
+ inscks_kinds={'Expr': FK.Func},
)
def compute_lazy_func(self, input_sockets: dict) -> typ.Any:
+ """Compute the lazy expression function."""
return input_sockets['Expr']
####################
@@ -73,23 +73,25 @@ class ExprConstantNode(base.MaxwellSimNode):
@events.computes_output_socket(
# Trigger
'Expr',
- kind=ct.FlowKind.Info,
+ kind=FK.Info,
# Loaded
input_sockets={'Expr'},
- input_socket_kinds={'Expr': ct.FlowKind.Info},
+ input_socket_kinds={'Expr': FK.Info},
)
def compute_info(self, input_sockets: dict) -> typ.Any:
+ """Compute the tracking information flow."""
return input_sockets['Expr']
@events.computes_output_socket(
# Trigger
'Expr',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
# Loaded
input_sockets={'Expr'},
- input_socket_kinds={'Expr': ct.FlowKind.Params},
+ input_socket_kinds={'Expr': FK.Params},
)
def compute_params(self, input_sockets: dict) -> typ.Any:
+ """Compute the fnuction parameters."""
return input_sockets['Expr']
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py
index 11068e3..7da78a4 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/scientific_constant.py
@@ -14,11 +14,13 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `ScientificConstantNode`."""
+
import typing as typ
import bpy
import sympy as sp
-import sympy.physics.units as spu
+from frozendict import frozendict
from blender_maxwell.utils import bl_cache, sci_constants, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
@@ -55,6 +57,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
self,
edit_text: str,
):
+ """Search for all valid scientific constants."""
return [
name
for name in sci_constants.SCI_CONSTANTS
@@ -82,7 +85,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
self.sci_constant_name,
self.sci_constant,
unit,
- is_constant=True,
+ new_domain=spux.BlessedSet(sp.FiniteSet(self.sci_constant)),
)
return None
@@ -91,11 +94,13 @@ class ScientificConstantNode(base.MaxwellSimNode):
# - UI
####################
def draw_label(self):
+ """Match the node's header label to the active scientific constant."""
if self.sci_constant_str:
return f'Const: {self.sci_constant_str}'
return self.bl_label
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
+ """Provide a node user-interface allowing the user to specify a scientific constant."""
col.prop(self, self.blfields['sci_constant_str'], text='')
row = col.row(align=True)
@@ -139,7 +144,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
row.label(text=f'{self.sci_constant_info["uncertainty"]}')
####################
- # - Output
+ # - FlowKind.Value
####################
@events.computes_output_socket(
'Expr',
@@ -148,8 +153,9 @@ class ScientificConstantNode(base.MaxwellSimNode):
def compute_value(self, props) -> typ.Any:
sci_constant = props['sci_constant']
sci_constant_sym = props['sci_constant_sym']
+ use_symbol = props['use_symbol']
- if props['use_symbol'] and sci_constant_sym is not None:
+ if use_symbol and sci_constant_sym is not None:
return sci_constant_sym.sp_symbol
if sci_constant is not None:
@@ -157,6 +163,9 @@ class ScientificConstantNode(base.MaxwellSimNode):
return ct.FlowSignal.FlowPending
+ ####################
+ # - FlowKind.Func
+ ####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Func,
@@ -178,19 +187,9 @@ class ScientificConstantNode(base.MaxwellSimNode):
)
return ct.FlowSignal.FlowPending
- @events.computes_output_socket(
- 'Expr',
- kind=ct.FlowKind.Info,
- props={'sci_constant_sym'},
- )
- def compute_info(self, props: dict) -> typ.Any:
- """Simple `FuncFlow` that computes the symbol value, with output units tracked correctly."""
- sci_constant_sym = props['sci_constant_sym']
-
- if sci_constant_sym is not None:
- return ct.InfoFlow(output=sci_constant_sym)
- return ct.FlowSignal.FlowPending
-
+ ####################
+ # - FlowKind.Params
+ ####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Params,
@@ -206,12 +205,30 @@ class ScientificConstantNode(base.MaxwellSimNode):
func_args=[sci_constant_sym.sp_symbol],
symbols={sci_constant_sym},
).realize_partial(
- {
- sci_constant_sym: sci_constant,
- }
+ frozendict(
+ {
+ sci_constant_sym: sci_constant,
+ }
+ )
)
return ct.FlowSignal.FlowPending
+ ####################
+ # - FlowKind.Info
+ ####################
+ @events.computes_output_socket(
+ 'Expr',
+ kind=ct.FlowKind.Info,
+ props={'sci_constant_sym'},
+ )
+ def compute_info(self, props: dict) -> typ.Any:
+ """Simple `FuncFlow` that computes the symbol value, with output units tracked correctly."""
+ sci_constant_sym = props['sci_constant_sym']
+
+ if sci_constant_sym is not None:
+ return ct.InfoFlow(output=sci_constant_sym)
+ return ct.FlowSignal.FlowPending
+
####################
# - Blender Registration
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py
index 213746a..7bc086f 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py
@@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `SymbolConstantNode`."""
+
import enum
import typing as typ
from fractions import Fraction
@@ -32,6 +34,8 @@ log = logger.get(__name__)
class SymbolConstantNode(base.MaxwellSimNode):
+ """A symbol usable as itself, or as a symbol."""
+
node_type = ct.NodeType.SymbolConstant
bl_label = 'Symbol'
@@ -87,7 +91,7 @@ class SymbolConstantNode(base.MaxwellSimNode):
return None
- @property
+ @bl_cache.cached_bl_property(depends_on={'unit'})
def unit_factor(self) -> spux.Unit | None:
"""Like `self.unit`, except `1` instead of `None` when unitless."""
return sp.Integer(1) if self.unit is None else self.unit
@@ -95,6 +99,8 @@ class SymbolConstantNode(base.MaxwellSimNode):
####################
# - Domain
####################
+ exclude_zero: bool = bl_cache.BLField(False)
+
interval_finite_z: tuple[int, int] = bl_cache.BLField((0, 1))
interval_finite_q: tuple[tuple[int, int], tuple[int, int]] = bl_cache.BLField(
((0, 1), (1, 1))
@@ -146,22 +152,27 @@ class SymbolConstantNode(base.MaxwellSimNode):
'preview_value_q',
'preview_value_re',
'preview_value_im',
+ 'unit_factor',
}
)
- def preview_value(
+ def preview_value_phy(
self,
) -> int | Fraction | float | complex:
"""Return the appropriate finite interval from the UI, as guided by `self.mathtype`."""
MT = spux.MathType
match self.mathtype:
case MT.Integer:
- return self.preview_value_z
+ preview_value = sp.Integer(self.preview_value_z)
case MT.Rational:
- return Fraction(*self.preview_value_q)
+ preview_value = sp.Rational(*self.preview_value_q)
case MT.Real:
- return self.preview_value_re
+ preview_value = sp.Float(self.preview_value_re, 4)
case MT.Complex:
- return complex(self.preview_value_re, self.preview_value_im)
+ preview_value = sp.Float(self.preview_value_re) + sp.I * sp.Float(
+ self.preview_value_im
+ )
+
+ return preview_value * self.unit_factor
@bl_cache.cached_bl_property(
depends_on={
@@ -179,11 +190,27 @@ class SymbolConstantNode(base.MaxwellSimNode):
"""Deduce the domain specified in the UI."""
MT = spux.MathType
match self.mathtype:
- case MT.Integer | MT.Real | MT.Rational:
- return sim_symbols.mk_interval(
- self.interval_finite,
- self.interval_inf,
- self.interval_closed,
+ case MT.Integer:
+ start = (
+ self.interval_inf[0]
+ if self.interval_inf[0]
+ else self.interval_finite[0]
+ )
+ end = (
+ self.interval_inf[1]
+ if self.interval_inf[1]
+ else self.interval_finite[1]
+ )
+
+ return spux.BlessedSet(sp.Range(start, end, 1))
+
+ case MT.Real | MT.Rational:
+ return spux.BlessedSet(
+ sim_symbols.mk_interval(
+ self.interval_finite,
+ self.interval_inf,
+ self.interval_closed,
+ )
)
case MT.Complex:
@@ -198,7 +225,11 @@ class SymbolConstantNode(base.MaxwellSimNode):
self.interval_inf_im,
self.interval_closed_im,
)
- return sp.ComplexRegion(domain_re, domain_im, polar=False)
+ return spux.BlessedSet(spux.ComplexRegion(domain_re, domain_im))
+
+ @bl_cache.cached_bl_property(depends_on={'domain'})
+ def is_nonzero(self) -> bool:
+ return 0 in self.domain
####################
# - Computed Properties
@@ -211,7 +242,6 @@ class SymbolConstantNode(base.MaxwellSimNode):
'unit',
'size',
'domain',
- 'preview_value',
}
)
def symbol(self) -> sim_symbols.SimSymbol:
@@ -224,7 +254,6 @@ class SymbolConstantNode(base.MaxwellSimNode):
rows=self.size.rows,
cols=self.size.cols,
domain=self.domain,
- preview_value=self.preview_value,
)
####################
@@ -268,6 +297,14 @@ class SymbolConstantNode(base.MaxwellSimNode):
row.label(text='Domain - Closure')
row = col.row(align=True)
+ match self.mathtype:
+ case spux.MathType.Rational | spux.MathType.Real:
+ row.prop(self, self.blfields['interval_closed'], text='')
+
+ case spux.MathType.Complex:
+ row.prop(self, self.blfields['interval_closed'], text='ℝ')
+ row.prop(self, self.blfields['interval_closed_im'], text='𝕀')
+
if self.mathtype is spux.MathType.Complex:
row.prop(self, self.blfields['interval_closed'], text='ℝ')
row.prop(self, self.blfields['interval_closed_im'], text='𝕀')
@@ -353,7 +390,6 @@ class SymbolConstantNode(base.MaxwellSimNode):
)
def compute_info(self, props) -> typ.Any:
return ct.InfoFlow(
- dims={props['symbol']: None},
output=props['symbol'],
)
@@ -362,14 +398,14 @@ class SymbolConstantNode(base.MaxwellSimNode):
'Expr',
kind=ct.FlowKind.Params,
# Loaded
- props={'symbol'},
+ props={'symbol', 'preview_value_phy'},
)
def compute_params(self, props) -> typ.Any:
sym = props['symbol']
return ct.ParamsFlow(
- arg_targets=[sym],
func_args=[sym.sp_symbol],
symbols={sym},
+ previewed_symbols={sym: props['preview_value_phy']},
)
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py
index e9134fc..71085f3 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/data_file_importer.py
@@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `DataFileImporterNode`."""
+
import enum
import typing as typ
from pathlib import Path
@@ -31,11 +33,15 @@ from ... import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+
####################
# - Node
####################
class DataFileImporterNode(base.MaxwellSimNode):
+ """Import expression data from a file."""
+
node_type = ct.NodeType.DataFileImporter
bl_label = 'Data File Importer'
@@ -43,7 +49,7 @@ class DataFileImporterNode(base.MaxwellSimNode):
'File Path': sockets.FilePathSocketDef(),
}
output_sockets: typ.ClassVar = {
- 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
+ 'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
}
####################
@@ -54,7 +60,7 @@ class DataFileImporterNode(base.MaxwellSimNode):
socket_name={'File Path'},
# Loaded
input_sockets={'File Path'},
- input_socket_kinds={'File Path': ct.FlowKind.Value},
+ input_socket_kinds={'File Path': FK.Value},
input_sockets_optional={'File Path': True},
# Flow
## -> See docs in TransformMathNode
@@ -69,9 +75,7 @@ class DataFileImporterNode(base.MaxwellSimNode):
@bl_cache.cached_bl_property()
def file_path(self) -> Path:
"""Retrieve the input file path."""
- file_path = self._compute_input(
- 'File Path', kind=ct.FlowKind.Value, optional=True
- )
+ file_path = self._compute_input('File Path', kind=FK.Value, optional=True)
has_file_path = not ct.FlowSignal.check(file_path)
if has_file_path:
return file_path
@@ -99,7 +103,7 @@ class DataFileImporterNode(base.MaxwellSimNode):
)
def expr_info(self) -> ct.InfoFlow | None:
"""Retrieve the output expression's `InfoFlow`."""
- info = self.compute_output('Expr', kind=ct.FlowKind.Info)
+ info = self.compute_output('Expr', kind=FK.Info)
has_info = not ct.FlowSignal.check(info)
if has_info:
return info
@@ -120,6 +124,15 @@ class DataFileImporterNode(base.MaxwellSimNode):
cb_depends_on={'output_physical_type'},
)
+ def search_units(self, physical_type: spux.PhysicalType) -> list[ct.BLEnumElement]:
+ """Determine valid units based on the given physical type."""
+ if physical_type is not spux.PhysicalType.NonPhysical:
+ return [
+ (sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
+ for i, unit in enumerate(physical_type.valid_units)
+ ]
+ return []
+
dim_0_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerA
)
@@ -139,14 +152,6 @@ class DataFileImporterNode(base.MaxwellSimNode):
sim_symbols.SimSymbolName.LowerF
)
- def search_units(self, physical_type: spux.PhysicalType) -> list[ct.BLEnumElement]:
- if physical_type is not spux.PhysicalType.NonPhysical:
- return [
- (sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
- for i, unit in enumerate(physical_type.valid_units)
- ]
- return []
-
####################
# - UI
####################
@@ -196,10 +201,15 @@ class DataFileImporterNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
# Loaded
- props={'output_name', 'output_mathtype', 'output_physical_type', 'output_unit'},
- input_sockets={'File Path'},
+ props={
+ 'output_name',
+ 'output_mathtype',
+ 'output_physical_type',
+ 'output_unit',
+ },
+ inscks_kinds={'File Path': FK.Func},
)
def compute_func(self, props, input_sockets) -> td.Simulation:
"""Declare a lazy, composable function that returns the loaded data.
@@ -244,7 +254,7 @@ class DataFileImporterNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
)
def compute_params(self) -> ct.ParamsFlow:
"""Declare an empty `Data:Params`, to indicate the start of a function-composition pipeline.
@@ -256,12 +266,12 @@ class DataFileImporterNode(base.MaxwellSimNode):
@events.computes_output_socket(
'Expr',
- kind=ct.FlowKind.Info,
+ kind=FK.Info,
# Loaded
props={'output_name', 'output_mathtype', 'output_physical_type', 'output_unit'}
| {f'dim_{i}_name' for i in range(6)},
output_sockets={'Expr'},
- output_socket_kinds={'Expr': ct.FlowKind.Func},
+ output_socket_kinds={'Expr': FK.Func},
)
def compute_info(self, props, output_sockets) -> ct.InfoFlow:
"""Declare an `InfoFlow` based on the data shape.
@@ -284,9 +294,7 @@ class DataFileImporterNode(base.MaxwellSimNode):
dims = {
sim_symbols.idx(None).update(
sym_name=props[f'dim_{i}_name'],
- interval_finite_z=(0, elements),
- interval_inf=(False, False),
- interval_closed=(True, True),
+ domain=spux.BlessedSet(sp.Range(0, elements)),
): [str(j) for j in range(elements)]
for i, elements in enumerate(shape)
}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/tidy_3d_file_importer.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/tidy_3d_file_importer.py
index 6b0473a..80c8dc0 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/tidy_3d_file_importer.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/file_importers/tidy_3d_file_importer.py
@@ -14,14 +14,15 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+import enum
+import functools
import typing as typ
from pathlib import Path
import bpy
import tidy3d as td
-import tidy3d.plugins.dispersion as td_dispersion
-from blender_maxwell.utils import logger
+from blender_maxwell.utils import bl_cache, logger
from .... import contracts as ct
from .... import managed_objs, sockets
@@ -29,35 +30,94 @@ from ... import base, events
log = logger.get(__name__)
-VALID_FILE_EXTS = {
- 'SIMULATION_DATA': {
- '.hdf5.gz',
- '.hdf5',
- },
- 'SIMULATION': {
- '.hdf5.gz',
- '.hdf5',
- '.json',
- '.yaml',
- },
- 'MEDIUM': {
- '.hdf5.gz',
- '.hdf5',
- '.json',
- '.yaml',
- },
- 'EXPERIM_DISP_MEDIUM': {
- '.txt',
- },
-}
+FK = ct.FlowKind
+FS = ct.FlowSignal
-CACHE = {}
+
+class ValidTDFileExts(enum.StrEnum):
+ """Valid importable Tidy3D file extensions."""
+
+ SimData = enum.auto()
+ Sim = enum.auto()
+ Medium = enum.auto()
+
+ @staticmethod
+ def to_name(value: typ.Self) -> str:
+ VFE = ValidTDFileExts
+ return {
+ VFE.SimData: 'Sim Data',
+ VFE.Sim: 'Sim',
+ VFE.Medium: 'Medium',
+ }[value]
+
+ @staticmethod
+ def to_icon(_: typ.Self) -> str:
+ return ''
+
+ def bl_enum_element(self, i: int) -> ct.BLEnumElement:
+ return (
+ str(self),
+ ValidTDFileExts.to_name(self),
+ ValidTDFileExts.to_name(self),
+ ValidTDFileExts.to_icon(self),
+ i,
+ )
+
+ ####################
+ # - Properties
+ ####################
+ @functools.cached_property
+ def valid_exts(self) -> set[str]:
+ VFE = ValidTDFileExts
+ match self:
+ case VFE.SimData:
+ return {
+ '.hdf5.gz',
+ '.hdf5',
+ }
+
+ case VFE.Sim | VFE.Medium:
+ return {
+ '.hdf5.gz',
+ '.hdf5',
+ '.json',
+ '.yaml',
+ }
+
+ raise TypeError
+
+ @functools.cached_property
+ def td_type(self) -> set[str]:
+ """The corresponding Tidy3D type."""
+ VFE = ValidTDFileExts
+ return {
+ VFE.SimData: td.SimulationData,
+ VFE.Sim: td.Simulation,
+ VFE.Medium: td.Medium,
+ }[self]
+
+ ####################
+ # - Methods
+ ####################
+ def is_path_compatible(self, path: Path) -> bool:
+ ext_matches = ''.join(path.suffixes) in self.valid_exts
+ return ext_matches and path.is_file()
+
+ def load(self, file_path: Path) -> typ.Any:
+ VFE = ValidTDFileExts
+ match self:
+ case VFE.SimData | VFE.Sim | VFE.Medium:
+ return self.td_type.from_file(str(file_path))
+
+ raise TypeError
####################
# - Node
####################
class Tidy3DFileImporterNode(base.MaxwellSimNode):
+ """Import a simulation design or analysis element from a Tidy3D object."""
+
node_type = ct.NodeType.Tidy3DFileImporter
bl_label = 'Tidy3D File Importer'
@@ -71,185 +131,74 @@ class Tidy3DFileImporterNode(base.MaxwellSimNode):
####################
# - Properties
####################
- ## TODO: More automatic determination of which file type is in use :)
- tidy3d_type: bpy.props.EnumProperty(
- name='Tidy3D Type',
- description='Type of Tidy3D object to load',
- items=[
- (
- 'SIMULATION_DATA',
- 'Sim Data',
- 'Data from Completed Tidy3D Simulation',
- ),
- ('SIMULATION', 'Sim', 'Tidy3D Simulation'),
- ('MEDIUM', 'Medium', 'A Tidy3D Medium'),
- (
- 'EXPERIM_DISP_MEDIUM',
- 'Experim Disp Medium',
- 'A pole-residue fit of experimental dispersive medium data, described by a .txt file specifying wl, n, k',
- ),
- ],
- default='SIMULATION_DATA',
- update=lambda self, context: self.on_prop_changed('tidy3d_type', context),
- )
-
- disp_fit__min_poles: bpy.props.IntProperty(
- name='min Poles',
- description='Min. # poles to fit to the experimental dispersive medium data',
- default=1,
- )
- disp_fit__max_poles: bpy.props.IntProperty(
- name='max Poles',
- description='Max. # poles to fit to the experimental dispersive medium data',
- default=5,
- )
- ## TODO: Bool of whether to fit eps_inf, with conditional choice of eps_inf as socket
- disp_fit__tolerance_rms: bpy.props.FloatProperty(
- name='Max RMS',
- description='The RMS error threshold, below which the fit should be considered converged',
- default=0.001,
- precision=5,
- )
- ## TODO: "AdvanceFastFitterParam" options incl. loss_bounds, weights, show_progress, show_unweighted_rms, relaxed, smooth, logspacing, numiters, passivity_num_iters, and slsqp_constraint_scale
+ tidy3d_type: ValidTDFileExts = bl_cache.BLField(ValidTDFileExts.SimData)
+ ####################
+ # - UI
+ ####################
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
col.prop(self, 'tidy3d_type', text='')
- if self.tidy3d_type == 'EXPERIM_DISP_MEDIUM':
- row = col.row(align=True)
- row.alignment = 'CENTER'
- row.label(text='Pole-Residue Fit')
-
- col.prop(self, 'disp_fit__min_poles')
- col.prop(self, 'disp_fit__max_poles')
- col.prop(self, 'disp_fit__tolerance_rms')
-
- ####################
- # - Event Methods: Output Data
- ####################
- def _compute_sim_data_for(
- self, output_socket_name: str, file_path: Path
- ) -> td.components.base.Tidy3dBaseModel:
- return {
- 'Sim': td.Simulation,
- 'Sim Data': td.SimulationData,
- 'Medium': td.Medium,
- }[output_socket_name].from_file(str(file_path))
-
- @events.computes_output_socket(
- 'Sim',
- input_sockets={'File Path'},
- )
- def compute_sim(self, input_sockets: dict) -> td.Simulation:
- return self._compute_sim_data_for('Sim', input_sockets['File Path'])
-
- @events.computes_output_socket(
- 'Sim Data',
- input_sockets={'File Path'},
- )
- def compute_sim_data(self, input_sockets: dict) -> td.SimulationData:
- return self._compute_sim_data_for('Sim Data', input_sockets['File Path'])
-
- @events.computes_output_socket(
- 'Medium',
- input_sockets={'File Path'},
- )
- def compute_medium(self, input_sockets: dict) -> td.Medium:
- return self._compute_sim_data_for('Medium', input_sockets['File Path'])
-
- ####################
- # - Event Methods: Output Data | Dispersive Media
- ####################
- @events.computes_output_socket(
- 'Experim Disp Medium',
- input_sockets={'File Path'},
- )
- def compute_experim_disp_medium(self, input_sockets: dict) -> td.Medium:
- if CACHE.get(self.bl_label) is not None:
- log.debug('Reusing Cached Dispersive Medium')
- return CACHE[self.bl_label]['model']
-
- log.info('Loading Experimental Data')
- dispersion_fitter = td_dispersion.FastDispersionFitter.from_file(
- str(input_sockets['File Path'])
- )
-
- log.info('Computing Fast Dispersive Fit of Experimental Data...')
- pole_residue_medium, rms_error = dispersion_fitter.fit(
- min_num_poles=self.disp_fit__min_poles,
- max_num_poles=self.disp_fit__max_poles,
- tolerance_rms=self.disp_fit__tolerance_rms,
- )
- log.info('Fit Succeeded w/RMS "%s"!', f'{rms_error:.5f}')
-
- # Populate Cache
- CACHE[self.bl_label] = {}
- CACHE[self.bl_label]['model'] = pole_residue_medium
- CACHE[self.bl_label]['fitter'] = dispersion_fitter
- CACHE[self.bl_label]['rms_error'] = rms_error
-
- return pole_residue_medium
####################
# - Event Methods: Setup Output Socket
####################
@events.on_value_changed(
- socket_name='File Path',
prop_name='tidy3d_type',
- input_sockets={'File Path'},
+ run_on_init=True,
+ # Loaded
props={'tidy3d_type'},
)
- def on_file_changed(self, input_sockets: dict, props: dict):
- if CACHE.get(self.bl_label) is not None:
- del CACHE[self.bl_label]
-
- file_ext = ''.join(input_sockets['File Path'].suffixes)
- if not (
- input_sockets['File Path'].is_file()
- and file_ext in VALID_FILE_EXTS[props['tidy3d_type']]
- ):
- self.loose_output_sockets = {}
- else:
- self.loose_output_sockets = {
- 'SIMULATION_DATA': {
- 'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(),
- },
- 'SIMULATION': {'Sim': sockets.MaxwellFDTDSimSocketDef()},
- 'MEDIUM': {'Medium': sockets.MaxwellMediumSocketDef()},
- 'EXPERIM_DISP_MEDIUM': {
- 'Experim Disp Medium': sockets.MaxwellMediumSocketDef()
- },
- }[props['tidy3d_type']]
+ def on_file_changed(self, props) -> None:
+ self.loose_output_sockets = {
+ ValidTDFileExts.SimData: {
+ 'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(),
+ },
+ ValidTDFileExts.Sim: {'Sim': sockets.MaxwellFDTDSimSocketDef()},
+ ValidTDFileExts.Medium: {'Medium': sockets.MaxwellMediumSocketDef()},
+ }[props['tidy3d_type']]
####################
- # - Event Methods: Plot
+ # - FlowKind.Value
####################
- @events.on_show_plot(
- managed_objs={'plot'},
+ def _compute_td_obj(
+ self, props, input_sockets
+ ) -> td.components.base.Tidy3dBaseModel:
+ tidy3d_type = props['tidy3d_type']
+ file_path = input_sockets['File Path']
+
+ if tidy3d_type.is_path_compatible(file_path):
+ return tidy3d_type.load(file_path)
+ return FS.FlowPending
+
+ @events.computes_output_socket(
+ 'Sim',
+ kind=FK.Value,
+ # Loaded
props={'tidy3d_type'},
+ inscks_kinds={'File Path': FK.Value},
)
- def on_show_plot(
- self,
- props: dict,
- managed_objs: dict,
- ):
- """When the filetype is 'Experimental Dispersive Medium', plot the computed model against the input data."""
- if props['tidy3d_type'] == 'EXPERIM_DISP_MEDIUM':
- # Populate Cache
- if CACHE.get(self.bl_label) is None:
- model_medium = self.compute_experim_disp_medium()
- disp_fitter = CACHE[self.bl_label]['fitter']
- else:
- model_medium = CACHE[self.bl_label]['model']
- disp_fitter = CACHE[self.bl_label]['fitter']
+ def compute_sim(self, props, input_sockets) -> td.Simulation:
+ return self._compute_td_obj(props, input_sockets)
- # Plot
- managed_objs['plot'].mpl_plot_to_image(
- lambda ax: disp_fitter.plot(
- medium=model_medium,
- ax=ax,
- ),
- bl_select=True,
- )
+ @events.computes_output_socket(
+ 'Sim Data',
+ kind=FK.Value,
+ # Loaded
+ props={'tidy3d_type'},
+ inscks_kinds={'File Path': FK.Value},
+ )
+ def compute_sim_data(self, props, input_sockets) -> td.SimulationData:
+ return self._compute_td_obj(props, input_sockets)
+
+ @events.computes_output_socket(
+ 'Medium',
+ kind=FK.Value,
+ # Loaded
+ props={'tidy3d_type'},
+ inscks_kinds={'File Path': FK.Value},
+ )
+ def compute_medium(self, props, input_sockets: dict) -> td.Medium:
+ return self._compute_td_obj(props, input_sockets)
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/web_importers/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/web_importers/__init__.py
index ff8cee9..a58446a 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/web_importers/__init__.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/web_importers/__init__.py
@@ -14,11 +14,11 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from . import tidy3d_web_runner
+from . import tidy3d_web_importer
BL_REGISTER = [
- *tidy3d_web_runner.BL_REGISTER,
+ *tidy3d_web_importer.BL_REGISTER,
]
BL_NODES = {
- **tidy3d_web_runner.BL_NODES,
+ **tidy3d_web_importer.BL_NODES,
}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/web_importers/tidy3d_web_importer.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/web_importers/tidy3d_web_importer.py
new file mode 100644
index 0000000..8550a58
--- /dev/null
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/web_importers/tidy3d_web_importer.py
@@ -0,0 +1,217 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+"""Implements `Tidy3DWebImporterNode`."""
+
+import typing as typ
+
+import bpy
+import tidy3d as td
+
+from blender_maxwell.services import tdcloud
+from blender_maxwell.utils import bl_cache, logger, sim_symbols
+from blender_maxwell.utils import sympy_extra as spux
+
+from .... import contracts as ct
+from .... import sockets
+from ... import base, events
+
+log = logger.get(__name__)
+
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MT = spux.MathType
+
+SimDataArray: typ.TypeAlias = dict[
+ tuple[sim_symbols.SimSymbol, ...], tuple[typ.Any, ...], td.SimulationData
+]
+SimDataArrayInfo: typ.TypeAlias = dict[
+ tuple[sim_symbols.SimSymbol, ...], tuple[typ.Any, ...], typ.Any
+]
+
+
+####################
+# - Node
+####################
+class Tidy3DWebImporterNode(base.MaxwellSimNode):
+ """Retrieve a simulation w/data from the Tidy3D cloud service."""
+
+ node_type = ct.NodeType.Tidy3DWebImporter
+ bl_label = 'Tidy3D Web Importer'
+
+ input_sockets: typ.ClassVar = {
+ 'Preview Sim': sockets.MaxwellFDTDSimSocketDef(),
+ }
+ input_socket_sets: typ.ClassVar = {
+ 'Single': {
+ 'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
+ should_exist=True,
+ ),
+ },
+ 'Batch': {
+ 'Cloud Tasks': sockets.Tidy3DCloudTaskSocketDef(
+ active_kind=FK.Array,
+ should_exist=True,
+ ),
+ },
+ }
+ output_socket_sets: typ.ClassVar = {
+ 'Single': {
+ 'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(),
+ },
+ 'Batch': {
+ 'Sim Datas': sockets.MaxwellFDTDSimDataSocketDef(
+ active_kind=FK.Array,
+ ),
+ },
+ }
+
+ ####################
+ # - Properties: Cloud Tasks -> Sim Datas
+ ####################
+ @events.on_value_changed(
+ # Trigger
+ socket_name={'Cloud Task': FK.Value, 'Cloud Tasks': FK.Array},
+ )
+ def on_cloud_tasks_changed(self) -> None: # noqa: D102
+ self.cloud_tasks = bl_cache.Signal.InvalidateCache
+
+ @bl_cache.cached_bl_property(depends_on={'active_socket_set'})
+ def cloud_tasks(self) -> list[tdcloud.CloudTask] | None:
+ """Retrieve the current cloud tasks from the input.
+
+ If one can't be loaded, return None.
+ """
+ if self.active_socket_set == 'Single':
+ cloud_task_single = self._compute_input(
+ 'Cloud Task',
+ kind=FK.Value,
+ )
+ has_cloud_task_single = not FS.check(cloud_task_single)
+ if has_cloud_task_single:
+ return [cloud_task_single]
+
+ if self.active_socket_set == 'Batch':
+ cloud_task_array = self._compute_input(
+ 'Cloud Tasks',
+ kind=FK.Array,
+ )
+ has_cloud_task_array = not FS.check(cloud_task_array)
+ if has_cloud_task_array:
+ return cloud_task_array
+
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'cloud_tasks'})
+ def task_infos(self) -> list[tdcloud.CloudTaskInfo | None] | None:
+ """Retrieve the current cloud task information from the input socket.
+
+ If it can't be loaded, return None.
+ """
+ if self.cloud_tasks is not None:
+ task_infos = [
+ tdcloud.TidyCloudTasks.task_info(cloud_task.task_id)
+ for cloud_task in self.cloud_tasks
+ ]
+ if task_infos:
+ return task_infos
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'cloud_tasks', 'task_infos'})
+ def sim_datas(self) -> SimDataArray | None:
+ """Retrieve the simulation data of the current cloud task from the input socket.
+
+ If it can't be loaded, return None.
+ """
+ cloud_tasks = self.cloud_tasks
+ task_infos = self.task_infos
+ if (
+ cloud_tasks is not None
+ and task_infos is not None
+ and all(task_info is not None for task_info in task_infos)
+ ):
+ sim_datas = {}
+ for cloud_task, task_info in [
+ (cloud_task, task_info)
+ for cloud_task, task_info in zip(cloud_tasks, task_infos, strict=True)
+ if task_info.status == 'success'
+ ]:
+ sim_data = tdcloud.TidyCloudTasks.download_task_sim_data(
+ cloud_task,
+ task_info.disk_cache_path(ct.addon.prefs().addon_cache_path),
+ )
+ if sim_data is not None:
+ sim_metadata = ct.SimMetadata.from_sim(sim_data)
+ sim_datas |= {sim_metadata.syms_vals: sim_data}
+
+ if sim_datas:
+ return sim_datas
+ return None
+
+ ####################
+ # - UI
+ ####################
+ def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout):
+ """Draw information about the cloud connection."""
+ tdcloud.draw_cloud_status(layout)
+
+ ####################
+ # - FlowKind.Value
+ ####################
+ @events.computes_output_socket(
+ 'Sim Data',
+ kind=FK.Value,
+ # Loaded
+ props={'sim_datas'},
+ )
+ def compute_sim_data(self, props) -> td.SimulationData | FS:
+ """A single simulation data object, when there only is one."""
+ sim_datas = props['sim_datas']
+
+ if sim_datas is not None and len(sim_datas) == 1:
+ return next(iter(sim_datas.values()))
+ return FS.FlowPending
+
+ ####################
+ # - FlowKind.Array
+ ####################
+ @events.computes_output_socket(
+ 'Sim Datas',
+ kind=FK.Array,
+ # Loaded
+ props={'sim_datas'},
+ )
+ def compute_sim_datas(self, props) -> SimDataArray | FS:
+ """All simulation data objects, for when there are more than one.
+
+ Generally part of the same batch.
+ """
+ sim_datas = props['sim_datas']
+
+ if sim_datas is not None and len(sim_datas) > 1:
+ return sim_datas
+ return FS.FlowPending
+
+
+####################
+# - Blender Registration
+####################
+BL_REGISTER = [
+ Tidy3DWebImporterNode,
+]
+BL_NODES = {
+ ct.NodeType.Tidy3DWebImporter: (ct.NodeCategory.MAXWELLSIM_INPUTS_WEBIMPORTERS)
+}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/web_importers/tidy3d_web_runner.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/web_importers/tidy3d_web_runner.py
deleted file mode 100644
index 6b40a3f..0000000
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/web_importers/tidy3d_web_runner.py
+++ /dev/null
@@ -1,276 +0,0 @@
-# blender_maxwell
-# Copyright (C) 2024 blender_maxwell Project Contributors
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see .
-
-import typing as typ
-
-import bpy
-import tidy3d as td
-
-from blender_maxwell.services import tdcloud
-from blender_maxwell.utils import bl_cache, logger
-
-from .... import contracts as ct
-from .... import sockets
-from ... import base, events
-
-log = logger.get(__name__)
-
-
-####################
-# - Operators
-####################
-class RunSimulation(bpy.types.Operator):
- """Run a Tidy3D simulation accessible from a `Tidy3DWebImporterNode`."""
-
- bl_idname = ct.OperatorType.NodeRunSimulation
- bl_label = 'Run Sim'
- bl_description = 'Run the currently tracked simulation task'
-
- @classmethod
- def poll(cls, context):
- return (
- # Check Tidy3D Cloud
- tdcloud.IS_AUTHENTICATED
- # Check Tidy3DWebImporterNode is Accessible
- and hasattr(context, 'node')
- and hasattr(context.node, 'node_type')
- and context.node.node_type == ct.NodeType.Tidy3DWebImporter
- # Check Task is Runnable
- and context.node.is_task_runnable
- )
-
- def execute(self, context):
- node = context.node
- node.cloud_task.submit()
-
- return {'FINISHED'}
-
-
-class ReloadTrackedTask(bpy.types.Operator):
- """Reload information of the selected task in a `Tidy3DWebImporterNode`."""
-
- bl_idname = ct.OperatorType.NodeReloadTrackedTask
- bl_label = 'Reload Tracked Tidy3D Cloud Task'
- bl_description = 'Reload the currently tracked simulation task'
-
- @classmethod
- def poll(cls, context):
- return (
- # Check Tidy3D Cloud
- tdcloud.IS_AUTHENTICATED
- # Check Tidy3DWebImporterNode is Accessible
- and hasattr(context, 'node')
- and hasattr(context.node, 'node_type')
- and context.node.node_type == ct.NodeType.Tidy3DWebImporter
- )
-
- def execute(self, context):
- node = context.node
- tdcloud.TidyCloudTasks.update_task(node.cloud_task)
-
- return {'FINISHED'}
-
-
-####################
-# - Node
-####################
-class Tidy3DWebImporterNode(base.MaxwellSimNode):
- node_type = ct.NodeType.Tidy3DWebImporter
- bl_label = 'Tidy3D Web Runner/Importer'
-
- input_sockets: typ.ClassVar = {
- 'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
- should_exist=True, ## Ensure it is never NewSimCloudTask
- ),
- }
- output_sockets: typ.ClassVar = {
- 'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(),
- }
-
- ####################
- # - Computed (Cached)
- ####################
- @property
- def cloud_task(self) -> tdcloud.CloudTask | None:
- """Retrieve the current cloud task from the input socket.
-
- If one can't be loaded, return None.
- """
- cloud_task = self._compute_input(
- 'Cloud Task',
- kind=ct.FlowKind.Value,
- )
- has_cloud_task = not ct.FlowSignal.check(cloud_task)
-
- if has_cloud_task:
- return cloud_task
- return None
-
- @property
- def task_info(self) -> tdcloud.CloudTaskInfo | None:
- """Retrieve the current cloud task information from the input socket.
-
- If it can't be loaded, return None.
- """
- cloud_task = self.cloud_task
- if cloud_task is None:
- return None
-
- # Retrieve Task Info
- task_info = tdcloud.TidyCloudTasks.task_info(cloud_task.task_id)
- if task_info is None:
- return None
-
- return task_info
-
- @property
- def sim_data(self) -> td.Simulation | None:
- """Retrieve the simulation data of the current cloud task from the input socket.
-
- If it can't be loaded, return None.
- """
- task_info = self.task_info
- if task_info is None:
- return None
-
- if task_info.status == 'success':
- # Download Sim Data
- ## -> self.cloud_task really shouldn't be able to be None here.
- ## -> So, we check it by applying the Ostrich method.
- sim_data = tdcloud.TidyCloudTasks.download_task_sim_data(
- self.cloud_task,
- tdcloud.TidyCloudTasks.task_info(
- self.cloud_task.task_id
- ).disk_cache_path(ct.addon.prefs().addon_cache_path),
- )
- if sim_data is None:
- return None
-
- return sim_data
-
- return None
-
- ####################
- # - Computed (Uncached)
- ####################
- @property
- def is_task_runnable(self) -> bool:
- """Checks whether all conditions are satisfied to be able to actually run a simulation."""
- if self.task_info is not None:
- return self.task_info.status == 'draft'
- return False
-
- ####################
- # - UI
- ####################
- def draw_operators(self, context, layout):
- # Row: Run Sim Buttons
- row = layout.row(align=True)
- row.operator(
- ct.OperatorType.NodeRunSimulation,
- text='Run Sim',
- )
-
- def draw_info(self, context, layout):
- # Connection Info
- auth_icon = 'CHECKBOX_HLT' if tdcloud.IS_AUTHENTICATED else 'CHECKBOX_DEHLT'
- conn_icon = 'CHECKBOX_HLT' if tdcloud.IS_ONLINE else 'CHECKBOX_DEHLT'
-
- row = layout.row()
- row.alignment = 'CENTER'
- row.label(text='Cloud Status')
- box = layout.box()
- split = box.split(factor=0.85)
-
- ## Split: Left Column
- col = split.column(align=False)
- col.label(text='Authed')
- col.label(text='Connected')
-
- ## Split: Right Column
- col = split.column(align=False)
- col.label(icon=auth_icon)
- col.label(icon=conn_icon)
-
- # Cloud Task Info
- if self.task_info is not None:
- # Header
- row = layout.row()
- row.alignment = 'CENTER'
- row.label(text='Task Info')
-
- # Task Run Progress
- # row = layout.row(align=True)
- # row.progress(
- # factor=0.0,
- # type='BAR',
- # text=f'Status: {self.task_info.status.capitalize()}',
- # )
- row.operator(
- ct.OperatorType.NodeReloadTrackedTask,
- text='',
- icon='FILE_REFRESH',
- )
-
- # Task Information
- box = layout.box()
- split = box.split(factor=0.4)
-
- ## Split: Left Column
- col = split.column(align=False)
- col.label(text='Status')
- col.label(text='Real Cost')
-
- ## Split: Right Column
- cost_real = (
- f'{self.task_info.cost_real:.2f}'
- if self.task_info.cost_real is not None
- else 'TBD'
- )
-
- col = split.column(align=False)
- col.alignment = 'RIGHT'
- col.label(text=self.task_info.status.capitalize())
- col.label(text=f'{cost_real} creds')
-
- ####################
- # - Output Methods
- ####################
- @events.computes_output_socket(
- 'Sim Data',
- props={'sim_data'},
- input_sockets={'Cloud Task'}, ## Keep to respect dependency chains.
- )
- def compute_sim_data(
- self, props, input_sockets
- ) -> td.SimulationData | ct.FlowSignal:
- if props['sim_data'] is None:
- return ct.FlowSignal.FlowPending
-
- return props['sim_data']
-
-
-####################
-# - Blender Registration
-####################
-BL_REGISTER = [
- RunSimulation,
- ReloadTrackedTask,
- Tidy3DWebImporterNode,
-]
-BL_NODES = {
- ct.NodeType.Tidy3DWebImporter: (ct.NodeCategory.MAXWELLSIM_INPUTS_WEBIMPORTERS)
-}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/kitchen_sink.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/kitchen_sink.py
deleted file mode 100644
index a3d40d8..0000000
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/kitchen_sink.py
+++ /dev/null
@@ -1,110 +0,0 @@
-# blender_maxwell
-# Copyright (C) 2024 blender_maxwell Project Contributors
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see .
-
-import sympy as sp
-
-from .. import contracts as ct
-from .. import sockets
-from . import base
-
-
-class KitchenSinkNode(base.MaxwellSimNode):
- node_type = ct.NodeType.KitchenSink
-
- bl_label = 'Kitchen Sink'
- # bl_icon = ...
-
- ####################
- # - Sockets
- ####################
- input_sockets = {
- 'Static Data': sockets.AnySocketDef(),
- }
- input_socket_sets = {
- 'Basic': {
- 'Any': sockets.AnySocketDef(),
- 'Bool': sockets.BoolSocketDef(),
- 'FilePath': sockets.FilePathSocketDef(),
- 'Text': sockets.TextSocketDef(),
- },
- 'Number': {
- 'Integer': sockets.IntegerNumberSocketDef(),
- 'Rational': sockets.RationalNumberSocketDef(),
- 'Real': sockets.RealNumberSocketDef(),
- 'Complex': sockets.ComplexNumberSocketDef(),
- },
- 'Vector': {
- 'Real 2D': sockets.Real2DVectorSocketDef(),
- 'Real 3D': sockets.Real3DVectorSocketDef(
- default_value=sp.Matrix([0.0, 0.0, 0.0])
- ),
- 'Complex 2D': sockets.Complex2DVectorSocketDef(),
- 'Complex 3D': sockets.Complex3DVectorSocketDef(),
- },
- 'Physical': {
- 'Time': sockets.PhysicalTimeSocketDef(),
- # "physical_point_2d": sockets.PhysicalPoint2DSocketDef(),
- 'Angle': sockets.PhysicalAngleSocketDef(),
- 'Length': sockets.PhysicalLengthSocketDef(),
- 'Area': sockets.PhysicalAreaSocketDef(),
- 'Volume': sockets.PhysicalVolumeSocketDef(),
- 'Point 3D': sockets.PhysicalPoint3DSocketDef(),
- ##"physical_size_2d": sockets.PhysicalSize2DSocketDef(),
- 'Size 3D': sockets.PhysicalSize3DSocketDef(),
- 'Mass': sockets.PhysicalMassSocketDef(),
- 'Speed': sockets.PhysicalSpeedSocketDef(),
- 'Accel Scalar': sockets.PhysicalAccelScalarSocketDef(),
- 'Force Scalar': sockets.PhysicalForceScalarSocketDef(),
- # "physical_accel_3dvector": sockets.PhysicalAccel3DVectorSocketDef(),
- ##"physical_force_3dvector": sockets.PhysicalForce3DVectorSocketDef(),
- 'Pol': sockets.PhysicalPolSocketDef(),
- 'Freq': sockets.PhysicalFreqSocketDef(),
- },
- 'Blender': {
- 'Object': sockets.BlenderObjectSocketDef(),
- 'Collection': sockets.BlenderCollectionSocketDef(),
- 'Image': sockets.BlenderImageSocketDef(),
- 'GeoNodes': sockets.BlenderGeoNodesSocketDef(),
- 'Text': sockets.BlenderTextSocketDef(),
- },
- 'Maxwell': {
- 'Source': sockets.MaxwellSourceSocketDef(),
- 'Temporal Shape': sockets.MaxwellTemporalShapeSocketDef(),
- 'Medium': sockets.MaxwellMediumSocketDef(),
- 'Medium Non-Linearity': sockets.MaxwellMediumNonLinearitySocketDef(),
- 'Structure': sockets.MaxwellStructureSocketDef(),
- 'Bound Box': sockets.MaxwellBoundBoxSocketDef(),
- 'Bound Face': sockets.MaxwellBoundFaceSocketDef(),
- 'Monitor': sockets.MaxwellMonitorSocketDef(),
- 'FDTD Sim': sockets.MaxwellFDTDSimSocketDef(),
- 'Sim Grid': sockets.MaxwellSimGridSocketDef(),
- 'Sim Grid Axis': sockets.MaxwellSimGridAxisSocketDef(),
- },
- }
-
- output_sockets = {
- 'Static Data': sockets.AnySocketDef(),
- }
- output_socket_sets = input_socket_sets
-
-
-####################
-# - Blender Registration
-####################
-BL_REGISTER = [
- KitchenSinkNode,
-]
-BL_NODES = {ct.NodeType.KitchenSink: (ct.NodeCategory.MAXWELLSIM_INPUTS)}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/__init__.py
index 42adff0..7b07c8e 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/__init__.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/__init__.py
@@ -14,20 +14,17 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from . import library_medium
-
# from . import pec_medium
# from . import isotropic_medium
# from . import anisotropic_medium
#
# from . import triple_sellmeier_medium
# from . import sellmeier_medium
-# from . import pole_residue_medium
# from . import drude_medium
# from . import drude_lorentz_medium
# from . import debye_medium
#
-# from . import non_linearities
+from . import library_medium, non_linearities, pole_residue_medium
BL_REGISTER = [
*library_medium.BL_REGISTER,
@@ -37,12 +34,12 @@ BL_REGISTER = [
#
# *triple_sellmeier_medium.BL_REGISTER,
# *sellmeier_medium.BL_REGISTER,
- # *pole_residue_medium.BL_REGISTER,
+ *pole_residue_medium.BL_REGISTER,
# *drude_medium.BL_REGISTER,
# *drude_lorentz_medium.BL_REGISTER,
# *debye_medium.BL_REGISTER,
#
- # *non_linearities.BL_REGISTER,
+ *non_linearities.BL_REGISTER,
]
BL_NODES = {
**library_medium.BL_NODES,
@@ -52,10 +49,10 @@ BL_NODES = {
#
# **triple_sellmeier_medium.BL_NODES,
# **sellmeier_medium.BL_NODES,
- # **pole_residue_medium.BL_NODES,
+ **pole_residue_medium.BL_NODES,
# **drude_medium.BL_NODES,
# **drude_lorentz_medium.BL_NODES,
# **debye_medium.BL_NODES,
#
- # **non_linearities.BL_NODES,
+ **non_linearities.BL_NODES,
}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/library_medium.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/library_medium.py
index e9c275b..466b005 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/library_medium.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/library_medium.py
@@ -14,7 +14,10 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `LibraryMediumNode`."""
+
import enum
+import functools
import typing as typ
import bpy
@@ -36,8 +39,14 @@ log = logger.get(__name__)
_mat_lib_iter = iter(td.material_library)
_mat_key = ''
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MT = spux.MathType
+
class VendoredMedium(enum.StrEnum):
+ """Static enum of all mediums vendored with the Tidy3D client library."""
+
# Declare StrEnum of All Tidy3D Mediums
## -> This is a 'for ... in ...', which uses globals as loop variables.
## -> It's a bit of a hack, but very effective.
@@ -53,16 +62,23 @@ class VendoredMedium(enum.StrEnum):
@staticmethod
def to_name(v: typ.Self) -> str:
+ """UI method to get the name of a vendored medium."""
return td.material_library[v].name
+ @functools.cached_property
+ def name(self) -> str:
+ """Name of the vendored medium."""
+ return VendoredMedium.to_name(self)
+
@staticmethod
def to_icon(_: typ.Self) -> str:
+ """No icon."""
return ''
####################
# - Medium Properties
####################
- @property
+ @functools.cached_property
def tidy3d_medium_item(self) -> Tidy3DMediumItem:
"""Extracts the Tidy3D "Medium Item", which encapsulates all the provided experimental variants."""
return td.material_library[self]
@@ -70,12 +86,12 @@ class VendoredMedium(enum.StrEnum):
####################
# - Medium Variant Properties
####################
- @property
+ @functools.cached_property
def medium_variants(self) -> set[Tidy3DMediumVariant]:
"""Extracts the list of medium variants, each corresponding to a particular experiment in the literature."""
return self.tidy3d_medium_item.variants
- @property
+ @functools.cached_property
def default_medium_variant(self) -> Tidy3DMediumVariant:
"""Extracts the "default" medium variant, as selected by Tidy3D."""
return self.medium_variants[self.tidy3d_medium_item.default]
@@ -83,7 +99,7 @@ class VendoredMedium(enum.StrEnum):
####################
# - Enum Helper
####################
- @property
+ @functools.cached_property
def variants_as_bl_enum_elements(self) -> list[ct.BLEnumElement]:
"""Computes a list of variants in a format suitable for use in a dynamic `EnumProperty`.
@@ -105,6 +121,8 @@ class VendoredMedium(enum.StrEnum):
class LibraryMediumNode(base.MaxwellSimNode):
+ """A pre-defined medium sourced from a particular experiment in the literature."""
+
node_type = ct.NodeType.LibraryMedium
bl_label = 'Library Medium'
@@ -113,7 +131,7 @@ class LibraryMediumNode(base.MaxwellSimNode):
####################
input_sockets: typ.ClassVar = {}
output_sockets: typ.ClassVar = {
- 'Medium': sockets.MaxwellMediumSocketDef(),
+ 'Medium': sockets.MaxwellMediumSocketDef(active_kind=FK.Func),
'Valid Freqs': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Range,
physical_type=spux.PhysicalType.Freq,
@@ -172,7 +190,7 @@ class LibraryMediumNode(base.MaxwellSimNode):
"""
return spu.convert_to(
- sp.Matrix([sp.nsimplify(el) for el in self.medium.frequency_range])
+ sp.ImmutableMatrix([sp.nsimplify(el) for el in self.medium.frequency_range])
* spu.hertz,
spux.terahertz,
)
@@ -180,7 +198,7 @@ class LibraryMediumNode(base.MaxwellSimNode):
@bl_cache.cached_bl_property(depends_on={'freq_range'})
def wl_range(self) -> sp.Expr:
"""Deduce the vacuum wavelength range as a unit-aware (nanometer, for convenience) column vector."""
- return sp.Matrix(
+ return sp.ImmutableMatrix(
self.freq_range.applyfunc(
lambda el: spu.convert_to(
sci_constants.vac_speed_of_light / el, spu.nanometer
@@ -216,13 +234,16 @@ class LibraryMediumNode(base.MaxwellSimNode):
# - UI
####################
def draw_label(self) -> str:
+ """Show the active medium in the node label."""
return f'Medium: {self.vendored_medium}'
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
+ """Dropdowns for a medium, and a particular experimental variant."""
layout.prop(self, self.blfields['vendored_medium'], text='')
layout.prop(self, self.blfields['variant_name'], text='')
def draw_info(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
+ """Draw information about a perticular variant of a particular medium."""
box = col.box()
row = box.row(align=True)
@@ -243,48 +264,81 @@ class LibraryMediumNode(base.MaxwellSimNode):
box.operator('wm.url_open', text='Link to Data').url = self.data_url
####################
- # - Output
+ # - FlowKind.Value
####################
@events.computes_output_socket(
'Medium',
+ kind=FK.Value,
+ # Loaded
props={'medium'},
)
- def compute_medium(self, props) -> sp.Expr:
- return props['medium']
+ def compute_medium_value(self, props) -> td.Medium | FS:
+ """Directly produce the medium."""
+ medium = props['medium']
+ if medium is not None:
+ return medium
+ return FS.FlowSignal
+ ####################
+ # - FlowKind.Func
+ ####################
@events.computes_output_socket(
- 'Valid Freqs',
- kind=ct.FlowKind.Range,
- props={'freq_range'},
+ 'Medium',
+ kind=FK.Func,
+ # Loaded
+ outscks_kinds={'Medium': FK.Value},
)
- def compute_valid_freqs_lazy(self, props) -> sp.Expr:
- return ct.RangeFlow(
- start=spu.scale_to_unit(['freq_range'][0], spux.THz),
- stop=spu.scale_to_unit(props['freq_range'][1], spux.THz),
- scaling=ct.ScalingMode.Lin,
- unit=spux.THz,
- )
+ def compute_medium_func(self, output_sockets) -> ct.FuncFlow:
+ """Simply bake `Value` into a function."""
+ return ct.FuncFlow(func=lambda: output_sockets['Medium'])
+ ####################
+ # - FlowKind.Params
+ ####################
@events.computes_output_socket(
- 'Valid WLs',
- kind=ct.FlowKind.Range,
- props={'wl_range'},
+ 'Medium',
+ kind=FK.Params,
)
- def compute_valid_wls_lazy(self, props) -> sp.Expr:
- return ct.RangeFlow(
- start=spu.scale_to_unit(['wl_range'][0], spu.nm),
- stop=spu.scale_to_unit(['wl_range'][0], spu.nm),
- scaling=ct.ScalingMode.Lin,
- unit=spu.nm,
- )
+ def compute_medium_params(self) -> ct.ParamsFlow:
+ """Return empty parameters for completeness."""
+ return ct.ParamsFlow()
+
+ ####################
+ # - FlowKind.Range
+ ####################
+ # @events.computes_output_socket(
+ # 'Valid Freqs',
+ # kind=ct.FlowKind.Range,
+ # props={'freq_range'},
+ # )
+ # def compute_valid_freqs_lazy(self, props) -> sp.Expr:
+ # return ct.RangeFlow(
+ # start=spu.scale_to_unit(['freq_range'][0], spux.THz),
+ # stop=spu.scale_to_unit(props['freq_range'][1], spux.THz),
+ # scaling=ct.ScalingMode.Lin,
+ # unit=spux.THz,
+ # )
+
+ # @events.computes_output_socket(
+ # 'Valid WLs',
+ # kind=ct.FlowKind.Range,
+ # props={'wl_range'},
+ # )
+ # def compute_valid_wls_lazy(self, props) -> sp.Expr:
+ # return ct.RangeFlow(
+ # start=spu.scale_to_unit(['wl_range'][0], spu.nm),
+ # stop=spu.scale_to_unit(['wl_range'][0], spu.nm),
+ # scaling=ct.ScalingMode.Lin,
+ # unit=spu.nm,
+ # )
####################
# - Preview
####################
+ ## TODO: Move medium preview to a viz node of some kind
@events.on_show_plot(
managed_objs={'plot'},
props={'medium'},
- stop_propagation=True,
)
def on_show_plot(
self,
@@ -293,7 +347,9 @@ class LibraryMediumNode(base.MaxwellSimNode):
):
managed_objs['plot'].mpl_plot_to_image(
lambda ax: props['medium'].plot(props['medium'].frequency_range, ax=ax),
- bl_select=True,
+ width_inches=6.0,
+ height_inches=3.0,
+ dpi=150,
)
## TODO: Plot based on Wl, not freq.
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/non_linearities/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/non_linearities/__init__.py
index dc6e340..0ab0759 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/non_linearities/__init__.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/non_linearities/__init__.py
@@ -18,18 +18,17 @@ from . import (
add_non_linearity,
chi_3_susceptibility_non_linearity,
kerr_non_linearity,
- two_photon_absorption_non_linearity,
)
BL_REGISTER = [
*add_non_linearity.BL_REGISTER,
*chi_3_susceptibility_non_linearity.BL_REGISTER,
*kerr_non_linearity.BL_REGISTER,
- *two_photon_absorption_non_linearity.BL_REGISTER,
+ # *two_photon_absorption_non_linearity.BL_REGISTER,
]
BL_NODES = {
**add_non_linearity.BL_NODES,
**chi_3_susceptibility_non_linearity.BL_NODES,
**kerr_non_linearity.BL_NODES,
- **two_photon_absorption_non_linearity.BL_NODES,
+ # **two_photon_absorption_non_linearity.BL_NODES,
}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/non_linearities/add_non_linearity.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/non_linearities/add_non_linearity.py
index 5b7d824..a6419c0 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/non_linearities/add_non_linearity.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/non_linearities/add_non_linearity.py
@@ -14,8 +14,179 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `AddNonLinearity`."""
+
+import functools
+import typing as typ
+
+import tidy3d as td
+
+from blender_maxwell.utils import logger
+from blender_maxwell.utils import sympy_extra as spux
+
+from .... import contracts as ct
+from .... import sockets
+from ... import base, events
+
+log = logger.get(__name__)
+
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MT = spux.MathType
+
+
+class AddNonLinearity(base.MaxwellSimNode):
+ """Add non-linearities to a medium, increasing the range of effects that it can encapsulate."""
+
+ node_type = ct.NodeType.AddNonLinearity
+ bl_label = 'Add Non-Linearity'
+
+ input_sockets: typ.ClassVar = {
+ 'Medium': sockets.MaxwellMediumSocketDef(active_kind=FK.Func),
+ 'Iterations': sockets.ExprSocketDef(
+ mathtype=MT.Integer,
+ default_value=5,
+ abs_min=1,
+ ),
+ }
+ output_sockets: typ.ClassVar = {
+ 'Medium': sockets.MaxwellMediumSocketDef(active_kind=FK.Func),
+ }
+
+ ####################
+ # - Events
+ ####################
+ @events.on_value_changed(
+ any_loose_input_socket=True,
+ run_on_init=True,
+ )
+ def on_inputs_changed(self) -> None:
+ """Always create one extra loose input socket off the end of the last linked loose socket."""
+ # Deduce SocketDef
+ ## -> Cheat by retrieving the class from the output sockets.
+ SocketDef = sockets.MaxwellMediumNonLinearitySocketDef
+
+ ## TODO: Move this code to events, so it can be shared w/Combine
+ # Deduce Current "Filled"
+ ## -> The first linked socket from the end bounds the "filled" region.
+ ## -> The length of that region, plus one, will be the new amount.
+ total_loose_inputs = len(self.loose_input_sockets)
+
+ reverse_linked_idxs = [
+ i
+ for i, bl_socket in enumerate(reversed(self.inputs.values()))
+ if i < total_loose_inputs and bl_socket.is_linked
+ ]
+ current_filled = total_loose_inputs - (
+ reverse_linked_idxs[0] if reverse_linked_idxs else total_loose_inputs
+ )
+ new_amount = current_filled + 1
+
+ # Deduce SocketDef | Current Amount
+ self.loose_input_sockets = {
+ '#0': SocketDef(active_kind=FK.Func),
+ } | {f'#{i}': SocketDef(active_kind=FK.Func) for i in range(1, new_amount)}
+
+ ####################
+ # - FlowKind.Value
+ ####################
+ @events.computes_output_socket(
+ 'Medium',
+ kind=FK.Value,
+ # Loaded
+ outscks_kinds={
+ 'Medium': {FK.Func, FK.Params},
+ },
+ )
+ def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
+ """Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
+ value = events.realize_known(output_sockets['Medium'])
+ if value is not None:
+ return value
+ return ct.FlowSignal.FlowPending
+
+ ####################
+ # - FlowKind.Func
+ ####################
+ @events.computes_output_socket(
+ 'Medium',
+ kind=FK.Func,
+ # Loaded
+ inscks_kinds={
+ 'Medium': FK.Func,
+ 'Iterations': FK.Func,
+ },
+ all_loose_input_sockets=True,
+ loose_input_sockets_kind=FK.Func,
+ )
+ def compute_func(self, input_sockets, loose_input_sockets) -> ct.ParamsFlow | FS:
+ """Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
+ medium = input_sockets['Medium']
+ iterations = input_sockets['Iterations']
+
+ funcs = [
+ non_linearity
+ for non_linearity in loose_input_sockets.values()
+ if not FS.check(non_linearity)
+ ]
+
+ if funcs:
+ non_linearities = functools.reduce(
+ lambda a, b: a | b,
+ funcs,
+ )
+
+ return (medium | iterations | non_linearities).compose_within(
+ lambda els: els[0].updated_copy(
+ nonlinear_spec=td.NonlinearSpec(
+ num_iters=els[1],
+ models=els[2] if isinstance(els[2], tuple) else [els[2]],
+ )
+ )
+ )
+ return ct.FlowSignal.FlowPending
+
+ ####################
+ # - FlowKind.Params
+ ####################
+ @events.computes_output_socket(
+ 'Medium',
+ kind=ct.FlowKind.Params,
+ # Loaded
+ inscks_kinds={
+ 'Medium': FK.Params,
+ 'Iterations': FK.Params,
+ },
+ all_loose_input_sockets=True,
+ loose_input_sockets_kind=FK.Params,
+ )
+ def compute_params(self, input_sockets, loose_input_sockets) -> td.Box:
+ """Aggregate the function parameters needed by the box."""
+ medium = input_sockets['Medium']
+ iterations = input_sockets['Iterations']
+
+ funcs = [
+ non_linearity
+ for non_linearity in loose_input_sockets.values()
+ if not FS.check(non_linearity)
+ ]
+
+ if funcs:
+ non_linearities = functools.reduce(
+ lambda a, b: a | b,
+ funcs,
+ )
+
+ return medium | iterations | non_linearities
+ return ct.FlowSignal.FlowPending
+
+
####################
# - Blender Registration
####################
-BL_REGISTER = []
-BL_NODES = {}
+BL_REGISTER = [
+ AddNonLinearity,
+]
+BL_NODES = {
+ ct.NodeType.AddNonLinearity: (ct.NodeCategory.MAXWELLSIM_MEDIUMS_NONLINEARITIES)
+}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/non_linearities/chi_3_susceptibility_non_linearity.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/non_linearities/chi_3_susceptibility_non_linearity.py
index 5b7d824..18fd2a7 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/non_linearities/chi_3_susceptibility_non_linearity.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/non_linearities/chi_3_susceptibility_non_linearity.py
@@ -14,8 +14,132 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `ChiThreeSuscepNonLinearity`."""
+
+import typing as typ
+
+import bpy
+import tidy3d as td
+
+from blender_maxwell.utils import logger
+from blender_maxwell.utils import sympy_extra as spux
+
+from .... import contracts as ct
+from .... import sockets
+from ... import base, events
+
+log = logger.get(__name__)
+
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MT = spux.MathType
+
+
+class ChiThreeSuscepNonLinearity(base.MaxwellSimNode):
+ r"""An instantaneous non-linear susceptibility described by a $\chi_3$ parameter.
+
+ The model of field-component does not permit component interactions; therefore, it is only valid when the electric field is predominantly polarized along an axis-aligned direction.
+ Additionally, strong non-linearities may suffer from divergence issues, since an iterative local method is used to resolve the relation.
+ """
+
+ node_type = ct.NodeType.ChiThreeSuscepNonLinearity
+ bl_label = 'Chi3 Non-Linearity'
+
+ input_sockets: typ.ClassVar = {
+ 'χ₃': sockets.ExprSocketDef(active_kind=FK.Value),
+ }
+ output_sockets: typ.ClassVar = {
+ 'Non-Linearity': sockets.MaxwellMediumNonLinearitySocketDef(
+ active_kind=FK.Func
+ ),
+ }
+
+ ####################
+ # - UI
+ ####################
+ def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
+ """Draw the user interfaces of the node's reported info.
+
+ Parameters:
+ layout: UI target for drawing.
+ """
+ box = layout.box()
+ row = box.row()
+ row.alignment = 'CENTER'
+ row.label(text='Interpretation')
+
+ # Split
+ split = box.split(factor=0.4, align=False)
+
+ ## LHS: Parameter Names
+ col = split.column()
+ col.alignment = 'RIGHT'
+ col.label(text='χ₃:')
+
+ ## RHS: Parameter Units
+ col = split.column()
+ col.label(text='um² / V²')
+
+ ####################
+ # - FlowKind.Value
+ ####################
+ @events.computes_output_socket(
+ 'Non-Linearity',
+ kind=FK.Value,
+ # Loaded
+ outscks_kinds={
+ 'Non-Linearity': {FK.Func, FK.Params},
+ },
+ )
+ def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
+ r"""The value realizes the output function w/output parameters."""
+ value = events.realize_known(output_sockets['Non-Linearity'])
+ if value is not None:
+ return value
+ return ct.FlowSignal.FlowPending
+
+ ####################
+ # - FlowKind.Func
+ ####################
+ @events.computes_output_socket(
+ 'Non-Linearity',
+ kind=FK.Func,
+ # Loaded
+ inscks_kinds={
+ 'χ₃': FK.Func,
+ },
+ )
+ def compute_func(self, input_sockets) -> ct.FuncFlow:
+ r"""The function encloses the $\chi_3$ parameter in the nonlinear susceptibility."""
+ chi_3 = input_sockets['χ₃']
+ return chi_3.compose_within(
+ lambda _chi_3: td.NonlinearSusceptibility(chi3=_chi_3)
+ )
+
+ ####################
+ # - FlowKind.Params
+ ####################
+ @events.computes_output_socket(
+ 'Non-Linearity',
+ kind=FK.Params,
+ # Loaded
+ inscks_kinds={
+ 'χ₃': FK.Params,
+ },
+ )
+ def compute_params(self, input_sockets) -> ct.FuncFlow:
+ r"""The function parameters of the non-linearity are identical to that of the $\chi_3$ parameter."""
+ return input_sockets['χ₃']
+
+
####################
# - Blender Registration
####################
-BL_REGISTER = []
-BL_NODES = {}
+BL_REGISTER = [
+ ChiThreeSuscepNonLinearity,
+]
+BL_NODES = {
+ ct.NodeType.ChiThreeSuscepNonLinearity: (
+ ct.NodeCategory.MAXWELLSIM_MEDIUMS_NONLINEARITIES
+ )
+}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/non_linearities/kerr_non_linearity.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/non_linearities/kerr_non_linearity.py
index 5b7d824..3045ef1 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/non_linearities/kerr_non_linearity.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/non_linearities/kerr_non_linearity.py
@@ -14,8 +14,131 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `KerrNonLinearity`."""
+
+import typing as typ
+
+import bpy
+import tidy3d as td
+
+from blender_maxwell.utils import logger
+from blender_maxwell.utils import sympy_extra as spux
+
+from .... import contracts as ct
+from .... import sockets
+from ... import base, events
+
+log = logger.get(__name__)
+
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MT = spux.MathType
+
+
+class KerrNonLinearity(base.MaxwellSimNode):
+ r"""An instantaneous non-linear susceptibility described by a $\chi_3$ parameter.
+
+ The model of field-component does not permit component interactions; therefore, it is only valid when the electric field is predominantly polarized along an axis-aligned direction.
+ Additionally, strong non-linearities may suffer from divergence issues, since an iterative local method is used to resolve the relation.
+ """
+
+ node_type = ct.NodeType.KerrNonLinearity
+ bl_label = 'Kerr Non-Linearity'
+
+ input_sockets: typ.ClassVar = {
+ 'n₂': sockets.ExprSocketDef(
+ active_kind=FK.Value,
+ mathtype=MT.Complex,
+ ),
+ }
+ output_sockets: typ.ClassVar = {
+ 'Non-Linearity': sockets.MaxwellMediumNonLinearitySocketDef(
+ active_kind=FK.Func
+ ),
+ }
+
+ ####################
+ # - UI
+ ####################
+ def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
+ """Draw the user interfaces of the node's reported info.
+
+ Parameters:
+ layout: UI target for drawing.
+ """
+ box = layout.box()
+ row = box.row()
+ row.alignment = 'CENTER'
+ row.label(text='Interpretation')
+
+ # Split
+ split = box.split(factor=0.4, align=False)
+
+ ## LHS: Parameter Names
+ col = split.column()
+ col.alignment = 'RIGHT'
+ col.label(text='n₂:')
+
+ ## RHS: Parameter Units
+ col = split.column()
+ col.label(text='um² / W')
+
+ ####################
+ # - FlowKind.Value
+ ####################
+ @events.computes_output_socket(
+ 'Non-Linearity',
+ kind=FK.Value,
+ # Loaded
+ outscks_kinds={
+ 'Non-Linearity': {FK.Func, FK.Params},
+ },
+ )
+ def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
+ """The value realizes the output function w/output parameters."""
+ value = events.realize_known(output_sockets['Non-Linearity'])
+ if value is not None:
+ return value
+ return ct.FlowSignal.FlowPending
+
+ ####################
+ # - FlowKind.Func
+ ####################
+ @events.computes_output_socket(
+ 'Non-Linearity',
+ kind=FK.Func,
+ # Loaded
+ inscks_kinds={
+ 'n₂': FK.Func,
+ },
+ )
+ def compute_func(self, input_sockets) -> ct.FuncFlow:
+ r"""The function encloses the $\chi_3$ parameter in the nonlinear susceptibility."""
+ n2 = input_sockets['n₂']
+ return n2.compose_within(lambda _n2: td.KerrNonlinearity(n2=_n2))
+
+ ####################
+ # - FlowKind.Params
+ ####################
+ @events.computes_output_socket(
+ 'Non-Linearity',
+ kind=FK.Params,
+ # Loaded
+ inscks_kinds={
+ 'n₂': FK.Params,
+ },
+ )
+ def compute_params(self, input_sockets) -> ct.FuncFlow:
+ r"""The function parameters of the non-linearity are identical to that of the $\chi_3$ parameter."""
+ return input_sockets['n₂']
+
+
####################
# - Blender Registration
####################
-BL_REGISTER = []
-BL_NODES = {}
+BL_REGISTER = [
+ KerrNonLinearity,
+]
+BL_NODES = {
+ ct.NodeType.KerrNonLinearity: (ct.NodeCategory.MAXWELLSIM_MEDIUMS_NONLINEARITIES)
+}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/pole_residue_medium.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/pole_residue_medium.py
index 5b7d824..bc7fb46 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/pole_residue_medium.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/mediums/pole_residue_medium.py
@@ -14,8 +14,319 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+import typing as typ
+
+import bpy
+import sympy.physics.units as spu
+import tidy3d as td
+import tidy3d.plugins.dispersion as td_dispersion
+
+from blender_maxwell.utils import bl_cache, logger
+from blender_maxwell.utils import sympy_extra as spux
+
+from ... import contracts as ct
+from ... import managed_objs, sockets
+from .. import base, events
+
+log = logger.get(__name__)
+
+FK = ct.FlowKind
+FS = ct.FlowSignal
+PT = spux.PhysicalType
+
+VALID_URL_PREFIXES = {
+ 'https://refractiveindex.info',
+}
+
+
+####################
+# - Operators
+####################
+class FitPoleResidueMedium(bpy.types.Operator):
+ """Trigger fitting of a dispersive medium to a Pole-Residue model, and store it on a `PoleResidueMediumnode`."""
+
+ bl_idname = ct.OperatorType.NodeFitDispersiveMedium
+ bl_label = 'Fit Dispersive Medium from Input'
+ bl_description = (
+ 'Fit the dispersive medium specified by the `PoleResidueMediumNode`.'
+ )
+
+ @classmethod
+ def poll(cls, context):
+ return (
+ # Check Tidy3DWebExporter is Accessible
+ hasattr(context, 'node')
+ and hasattr(context.node, 'node_type')
+ and context.node.node_type == ct.NodeType.PoleResidueMedium
+ # Check Medium is Fittable
+ and context.node.fitter is not None
+ and context.node.fitted_medium is None
+ and context.node.fitted_rms_error is None
+ )
+
+ def execute(self, context):
+ node = context.node
+
+ try:
+ pole_residue_medium, rms_error = node.fitter.fit(
+ min_num_poles=node.fit_min_poles,
+ max_num_poles=node.fit_max_poles,
+ tolerance_rms=node.fit_tolerance_rms,
+ )
+
+ except: # noqa: E722
+ self.report(
+ {'ERROR'}, "Couldn't perform PoleResidue data fit - check inputs."
+ )
+ return {'FINISHED'}
+
+ else:
+ node.fitted_medium = pole_residue_medium
+ node.fitted_rms_error = float(rms_error)
+ for bl_socket in node.inputs:
+ bl_socket.trigger_event(ct.FlowEvent.EnableLock)
+
+ return {'FINISHED'}
+
+
+class ReleasePoleResidueFit(bpy.types.Operator):
+ """Release a previous fit of a dispersive medium to a Pole-Residue model, from a `PoleResidueMediumnode`."""
+
+ bl_idname = ct.OperatorType.NodeReleaseDispersiveFit
+ bl_label = 'Release Dispersive Medium fit'
+ bl_description = (
+ 'Release the dispersive medium fit from the `PoleResidueMediumNode`.'
+ )
+
+ @classmethod
+ def poll(cls, context):
+ return (
+ # Check Tidy3DWebExporter is Accessible
+ hasattr(context, 'node')
+ and hasattr(context.node, 'node_type')
+ and context.node.node_type == ct.NodeType.PoleResidueMedium
+ # Check Medium is Fittable
+ and context.node.fitted_medium is not None
+ and context.node.fitted_rms_error is not None
+ )
+
+ def execute(self, context):
+ node = context.node
+
+ node.fitted_medium = None
+ node.fitted_rms_error = None
+ for bl_socket in node.inputs:
+ bl_socket.trigger_event(ct.FlowEvent.DisableLock)
+
+ return {'FINISHED'}
+
+
+####################
+# - Node
+####################
+class PoleResidueMediumNode(base.MaxwellSimNode):
+ """A dispersive medium described by a pole-residue model."""
+
+ node_type = ct.NodeType.PoleResidueMedium
+ bl_label = 'Pole Residue Medium'
+
+ input_socket_sets: typ.ClassVar = {
+ 'Fit URL': {
+ 'URL': sockets.StringSocketDef(),
+ },
+ 'Fit Data': {
+ 'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
+ },
+ }
+ output_sockets: typ.ClassVar = {
+ 'Medium': sockets.MaxwellMediumSocketDef(),
+ }
+ managed_obj_types: typ.ClassVar = {
+ 'plot': managed_objs.ManagedBLImage,
+ }
+
+ ####################
+ # - Properties
+ ####################
+ fit_min_poles: int = bl_cache.BLField(1)
+ fit_max_poles: int = bl_cache.BLField(5)
+ fit_tolerance_rms: float = bl_cache.BLField(0.001)
+
+ fitted_medium: td.PoleResidue | None = bl_cache.BLField(None)
+ fitted_rms_error: float | None = bl_cache.BLField(None)
+
+ ## TODO: Bool of whether to fit eps_inf, with conditional choice of eps_inf as socket
+ ## TODO: "AdvanceFastFitterParam" options incl. loss_bounds, weights, show_progress, show_unweighted_rms, relaxed, smooth, logspacing, numiters, passivity_num_iters, and slsqp_constraint_scale
+
+ ####################
+ # - Data Fitting
+ ####################
+ @events.on_value_changed(
+ socket_name={'Expr': {FK.Func, FK.Params, FK.Info}, 'URL': FK.Value},
+ stop_propagation=True,
+ )
+ def on_expr_changed(self) -> None:
+ """Respond to changes in `Func`, `Params`, and `Info` to invalidate `self.fitter`."""
+ self.fitter = bl_cache.Signal.InvalidateCache
+
+ @bl_cache.cached_bl_property(depends_on={'active_socket_set'})
+ def fitter(self) -> td_dispersion.FastDispersionFitter | None:
+ """Compute a `FastDispersionFitter`, which can be used to initiate a data-fit."""
+ match self.active_socket_set:
+ case 'Fit Data':
+ func = self._compute_input('Expr', kind=FK.Func)
+ info = self._compute_input('Expr', kind=FK.Info)
+ params = self._compute_input('Expr', kind=FK.Params)
+
+ has_info = not FS.check(info)
+
+ expr = events.realize_known({FK.Func: func, FK.Params: params})
+ if (
+ expr is not None
+ and has_info
+ and len(info.dims == 1)
+ and info.first_dim
+ and info.first_dim.physical_type is PT.Length
+ ):
+ return td_dispersion.FastDispersionFitter(
+ wvl_um=info.dims[info.first_dim]
+ .rescale_to_unit(spu.micrometer)
+ .values,
+ n_data=expr.real,
+ k_data=expr.imag,
+ )
+ return None
+
+ case 'Fit URL':
+ url = self._compute_input('URL', kind=FK.Value)
+ has_url = not FS.check(url)
+
+ if has_url and any(
+ url.startswith(valid_prefix) for valid_prefix in VALID_URL_PREFIXES
+ ):
+ return None
+ # return td_dispersion.FastDispersionFitter.from_url(url)
+ return None
+
+ raise TypeError
+
+ ####################
+ # - UI
+ ####################
+ def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
+ """Draw loaded properties."""
+ # Fit/Release Operator
+ row = col.row(align=True)
+ row.operator(
+ ct.OperatorType.NodeFitDispersiveMedium,
+ text='Fit Medium',
+ )
+ if self.fitted_medium is not None:
+ row.operator(
+ ct.OperatorType.NodeReleaseDispersiveFit,
+ icon='LOOP_BACK',
+ text='',
+ )
+
+ # Fit Parameters / Fit Info
+ if self.fitted_medium is None:
+ row = col.row(align=True)
+ row.alignment = 'CENTER'
+ row.label(text='min|max|tol')
+
+ col.prop(self, self.blfields['fit_min_poles'])
+ col.prop(self, self.blfields['fit_max_poles'])
+ col.prop(self, self.blfields['fit_tolerance_rms'])
+ else:
+ box = col.box()
+ row = box.row(align=True)
+ row.alignment = 'CENTER'
+ row.label(text='Fit Info')
+
+ split = box.split(factor=0.4)
+
+ col = split.column()
+ row = col.row()
+ row.label(text='RMS Err')
+
+ col = split.column()
+ row = col.row()
+ row.label(text=f'{self.fitted_rms_error:.4f}')
+
+ ####################
+ # - FlowKind.Value
+ ####################
+ @events.computes_output_socket(
+ 'Medium',
+ kind=FK.Value,
+ # Loaded
+ props={'fitted_medium'},
+ )
+ def compute_fitted_medium_value(self, props) -> td.Medium | FS:
+ """Return the fitted medium."""
+ fitted_medium = props['fitted_medium']
+ if fitted_medium is not None:
+ return fitted_medium
+ return FS.FlowPending
+
+ ####################
+ # - FlowKind.Func
+ ####################
+ @events.computes_output_socket(
+ 'Medium',
+ kind=FK.Func,
+ # Loaded
+ outscks_kinds={'Medium': FK.Value},
+ )
+ def compute_fitted_medium_func(self, output_sockets) -> td.Medium | FS:
+ """Return the fitted medium as a function with that medium baked in."""
+ fitted_medium = output_sockets['Medium']
+ return ct.FuncFlow(
+ func=lambda: fitted_medium,
+ )
+
+ ####################
+ # - FlowKind.Func
+ ####################
+ @events.computes_output_socket(
+ 'Medium',
+ kind=FK.Params,
+ )
+ def compute_fitted_medium_params(self) -> td.Medium | FS:
+ """Declare no function parameters."""
+ return ct.ParamsFlow()
+
+ ####################
+ # - Event Methods: Plot
+ ####################
+ @events.on_show_plot(
+ managed_objs={'plot'},
+ # Loaded
+ props={'fitter', 'fitted_medium'},
+ )
+ def on_show_plot(self, props, managed_objs):
+ """When the filetype is 'Experimental Dispersive Medium', plot the computed model against the input data."""
+ fitter = props['fitter']
+ fitted_medium = props['fitted_medium']
+ if fitter is not None and fitted_medium is not None:
+ managed_objs['plot'].mpl_plot_to_image(
+ lambda ax: props['fitter'].plot(
+ medium=props['fitted_medium'],
+ ax=ax,
+ ),
+ width_inches=6.0,
+ height_inches=3.0,
+ dpi=150,
+ )
+ return FS.FlowPending
+
+
####################
# - Blender Registration
####################
-BL_REGISTER = []
-BL_NODES = {}
+BL_REGISTER = [
+ FitPoleResidueMedium,
+ ReleasePoleResidueFit,
+ PoleResidueMediumNode,
+]
+BL_NODES = {ct.NodeType.PoleResidueMedium: (ct.NodeCategory.MAXWELLSIM_MEDIUMS)}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py
index d9ca2e8..0bf9bb9 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/eh_field_monitor.py
@@ -31,6 +31,11 @@ from .. import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MT = spux.MathType
+PT = spux.PhysicalType
+
class EHFieldMonitorNode(base.MaxwellSimNode):
"""Node providing for the monitoring of electromagnetic fields within a given planar region or volume."""
@@ -45,26 +50,27 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
input_sockets: typ.ClassVar = {
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
- physical_type=spux.PhysicalType.Length,
+ physical_type=PT.Length,
),
'Size': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
- physical_type=spux.PhysicalType.Length,
- default_value=sp.Matrix([1, 1, 1]),
+ physical_type=PT.Length,
+ default_value=sp.ImmutableMatrix([1, 1, 1]),
abs_min=0,
+ abs_min_closed=False,
),
'Stride': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
- mathtype=spux.MathType.Integer,
- default_value=sp.Matrix([10, 10, 10]),
- abs_min=0,
+ mathtype=MT.Integer,
+ default_value=sp.ImmutableMatrix([10, 10, 10]),
+ abs_min=1,
),
}
input_socket_sets: typ.ClassVar = {
'Freq Domain': {
'Freqs': sockets.ExprSocketDef(
- active_kind=ct.FlowKind.Range,
- physical_type=spux.PhysicalType.Freq,
+ active_kind=FK.Range,
+ physical_type=PT.Freq,
default_unit=spux.THz,
default_min=374.7406, ## 800nm
default_max=1498.962, ## 200nm
@@ -73,21 +79,22 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
},
'Time Domain': {
't Range': sockets.ExprSocketDef(
- active_kind=ct.FlowKind.Range,
- physical_type=spux.PhysicalType.Time,
+ active_kind=FK.Range,
+ physical_type=PT.Time,
default_unit=spu.picosecond,
default_min=0,
default_max=10,
- default_steps=0,
+ default_steps=2,
),
't Stride': sockets.ExprSocketDef(
- mathtype=spux.MathType.Integer,
+ mathtype=MT.Integer,
default_value=100,
+ abs_min=1,
),
},
}
output_sockets: typ.ClassVar = {
- 'Monitor': sockets.MaxwellMonitorSocketDef(active_kind=ct.FlowKind.Func),
+ 'Monitor': sockets.MaxwellMonitorSocketDef(active_kind=FK.Func),
}
managed_obj_types: typ.ClassVar = {
@@ -103,6 +110,11 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
+ """Draw the field-selector in the node UI.
+
+ Parameters:
+ layout: UI target for drawing.
+ """
layout.prop(self, self.blfields['fields'], expand=True)
####################
@@ -110,186 +122,134 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Monitor',
- kind=ct.FlowKind.Value,
+ kind=FK.Value,
# Loaded
- output_sockets={'Monitor'},
- output_socket_kinds={'Monitor': {ct.FlowKind.Func, ct.FlowKind.Params}},
+ outscks_kinds={
+ 'Monitor': {FK.Func, FK.Params},
+ },
)
- def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
- """Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
- output_func = output_sockets['Monitor'][ct.FlowKind.Func]
- output_params = output_sockets['Monitor'][ct.FlowKind.Params]
-
- has_output_func = not ct.FlowSignal.check(output_func)
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if has_output_func and has_output_params and not output_params.symbols:
- return output_func.realize(output_params, disallow_jax=True)
- return ct.FlowSignal.FlowPending
+ def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
+ """Realizes the output function w/output parameters."""
+ value = events.realize_known(output_sockets['Monitor'])
+ if value is not None:
+ return value
+ return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Monitor',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
# Loaded
props={'active_socket_set', 'sim_node_name', 'fields'},
- input_sockets={
- 'Center',
- 'Size',
- 'Stride',
- 'Freqs',
- 't Range',
- 't Stride',
+ inscks_kinds={
+ 'Center': FK.Func,
+ 'Size': FK.Func,
+ 'Stride': FK.Func,
+ 'Freqs': FK.Func,
+ 't Range': FK.Func,
+ 't Stride': FK.Func,
},
- input_socket_kinds={
- 'Center': ct.FlowKind.Func,
- 'Size': ct.FlowKind.Func,
- 'Stride': ct.FlowKind.Func,
- 'Freqs': ct.FlowKind.Func,
- 't Range': ct.FlowKind.Func,
- 't Stride': ct.FlowKind.Func,
+ input_sockets_optional={'Freqs', 't Range', 't Stride'},
+ scale_input_sockets={
+ 'Center': ct.UNITS_TIDY3D,
+ 'Size': ct.UNITS_TIDY3D,
+ 'Freqs': ct.UNITS_TIDY3D,
+ 't Range': ct.UNITS_TIDY3D,
},
)
def compute_func(self, props, input_sockets) -> td.FieldMonitor:
+ """Lazily assembles the FieldMonitor from the input functions."""
center = input_sockets['Center']
size = input_sockets['Size']
stride = input_sockets['Stride']
- has_center = not ct.FlowSignal.check(center)
- has_size = not ct.FlowSignal.check(size)
- has_stride = not ct.FlowSignal.check(stride)
+ freqs = input_sockets['Freqs']
+ t_range = input_sockets['t Range']
+ t_stride = input_sockets['t Stride']
- if has_center and has_size and has_stride:
- name = props['sim_node_name']
- fields = props['fields']
+ sim_node_name = props['sim_node_name']
+ fields = props['fields']
- common_func_flow = (
- center.scale_to_unit_system(ct.UNITS_TIDY3D)
- | size.scale_to_unit_system(ct.UNITS_TIDY3D)
- | stride
- )
+ common_func_flow = center | size | stride
+ match props['active_socket_set']:
+ case 'Freq Domain' if not FS.check(freqs):
+ return (common_func_flow | freqs).compose_within(
+ lambda els: td.FieldMonitor(
+ name=sim_node_name,
+ center=els[0].flatten().tolist(),
+ size=els[1].flatten().tolist(),
+ interval_space=els[2].flatten().tolist(),
+ freqs=els[3].flatten(),
+ fields=fields,
+ )
+ )
- match props['active_socket_set']:
- case 'Freq Domain':
- freqs = input_sockets['Freqs']
- has_freqs = not ct.FlowSignal.check(freqs)
-
- if has_freqs:
- return (
- common_func_flow
- | freqs.scale_to_unit_system(ct.UNITS_TIDY3D)
- ).compose_within(
- lambda els: td.FieldMonitor(
- center=els[0].flatten().tolist(),
- size=els[1].flatten().tolist(),
- name=name,
- interval_space=els[2].flatten().tolist(),
- freqs=els[3].flatten(),
- fields=fields,
- )
- )
-
- case 'Time Domain':
- t_range = input_sockets['t Range']
- t_stride = input_sockets['t Stride']
-
- has_t_range = not ct.FlowSignal.check(t_range)
- has_t_stride = not ct.FlowSignal.check(t_stride)
-
- if has_t_range and has_t_stride:
- return (
- common_func_flow
- | t_range.scale_to_unit_system(ct.UNITS_TIDY3D)
- | t_stride.scale_to_unit_system(ct.UNITS_TIDY3D)
- ).compose_within(
- lambda els: td.FieldTimeMonitor(
- center=els[0].flatten().tolist(),
- size=els[1].flatten().tolist(),
- name=name,
- interval_space=els[2].flatten().tolist(),
- start=els[3][0],
- stop=els[3][-1],
- interval=els[4],
- fields=fields,
- )
- )
- return ct.FlowSignal.FlowPending
+ case 'Time Domain' if not FS.check(t_range) and not FS.check(t_stride):
+ return (common_func_flow | t_range | t_stride).compose_within(
+ lambda els: td.FieldTimeMonitor(
+ name=sim_node_name,
+ center=els[0].flatten().tolist(),
+ size=els[1].flatten().tolist(),
+ interval_space=els[2].flatten().tolist(),
+ start=els[3][0],
+ stop=els[3][-1],
+ interval=els[4],
+ fields=fields,
+ )
+ )
+ return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Monitor',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
# Loaded
props={'active_socket_set'},
- input_sockets={
- 'Center',
- 'Size',
- 'Stride',
- 'Freqs',
- 't Range',
- 't Stride',
- },
- input_socket_kinds={
- 'Center': ct.FlowKind.Params,
- 'Size': ct.FlowKind.Params,
- 'Stride': ct.FlowKind.Params,
- 'Freqs': ct.FlowKind.Params,
- 't Range': ct.FlowKind.Params,
- 't Stride': ct.FlowKind.Params,
+ inscks_kinds={
+ 'Center': FK.Params,
+ 'Size': FK.Params,
+ 'Stride': FK.Params,
+ 'Freqs': FK.Params,
+ 't Range': FK.Params,
+ 't Stride': FK.Params,
},
+ input_sockets_optional={'Freqs', 't Range', 't Stride'},
)
def compute_params(self, props, input_sockets) -> None:
+ """Lazily assembles the FieldMonitor from the input functions."""
center = input_sockets['Center']
size = input_sockets['Size']
stride = input_sockets['Stride']
- has_center = not ct.FlowSignal.check(center)
- has_size = not ct.FlowSignal.check(size)
- has_stride = not ct.FlowSignal.check(stride)
+ freqs = input_sockets['Freqs']
+ t_range = input_sockets['t Range']
+ t_stride = input_sockets['t Stride']
- if has_center and has_size and has_stride:
- common_params = center | size | stride
- match props['active_socket_set']:
- case 'Freq Domain':
- freqs = input_sockets['Freqs']
- has_freqs = not ct.FlowSignal.check(freqs)
+ common_params = center | size | stride
+ match props['active_socket_set']:
+ case 'Freq Domain' if not FS.check(freqs):
+ return common_params | freqs
- if has_freqs:
- return common_params | freqs
-
- case 'Time Domain':
- t_range = input_sockets['t Range']
- t_stride = input_sockets['t Stride']
-
- has_t_range = not ct.FlowSignal.check(t_range)
- has_t_stride = not ct.FlowSignal.check(t_stride)
-
- if has_t_range and has_t_stride:
- return common_params | t_range | t_stride
- return ct.FlowSignal.FlowPending
+ case 'Time Domain' if not FS.check(t_range) and not FS.check(t_stride):
+ return common_params | t_range | t_stride
+ return FS.FlowPending
####################
# - Preview
####################
@events.computes_output_socket(
'Monitor',
- kind=ct.FlowKind.Previews,
+ kind=FK.Previews,
# Loaded
props={'sim_node_name'},
- output_sockets={'Monitor'},
- output_socket_kinds={'Monitor': ct.FlowKind.Params},
)
- def compute_previews(self, props, output_sockets):
- output_params = output_sockets['Monitor']
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if has_output_params and not output_params.symbols:
- return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
- return ct.PreviewsFlow()
+ def compute_previews(self, props):
+ """Mark the monitor as participating in the preview."""
+ return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed(
# Trigger
@@ -297,32 +257,31 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
run_on_init=True,
# Loaded
managed_objs={'modifier'},
- input_sockets={'Center', 'Size'},
- output_sockets={'Monitor'},
- output_socket_kinds={'Monitor': ct.FlowKind.Params},
+ inscks_kinds={
+ 'Center': {FK.Func, FK.Params},
+ 'Size': {FK.Func, FK.Params},
+ },
+ scale_input_sockets={
+ 'Center': ct.UNITS_BLENDER,
+ 'Size': ct.UNITS_BLENDER,
+ },
)
- def on_previewable_changed(self, managed_objs, input_sockets, output_sockets):
- center = input_sockets['Center']
- size = input_sockets['Size']
- output_params = output_sockets['Monitor']
+ def on_previewable_changed(self, managed_objs, input_sockets):
+ """Push changes in the inputs to the center / size."""
+ center = events.realize_preview(input_sockets['Center'])
+ size = events.realize_preview(input_sockets['Size'])
- has_center = not ct.FlowSignal.check(center)
- has_size = not ct.FlowSignal.check(size)
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if has_center and has_size and has_output_params and not output_params.symbols:
- # Push Input Values to GeoNodes Modifier
- managed_objs['modifier'].bl_modifier(
- 'NODES',
- {
- 'node_group': import_geonodes(GeoNodes.MonitorEHField),
- 'unit_system': ct.UNITS_BLENDER,
- 'inputs': {
- 'Size': size,
- },
+ # Push Input Values to GeoNodes Modifier
+ managed_objs['modifier'].bl_modifier(
+ 'NODES',
+ {
+ 'node_group': import_geonodes(GeoNodes.MonitorEHField),
+ 'inputs': {
+ 'Size': size,
},
- location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
- )
+ },
+ location=center,
+ )
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py
index d40e0ab..ca59062 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/field_power_flux_monitor.py
@@ -14,6 +14,9 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `PowerFluxMonitorNode`."""
+
+import itertools
import typing as typ
import bpy
@@ -30,6 +33,12 @@ from ... import managed_objs, sockets
from .. import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+
+ALL_3D_DIRS = {
+ ax + sgn for ax, sgn in itertools.product(set(ct.SimSpaceAxis), set(ct.SimAxisDir))
+}
class PowerFluxMonitorNode(base.MaxwellSimNode):
@@ -50,20 +59,21 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
'Size': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
physical_type=spux.PhysicalType.Length,
- default_value=sp.Matrix([1, 1, 1]),
+ default_value=sp.ImmutableMatrix([1, 1, 1]),
abs_min=0,
+ abs_min_closed=False,
),
'Stride': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
mathtype=spux.MathType.Integer,
- default_value=sp.Matrix([10, 10, 10]),
- abs_min=0,
+ default_value=sp.ImmutableMatrix([10, 10, 10]),
+ abs_min=1,
),
}
input_socket_sets: typ.ClassVar = {
'Freq Domain': {
'Freqs': sockets.ExprSocketDef(
- active_kind=ct.FlowKind.Range,
+ active_kind=FK.Range,
physical_type=spux.PhysicalType.Freq,
default_unit=spux.THz,
default_min=374.7406, ## 800nm
@@ -73,7 +83,7 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
},
'Time Domain': {
't Range': sockets.ExprSocketDef(
- active_kind=ct.FlowKind.Range,
+ active_kind=FK.Range,
physical_type=spux.PhysicalType.Time,
default_unit=spu.picosecond,
default_min=0,
@@ -83,16 +93,15 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
't Stride': sockets.ExprSocketDef(
mathtype=spux.MathType.Integer,
default_value=100,
+ abs_min=1,
),
},
}
- output_socket_sets: typ.ClassVar = {
- 'Freq Domain': {'Freq Monitor': sockets.MaxwellMonitorSocketDef()},
- 'Time Domain': {'Time Monitor': sockets.MaxwellMonitorSocketDef()},
+ output_sockets: typ.ClassVar = {
+ 'Monitor': sockets.MaxwellMonitorSocketDef(),
}
managed_obj_types: typ.ClassVar = {
- 'mesh': managed_objs.ManagedBLMesh,
'modifier': managed_objs.ManagedBLModifier,
}
@@ -109,6 +118,7 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
+ """Draw the properties of the node."""
# 2D Monitor
if 0 in self._compute_input('Size'):
layout.prop(self, self.blfields['direction_2d'], expand=True)
@@ -125,67 +135,174 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
row.prop(self, self.blfields['include_3d_z'], expand=True)
####################
- # - Event Methods: Computation
+ # - FlowKind.Value
####################
@events.computes_output_socket(
- 'Freq Monitor',
- props={'sim_node_name', 'direction_2d'},
- input_sockets={
- 'Center',
- 'Size',
- 'Stride',
- 'Freqs',
- },
- input_socket_kinds={
- 'Freqs': ct.FlowKind.Range,
- },
- unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
- scale_input_sockets={
- 'Center': 'Tidy3DUnits',
- 'Size': 'Tidy3DUnits',
- 'Freqs': 'Tidy3DUnits',
+ 'Monitor',
+ kind=FK.Value,
+ # Loaded
+ output_sockets={'Monitor'},
+ output_socket_kinds={
+ 'Monitor': {FK.Func, FK.Params},
},
)
- def compute_freq_monitor(
- self,
- input_sockets: dict,
- props: dict,
- unit_systems: dict,
- ) -> td.FieldMonitor:
- log.info(
- 'Computing FluxMonitor (name="%s") with center="%s", size="%s"',
- props['sim_node_name'],
- input_sockets['Center'],
- input_sockets['Size'],
- )
- return td.FluxMonitor(
- center=input_sockets['Center'],
- size=input_sockets['Size'],
- name=props['sim_node_name'],
- interval_space=(1, 1, 1),
- freqs=input_sockets['Freqs'].realize_array.values,
- normal_dir=props['direction_2d'].plus_or_minus,
- )
+ def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
+ """Realizes the output function w/output parameters."""
+ value = events.realize_known(output_sockets['Monitor'])
+ if value is not None:
+ return value
+ return FS.FlowPending
+
+ ####################
+ # - FlowKind.Func
+ ####################
+ @events.computes_output_socket(
+ 'Monitor',
+ kind=FK.Func,
+ # Loaded
+ props={'active_socket_set', 'sim_node_name', 'direction_2d'},
+ inscks_kinds={
+ 'Center': FK.Func,
+ 'Size': FK.Func,
+ 'Stride': FK.Func,
+ 'Freqs': FK.Func,
+ 't Range': FK.Func,
+ 't Stride': FK.Func,
+ },
+ input_sockets_optional={'Freqs', 't Range', 't Stride'},
+ scale_input_sockets={
+ 'Center': ct.UNITS_TIDY3D,
+ 'Size': ct.UNITS_TIDY3D,
+ 'Freqs': ct.UNITS_TIDY3D,
+ 't Range': ct.UNITS_TIDY3D,
+ },
+ )
+ def compute_func(self, input_sockets, props) -> ct.FuncFlow: # noqa: C901
+ """Compute the correct flux monitor."""
+ center = input_sockets['Center']
+ size = input_sockets['Size']
+ stride = input_sockets['Stride']
+ freqs = input_sockets['Freqs']
+ t_range = input_sockets['t Range']
+ t_stride = input_sockets['t Stride']
+
+ sim_node_name = props['sim_node_name']
+ direction_2d = props['direction_2d']
+
+ # 3D Flux Monitor: Computed Excluded Directions
+ ## -> The flux is always recorded outgoing.
+ ## -> However, one can exclude certain faces from participating.
+ include_3d = props['include_3d']
+ include_3d_x = props['include_3d_x']
+ include_3d_y = props['include_3d_y']
+ include_3d_z = props['include_3d_z']
+ excluded_3d = set()
+ if ct.SimSpaceAxis.X in include_3d:
+ if ct.SimAxisDir.Plus in include_3d_x:
+ excluded_3d.add('x+')
+ if ct.SimAxisDir.Minus in include_3d_x:
+ excluded_3d.add('x-')
+
+ if ct.SimSpaceAxis.Y in include_3d:
+ if ct.SimAxisDir.Plus in include_3d_y:
+ excluded_3d.add('y+')
+ if ct.SimAxisDir.Minus in include_3d_y:
+ excluded_3d.add('y-')
+
+ if ct.SimSpaceAxis.Z in include_3d:
+ if ct.SimAxisDir.Plus in include_3d_z:
+ excluded_3d.add('z+')
+ if ct.SimAxisDir.Minus in include_3d_z:
+ excluded_3d.add('z-')
+
+ excluded_3d = tuple(ALL_3D_DIRS - excluded_3d)
+
+ # Compute Monitor
+ common_func = center | size | stride
+ active_socket_set = props['active_socket_set']
+ match active_socket_set:
+ case 'Freq Domain' if not FS.check(freqs):
+ return (common_func | freqs).compose_within(
+ lambda els: td.FluxMonitor(
+ name=sim_node_name,
+ center=els[0],
+ size=els[1],
+ interval_space=els[2],
+ freqs=els[3],
+ normal_dir=direction_2d.plus_or_minus,
+ exclude_surfaces=excluded_3d,
+ )
+ )
+
+ case 'Time Domain' if not FS.check(t_range) and not FS.check(t_stride):
+ return (common_func | t_range | t_stride).compose_within(
+ lambda els: td.FluxTimeMonitor(
+ name=sim_node_name,
+ center=els[0],
+ size=els[1],
+ interval_space=els[2],
+ start=els[3].item(0),
+ stop=els[3].item(1),
+ interval=els[4],
+ normal_dir=direction_2d.plus_or_minus,
+ exclude_surfaces=excluded_3d,
+ )
+ )
+
+ return FS.FlowPending
+
+ ####################
+ # - FlowKind.Params
+ ####################
+ @events.computes_output_socket(
+ 'Monitor',
+ kind=FK.Params,
+ # Loaded
+ props={'active_socket_set'},
+ inscks_kinds={
+ 'Center': FK.Params,
+ 'Size': FK.Params,
+ 'Stride': FK.Params,
+ 'Freqs': FK.Params,
+ 't Range': FK.Params,
+ 't Stride': FK.Params,
+ },
+ input_sockets_optional={'Freqs', 't Range', 't Stride'},
+ )
+ def compute_params(self, input_sockets, props) -> ct.ParamsFlow:
+ """Compute the function parameters of the monitor."""
+ center = input_sockets['Center']
+ size = input_sockets['Size']
+ stride = input_sockets['Stride']
+
+ freqs = input_sockets['Freqs']
+ t_range = input_sockets['t Range']
+ t_stride = input_sockets['t Stride']
+
+ common_params = center | size | stride
+
+ # Compute Monitor
+ active_socket_set = props['active_socket_set']
+ match active_socket_set:
+ case 'Freq Domain' if not FS.check(freqs):
+ return common_params | freqs
+
+ case 'Time Domain' if not FS.check(t_range) and not FS.check(t_stride):
+ return common_params | t_range | t_stride
+
+ return FS.FlowPending
####################
# - Preview - Changes to Input Sockets
####################
@events.computes_output_socket(
- 'Time Monitor',
- kind=ct.FlowKind.Previews,
+ 'Monitor',
+ kind=FK.Previews,
# Loaded
props={'sim_node_name'},
)
def compute_previews_time(self, props):
- return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
-
- @events.computes_output_socket(
- 'Freq Monitor',
- kind=ct.FlowKind.Previews,
- # Loaded
- props={'sim_node_name'},
- )
- def compute_previews_freq(self, props):
+ """Mark the box structure as participating in the preview."""
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed(
@@ -194,27 +311,30 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
prop_name={'direction_2d'},
run_on_init=True,
# Loaded
- managed_objs={'mesh', 'modifier'},
- props={'direction_2d'},
- input_sockets={'Center', 'Size'},
- unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
+ managed_objs={'modifier'},
+ inscks_kinds={
+ 'Center': {FK.Func, FK.Params},
+ 'Size': {FK.Func, FK.Params},
+ },
scale_input_sockets={
- 'Center': 'BlenderUnits',
+ 'Center': ct.UNITS_BLENDER,
+ 'Size': ct.UNITS_BLENDER,
},
)
- def on_inputs_changed(self, managed_objs, props, input_sockets, unit_systems):
- # Push Input Values to GeoNodes Modifier
+ def on_previewable_changed(self, managed_objs, input_sockets):
+ """Push changes in the inputs to the center / size."""
+ center = events.realize_preview(input_sockets['Center'])
+ size = events.realize_preview(input_sockets['Size'])
+
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.MonitorPowerFlux),
- 'unit_system': unit_systems['BlenderUnits'],
'inputs': {
- 'Size': input_sockets['Size'],
- 'Direction': props['direction_2d'].true_or_false,
+ 'Size': size,
},
},
- location=input_sockets['Center'],
+ location=center,
)
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/permittivity_monitor.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/permittivity_monitor.py
index 85d845b..9fbc911 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/permittivity_monitor.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/monitors/permittivity_monitor.py
@@ -20,8 +20,8 @@ import sympy as sp
import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
-from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
+from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets
@@ -29,6 +29,11 @@ from .. import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MT = spux.MathType
+PT = spux.PhysicalType
+
class PermittivityMonitorNode(base.MaxwellSimNode):
"""Provides a bounded 1D/2D/3D recording region for the diagonal of the complex-valued permittivity tensor."""
@@ -43,23 +48,24 @@ class PermittivityMonitorNode(base.MaxwellSimNode):
input_sockets: typ.ClassVar = {
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
- physical_type=spux.PhysicalType.Length,
+ physical_type=PT.Length,
),
'Size': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
- physical_type=spux.PhysicalType.Length,
- default_value=sp.Matrix([1, 1, 1]),
+ physical_type=PT.Length,
+ default_value=sp.ImmutableMatrix([1, 1, 1]),
abs_min=0,
+ abs_min_closed=False,
),
'Stride': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
- mathtype=spux.MathType.Integer,
- default_value=sp.Matrix([10, 10, 10]),
- abs_min=0,
+ mathtype=MT.Integer,
+ default_value=sp.ImmutableMatrix([10, 10, 10]),
+ abs_min=1,
),
'Freqs': sockets.ExprSocketDef(
- active_kind=ct.FlowKind.Range,
- physical_type=spux.PhysicalType.Freq,
+ active_kind=FK.Range,
+ physical_type=PT.Freq,
default_unit=spux.THz,
default_min=374.7406, ## 800nm
default_max=1498.962, ## 200nm
@@ -67,7 +73,7 @@ class PermittivityMonitorNode(base.MaxwellSimNode):
),
}
output_sockets: typ.ClassVar = {
- 'Permittivity Monitor': sockets.MaxwellMonitorSocketDef()
+ 'Monitor': sockets.MaxwellMonitorSocketDef(),
}
managed_obj_types: typ.ClassVar = {
@@ -75,57 +81,95 @@ class PermittivityMonitorNode(base.MaxwellSimNode):
}
####################
- # - Output
+ # - FlowKind.Value
####################
@events.computes_output_socket(
- 'Permittivity Monitor',
- props={'sim_node_name'},
- input_sockets={
- 'Center',
- 'Size',
- 'Stride',
- 'Freqs',
- },
- input_socket_kinds={
- 'Freqs': ct.FlowKind.Range,
- },
- unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
- scale_input_sockets={
- 'Center': 'Tidy3DUnits',
- 'Size': 'Tidy3DUnits',
- 'Freqs': 'Tidy3DUnits',
+ 'Monitor',
+ kind=FK.Value,
+ # Loaded
+ outscks_kinds={
+ 'Monitor': {FK.Func, FK.Params},
},
)
- def compute_permittivity_monitor(
- self,
- input_sockets: dict,
- props: dict,
- unit_systems: dict,
- ) -> td.FieldMonitor:
- log.info(
- 'Computing PermittivityMonitor (name="%s") with center="%s", size="%s"',
- props['sim_node_name'],
- input_sockets['Center'],
- input_sockets['Size'],
- )
- return td.PermittivityMonitor(
- center=input_sockets['Center'],
- size=input_sockets['Size'],
- name=props['sim_node_name'],
- interval_space=tuple(input_sockets['Stride']),
- freqs=input_sockets['Freqs'].realize().values,
+ def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
+ """Realizes the output function w/output parameters."""
+ value = events.realize_known(output_sockets['Monitor'])
+ if value is not None:
+ return value
+ return FS.FlowPending
+
+ ####################
+ # - FlowKind.Func
+ ####################
+ @events.computes_output_socket(
+ 'Monitor',
+ kind=FK.Func,
+ # Loaded
+ props={'sim_node_name'},
+ inscks_kinds={
+ 'Center': FK.Func,
+ 'Size': FK.Func,
+ 'Stride': FK.Func,
+ 'Freqs': FK.Func,
+ },
+ scale_input_sockets={
+ 'Center': ct.UNITS_TIDY3D,
+ 'Size': ct.UNITS_TIDY3D,
+ 'Freqs': ct.UNITS_TIDY3D,
+ },
+ )
+ def compute_func(self, props, input_sockets) -> td.FieldMonitor:
+ """Lazily assemble the permittivity monitor from the input functions."""
+ center = input_sockets['Center']
+ size = input_sockets['Size']
+ stride = input_sockets['Stride']
+ freqs = input_sockets['Freqs']
+
+ sim_node_name = props['sim_node_name']
+
+ return (center | size | stride | freqs).compose_within(
+ lambda els: td.PermittivityMonitor(
+ name=sim_node_name,
+ center=els[0].flatten().tolist(),
+ size=els[1].flatten().tolist(),
+ interval_space=els[2].flatten().tolist(),
+ freqs=els[3].flatten(),
+ )
)
+ ####################
+ # - FlowKind.Params
+ ####################
+ @events.computes_output_socket(
+ 'Monitor',
+ kind=FK.Params,
+ # Loaded
+ inscks_kinds={
+ 'Center': FK.Params,
+ 'Size': FK.Params,
+ 'Stride': FK.Params,
+ 'Freqs': FK.Params,
+ },
+ )
+ def compute_params(self, input_sockets) -> td.FieldMonitor:
+ center = input_sockets['Center']
+ size = input_sockets['Size']
+ stride = input_sockets['Stride']
+ freqs = input_sockets['Freqs']
+
+ return center | size | stride | freqs
+
####################
# - Preview
####################
@events.computes_output_socket(
'Permittivity Monitor',
- kind=ct.FlowKind.Previews,
+ kind=FK.Previews,
# Loaded
props={'sim_node_name'},
)
def compute_previews_freq(self, props):
+ """Mark the monitor as participating in the preview."""
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed(
@@ -134,29 +178,30 @@ class PermittivityMonitorNode(base.MaxwellSimNode):
run_on_init=True,
# Loaded
managed_objs={'modifier'},
- input_sockets={'Center', 'Size'},
- unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
+ inscks_kinds={
+ 'Center': {FK.Func, FK.Params},
+ 'Size': {FK.Func, FK.Params},
+ },
scale_input_sockets={
- 'Center': 'BlenderUnits',
+ 'Center': ct.UNITS_BLENDER,
+ 'Size': ct.UNITS_BLENDER,
},
)
- def on_inputs_changed(
- self,
- managed_objs: dict,
- input_sockets: dict,
- unit_systems: dict,
- ):
+ def on_previewable_changed(self, managed_objs, input_sockets):
+ """Push changes in the inputs to the center / size."""
+ center = events.realize_preview(input_sockets['Center'])
+ size = events.realize_preview(input_sockets['Size'])
+
# Push Input Values to GeoNodes Modifier
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.MonitorPermittivity),
- 'unit_system': unit_systems['BlenderUnits'],
'inputs': {
- 'Size': input_sockets['Size'],
+ 'Size': size,
},
},
- location=input_sockets['Center'],
+ location=center,
)
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py
index 0ad5d02..9d51cbe 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/viewer.py
@@ -104,6 +104,7 @@ class ViewerNode(base.MaxwellSimNode):
@events.on_value_changed(
# Trugger
socket_name='Any',
+ prop_name='auto_expr',
# Loaded
props={'auto_expr', 'console_print_kind'},
)
@@ -114,11 +115,7 @@ class ViewerNode(base.MaxwellSimNode):
This **does not** call the flow twice, as `self._compute_input()` will be cached the first time.
"""
# Invalidate PreviewsFlow
- setattr(
- self,
- 'input_' + ct.FlowKind.Previews.property_name,
- bl_cache.Signal.InvalidateCache,
- )
+ self.input_previews = bl_cache.Signal.InvalidateCache
# Invalidate PreviewsFlow
if props['auto_expr']:
@@ -128,35 +125,35 @@ class ViewerNode(base.MaxwellSimNode):
bl_cache.Signal.InvalidateCache,
)
- @bl_cache.cached_bl_property(depends_on={'auto_expr'})
+ @bl_cache.cached_bl_property()
def input_capabilities(self) -> ct.CapabilitiesFlow | None:
return self.get_flow(ct.FlowKind.Capabilities)
- @bl_cache.cached_bl_property(depends_on={'auto_expr'})
+ @bl_cache.cached_bl_property()
def input_previews(self) -> ct.PreviewsFlow | None:
return self.get_flow(ct.FlowKind.Previews, always_load=True)
- @bl_cache.cached_bl_property(depends_on={'auto_expr'})
+ @bl_cache.cached_bl_property()
def input_value(self) -> ct.ValueFlow | None:
return self.get_flow(ct.FlowKind.Value)
- @bl_cache.cached_bl_property(depends_on={'auto_expr'})
+ @bl_cache.cached_bl_property()
def input_array(self) -> ct.ArrayFlow | None:
return self.get_flow(ct.FlowKind.Array)
- @bl_cache.cached_bl_property(depends_on={'auto_expr'})
+ @bl_cache.cached_bl_property()
def input_lazy_range(self) -> ct.RangeFlow | None:
return self.get_flow(ct.FlowKind.Range)
- @bl_cache.cached_bl_property(depends_on={'auto_expr'})
+ @bl_cache.cached_bl_property()
def input_lazy_func(self) -> ct.FuncFlow | None:
return self.get_flow(ct.FlowKind.Func)
- @bl_cache.cached_bl_property(depends_on={'auto_expr'})
+ @bl_cache.cached_bl_property()
def input_params(self) -> ct.ParamsFlow | None:
return self.get_flow(ct.FlowKind.Params)
- @bl_cache.cached_bl_property(depends_on={'auto_expr'})
+ @bl_cache.cached_bl_property()
def input_info(self) -> ct.InfoFlow | None:
return self.get_flow(ct.FlowKind.Info)
@@ -214,6 +211,11 @@ class ViewerNode(base.MaxwellSimNode):
if key != 'type'
]
+ # Parse Straight String
+ if isinstance(value, str):
+ lines = value.split('\n')
+ return [[line] for line in lines]
+
return None
####################
@@ -305,6 +307,9 @@ class ViewerNode(base.MaxwellSimNode):
log.info('Printing to Console')
if isinstance(flow, spux.SympyType):
console.print(sp.pretty(flow, use_unicode=True))
+ elif isinstance(flow, ct.ArrayFlow):
+ console.print(flow)
+ console.print(flow.values)
else:
console.print(flow)
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/web_exporters/tidy3d_web_exporter.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/web_exporters/tidy3d_web_exporter.py
index 3371d6a..5496c12 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/web_exporters/tidy3d_web_exporter.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/outputs/web_exporters/tidy3d_web_exporter.py
@@ -14,13 +14,15 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `Tidy3DWebExporter`."""
+
import typing as typ
import bpy
import tidy3d as td
from blender_maxwell.services import tdcloud
-from blender_maxwell.utils import bl_cache, logger
+from blender_maxwell.utils import bl_cache, logger, sim_symbols
from .... import contracts as ct
from .... import sockets
@@ -28,47 +30,30 @@ from ... import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+
+SimArray: typ.TypeAlias = dict[
+ tuple[sim_symbols.SimSymbol, ...], tuple[typ.Any, ...], td.Simulation
+]
+SimArrayInfo: typ.TypeAlias = dict[
+ tuple[sim_symbols.SimSymbol, ...], tuple[typ.Any, ...], td.Simulation
+]
+
####################
# - Operators
####################
-class RecomputeSimInfo(bpy.types.Operator):
- bl_idname = ct.OperatorType.NodeRecomputeSimInfo
- bl_label = 'Recompute Tidy3D Sim Info'
- bl_description = 'Recompute info for any currently attached sim info'
-
- @classmethod
- def poll(cls, context):
- return (
- # Check Tidy3DWebExporter is Accessible
- hasattr(context, 'node')
- and hasattr(context.node, 'node_type')
- and context.node.node_type == ct.NodeType.Tidy3DWebExporter
- # Check Sim is Available (aka. uploadeable)
- and context.node.sim_info_available
- and context.node.sim_info_invalidated
- )
-
- def execute(self, context):
- node = context.node
-
- # Rehydrate the Cache
- node.total_monitor_data = bl_cache.Signal.InvalidateCache
- node.is_sim_uploadable = bl_cache.Signal.InvalidateCache
-
- # Remove the Invalidation Marker
- ## -> This is OK, since we manually guaranteed that it's available.
- node.sim_info_invalidated = False
- return {'FINISHED'}
-
-
class UploadSimulation(bpy.types.Operator):
+ """Upload the simulation embedded in the `Tidy3DWebExpoerter`."""
+
bl_idname = ct.OperatorType.NodeUploadSimulation
bl_label = 'Upload Tidy3D Simulation'
bl_description = 'Upload the attached (locked) simulation, such that it is ready to run on the Tidy3D cloud'
@classmethod
def poll(cls, context):
+ """Allow running whenever there are simulations to upload."""
return (
# Check Tidy3D Cloud
tdcloud.IS_AUTHENTICATED
@@ -77,41 +62,80 @@ class UploadSimulation(bpy.types.Operator):
and hasattr(context.node, 'node_type')
and context.node.node_type == ct.NodeType.Tidy3DWebExporter
# Check Sim is Available (aka. uploadeable)
- and context.node.is_sim_uploadable
- and context.node.uploaded_task_id == ''
+ and context.node.sims
+ and context.node.sims_uploadable
+ and not context.node.uploaded_task_ids
)
def execute(self, context):
+ """Upload either a single or a batch of simulations.
+
+ Later retrieval of realized parameter points exist in the form of a realized serializable dictionary attached to the `.attrs` field of any `td.Simulation` object.
+ """
node = context.node
- cloud_task = tdcloud.TidyCloudTasks.mk_task(
- task_name=node.new_cloud_task.task_name,
- cloud_folder=node.new_cloud_task.cloud_folder,
- sim=node.sim,
- verbose=True,
- )
- node.uploaded_task_id = cloud_task.task_id
+
+ if node.base_cloud_task is not None:
+ base_task_name = node.base_cloud_task.task_name
+ base_task_folder = node.base_cloud_task.cloud_folder
+ else:
+ self.report({'ERROR'}, 'No base cloud task name')
+ return {'FINISHED'}
+
+ if node.active_socket_set == 'Single':
+ if len(list(node.sims.values())) == 1:
+ sim = next(iter(node.sims.values()))
+ else:
+ self.report({'ERROR'}, '>1 sims for "Single"-mode sim exporter.')
+ return {'FINISHED'}
+
+ cloud_task = tdcloud.TidyCloudTasks.mk_task(
+ task_name=base_task_name,
+ cloud_folder=base_task_folder,
+ sim=sim,
+ verbose=True,
+ )
+ node.uploaded_task_ids = (cloud_task.task_id,)
+
+ if node.active_socket_set == 'Batch':
+ cloud_tasks = [
+ tdcloud.TidyCloudTasks.mk_task(
+ task_name=base_task_name + f'_{i}',
+ cloud_folder=base_task_folder,
+ sim=sim,
+ verbose=True,
+ )
+ for i, sim in enumerate(node.sims.values())
+ ]
+ node.uploaded_task_ids = tuple(
+ [cloud_task.task_id for cloud_task in cloud_tasks]
+ )
+
return {'FINISHED'}
class ReleaseUploadedTask(bpy.types.Operator):
+ """Release the uploaded simulation embedded in the `Tidy3DWebExpoerter`."""
+
bl_idname = ct.OperatorType.NodeReleaseUploadedTask
bl_label = 'Release Tracked Tidy3D Cloud Task'
bl_description = 'Release the currently tracked simulation task'
@classmethod
def poll(cls, context):
+ """Allow running whenever a particular FDTDSim node is tracking uploaded simulations."""
return (
# Check Tidy3DWebExporter is Accessible
hasattr(context, 'node')
and hasattr(context.node, 'node_type')
and context.node.node_type == ct.NodeType.Tidy3DWebExporter
# Check Sim is Available (aka. uploadeable)
- and context.node.uploaded_task_id != ''
+ and context.node.uploaded_task_ids
)
def execute(self, context):
+ """Invalidate the `.sims` property, triggering reevaluation of all downstream information about the simulation."""
node = context.node
- node.uploaded_task_id = ''
+ node.uploaded_task_ids = ()
return {'FINISHED'}
@@ -119,303 +143,357 @@ class ReleaseUploadedTask(bpy.types.Operator):
# - Node
####################
class Tidy3DWebExporterNode(base.MaxwellSimNode):
+ """Export a simulation to the Tidy3D cloud service, where it can be queried and run."""
+
node_type = ct.NodeType.Tidy3DWebExporter
bl_label = 'Tidy3D Web Exporter'
- input_sockets: typ.ClassVar = {
- 'Sim': sockets.MaxwellFDTDSimSocketDef(),
- 'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
- should_exist=False,
- ),
+ input_socket_sets: typ.ClassVar = {
+ 'Single': {
+ 'Sim': sockets.MaxwellFDTDSimSocketDef(),
+ 'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
+ active_kind=FK.Value,
+ should_exist=False,
+ ),
+ },
+ 'Batch': {
+ 'Sims': sockets.MaxwellFDTDSimSocketDef(
+ active_kind=FK.Array,
+ ),
+ 'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
+ active_kind=FK.Value,
+ should_exist=False,
+ ),
+ },
}
- output_sockets: typ.ClassVar = {
- 'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
- should_exist=True,
- ),
+ output_socket_sets: typ.ClassVar = {
+ 'Single': {
+ 'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
+ active_kind=FK.Value,
+ should_exist=True,
+ ),
+ },
+ 'Batch': {
+ 'Cloud Tasks': sockets.Tidy3DCloudTaskSocketDef(
+ active_kind=FK.Array,
+ should_exist=True,
+ ),
+ },
}
####################
# - Properties
####################
- sim_info_available: bool = bl_cache.BLField(False)
- sim_info_invalidated: bool = bl_cache.BLField(False)
- uploaded_task_id: str = bl_cache.BLField('')
+ uploaded_task_ids: tuple[str, ...] = bl_cache.BLField(())
####################
- # - Computed - Sim
+ # - Properties: Socket -> Props
####################
- @bl_cache.cached_bl_property()
- def sim(self) -> td.Simulation | None:
- sim = self._compute_input('Sim')
- has_sim = not ct.FlowSignal.check(sim)
+ @events.on_value_changed(
+ socket_name={'Sim': {FK.Value, FK.Array}},
+ )
+ def on_sims_changed(self) -> None:
+ """Regenerate the simulation property on changes."""
+ self.sims = bl_cache.Signal.InvalidateCache
+
+ @events.on_value_changed(
+ socket_name={'Cloud Task': {FK.Value}},
+ )
+ def on_base_cloud_task_changed(self) -> None:
+ """Regenerate the cloud task property on changes."""
+ self.base_cloud_task = bl_cache.Signal.InvalidateCache
+
+ ####################
+ # - Properties: Socket Alias
+ ####################
+ @bl_cache.cached_bl_property(depends_on={'active_socket_set'})
+ def sims(self) -> SimArray | None:
+ """The complete description of all simulation data input."""
+ if self.active_socket_set == 'Single':
+ sim_data_value = self._compute_input('Sim', kind=FK.Value)
+ has_sim_data_value = not FS.check(sim_data_value)
+ if has_sim_data_value:
+ return {(): sim_data_value}
+
+ elif self.active_socket_set == 'Batch':
+ sim_data_array = self._compute_input('Sims', kind=FK.Array)
+ has_sim_data_array = not FS.check(sim_data_array)
+ if has_sim_data_array:
+ return sim_data_array
- if has_sim:
- return sim
return None
@bl_cache.cached_bl_property()
- def total_monitor_data(self) -> float | None:
- if self.sim is not None:
- return sum(self.sim.monitors_data_size.values())
- return None
-
- ####################
- # - Computed - New Cloud Task
- ####################
- @property
- def new_cloud_task(self) -> ct.NewSimCloudTask | None:
- """Retrieve the current new cloud task from the input socket.
-
- If one can't be loaded, return None.
- """
- new_cloud_task = self._compute_input(
+ def base_cloud_task(self) -> ct.NewSimCloudTask | None:
+ """The complete description of all simulation input objects."""
+ base_cloud_task = self._compute_input(
'Cloud Task',
kind=ct.FlowKind.Value,
)
- has_new_cloud_task = not ct.FlowSignal.check(new_cloud_task)
-
- if has_new_cloud_task and new_cloud_task.task_name != '':
- return new_cloud_task
+ has_base_cloud_task = not FS.check(base_cloud_task)
+ if has_base_cloud_task and base_cloud_task.task_name != '':
+ return base_cloud_task
return None
####################
- # - Computed - Uploaded Cloud Task
+ # - Properties: Simulations
####################
- @property
- def uploaded_task(self) -> tdcloud.CloudTask | None:
- """Retrieve the uploaded cloud task.
+ @bl_cache.cached_bl_property(depends_on={'sims'})
+ def sims_valid(
+ self,
+ ) -> (
+ dict[tuple[tuple[sim_symbols.SimSymbol, ...], tuple[typ.Any, ...]], bool] | None
+ ):
+ """Whether all sims are valid."""
+ if self.sims is not None:
+ validity = {}
+ for k, sim in self.sims.items(): # noqa: B007
+ try:
+ pass ## TODO: VERY slow, batch checking is infeasible
+ # sim.validate_pre_upload(source_required=True)
+ except td.exceptions.SetupError:
+ validity[k] = False
+ else:
+ validity[k] = True
- If one can't be loaded, return None.
- """
- has_uploaded_task = self.uploaded_task_id != ''
- has_new_cloud_task = self.new_cloud_task is not None
-
- if has_uploaded_task and has_new_cloud_task:
- return tdcloud.TidyCloudTasks.tasks(self.new_cloud_task.cloud_folder).get(
- self.uploaded_task_id
- )
+ return validity
return None
- @property
- def uploaded_task_info(self) -> tdcloud.CloudTask | None:
- """Retrieve the uploaded cloud task.
-
- If one can't be loaded, return None.
- """
- uploaded_task = self.uploaded_task
- if uploaded_task is not None:
- return tdcloud.TidyCloudTasks.task_info(self.uploaded_task_id)
+ ####################
+ # - Properties: Tasks
+ ####################
+ @bl_cache.cached_bl_property(depends_on={'uploaded_task_ids'})
+ def uploaded_task_infos(self) -> list[tdcloud.CloudTask | None] | None:
+ """Retrieve information about the uploaded cloud tasks."""
+ if self.uploaded_task_ids:
+ return [
+ tdcloud.TidyCloudTasks.task_info(task_id)
+ for task_id in self.uploaded_task_ids
+ ]
return None
- @bl_cache.cached_bl_property()
- def uploaded_est_cost(self) -> float | None:
- task_info = self.uploaded_task_info
- if task_info is not None:
- est_cost = task_info.cost_est()
- if est_cost is not None:
- return est_cost
+ @bl_cache.cached_bl_property(depends_on={'uploaded_task_infos'})
+ def est_costs(self) -> list[float | None] | None:
+ """Estimate the FlexCredit cost of each uploaded task."""
+ if self.uploaded_task_infos is not None and all(
+ task_info is not None for task_info in self.uploaded_task_infos
+ ):
+ return [task_info.cost_est() for task_info in self.uploaded_task_infos]
+ return None
+ @bl_cache.cached_bl_property(depends_on={'est_costs'})
+ def total_est_cost(self) -> list[float | None] | None:
+ """Estimate the total FlexCredits cost of all uploaded tasks."""
+ if self.est_costs is not None and all(
+ est_cost is not None for est_cost in self.est_costs
+ ):
+ return sum(self.est_costs)
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'uploaded_task_infos'})
+ def real_costs(self) -> list[float | None] | None:
+ """Estimate the FlexCredit cost of each uploaded task."""
+ if self.uploaded_task_infos is not None and all(
+ task_info is not None for task_info in self.uploaded_task_infos
+ ):
+ return [task_info.cost_real for task_info in self.uploaded_task_infos]
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'real_costs'})
+ def total_real_cost(self) -> list[float | None] | None:
+ """Estimate the total FlexCredits cost of all uploaded tasks."""
+ if self.real_costs is not None and all(
+ real_cost is not None for real_cost in self.real_costs
+ ):
+ return sum(self.real_costs)
return None
####################
# - Computed - Combined
####################
- @bl_cache.cached_bl_property()
- def is_sim_uploadable(self) -> bool:
- if (
- self.sim is not None
- and self.uploaded_task_id == ''
- and self.new_cloud_task is not None
- and self.new_cloud_task.task_name != ''
- ):
- try:
- self.sim.validate_pre_upload(source_required=True)
- except:
- log.exception()
- return False
- else:
- return True
- return False
+ @bl_cache.cached_bl_property(depends_on={'sims_valid'})
+ def sims_uploadable(self) -> bool:
+ """Whether all simulations can be uploaded."""
+ return self.sims_valid is not None and all(self.sims_valid.values())
####################
# - UI
####################
- def draw_operators(self, context, layout):
+ @bl_cache.cached_bl_property(
+ depends_on={
+ 'uploaded_task_infos',
+ 'total_est_cost',
+ 'total_real_cost',
+ }
+ )
+ def task_labels(self) -> SimArrayInfo | None:
+ """Pre-processed labels for efficient drawing of task info."""
+ if self.uploaded_task_infos is not None and all(
+ task_info is not None for task_info in self.uploaded_task_infos
+ ):
+ return {
+ task_info.task_id: [
+ f'Task: {task_info.task_name}',
+ ('Status', task_info.status),
+ (
+ 'Est.',
+ (
+ f'{self.total_est_cost:.2f} creds'
+ if self.total_est_cost is not None
+ else 'TBD...'
+ ),
+ ),
+ (
+ 'Real',
+ (
+ f'{self.total_real_cost:.2f} creds'
+ if self.total_real_cost is not None
+ else 'TBD'
+ ),
+ ),
+ ]
+ for task_info in self.uploaded_task_infos
+ }
+ return None
+
+ def draw_operators(self, _, layout):
+ """Draw operators for uploading/releasing simulations."""
# Row: Upload Sim Buttons
row = layout.row(align=True)
row.operator(
ct.OperatorType.NodeUploadSimulation,
text='Upload',
)
- if self.uploaded_task_id:
+ if self.uploaded_task_ids:
row.operator(
ct.OperatorType.NodeReleaseUploadedTask,
icon='LOOP_BACK',
text='',
)
- def draw_info(self, context, layout):
+ def draw_info(self, _, layout):
+ """Draw information relevant for simulation uploading."""
# Connection Info
auth_icon = 'CHECKBOX_HLT' if tdcloud.IS_AUTHENTICATED else 'CHECKBOX_DEHLT'
conn_icon = 'CHECKBOX_HLT' if tdcloud.IS_ONLINE else 'CHECKBOX_DEHLT'
- row = layout.row()
+ box = layout.box()
+
+ # Cloud Info
+ row = box.row()
row.alignment = 'CENTER'
row.label(text='Cloud Status')
- box = layout.box()
+
split = box.split(factor=0.85)
- ## Split: Left Column
col = split.column(align=False)
col.label(text='Authed')
col.label(text='Connected')
- ## Split: Right Column
col = split.column(align=False)
col.label(icon=auth_icon)
col.label(icon=conn_icon)
- # Simulation Info
- if self.sim is not None:
- row = layout.row()
- row.alignment = 'CENTER'
- row.label(text='Sim Info')
- box = layout.box()
+ if self.task_labels is not None:
+ for labels in self.task_labels.values():
+ row = layout.row(align=True)
+ row.alignment = 'CENTER'
+ row.label(text='Task Status')
- if self.sim_info_invalidated:
- box.operator(ct.OperatorType.NodeRecomputeSimInfo, text='Regenerate')
- else:
- split = box.split(factor=0.5)
+ for el in labels:
+ # Header
+ if isinstance(el, str):
+ box = layout.box()
+ row = box.row(align=True)
+ row.alignment = 'CENTER'
+ row.label(text=el)
- ## Split: Left Column
- col = split.column(align=False)
- col.label(text='𝝨 Data')
+ split = box.split(factor=0.4)
+ col_l = split.column(align=True)
+ col_r = split.column(align=True)
- ## Split: Right Column
- col = split.column(align=False)
- col.alignment = 'RIGHT'
- col.label(text=f'{self.total_monitor_data / 1_000_000:.2f}MB')
+ # Label Pair
+ elif isinstance(el, tuple):
+ col_l.label(text=el[0])
+ col_r.label(text=el[1])
- if self.uploaded_task_info is not None:
- # Uploaded Task Information
- box = layout.box()
- split = box.split(factor=0.6)
+ else:
+ raise TypeError
- ## Split: Left Column
- col = split.column(align=False)
- col.label(text='Status')
- col.label(text='Est. Cost')
- col.label(text='Real Cost')
-
- ## Split: Right Column
- cost_est = (
- f'{self.uploaded_est_cost:.2f}'
- if self.uploaded_est_cost is not None
- else 'TBD'
- )
- cost_real = (
- f'{self.uploaded_task_info.cost_real:.2f}'
- if self.uploaded_task_info.cost_real is not None
- else 'TBD'
- )
-
- col = split.column(align=False)
- col.alignment = 'RIGHT'
- col.label(text=self.uploaded_task_info.status.capitalize())
- col.label(text=f'{cost_est} creds')
- col.label(text=f'{cost_real} creds')
-
- # Connection Information
+ break
####################
# - Events
####################
@events.on_value_changed(
- socket_name='Sim',
- run_on_init=True,
- props={'sim_info_available', 'sim_info_invalidated'},
- )
- def on_sim_changed(self, props) -> None:
- # Sim Linked | First Value Change
- if self.inputs['Sim'].is_linked and not props['sim_info_available']:
- log.debug('%s: First Change; Mark Sim Info Available', self.sim_node_name)
- self.sim = bl_cache.Signal.InvalidateCache
- self.total_monitor_data = bl_cache.Signal.InvalidateCache
- self.is_sim_uploadable = bl_cache.Signal.InvalidateCache
- self.sim_info_available = True
-
- # Sim Linked | Second Value Change
- if (
- self.inputs['Sim'].is_linked
- and props['sim_info_available']
- and not props['sim_info_invalidated']
- ):
- log.debug('%s: Second Change; Mark Sim Info Invalided', self.sim_node_name)
- self.sim_info_invalidated = True
-
- # Sim Linked | Nth Time
- ## -> Danger of infinite expensive recompute of the sim every change.
- ## -> Instead, user must manually set "available & not invalidated".
- ## -> The UI should explain that the caches are dry.
- ## -> The UI should also provide such a "hydration" button.
-
- # Sim Not Linked
- ## -> If the sim is straight-up not available, cache needs changing.
- ## -> Luckily, since we know there's no sim, invalidation is cheap.
- ## -> Ends up being a "circuit breaker" for sim_info_invalidated.
- elif not self.inputs['Sim'].is_linked:
- log.debug(
- '%s: Unlinked; Short Circuit the Sim Info Cache', self.sim_node_name
- )
- self.sim = bl_cache.Signal.InvalidateCache
- self.total_monitor_data = bl_cache.Signal.InvalidateCache
- self.is_sim_uploadable = bl_cache.Signal.InvalidateCache
- self.sim_info_available = False
- self.sim_info_invalidated = False
-
- @events.on_value_changed(
- socket_name='Cloud Task',
- run_on_init=True,
- )
- def on_new_cloud_task_changed(self):
- self.is_sim_uploadable = bl_cache.Signal.InvalidateCache
-
- @events.on_value_changed(
- # Trigger
- prop_name='uploaded_task_id',
- run_on_init=True,
+ prop_name='uploaded_task_ids',
# Loaded
- props={'uploaded_task_id'},
+ props={'uploaded_task_ids'},
)
def on_uploaded_task_changed(self, props):
- log.debug('Uploaded Task Changed')
- self.is_sim_uploadable = bl_cache.Signal.InvalidateCache
+ """When uploaded tasks change, take appropriate action.
- if props['uploaded_task_id'] != '':
+ - Enable/Disable Lock: To prevent node-tree modifications that would invalidate the validity of uploaded tasks.
+ - Ensure Est Cost: Repeatedly try to load the estimated cost of all tasks, until all are available.
+ """
+ uploaded_task_ids = props['uploaded_task_ids']
+
+ # Lock
+ if uploaded_task_ids:
self.trigger_event(ct.FlowEvent.EnableLock)
self.locked = False
+ # Force Computation of Estimated Cost
+ ## -> Try recomputing the estimated cost of all tasks.
+ ## -> Once all are non-None, stop.
+ max_tries = 20
+ for _ in range(max_tries):
+ self.est_costs = bl_cache.Signal.InvalidateCache
+ if self.total_est_cost is not None:
+ break
+
else:
self.trigger_event(ct.FlowEvent.DisableLock)
- max_tries = 10
- for _ in range(max_tries):
- self.uploaded_est_cost = bl_cache.Signal.InvalidateCache
- if self.uploaded_est_cost is not None:
- break
-
####################
- # - Outputs
+ # - FlowKind.Value
####################
@events.computes_output_socket(
'Cloud Task',
- props={'uploaded_task_id', 'uploaded_task'},
+ kind=ct.FlowKind.Value,
+ # Loaded
+ props={'uploaded_task_ids'},
)
def compute_cloud_task(self, props) -> tdcloud.CloudTask | None:
- if props['uploaded_task_id'] != '':
- return props['uploaded_task']
+ """A single uploaded cloud task, when there only is one."""
+ uploaded_task_ids = props['uploaded_task_ids']
- return ct.FlowSignal.FlowPending
+ if uploaded_task_ids is not None and len(uploaded_task_ids) == 1:
+ cloud_task = tdcloud.TidyCloudTasks.task(uploaded_task_ids[0])
+ if cloud_task is not None:
+ return cloud_task
+ return FS.FlowPending
+
+ ####################
+ # - FlowKind.Array
+ ####################
+ @events.computes_output_socket(
+ 'Cloud Tasks',
+ kind=ct.FlowKind.Array,
+ # Loaded
+ props={'uploaded_task_ids'},
+ )
+ def compute_cloud_tasks(self, props) -> tdcloud.CloudTask | None:
+ """All uploaded cloud task, when there are more than one."""
+ uploaded_task_ids = props['uploaded_task_ids']
+
+ if len(uploaded_task_ids) > 1:
+ cloud_tasks = [
+ tdcloud.TidyCloudTasks.task(task_id) for task_id in uploaded_task_ids
+ ]
+ if all(cloud_task is not None for cloud_task in cloud_tasks):
+ return cloud_tasks
+ return FS.FlowPending
####################
@@ -425,7 +503,6 @@ BL_REGISTER = [
UploadSimulation,
ReleaseUploadedTask,
Tidy3DWebExporterNode,
- RecomputeSimInfo,
]
BL_NODES = {
ct.NodeType.Tidy3DWebExporter: (ct.NodeCategory.MAXWELLSIM_OUTPUTS_WEBEXPORTERS)
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/__init__.py
index 97c7d15..1d58108 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/__init__.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/__init__.py
@@ -14,21 +14,28 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-# from . import sim_grid
-# from . import sim_grid_axes
-from . import combine, fdtd_sim, sim_domain
+from . import (
+ bound_cond_faces,
+ bound_conds,
+ fdtd_sim,
+ sim_domain,
+ sim_grid,
+ sim_grid_axes,
+)
BL_REGISTER = [
- *combine.BL_REGISTER,
- *sim_domain.BL_REGISTER,
- # *sim_grid.BL_REGISTER,
- # *sim_grid_axes.BL_REGISTER,
*fdtd_sim.BL_REGISTER,
+ *sim_domain.BL_REGISTER,
+ *bound_conds.BL_REGISTER,
+ *bound_cond_faces.BL_REGISTER,
+ *sim_grid.BL_REGISTER,
+ *sim_grid_axes.BL_REGISTER,
]
BL_NODES = {
- **combine.BL_NODES,
- **sim_domain.BL_NODES,
- # **sim_grid.BL_NODES,
- # **sim_grid_axes.BL_NODES,
**fdtd_sim.BL_NODES,
+ **sim_domain.BL_NODES,
+ **bound_conds.BL_NODES,
+ **bound_cond_faces.BL_NODES,
+ **sim_grid.BL_NODES,
+ **sim_grid_axes.BL_NODES,
}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/bound_cond_faces/__init__.py
similarity index 100%
rename from src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/__init__.py
rename to src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/bound_cond_faces/__init__.py
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/absorbing_bound_cond.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/bound_cond_faces/absorbing_bound_cond.py
similarity index 63%
rename from src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/absorbing_bound_cond.py
rename to src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/bound_cond_faces/absorbing_bound_cond.py
index f03eb77..b2887ec 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/absorbing_bound_cond.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/bound_cond_faces/absorbing_bound_cond.py
@@ -22,8 +22,8 @@ import bpy
import sympy as sp
import tidy3d as td
-from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
+from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets
@@ -31,6 +31,9 @@ from ... import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+
class AdiabAbsorbBoundCondNode(base.MaxwellSimNode):
r"""A boundary condition that generically (adiabatically) absorbs outgoing energy, by gradually ramping up the strength of the conductor over many layers, until a final PEC layer.
@@ -78,7 +81,7 @@ class AdiabAbsorbBoundCondNode(base.MaxwellSimNode):
),
'σ Range': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec2,
- default_value=sp.Matrix([0, 1.5]),
+ default_value=sp.ImmutableMatrix([0, 1.5]),
abs_min=0,
),
},
@@ -91,6 +94,11 @@ class AdiabAbsorbBoundCondNode(base.MaxwellSimNode):
# - UI
####################
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
+ """Draw the user interfaces of the node's reported info.
+
+ Parameters:
+ layout: UI target for drawing.
+ """
if self.active_socket_set == 'Full':
box = layout.box()
row = box.row()
@@ -115,118 +123,71 @@ class AdiabAbsorbBoundCondNode(base.MaxwellSimNode):
@events.computes_output_socket(
'BC',
# Loaded
- props={'active_socket_set'},
- input_sockets={
- 'Layers',
- 'σ Order',
- 'σ Range',
+ outscks_kinds={
+ 'BC': {FK.Func, FK.Params},
},
- input_sockets_optional={
- 'σ Order': True,
- 'σ Range': True,
- },
- output_sockets={'BC'},
- output_socket_kinds={'BC': ct.FlowKind.Params},
)
- def compute_bc_value(self, props, input_sockets, output_sockets) -> td.Absorber:
+ def compute_bc_value(self, output_sockets) -> td.Absorber | FS:
r"""Computes the adiabatic absorber boundary condition based on the active socket set.
- **Simple**: Use `tidy3d`'s default parameters for defining the absorber parameters (apart from number of layers).
- **Full**: Use the user-defined $\sigma$ parameters, specifically polynomial order and sim-relative min/max conductivity values.
"""
- output_params = output_sockets['BC']
- layers = input_sockets['Layers']
-
- has_output_params = not ct.FlowSignal.check(output_params)
- has_layers = not ct.FlowSignal.check(layers)
-
- active_socket_set = props['active_socket_set']
- if has_layers and has_output_params and not output_params.symbols:
- # Simple PML
- if active_socket_set == 'Simple':
- return td.Absorber(num_layers=layers)
-
- # Full PML
- sig_order = input_sockets['σ Order']
- sig_range = input_sockets['σ Range']
-
- has_sig_order = not ct.FlowSignal.check(sig_order)
- has_sig_range = not ct.FlowSignal.check(sig_range)
-
- if has_sig_order and has_sig_range:
- return td.Absorber(
- num_layers=layers,
- parameters=td.AbsorberParams(
- sigma_order=sig_order,
- sigma_min=sig_range[0],
- sigma_max=sig_range[1],
- ),
- )
- return ct.FlowSignal.FlowPending
+ value = events.realize_known(output_sockets['BC'])
+ if value is not None:
+ return value
+ return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'BC',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
# Loaded
props={'active_socket_set'},
- input_sockets={
- 'Layers',
- 'σ Order',
- 'σ Range',
+ inscks_kinds={
+ 'Layers': FK.Func,
+ 'σ Order': FK.Func,
+ 'σ Range': FK.Func,
},
- input_socket_kinds={
- 'Layers': ct.FlowKind.Func,
- 'σ Order': ct.FlowKind.Func,
- 'σ Range': ct.FlowKind.Func,
- },
- input_sockets_optional={
- 'σ Order': True,
- 'σ Range': True,
- },
- output_sockets={'BC'},
- output_socket_kinds={'BC': ct.FlowKind.Params},
+ input_sockets_optional={'σ Order', 'σ Range'},
)
- def compute_bc_func(self, props, input_sockets, output_sockets) -> td.Absorber:
+ def compute_bc_func(self, props, input_sockets) -> td.Absorber:
r"""Computes the adiabatic absorber boundary condition based on the active socket set.
- **Simple**: Use `tidy3d`'s default parameters for defining the absorber parameters (apart from number of layers).
- **Full**: Use the user-defined $\sigma$ parameters, specifically polynomial order and sim-relative min/max conductivity values.
"""
layers = input_sockets['Layers']
-
- has_layers = not ct.FlowSignal.check(layers)
-
active_socket_set = props['active_socket_set']
- if has_layers:
- # Simple PML
- if active_socket_set == 'Simple':
+
+ match active_socket_set:
+ case 'Simple':
return layers.compose_within(
enclosing_func=lambda _layers: td.Absorber(num_layers=_layers),
supports_jax=False,
)
+ case 'Full':
+ sig_order = input_sockets['σ Order']
+ sig_range = input_sockets['σ Range']
- # Full PML
- sig_order = input_sockets['σ Order']
- sig_range = input_sockets['σ Range']
+ has_sig_order = not FS.check(sig_order)
+ has_sig_range = not FS.check(sig_range)
- has_sig_order = not ct.FlowSignal.check(sig_order)
- has_sig_range = not ct.FlowSignal.check(sig_range)
-
- if has_sig_order and has_sig_range:
- return (layers | sig_order | sig_range).compose_within(
- enclosing_func=lambda els: td.Absorber(
- num_layers=els[0],
- parameters=td.AbsorberParams(
- sigma_order=els[1],
- sigma_min=els[2][0],
- sigma_max=els[2][1],
+ if has_sig_order and has_sig_range:
+ return (layers | sig_order | sig_range).compose_within(
+ enclosing_func=lambda els: td.Absorber(
+ num_layers=els[0],
+ parameters=td.AbsorberParams(
+ sigma_order=els[1],
+ sigma_min=els[2].item(0),
+ sigma_max=els[2].item(1),
+ ),
),
- ),
- supports_jax=False,
- )
+ supports_jax=False,
+ )
+
return ct.FlowSignal.FlowPending
####################
@@ -234,44 +195,35 @@ class AdiabAbsorbBoundCondNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'BC',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
# Loaded
props={'active_socket_set'},
- input_sockets={
- 'Layers',
- 'σ Order',
- 'σ Range',
- },
- input_socket_kinds={
- 'Layers': ct.FlowKind.Params,
- 'σ Order': ct.FlowKind.Params,
- 'σ Range': ct.FlowKind.Params,
- },
- input_sockets_optional={
- 'σ Order': True,
- 'σ Range': True,
+ inscks_kinds={
+ 'Layers': FK.Params,
+ 'σ Order': FK.Params,
+ 'σ Range': FK.Params,
},
+ input_sockets_optional={'σ Order', 'σ Range'},
)
def compute_params(self, props, input_sockets) -> td.Box:
+ """Aggregate the function parameters needed by the absorbing BC."""
layers = input_sockets['Layers']
-
- has_layers = not ct.FlowSignal.check(layers)
-
active_socket_set = props['active_socket_set']
- if has_layers:
- # Simple PML
- if active_socket_set == 'Simple':
+
+ match active_socket_set:
+ case 'Simple':
return layers
- # Full PML
- sig_order = input_sockets['σ Order']
- sig_range = input_sockets['σ Range']
+ case 'Full':
+ sig_order = input_sockets['σ Order']
+ sig_range = input_sockets['σ Range']
- has_sig_order = not ct.FlowSignal.check(sig_order)
- has_sig_range = not ct.FlowSignal.check(sig_range)
+ has_sig_order = not FS.check(sig_order)
+ has_sig_range = not FS.check(sig_range)
+
+ if has_sig_order and has_sig_range:
+ return layers | sig_order | sig_range
- if has_sig_order and has_sig_range:
- return layers | sig_order | sig_range
return ct.FlowSignal.FlowPending
@@ -281,4 +233,6 @@ class AdiabAbsorbBoundCondNode(base.MaxwellSimNode):
BL_REGISTER = [
AdiabAbsorbBoundCondNode,
]
-BL_NODES = {ct.NodeType.AdiabAbsorbBoundCond: (ct.NodeCategory.MAXWELLSIM_BOUNDS)}
+BL_NODES = {
+ ct.NodeType.AdiabAbsorbBoundCond: (ct.NodeCategory.MAXWELLSIM_SIMS_BOUNDCONDFACES)
+}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/bloch_bound_cond.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/bound_cond_faces/bloch_bound_cond.py
similarity index 87%
rename from src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/bloch_bound_cond.py
rename to src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/bound_cond_faces/bloch_bound_cond.py
index 49b0b62..8bae791 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/bloch_bound_cond.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/bound_cond_faces/bloch_bound_cond.py
@@ -29,6 +29,9 @@ from ... import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+
class BlochBoundCondNode(base.MaxwellSimNode):
r"""A boundary condition that declares an "infinitely repeating" window, by applying Bloch's theorem to accurately describe how a boundary would behave if it were interacting with an infinitely repeating simulation structure.
@@ -117,7 +120,7 @@ class BlochBoundCondNode(base.MaxwellSimNode):
input_socket_sets: typ.ClassVar = {
'Naive': {},
'Source-Derived': {
- 'Angled Source': sockets.MaxwellSourceSocketDef(),
+ 'Angled Source': sockets.MaxwellSourceSocketDef(active_kind=FK.Func),
## TODO: Constrain to gaussian beam, plane wafe, and tfsf
'Sim Domain': sockets.MaxwellSimDomainSocketDef(),
},
@@ -138,10 +141,20 @@ class BlochBoundCondNode(base.MaxwellSimNode):
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
+ """Draw the user interface of the node's properties.
+
+ Parameters:
+ layout: UI target for drawing.
+ """
if self.active_socket_set == 'Source-Derived':
layout.prop(self, self.blfields['valid_sim_axis'], expand=True)
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
+ """Draw the user interfaces of the node's reported info.
+
+ Parameters:
+ layout: UI target for drawing.
+ """
if self.active_socket_set == 'Manual':
box = layout.box()
row = box.row()
@@ -204,7 +217,7 @@ class BlochBoundCondNode(base.MaxwellSimNode):
'Bloch Vector': True,
},
output_sockets={'BC'},
- output_socket_kinds={'BC': ct.FlowKind.Params},
+ output_socket_kinds={'BC': FK.Params},
)
def compute_value(
self, props, input_sockets, output_sockets
@@ -217,9 +230,9 @@ class BlochBoundCondNode(base.MaxwellSimNode):
- **Manual**: Set the Bloch vector to the user-specified value.
"""
output_params = output_sockets['BC']
- has_output_params = not ct.FlowSignal.check(output_params)
+ has_output_params = not FS.check(output_params)
if not has_output_params or (has_output_params and output_params.symbols):
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
active_socket_set = props['active_socket_set']
match active_socket_set:
@@ -230,8 +243,8 @@ class BlochBoundCondNode(base.MaxwellSimNode):
angled_source = input_sockets['Angled Source']
sim_domain = input_sockets['Sim Domain']
- has_angled_source = not ct.FlowSignal.check(angled_source)
- has_sim_domain = not ct.FlowSignal.check(sim_domain)
+ has_angled_source = not FS.check(angled_source)
+ has_sim_domain = not FS.check(sim_domain)
if has_angled_source and has_sim_domain:
valid_sim_axis = props['valid_sim_axis']
@@ -241,22 +254,22 @@ class BlochBoundCondNode(base.MaxwellSimNode):
axis=valid_sim_axis.axis,
medium=sim_domain['medium'],
)
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
case 'Manual':
bloch_vector = input_sockets['Bloch Vector']
- has_bloch_vector = not ct.FlowSignal.check(bloch_vector)
+ has_bloch_vector = not FS.check(bloch_vector)
if has_bloch_vector:
return td.BlochBoundary(bloch_vec=bloch_vector)
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'BC',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
# Loaded
props={'active_socket_set', 'valid_sim_axis'},
input_sockets={
@@ -265,29 +278,18 @@ class BlochBoundCondNode(base.MaxwellSimNode):
'Bloch Vector',
},
input_socket_kinds={
- 'Angled Source': ct.FlowKind.Func,
- 'Sim Domain': ct.FlowKind.Func,
- 'Bloch Vector': ct.FlowKind.Func,
+ 'Angled Source': FK.Func,
+ 'Sim Domain': FK.Func,
+ 'Bloch Vector': FK.Func,
},
- input_sockets_optional={
- 'Angled Source': True,
- 'Sim Domain': True,
- 'Bloch Vector': True,
- },
- output_sockets={'BC'},
- output_socket_kinds={'BC': ct.FlowKind.Params},
+ input_sockets_optional={'Angled Source', 'Sim Domain', 'Bloch Vector'},
)
- def compute_bc_func(self, props, input_sockets, output_sockets) -> td.Absorber:
- r"""Computes the adiabatic absorber boundary condition based on the active socket set.
+ def compute_bc_func(self, props, input_sockets) -> td.Absorber:
+ r"""Computes the bloch boundary condition based on the active socket set.
- **Simple**: Use `tidy3d`'s default parameters for defining the absorber parameters (apart from number of layers).
- **Full**: Use the user-defined $\sigma$ parameters, specifically polynomial order and sim-relative min/max conductivity values.
"""
- output_params = output_sockets['BC']
- has_output_params = not ct.FlowSignal.check(output_params)
- if not has_output_params:
- return ct.FlowSignal.FlowPending
-
active_socket_set = props['active_socket_set']
match active_socket_set:
case 'Naive':
@@ -300,8 +302,8 @@ class BlochBoundCondNode(base.MaxwellSimNode):
angled_source = input_sockets['Angled Source']
sim_domain = input_sockets['Sim Domain']
- has_angled_source = not ct.FlowSignal.check(angled_source)
- has_sim_domain = not ct.FlowSignal.check(sim_domain)
+ has_angled_source = not FS.check(angled_source)
+ has_sim_domain = not FS.check(sim_domain)
if has_angled_source and has_sim_domain:
valid_sim_axis = props['valid_sim_axis']
@@ -314,11 +316,11 @@ class BlochBoundCondNode(base.MaxwellSimNode):
),
supports_jax=False,
)
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
case 'Manual':
bloch_vector = input_sockets['Bloch Vector']
- has_bloch_vector = not ct.FlowSignal.check(bloch_vector)
+ has_bloch_vector = not FS.check(bloch_vector)
if has_bloch_vector:
return bloch_vector.compose_within(
@@ -327,33 +329,25 @@ class BlochBoundCondNode(base.MaxwellSimNode):
),
supports_jax=False,
)
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'BC',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
# Loaded
props={'active_socket_set'},
- input_sockets={
- 'Angled Source',
- 'Sim Domain',
- 'Bloch Vector',
- },
- input_socket_kinds={
- 'Angled Source': ct.FlowKind.Params,
- 'Sim Domain': ct.FlowKind.Params,
- 'Bloch Vector': ct.FlowKind.Params,
- },
- input_sockets_optional={
- 'Angled Source': True,
- 'Sim Domain': True,
- 'Bloch Vector': True,
+ inscks_kinds={
+ 'Angled Source': FK.Params,
+ 'Sim Domain': FK.Params,
+ 'Bloch Vector': FK.Params,
},
+ input_sockets_optional={'Angled Source', 'Sim Domain', 'Bloch Vector'},
)
- def compute_bc_params(self, props, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
+ def compute_bc_params(self, props, input_sockets) -> ct.ParamsFlow | FS:
+ """Aggregate the function parameters needed by the bloch BC."""
active_socket_set = props['active_socket_set']
match active_socket_set:
case 'Naive':
@@ -363,20 +357,20 @@ class BlochBoundCondNode(base.MaxwellSimNode):
angled_source = input_sockets['Angled Source']
sim_domain = input_sockets['Sim Domain']
- has_angled_source = not ct.FlowSignal.check(angled_source)
- has_sim_domain = not ct.FlowSignal.check(sim_domain)
+ has_angled_source = not FS.check(angled_source)
+ has_sim_domain = not FS.check(sim_domain)
if has_sim_domain and has_angled_source:
return angled_source | sim_domain
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
case 'Manual':
bloch_vector = input_sockets['Bloch Vector']
- has_bloch_vector = not ct.FlowSignal.check(bloch_vector)
+ has_bloch_vector = not FS.check(bloch_vector)
if has_bloch_vector:
return bloch_vector
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
@@ -385,4 +379,6 @@ class BlochBoundCondNode(base.MaxwellSimNode):
BL_REGISTER = [
BlochBoundCondNode,
]
-BL_NODES = {ct.NodeType.BlochBoundCond: (ct.NodeCategory.MAXWELLSIM_BOUNDS)}
+BL_NODES = {
+ ct.NodeType.BlochBoundCond: (ct.NodeCategory.MAXWELLSIM_SIMS_BOUNDCONDFACES)
+}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/pml_bound_cond.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/bound_cond_faces/pml_bound_cond.py
similarity index 59%
rename from src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/pml_bound_cond.py
rename to src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/bound_cond_faces/pml_bound_cond.py
index 1490bdd..d016f05 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_cond_nodes/pml_bound_cond.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/bound_cond_faces/pml_bound_cond.py
@@ -22,8 +22,8 @@ import bpy
import sympy as sp
import tidy3d as td
-from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
+from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets
@@ -31,6 +31,9 @@ from ... import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+
class PMLBoundCondNode(base.MaxwellSimNode):
r"""A "Perfectly Matched Layer" boundary condition, which is a theoretical medium that attempts to _perfectly_ absorb all outgoing waves, so as to represent "infinite space" in FDTD simulations.
@@ -82,7 +85,7 @@ class PMLBoundCondNode(base.MaxwellSimNode):
'σ Range': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec2,
mathtype=spux.MathType.Real,
- default_value=sp.Matrix([0, 1.5]),
+ default_value=sp.ImmutableMatrix([0, 1.5]),
abs_min=0,
),
'κ Order': sockets.ExprSocketDef(
@@ -94,7 +97,7 @@ class PMLBoundCondNode(base.MaxwellSimNode):
'κ Range': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec2,
mathtype=spux.MathType.Real,
- default_value=sp.Matrix([0, 1.5]),
+ default_value=sp.ImmutableMatrix([0, 1.5]),
abs_min=0,
),
'α Order': sockets.ExprSocketDef(
@@ -106,7 +109,7 @@ class PMLBoundCondNode(base.MaxwellSimNode):
'α Range': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec2,
mathtype=spux.MathType.Real,
- default_value=sp.Matrix([0, 1.5]),
+ default_value=sp.ImmutableMatrix([0, 1.5]),
abs_min=0,
),
},
@@ -119,6 +122,11 @@ class PMLBoundCondNode(base.MaxwellSimNode):
# - UI
####################
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
+ """Draw the user interfaces of the node's reported info.
+
+ Parameters:
+ layout: UI target for drawing.
+ """
if self.active_socket_set == 'Full':
box = layout.box()
row = box.row()
@@ -144,11 +152,19 @@ class PMLBoundCondNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'BC',
- kind=ct.FlowKind.Value,
+ kind=FK.Value,
# Loaded
props={'active_socket_set'},
- input_sockets={
- 'Layers',
+ inscks_kinds={
+ 'Layers': FK.Value,
+ 'σ Order': FK.Value,
+ 'σ Range': FK.Value,
+ 'κ Order': FK.Value,
+ 'κ Range': FK.Value,
+ 'α Order': FK.Value,
+ 'α Range': FK.Value,
+ },
+ input_sockets_optional={
'σ Order',
'σ Range',
'κ Order',
@@ -156,16 +172,8 @@ class PMLBoundCondNode(base.MaxwellSimNode):
'α Order',
'α Range',
},
- input_sockets_optional={
- 'σ Order': True,
- 'σ Range': True,
- 'κ Order': True,
- 'κ Range': True,
- 'α Order': True,
- 'α Range': True,
- },
output_sockets={'BC'},
- output_socket_kinds={'BC': ct.FlowKind.Params},
+ output_socket_kinds={'BC': FK.Params},
)
def compute_pml_value(self, props, input_sockets, output_sockets) -> td.PML:
r"""Computes the PML boundary condition based on the active socket set.
@@ -173,13 +181,11 @@ class PMLBoundCondNode(base.MaxwellSimNode):
- **Simple**: Use `tidy3d`'s default parameters for defining the PML conductor (apart from number of layers).
- **Full**: Use the user-defined $\sigma$, $\kappa$, and $\alpha$ parameters, specifically polynomial order and sim-relative min/max conductivity values.
"""
- output_params = output_sockets['BC']
layers = input_sockets['Layers']
+ output_params = output_sockets['BC']
+ has_output_params = not FS.check(output_params)
- has_layers = not ct.FlowSignal.check(layers)
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if has_output_params and has_layers and not output_params.symbols:
+ if has_output_params and not output_params.symbols:
active_socket_set = props['active_socket_set']
match active_socket_set:
case 'Simple':
@@ -193,12 +199,12 @@ class PMLBoundCondNode(base.MaxwellSimNode):
alpha_order = input_sockets['α Order']
alpha_range = input_sockets['α Range']
- has_sigma_order = not ct.FlowSignal.check(sigma_order)
- has_sigma_range = not ct.FlowSignal.check(sigma_range)
- has_kappa_order = not ct.FlowSignal.check(kappa_order)
- has_kappa_range = not ct.FlowSignal.check(kappa_range)
- has_alpha_order = not ct.FlowSignal.check(alpha_order)
- has_alpha_range = not ct.FlowSignal.check(alpha_range)
+ has_sigma_order = not FS.check(sigma_order)
+ has_sigma_range = not FS.check(sigma_range)
+ has_kappa_order = not FS.check(kappa_order)
+ has_kappa_range = not FS.check(kappa_range)
+ has_alpha_order = not FS.check(alpha_order)
+ has_alpha_range = not FS.check(alpha_range)
if (
has_sigma_order
@@ -223,18 +229,26 @@ class PMLBoundCondNode(base.MaxwellSimNode):
),
)
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'BC',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
# Loaded
props={'active_socket_set'},
- input_sockets={
- 'Layers',
+ inscks_kinds={
+ 'Layers': FK.Func,
+ 'σ Order': FK.Func,
+ 'σ Range': FK.Func,
+ 'κ Order': FK.Func,
+ 'κ Range': FK.Func,
+ 'α Order': FK.Func,
+ 'α Range': FK.Func,
+ },
+ input_sockets_optional={
'σ Order',
'σ Range',
'κ Order',
@@ -242,101 +256,86 @@ class PMLBoundCondNode(base.MaxwellSimNode):
'α Order',
'α Range',
},
- input_socket_kinds={
- 'Layers': ct.FlowKind.Func,
- 'σ Order': ct.FlowKind.Func,
- 'σ Range': ct.FlowKind.Func,
- 'κ Order': ct.FlowKind.Func,
- 'κ Range': ct.FlowKind.Func,
- 'α Order': ct.FlowKind.Func,
- 'α Range': ct.FlowKind.Func,
- },
- input_sockets_optional={
- 'σ Order': True,
- 'σ Range': True,
- 'κ Order': True,
- 'κ Range': True,
- 'α Order': True,
- 'α Range': True,
- },
- output_sockets={'BC'},
- output_socket_kinds={'BC': ct.FlowKind.Params},
)
- def compute_pml_func(self, props, input_sockets, output_sockets) -> td.PML:
- output_params = output_sockets['BC']
+ def compute_pml_func(self, props, input_sockets) -> td.PML:
+ r"""Computes the PML boundary condition as a lazy function."""
layers = input_sockets['Layers']
+ active_socket_set = props['active_socket_set']
- has_output_params = not ct.FlowSignal.check(output_params)
- has_layers = not ct.FlowSignal.check(layers)
+ match active_socket_set:
+ case 'Simple':
+ return layers.compose_within(
+ enclosing_func=lambda layers: td.PML(num_layers=layers),
+ supports_jax=False,
+ )
- if has_output_params and has_layers:
- active_socket_set = props['active_socket_set']
- match active_socket_set:
- case 'Simple':
- return layers.compose_within(
- enclosing_func=lambda layers: td.PML(num_layers=layers),
- supports_jax=False,
+ case 'Full':
+ sigma_order = input_sockets['σ Order']
+ sigma_range = input_sockets['σ Range']
+ kappa_order = input_sockets['κ Order']
+ kappa_range = input_sockets['κ Range']
+ alpha_order = input_sockets['α Order']
+ alpha_range = input_sockets['α Range']
+
+ has_sigma_order = not FS.check(sigma_order)
+ has_sigma_range = not FS.check(sigma_range)
+ has_kappa_order = not FS.check(kappa_order)
+ has_kappa_range = not FS.check(kappa_range)
+ has_alpha_order = not FS.check(alpha_order)
+ has_alpha_range = not FS.check(alpha_range)
+
+ if (
+ has_sigma_order
+ and has_sigma_range
+ and has_kappa_order
+ and has_kappa_range
+ and has_alpha_order
+ and has_alpha_range
+ ):
+ return (
+ sigma_order
+ | sigma_range
+ | kappa_order
+ | kappa_range
+ | alpha_order
+ | alpha_range
+ ).compose_within(
+ enclosing_func=lambda els: td.PML(
+ num_layers=layers,
+ parameters=td.PMLParams(
+ sigma_order=els[0],
+ sigma_min=els[1][0],
+ sigma_max=els[1][1],
+ kappa_order=els[2],
+ kappa_min=els[3][0],
+ kappa_max=els[3][1],
+ alpha_order=els[4][1],
+ alpha_min=els[5][0],
+ alpha_max=els[5][1],
+ ),
+ )
)
- case 'Full':
- sigma_order = input_sockets['σ Order']
- sigma_range = input_sockets['σ Range']
- kappa_order = input_sockets['κ Order']
- kappa_range = input_sockets['κ Range']
- alpha_order = input_sockets['α Order']
- alpha_range = input_sockets['α Range']
-
- has_sigma_order = not ct.FlowSignal.check(sigma_order)
- has_sigma_range = not ct.FlowSignal.check(sigma_range)
- has_kappa_order = not ct.FlowSignal.check(kappa_order)
- has_kappa_range = not ct.FlowSignal.check(kappa_range)
- has_alpha_order = not ct.FlowSignal.check(alpha_order)
- has_alpha_range = not ct.FlowSignal.check(alpha_range)
-
- if (
- has_sigma_order
- and has_sigma_range
- and has_kappa_order
- and has_kappa_range
- and has_alpha_order
- and has_alpha_range
- ):
- return (
- sigma_order
- | sigma_range
- | kappa_order
- | kappa_range
- | alpha_order
- | alpha_range
- ).compose_within(
- enclosing_func=lambda els: td.PML(
- num_layers=layers,
- parameters=td.PMLParams(
- sigma_order=els[0],
- sigma_min=els[1][0],
- sigma_max=els[1][1],
- kappa_order=els[2],
- kappa_min=els[3][0],
- kappa_max=els[3][1],
- alpha_order=els[4][1],
- alpha_min=els[5][0],
- alpha_max=els[5][1],
- ),
- )
- )
-
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'BC',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
# Loaded
props={'active_socket_set'},
- input_sockets={
- 'Layers',
+ inscks_kinds={
+ 'Layers': FK.Params,
+ 'σ Order': FK.Params,
+ 'σ Range': FK.Params,
+ 'κ Order': FK.Params,
+ 'κ Range': FK.Params,
+ 'α Order': FK.Params,
+ 'α Range': FK.Params,
+ },
+ input_sockets_optional={
'σ Order',
'σ Range',
'κ Order',
@@ -344,72 +343,53 @@ class PMLBoundCondNode(base.MaxwellSimNode):
'α Order',
'α Range',
},
- input_socket_kinds={
- 'Layers': ct.FlowKind.Params,
- 'σ Order': ct.FlowKind.Params,
- 'σ Range': ct.FlowKind.Params,
- 'κ Order': ct.FlowKind.Params,
- 'κ Range': ct.FlowKind.Params,
- 'α Order': ct.FlowKind.Params,
- 'α Range': ct.FlowKind.Params,
- },
- input_sockets_optional={
- 'σ Order': True,
- 'σ Range': True,
- 'κ Order': True,
- 'κ Range': True,
- 'α Order': True,
- 'α Range': True,
- },
)
- def compute_pml_params(self, props, input_sockets) -> td.PML:
+ def compute_pml_params(self, props, input_sockets) -> ct.ParamsFlow | FS:
r"""Computes the PML boundary condition based on the active socket set.
- **Simple**: Use `tidy3d`'s default parameters for defining the PML conductor (apart from number of layers).
- **Full**: Use the user-defined $\sigma$, $\kappa$, and $\alpha$ parameters, specifically polynomial order and sim-relative min/max conductivity values.
"""
layers = input_sockets['Layers']
- has_layers = not ct.FlowSignal.check(layers)
+ active_socket_set = props['active_socket_set']
- if has_layers:
- active_socket_set = props['active_socket_set']
- match active_socket_set:
- case 'Simple':
- return layers
+ match active_socket_set:
+ case 'Simple':
+ return layers
- case 'Full':
- sigma_order = input_sockets['σ Order']
- sigma_range = input_sockets['σ Range']
- kappa_order = input_sockets['σ Order']
- kappa_range = input_sockets['σ Range']
- alpha_order = input_sockets['σ Order']
- alpha_range = input_sockets['σ Range']
+ case 'Full':
+ sigma_order = input_sockets['σ Order']
+ sigma_range = input_sockets['σ Range']
+ kappa_order = input_sockets['σ Order']
+ kappa_range = input_sockets['σ Range']
+ alpha_order = input_sockets['σ Order']
+ alpha_range = input_sockets['σ Range']
- has_sigma_order = not ct.FlowSignal.check(sigma_order)
- has_sigma_range = not ct.FlowSignal.check(sigma_range)
- has_kappa_order = not ct.FlowSignal.check(kappa_order)
- has_kappa_range = not ct.FlowSignal.check(kappa_range)
- has_alpha_order = not ct.FlowSignal.check(alpha_order)
- has_alpha_range = not ct.FlowSignal.check(alpha_range)
+ has_sigma_order = not FS.check(sigma_order)
+ has_sigma_range = not FS.check(sigma_range)
+ has_kappa_order = not FS.check(kappa_order)
+ has_kappa_range = not FS.check(kappa_range)
+ has_alpha_order = not FS.check(alpha_order)
+ has_alpha_range = not FS.check(alpha_range)
- if (
- has_sigma_order
- and has_sigma_range
- and has_kappa_order
- and has_kappa_range
- and has_alpha_order
- and has_alpha_range
- ):
- return (
- sigma_order
- | sigma_range
- | kappa_order
- | kappa_range
- | alpha_order
- | alpha_range
- )
+ if (
+ has_sigma_order
+ and has_sigma_range
+ and has_kappa_order
+ and has_kappa_range
+ and has_alpha_order
+ and has_alpha_range
+ ):
+ return (
+ sigma_order
+ | sigma_range
+ | kappa_order
+ | kappa_range
+ | alpha_order
+ | alpha_range
+ )
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
@@ -418,4 +398,4 @@ class PMLBoundCondNode(base.MaxwellSimNode):
BL_REGISTER = [
PMLBoundCondNode,
]
-BL_NODES = {ct.NodeType.PMLBoundCond: (ct.NodeCategory.MAXWELLSIM_BOUNDS)}
+BL_NODES = {ct.NodeType.PMLBoundCond: (ct.NodeCategory.MAXWELLSIM_SIMS_BOUNDCONDFACES)}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_conds.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/bound_conds.py
similarity index 73%
rename from src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_conds.py
rename to src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/bound_conds.py
index 8904df4..f2cf9c3 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/bound_conds.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/bound_conds.py
@@ -30,6 +30,8 @@ log = logger.get(__name__)
SSA = ct.SimSpaceAxis
+FK = ct.FlowKind
+FS = ct.FlowSignal
class BoundCondsNode(base.MaxwellSimNode):
@@ -97,27 +99,26 @@ class BoundCondsNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'BCs',
- kind=ct.FlowKind.Value,
+ kind=FK.Value,
# Loaded
input_sockets={'X', 'Y', 'Z', '+X', '-X', '+Y', '-Y', '+Z', '-Z'},
input_sockets_optional={
- 'X': True,
- 'Y': True,
- 'Z': True,
- '+X': True,
- '-X': True,
- '+Y': True,
- '-Y': True,
- '+Z': True,
- '-Z': True,
+ 'X',
+ 'Y',
+ 'Z',
+ '+X',
+ '-X',
+ '+Y',
+ '-Y',
+ '+Z',
+ '-Z',
},
- output_sockets={'BCs'},
- output_socket_kinds={'BCs': ct.FlowKind.Params},
+ outscks_kinds={'BCs': FK.Params},
)
def compute_bcs_value(self, input_sockets, output_sockets) -> td.BoundarySpec:
"""Compute the simulation boundary conditions, by combining the individual input by specified half axis."""
output_params = output_sockets['BCs']
- has_output_params = not ct.FlowSignal.check(output_params)
+ has_output_params = not FS.check(output_params)
# Deduce "Doubledness"
## -> A "doubled" axis defines the same bound cond both ways
@@ -125,9 +126,9 @@ class BoundCondsNode(base.MaxwellSimNode):
y = input_sockets['Y']
z = input_sockets['Z']
- has_doubled_x = not ct.FlowSignal.check_single(x, ct.FlowSignal.NoFlow)
- has_doubled_y = not ct.FlowSignal.check_single(y, ct.FlowSignal.NoFlow)
- has_doubled_z = not ct.FlowSignal.check_single(z, ct.FlowSignal.NoFlow)
+ has_doubled_x = not FS.check_single(x, FS.NoFlow)
+ has_doubled_y = not FS.check_single(y, FS.NoFlow)
+ has_doubled_z = not FS.check_single(z, FS.NoFlow)
# Deduce +/- of Each Axis
## -> +/- X
@@ -138,8 +139,8 @@ class BoundCondsNode(base.MaxwellSimNode):
x_pos = input_sockets['+X']
x_neg = input_sockets['-X']
- has_x_pos = not ct.FlowSignal.check(x_pos)
- has_x_neg = not ct.FlowSignal.check(x_neg)
+ has_x_pos = not FS.check(x_pos)
+ has_x_neg = not FS.check(x_neg)
## -> +/- Y
if has_doubled_y:
@@ -149,8 +150,8 @@ class BoundCondsNode(base.MaxwellSimNode):
y_pos = input_sockets['+Y']
y_neg = input_sockets['-Y']
- has_y_pos = not ct.FlowSignal.check(y_pos)
- has_y_neg = not ct.FlowSignal.check(y_neg)
+ has_y_pos = not FS.check(y_pos)
+ has_y_neg = not FS.check(y_neg)
## -> +/- Z
if has_doubled_z:
@@ -160,8 +161,8 @@ class BoundCondsNode(base.MaxwellSimNode):
z_pos = input_sockets['+Z']
z_neg = input_sockets['-Z']
- has_z_pos = not ct.FlowSignal.check(z_pos)
- has_z_neg = not ct.FlowSignal.check(z_neg)
+ has_z_pos = not FS.check(z_pos)
+ has_z_neg = not FS.check(z_neg)
if (
has_x_pos
@@ -187,39 +188,29 @@ class BoundCondsNode(base.MaxwellSimNode):
minus=z_neg,
),
)
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'BCs',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
input_sockets={'X', 'Y', 'Z', '+X', '-X', '+Y', '-Y', '+Z', '-Z'},
input_socket_kinds={
- 'X': ct.FlowKind.Func,
- 'Y': ct.FlowKind.Func,
- 'Z': ct.FlowKind.Func,
- '+X': ct.FlowKind.Func,
- '-X': ct.FlowKind.Func,
- '+Y': ct.FlowKind.Func,
- '-Y': ct.FlowKind.Func,
- '+Z': ct.FlowKind.Func,
- '-Z': ct.FlowKind.Func,
- },
- input_sockets_optional={
- 'X': True,
- 'Y': True,
- 'Z': True,
- '+X': True,
- '-X': True,
- '+Y': True,
- '-Y': True,
- '+Z': True,
- '-Z': True,
+ 'X': FK.Func,
+ 'Y': FK.Func,
+ 'Z': FK.Func,
+ '+X': FK.Func,
+ '-X': FK.Func,
+ '+Y': FK.Func,
+ '-Y': FK.Func,
+ '+Z': FK.Func,
+ '-Z': FK.Func,
},
+ input_sockets_optional={'X', 'Y', 'Z', '+X', '-X', '+Y', '-Y', '+Z', '-Z'},
)
- def compute_bcs_func(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
+ def compute_bcs_func(self, input_sockets) -> ct.ParamsFlow | FS:
"""Compute the simulation boundary conditions, by combining the individual input by specified half axis."""
# Deduce "Doubledness"
## -> A "doubled" axis defines the same bound cond both ways
@@ -227,9 +218,9 @@ class BoundCondsNode(base.MaxwellSimNode):
y = input_sockets['Y']
z = input_sockets['Z']
- has_doubled_x = not ct.FlowSignal.check_single(x, ct.FlowSignal.NoFlow)
- has_doubled_y = not ct.FlowSignal.check_single(y, ct.FlowSignal.NoFlow)
- has_doubled_z = not ct.FlowSignal.check_single(z, ct.FlowSignal.NoFlow)
+ has_doubled_x = not FS.check_single(x, FS.NoFlow)
+ has_doubled_y = not FS.check_single(y, FS.NoFlow)
+ has_doubled_z = not FS.check_single(z, FS.NoFlow)
# Deduce +/- of Each Axis
## -> +/- X
@@ -240,8 +231,8 @@ class BoundCondsNode(base.MaxwellSimNode):
x_pos = input_sockets['+X']
x_neg = input_sockets['-X']
- has_x_pos = not ct.FlowSignal.check(x_pos)
- has_x_neg = not ct.FlowSignal.check(x_neg)
+ has_x_pos = not FS.check(x_pos)
+ has_x_neg = not FS.check(x_neg)
## -> +/- Y
if has_doubled_y:
@@ -251,8 +242,8 @@ class BoundCondsNode(base.MaxwellSimNode):
y_pos = input_sockets['+Y']
y_neg = input_sockets['-Y']
- has_y_pos = not ct.FlowSignal.check(y_pos)
- has_y_neg = not ct.FlowSignal.check(y_neg)
+ has_y_pos = not FS.check(y_pos)
+ has_y_neg = not FS.check(y_neg)
## -> +/- Z
if has_doubled_z:
@@ -262,8 +253,8 @@ class BoundCondsNode(base.MaxwellSimNode):
z_pos = input_sockets['+Z']
z_neg = input_sockets['-Z']
- has_z_pos = not ct.FlowSignal.check(z_pos)
- has_z_neg = not ct.FlowSignal.check(z_neg)
+ has_z_pos = not FS.check(z_pos)
+ has_z_neg = not FS.check(z_neg)
if (
has_x_pos
@@ -290,39 +281,29 @@ class BoundCondsNode(base.MaxwellSimNode):
),
supports_jax=False,
)
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'BCs',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
input_sockets={'X', 'Y', 'Z', '+X', '-X', '+Y', '-Y', '+Z', '-Z'},
input_socket_kinds={
- 'X': ct.FlowKind.Params,
- 'Y': ct.FlowKind.Params,
- 'Z': ct.FlowKind.Params,
- '+X': ct.FlowKind.Params,
- '-X': ct.FlowKind.Params,
- '+Y': ct.FlowKind.Params,
- '-Y': ct.FlowKind.Params,
- '+Z': ct.FlowKind.Params,
- '-Z': ct.FlowKind.Params,
- },
- input_sockets_optional={
- 'X': True,
- 'Y': True,
- 'Z': True,
- '+X': True,
- '-X': True,
- '+Y': True,
- '-Y': True,
- '+Z': True,
- '-Z': True,
+ 'X': FK.Params,
+ 'Y': FK.Params,
+ 'Z': FK.Params,
+ '+X': FK.Params,
+ '-X': FK.Params,
+ '+Y': FK.Params,
+ '-Y': FK.Params,
+ '+Z': FK.Params,
+ '-Z': FK.Params,
},
+ input_sockets_optional={'X', 'Y', 'Z', '+X', '-X', '+Y', '-Y', '+Z', '-Z'},
)
- def compute_bcs_params(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
+ def compute_bcs_params(self, input_sockets) -> ct.ParamsFlow | FS:
"""Compute the simulation boundary conditions, by combining the individual input by specified half axis."""
# Deduce "Doubledness"
## -> A "doubled" axis defines the same bound cond both ways
@@ -330,9 +311,9 @@ class BoundCondsNode(base.MaxwellSimNode):
y = input_sockets['Y']
z = input_sockets['Z']
- has_doubled_x = not ct.FlowSignal.check_single(x, ct.FlowSignal.NoFlow)
- has_doubled_y = not ct.FlowSignal.check_single(y, ct.FlowSignal.NoFlow)
- has_doubled_z = not ct.FlowSignal.check_single(z, ct.FlowSignal.NoFlow)
+ has_doubled_x = not FS.check_single(x, FS.NoFlow)
+ has_doubled_y = not FS.check_single(y, FS.NoFlow)
+ has_doubled_z = not FS.check_single(z, FS.NoFlow)
# Deduce +/- of Each Axis
## -> +/- X
@@ -343,8 +324,8 @@ class BoundCondsNode(base.MaxwellSimNode):
x_pos = input_sockets['+X']
x_neg = input_sockets['-X']
- has_x_pos = not ct.FlowSignal.check(x_pos)
- has_x_neg = not ct.FlowSignal.check(x_neg)
+ has_x_pos = not FS.check(x_pos)
+ has_x_neg = not FS.check(x_neg)
## -> +/- Y
if has_doubled_y:
@@ -354,8 +335,8 @@ class BoundCondsNode(base.MaxwellSimNode):
y_pos = input_sockets['+Y']
y_neg = input_sockets['-Y']
- has_y_pos = not ct.FlowSignal.check(y_pos)
- has_y_neg = not ct.FlowSignal.check(y_neg)
+ has_y_pos = not FS.check(y_pos)
+ has_y_neg = not FS.check(y_neg)
## -> +/- Z
if has_doubled_z:
@@ -365,8 +346,8 @@ class BoundCondsNode(base.MaxwellSimNode):
z_pos = input_sockets['+Z']
z_neg = input_sockets['-Z']
- has_z_pos = not ct.FlowSignal.check(z_pos)
- has_z_neg = not ct.FlowSignal.check(z_neg)
+ has_z_pos = not FS.check(z_pos)
+ has_z_neg = not FS.check(z_neg)
if (
has_x_pos
@@ -377,7 +358,7 @@ class BoundCondsNode(base.MaxwellSimNode):
and has_z_neg
):
return x_pos | x_neg | y_pos | y_neg | z_pos | z_neg
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
@@ -386,4 +367,4 @@ class BoundCondsNode(base.MaxwellSimNode):
BL_REGISTER = [
BoundCondsNode,
]
-BL_NODES = {ct.NodeType.BoundConds: (ct.NodeCategory.MAXWELLSIM_BOUNDS)}
+BL_NODES = {ct.NodeType.BoundConds: (ct.NodeCategory.MAXWELLSIM_SIMS)}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/fdtd_sim.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/fdtd_sim.py
index 42535d1..d949edc 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/fdtd_sim.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/fdtd_sim.py
@@ -16,13 +16,19 @@
"""Implements `FDTDSimNode`."""
+import itertools
import typing as typ
import bpy
+import jax
+import numpy as np
+import sympy as sp
+import sympy.physics.units as spu
import tidy3d as td
-import tidy3d.plugins.adjoint as tdadj
-from blender_maxwell.utils import bl_cache, logger
+from blender_maxwell.utils import bl_cache, logger, sim_symbols
+from blender_maxwell.utils import sympy_extra as spux
+from blender_maxwell.utils.frozendict import frozendict
from ... import contracts as ct
from ... import sockets
@@ -30,6 +36,44 @@ from .. import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MT = spux.MathType
+PT = spux.PhysicalType
+
+SimArray: typ.TypeAlias = frozendict[
+ tuple[sim_symbols.SimSymbol, ...], tuple[typ.Any, ...], td.Simulation
+]
+SimArrayInfo: typ.TypeAlias = frozendict[
+ tuple[sim_symbols.SimSymbol, ...], tuple[typ.Any, ...], td.Simulation
+]
+
+
+class RecomputeSimInfo(bpy.types.Operator):
+ """Recompute information about the simulation."""
+
+ bl_idname = ct.OperatorType.NodeRecomputeSimInfo
+ bl_label = 'Recompute Simuation Info'
+ bl_description = (
+ 'Recompute information of a simulation attached to a `FDTDSimNode`.'
+ )
+
+ @classmethod
+ def poll(cls, context):
+ """Allow running whenever a particular FDTDSim node is available."""
+ return (
+ # Check Tidy3DWebExporter is Accessible
+ hasattr(context, 'node')
+ and hasattr(context.node, 'node_type')
+ and context.node.node_type == ct.NodeType.FDTDSim
+ )
+
+ def execute(self, context):
+ """Invalidate the `.sims` property, triggering reevaluation of all downstream information about the simulation."""
+ node = context.node
+ node.sims = bl_cache.Signal.InvalidateCache
+ return {'FINISHED'}
+
class FDTDSimNode(base.MaxwellSimNode):
"""Definition of a complete FDTD simulation, including boundary conditions, domain, sources, structures, monitors, and other configuration."""
@@ -44,39 +88,311 @@ class FDTDSimNode(base.MaxwellSimNode):
'BCs': sockets.MaxwellBoundCondsSocketDef(),
'Domain': sockets.MaxwellSimDomainSocketDef(),
'Sources': sockets.MaxwellSourceSocketDef(
- active_kind=ct.FlowKind.Array,
+ active_kind=FK.Array,
),
'Structures': sockets.MaxwellStructureSocketDef(
- active_kind=ct.FlowKind.Array,
+ active_kind=FK.Array,
),
'Monitors': sockets.MaxwellMonitorSocketDef(
- active_kind=ct.FlowKind.Array,
+ active_kind=FK.Array,
),
}
output_socket_sets: typ.ClassVar = {
'Single': {
- 'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Value),
+ 'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=FK.Value),
},
- 'Func': {
- 'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Func),
+ 'Batch': {
+ 'Sims': sockets.MaxwellFDTDSimSocketDef(active_kind=FK.Array),
+ },
+ 'Lazy': {
+ 'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=FK.Func),
},
}
####################
- # - Properties
+ # - Properties: UI
####################
- differentiable: bool = bl_cache.BLField(False)
+ ui_limits: bool = bl_cache.BLField(False)
+ ui_discretization: bool = bl_cache.BLField(False)
+ ui_portability: bool = bl_cache.BLField(False)
+ ui_propagation: bool = bl_cache.BLField(False)
####################
- # - UI
+ # - Properties: Simulation
####################
- def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
- layout.prop(
- self,
- self.blfields['differentiable'],
- text='Differentiable',
- toggle=True,
- )
+ @bl_cache.cached_bl_property()
+ def sims(self) -> SimArray | None:
+ """The complete description of all simulation output objects."""
+ if self.active_socket_set == 'Single':
+ sim_value = self.compute_output('Sim', kind=FK.Value)
+ has_sim_value = not FS.check(sim_value)
+ if has_sim_value:
+ return {(): sim_value}
+
+ elif self.active_socket_set == 'Batch':
+ sim_array = self.compute_output('Sims', kind=FK.Array)
+ has_sim_array = not FS.check(sim_array)
+ if has_sim_array:
+ return sim_array
+
+ return None
+
+ ####################
+ # - Properties: Propagation
+ ####################
+ @bl_cache.cached_bl_property(depends_on={'sims'})
+ def has_gain(self) -> SimArrayInfo | None:
+ """Whether any mediums in the simulation allow gain."""
+ if self.sims is not None:
+ return {k: sim.allow_gain for k, sim in self.sims.items()}
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'sims'})
+ def has_complex_fields(self) -> SimArrayInfo | None:
+ """Whether complex fields are currently used in the simulation."""
+ if self.sims is not None:
+ return {k: sim.complex_fields for k, sim in self.sims.items()}
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'sims'})
+ def min_wl(self) -> SimArrayInfo | None:
+ """The smallest wavelength that occurs in the simulation."""
+ if self.sims is not None:
+ return {k: sim.wvl_mat_min * spu.um for k, sim in self.sims.items()}
+ return None
+
+ ####################
+ # - Properties: Discretization
+ ####################
+ @bl_cache.cached_bl_property(depends_on={'sims'})
+ def time_range(self) -> SimArrayInfo | None:
+ """The time range of the simulation."""
+ if self.sims is not None:
+ return {
+ k: sp.Matrix(
+ [0, spu.convert_to(sim.run_time * spu.second, spu.picosecond)]
+ )
+ for k, sim in self.sims.items()
+ }
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'sims'})
+ def freq_range(self) -> SimArrayInfo | None:
+ """The total frequency range of the simulation, across all sources."""
+ if self.sims is not None:
+ return {
+ k: spu.convert_to(
+ sp.Matrix([sim.frequency_range[0], sim.frequency_range[1]])
+ * spu.hertz,
+ spux.THz,
+ )
+ for k, sim in self.sims.items()
+ }
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'sims'})
+ def time_step(self) -> SimArrayInfo | None:
+ """The time step of the simulation."""
+ if self.sims is not None:
+ return {k: sim.dt * spu.second for k, sim in self.sims.items()}
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'sims'})
+ def time_steps(self) -> SimArrayInfo | None:
+ """The time step of the simulation."""
+ if self.sims is not None:
+ return {k: sim.num_time_steps for k, sim in self.sims.items()}
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'sims'})
+ def nyquist_step(self) -> SimArrayInfo | None:
+ """The number of time-steps needed to theoretically provide for correctly resolved sampling of the simulation grid."""
+ if self.sims is not None:
+ return {k: sim.nyquist_step for k, sim in self.sims.items()}
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'sims'})
+ def num_cells(self) -> SimArrayInfo | None:
+ """The number of 3D cells for which the simulation is discretized."""
+ if self.sims is not None:
+ return {k: sim.num_cells for k, sim in self.sims.items()}
+ return None
+
+ ####################
+ # - Properties: Data
+ ####################
+ @bl_cache.cached_bl_property(depends_on={'sims'})
+ def monitor_data_sizes(self) -> SimArrayInfo | None:
+ """The total data expected to be taken by each monitors."""
+ if self.sims is not None:
+ return {k: sim.monitors_data_size for k, sim in self.sims.items()}
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'monitor_data_sizes'})
+ def total_monitor_data_size(self) -> SimArrayInfo | None:
+ """The total data taken by the monitors."""
+ if self.monitor_data_sizes is not None:
+ return {
+ k: sum(sizes.values()) for k, sizes in self.monitor_data_sizes.items()
+ }
+ return None
+
+ ####################
+ # - Properties: Lists
+ ####################
+ @bl_cache.cached_bl_property(depends_on={'sims'})
+ def list_datasets(self) -> SimArrayInfo | None:
+ """List of custom datasets required by the simulation."""
+ if self.sims is not None:
+ return {k: sim.custom_datasets for k, sim in self.sims.items()}
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'sims'})
+ def list_vol_structures(self) -> SimArrayInfo | None:
+ """List of volumetric structures, where 2D mediums were converted to 3D."""
+ if self.sims is not None:
+ return {k: sim.volumetric_structures for k, sim in self.sims.items()}
+ return None
+
+ ####################
+ # - Properties: Validated
+ ####################
+ @bl_cache.cached_bl_property(depends_on={'sims'})
+ def sims_valid(self) -> SimArrayInfo | None:
+ """Whether all sims are valid."""
+ if self.sims is not None:
+ validity = {}
+ for k, sim in self.sims.items(): # noqa: B007
+ try:
+ pass ## TODO: VERY slow, batch checking is infeasible
+ # sim.validate_pre_upload(source_required=True)
+ except td.exceptions.SetupError:
+ validity[k] = False
+ else:
+ validity[k] = True
+
+ return validity
+ return None
+
+ ####################
+ # - Info
+ ####################
+ @bl_cache.cached_bl_property(
+ depends_on={
+ 'sims',
+ 'time_range',
+ 'freq_range',
+ 'min_wl',
+ 'num_cells',
+ 'time_steps',
+ 'nyquist_steps',
+ 'time_step',
+ 'sims_valid',
+ 'total_monitor_data_size',
+ 'has_gain',
+ 'has_complex_fields',
+ 'ui_limits',
+ 'ui_discretization',
+ 'ui_portability',
+ 'ui_propagation',
+ }
+ )
+ def sim_labels(self) -> SimArrayInfo | None:
+ """Pre-processed labels for efficient drawing of simulation info."""
+ if self.sims is not None:
+ sims_vals_labels = {}
+ for syms_vals in self.sims:
+ labels = []
+
+ if syms_vals:
+ labels += [
+ '|'.join(
+ [
+ f'{sym.name_pretty}={val:,.2f}'
+ for sym, val in zip(*syms_vals, strict=True)
+ ]
+ ),
+ ]
+
+ labels += [
+ ['Limits', 'ui_limits'],
+ ]
+ if self.ui_limits:
+ labels += [
+ ('max t', spux.sp_to_str(self.time_range[syms_vals][1])),
+ ('min f', spux.sp_to_str(self.freq_range[syms_vals][0])),
+ ('max f', spux.sp_to_str(self.freq_range[syms_vals][1])),
+ ('min λ', spux.sp_to_str(self.min_wl[syms_vals].n(2))),
+ ]
+
+ labels += [
+ ['Discretization', 'ui_discretization'],
+ ]
+ if self.ui_discretization:
+ labels += [
+ ('cells', f'{self.num_cells[syms_vals]:,}'),
+ ('num Δt', f'{self.time_steps[syms_vals]:,}'),
+ ('nyq Δt', f'{self.nyquist_step[syms_vals]}'),
+ ('Δt', spux.sp_to_str(self.time_step[syms_vals].n(2))),
+ ]
+
+ labels += [
+ ['Portability', 'ui_portability'],
+ ]
+ if self.ui_portability:
+ labels += [
+ ('Valid?', str(self.sims_valid[syms_vals])),
+ (
+ 'Σ mon',
+ f'{self.total_monitor_data_size[syms_vals] / 1000000:,.2f}MB',
+ ),
+ ]
+
+ labels += [
+ ['Propagation', 'ui_propagation'],
+ ]
+ if self.ui_propagation:
+ labels += [
+ ('Gain?', str(self.has_gain[syms_vals])),
+ ('𝐄𝐇 ∈ ℂ', str(self.has_complex_fields[syms_vals])),
+ ]
+
+ sims_vals_labels[syms_vals] = labels
+ return sims_vals_labels
+ return None
+
+ def draw_info(self, _, layout: bpy.types.UILayout):
+ """Draw information about the simulation, if any."""
+ row = layout.row(align=True)
+ row.alignment = 'CENTER'
+ row.label(text='Sim Info')
+ row.operator(ct.OperatorType.NodeRecomputeSimInfo, icon='FILE_REFRESH', text='')
+
+ # Simulation Info
+ if self.sim_labels is not None:
+ for labels in self.sim_labels.values():
+ box = layout.box()
+ for el in labels:
+ # Header
+ if isinstance(el, list):
+ row = box.row(align=True)
+ # row.alignment = 'EXPAND'
+ row.prop(self, self.blfields[el[1]], text=el[0], toggle=True)
+ # row.label(text=el)
+
+ split = box.split(factor=0.4)
+ col_l = split.column(align=True)
+ col_r = split.column(align=True)
+
+ # Label Pair
+ elif isinstance(el, tuple):
+ col_l.label(text=el[0])
+ col_r.label(text=el[1])
+
+ else:
+ raise TypeError
+
+ break
####################
# - Events
@@ -88,36 +404,31 @@ class FDTDSimNode(base.MaxwellSimNode):
# Loaded
props={'active_socket_set'},
output_sockets={'Sim'},
- output_socket_kinds={'Sim': ct.FlowKind.Params},
+ outscks_kinds={'Sim': FK.Params},
)
def on_any_changed(self, props, output_sockets) -> None:
- """Create loose input sockets."""
+ """Manage loose input sockets in response to symbolic simulation elements."""
+ # Loose Input Sockets
output_params = output_sockets['Sim']
- has_output_params = not ct.FlowSignal.check(output_params)
-
active_socket_set = props['active_socket_set']
if (
- active_socket_set == 'Single'
- and has_output_params
+ active_socket_set in ['Single', 'Batch']
and output_params.symbols
+ and set(self.loose_input_sockets)
+ != {sym.name for sym in output_params.symbols}
):
- if set(self.loose_input_sockets) != {
- sym.name for sym in output_params.symbols
- }:
- self.loose_input_sockets = {
- sym.name: sockets.ExprSocketDef(
- **(
- expr_info
- | {
- 'active_kind': ct.FlowKind.Value,
- 'use_value_range_swapper': (
- active_socket_set == 'Value'
- ),
- }
- )
+ self.loose_input_sockets = {
+ sym.name: sockets.ExprSocketDef(
+ **(
+ expr_info
+ | {
+ 'active_kind': FK.Value,
+ 'use_value_range_swapper': (active_socket_set == 'Batch'),
+ }
)
- for sym, expr_info in output_params.sym_expr_infos.items()
- }
+ )
+ for sym, expr_info in output_params.sym_expr_infos.items()
+ }
elif self.loose_input_sockets:
self.loose_input_sockets = {}
@@ -127,144 +438,184 @@ class FDTDSimNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Sim',
- kind=ct.FlowKind.Value,
+ kind=FK.Value,
# Loaded
+ props={'active_socket_set'},
+ outscks_kinds={'Sim': {FK.Func, FK.Params}},
all_loose_input_sockets=True,
- output_sockets={'Sim'},
- output_socket_kinds={'Sim': {ct.FlowKind.Func, ct.FlowKind.Params}},
+ loose_input_sockets_kind={FK.Func, FK.Params},
)
def compute_value(
- self, loose_input_sockets, output_sockets
- ) -> ct.ParamsFlow | ct.FlowSignal:
+ self, props, loose_input_sockets, output_sockets
+ ) -> td.Simulation | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
- output_func = output_sockets['Sim'][ct.FlowKind.Func]
- output_params = output_sockets['Sim'][ct.FlowKind.Params]
+ func = output_sockets['Sim'][FK.Func]
+ params = output_sockets['Sim'][FK.Params]
- has_output_func = not ct.FlowSignal.check(output_func)
- has_output_params = not ct.FlowSignal.check(output_params)
+ has_func = not FS.check(func)
+ has_params = not FS.check(params)
- if has_output_func and has_output_params:
- return output_func.realize(
- output_params,
- symbol_values={
- sym: loose_input_sockets[sym.name]
- for sym in output_params.sorted_symbols
- },
- disallow_jax=True,
+ active_socket_set = props['active_socket_set']
+ if has_func and has_params and active_socket_set == 'Single':
+ symbol_values = {
+ sym: events.realize_known(loose_input_sockets[sym.name])
+ for sym in params.sorted_symbols
+ }
+
+ return func.realize(
+ params,
+ symbol_values=frozendict(
+ {
+ sym: events.realize_known(loose_input_sockets[sym.name])
+ for sym in params.sorted_symbols
+ }
+ ),
+ ).updated_copy(
+ attrs=dict(
+ ct.SimMetadata(
+ realizations=ct.SimRealizations(
+ syms=tuple(symbol_values.keys()),
+ vals=tuple(symbol_values.values()),
+ )
+ )
+ )
)
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
+
+ ####################
+ # - FlowKind.Array
+ ####################
+ @events.computes_output_socket(
+ 'Sims',
+ kind=FK.Array,
+ # Loaded
+ props={'active_socket_set'},
+ outscks_kinds={'Sim': {FK.Func, FK.Params}},
+ all_loose_input_sockets=True,
+ loose_input_sockets_kind={FK.Func, FK.Params},
+ )
+ def compute_array(
+ self, props, loose_input_sockets, output_sockets
+ ) -> SimArray | FS:
+ """Produce a batch of simulations as a dictionary, indexed by a twice-nested tuple matching a `SimSymbol` tuple to a corresponding tuple of values."""
+ func = output_sockets['Sim'][FK.Func]
+ params = output_sockets['Sim'][FK.Params]
+
+ active_socket_set = props['active_socket_set']
+ if active_socket_set == 'Batch':
+ # Realize Values per-Symbol
+ ## -> First, we realize however many values requested per symbol.
+ sym_datas: dict[sim_symbols.SimSymbol, list] = {}
+ for sym in params.sorted_symbols:
+ if sym.name not in loose_input_sockets:
+ return FS.FlowPending
+
+ # Realize Data for Symbol
+ ## -> This may be a single scalar/vector/matrix.
+ ## -> This may also be many _scalars_.
+ sym_data = events.realize_known(
+ loose_input_sockets[sym.name],
+ freeze=True,
+ )
+ if sym_data is None:
+ return FS.FlowPending
+
+ # Single Value per-Symbol
+ if sym.shape_len == 0:
+ sym_datas |= {sym: (sym_data,)}
+
+ # Many Values per-Symbol
+ else:
+ sym_datas |= {sym: sym_data}
+
+ # Realize Function per-Combination
+ ## -> td.Simulation requires single, specific values for all syms.
+ ## -> What we have is many specific values for each sym.
+ ## -> With a single comprehension, we resolve this difference.
+ ## -> The end-result is an annotated td.Simulation per-combo.
+ ## -> NOTE: This might be big! Such are parameter-sweeps.
+ syms = tuple(sym_datas.keys())
+ return {
+ (syms, vals): func.realize(
+ params,
+ symbol_values=frozendict(zip(syms, vals, strict=True)),
+ ).updated_copy(
+ attrs=dict(
+ ct.SimMetadata(
+ realizations=ct.SimRealizations(syms=syms, vals=vals)
+ )
+ )
+ )
+ for vals in itertools.product(*sym_datas.values())
+ }
+
+ return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Sim',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
# Loaded
- props={'differentiable'},
- input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
- input_socket_kinds={
- 'BCs': ct.FlowKind.Func,
- 'Domain': ct.FlowKind.Func,
- 'Sources': ct.FlowKind.Func,
- 'Structures': ct.FlowKind.Func,
- 'Monitors': ct.FlowKind.Func,
+ inscks_kinds={
+ 'BCs': FK.Func,
+ 'Domain': FK.Func,
+ 'Sources': FK.Func,
+ 'Structures': FK.Func,
+ 'Monitors': FK.Func,
},
- output_sockets={'Sim'},
- output_socket_kinds={'Sim': ct.FlowKind.Params},
)
- def compute_fdtd_sim_func(
- self, props, input_sockets, output_sockets
- ) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
+ def compute_fdtd_sim_func(self, input_sockets) -> ct.FuncFlow:
"""Compute a single simulation, given that all inputs are non-symbolic."""
bounds = input_sockets['BCs']
sim_domain = input_sockets['Domain']
sources = input_sockets['Sources']
structures = input_sockets['Structures']
monitors = input_sockets['Monitors']
- output_params = output_sockets['Sim']
- has_bounds = not ct.FlowSignal.check(bounds)
- has_sim_domain = not ct.FlowSignal.check(sim_domain)
- has_sources = not ct.FlowSignal.check(sources)
- has_structures = not ct.FlowSignal.check(structures)
- has_monitors = not ct.FlowSignal.check(monitors)
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if (
- has_sim_domain
- and has_sources
- and has_structures
- and has_bounds
- and has_monitors
- and has_output_params
- ):
- differentiable = props['differentiable']
- if differentiable:
- raise NotImplementedError
-
- return (
- bounds | sim_domain | sources | structures | monitors
- ).compose_within(
- enclosing_func=lambda els: td.Simulation(
- boundary_spec=els[0],
- **els[1],
- sources=els[2],
- structures=els[3],
- monitors=els[4],
- ),
- supports_jax=False,
- )
- return ct.FlowSignal.FlowPending
+ return (bounds | sim_domain | sources | structures | monitors).compose_within(
+ lambda els: td.Simulation(
+ boundary_spec=els[0],
+ **els[1],
+ sources=els[2],
+ structures=els[3],
+ monitors=els[4],
+ ),
+ )
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Sim',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
# Loaded
- props={'differentiable'},
- input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
- input_socket_kinds={
- 'BCs': ct.FlowKind.Params,
- 'Domain': ct.FlowKind.Params,
- 'Sources': ct.FlowKind.Params,
- 'Structures': ct.FlowKind.Params,
- 'Monitors': ct.FlowKind.Params,
+ inscks_kinds={
+ 'BCs': FK.Params,
+ 'Domain': FK.Params,
+ 'Sources': FK.Params,
+ 'Structures': FK.Params,
+ 'Monitors': FK.Params,
},
)
- def compute_fdtd_sim_params(
- self, props, input_sockets
- ) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
- """Compute a single simulation, given that all inputs are non-symbolic."""
+ def compute_params(self, input_sockets) -> td.Simulation | FS:
+ """Compute all function parameters needed to create the simulation."""
+ # Compute Output Parameters
bounds = input_sockets['BCs']
sim_domain = input_sockets['Domain']
sources = input_sockets['Sources']
structures = input_sockets['Structures']
monitors = input_sockets['Monitors']
- has_bounds = not ct.FlowSignal.check(bounds)
- has_sim_domain = not ct.FlowSignal.check(sim_domain)
- has_sources = not ct.FlowSignal.check(sources)
- has_structures = not ct.FlowSignal.check(structures)
- has_monitors = not ct.FlowSignal.check(monitors)
-
- if (
- has_bounds
- and has_sim_domain
- and has_sources
- and has_structures
- and has_monitors
- ):
- return bounds | sim_domain | sources | structures | monitors
- return ct.FlowSignal.FlowPending
+ return bounds | sim_domain | sources | structures | monitors
####################
# - Blender Registration
####################
BL_REGISTER = [
+ RecomputeSimInfo,
FDTDSimNode,
]
BL_NODES = {ct.NodeType.FDTDSim: (ct.NodeCategory.MAXWELLSIM_SIMS)}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py
index 7c525ff..e6659c3 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py
@@ -22,8 +22,8 @@ import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
-from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
+from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets
@@ -31,6 +31,11 @@ from .. import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MT = spux.MathType
+PT = spux.PhysicalType
+
class SimDomainNode(base.MaxwellSimNode):
"""The domain of a simulation in space and time, including bounds, discretization strategy, and the ambient medium."""
@@ -41,24 +46,24 @@ class SimDomainNode(base.MaxwellSimNode):
input_sockets: typ.ClassVar = {
'Duration': sockets.ExprSocketDef(
- physical_type=spux.PhysicalType.Time,
+ physical_type=PT.Time,
default_unit=spu.picosecond,
default_value=5,
abs_min=0,
),
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
- mathtype=spux.MathType.Real,
- physical_type=spux.PhysicalType.Length,
+ mathtype=MT.Real,
+ physical_type=PT.Length,
default_unit=spu.micrometer,
- default_value=sp.Matrix([0, 0, 0]),
+ default_value=sp.ImmutableMatrix([0, 0, 0]),
),
'Size': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
- mathtype=spux.MathType.Real,
- physical_type=spux.PhysicalType.Length,
+ mathtype=MT.Real,
+ physical_type=PT.Length,
default_unit=spu.micrometer,
- default_value=sp.Matrix([1, 1, 1]),
+ default_value=sp.ImmutableMatrix([1, 1, 1]),
abs_min=0.001,
),
'Grid': sockets.MaxwellSimGridSocketDef(),
@@ -77,89 +82,75 @@ class SimDomainNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Domain',
- kind=ct.FlowKind.Value,
+ kind=FK.Value,
# Loaded
- output_sockets={'Domain'},
- output_socket_kinds={'Domain': {ct.FlowKind.Func, ct.FlowKind.Params}},
+ outscks_kinds={
+ 'Domain': {FK.Func, FK.Params},
+ },
)
- def compute_domain_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
+ def compute_domain_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
- output_func = output_sockets['Domain'][ct.FlowKind.Func]
- output_params = output_sockets['Domain'][ct.FlowKind.Params]
-
- has_output_func = not ct.FlowSignal.check(output_func)
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if has_output_func and has_output_params and not output_params.symbols:
- return output_func.realize(output_params)
- return ct.FlowSignal.FlowPending
+ value = events.realize_known(output_sockets['Domain'])
+ if value is not None:
+ return value
+ return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Domain',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
# Loaded
input_sockets={'Duration', 'Center', 'Size', 'Grid', 'Ambient Medium'},
input_socket_kinds={
- 'Duration': ct.FlowKind.Func,
- 'Center': ct.FlowKind.Func,
- 'Size': ct.FlowKind.Func,
- 'Grid': ct.FlowKind.Func,
- 'Ambient Medium': ct.FlowKind.Func,
+ 'Duration': FK.Func,
+ 'Center': FK.Func,
+ 'Size': FK.Func,
+ 'Grid': FK.Func,
+ 'Ambient Medium': FK.Func,
+ },
+ scale_input_sockets={
+ 'Duration': ct.UNITS_TIDY3D,
+ 'Center': ct.UNITS_TIDY3D,
+ 'Size': ct.UNITS_TIDY3D,
},
)
- def compute_domain_func(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
+ def compute_domain_func(self, input_sockets) -> ct.ParamsFlow | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
duration = input_sockets['Duration']
center = input_sockets['Center']
size = input_sockets['Size']
grid = input_sockets['Grid']
medium = input_sockets['Ambient Medium']
-
- has_duration = not ct.FlowSignal.check(duration)
- has_center = not ct.FlowSignal.check(center)
- has_size = not ct.FlowSignal.check(size)
- has_grid = not ct.FlowSignal.check(grid)
- has_medium = not ct.FlowSignal.check(medium)
-
- if has_duration and has_center and has_size and has_grid and has_medium:
- return (
- duration.scale_to_unit_system(ct.UNITS_TIDY3D)
- | center.scale_to_unit_system(ct.UNITS_TIDY3D)
- | size.scale_to_unit_system(ct.UNITS_TIDY3D)
- | grid
- | medium
- ).compose_within(
- lambda els: {
- 'run_time': els[0],
- 'center': els[1].flatten().tolist(),
- 'size': els[2].flatten().tolist(),
- 'grid_spec': els[3],
- 'medium': els[4],
- },
- supports_jax=False,
- )
- return ct.FlowSignal.FlowPending
+ return (duration | center | size | grid | medium).compose_within(
+ lambda els: {
+ 'run_time': els[0],
+ 'center': els[1].flatten().tolist(),
+ 'size': els[2].flatten().tolist(),
+ 'grid_spec': els[3],
+ 'medium': els[4],
+ },
+ supports_jax=False,
+ )
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Domain',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
# Loaded
input_sockets={'Duration', 'Center', 'Size', 'Grid', 'Ambient Medium'},
input_socket_kinds={
- 'Duration': ct.FlowKind.Params,
- 'Center': ct.FlowKind.Params,
- 'Size': ct.FlowKind.Params,
- 'Grid': ct.FlowKind.Params,
- 'Ambient Medium': ct.FlowKind.Params,
+ 'Duration': FK.Params,
+ 'Center': FK.Params,
+ 'Size': FK.Params,
+ 'Grid': FK.Params,
+ 'Ambient Medium': FK.Params,
},
)
- def compute_domain_params(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
+ def compute_domain_params(self, input_sockets) -> ct.ParamsFlow | FS:
"""Compute the output `ParamsFlow` of the simulation domain from strictly non-symbolic inputs."""
duration = input_sockets['Duration']
center = input_sockets['Center']
@@ -167,22 +158,14 @@ class SimDomainNode(base.MaxwellSimNode):
grid = input_sockets['Grid']
medium = input_sockets['Ambient Medium']
- has_duration = not ct.FlowSignal.check(duration)
- has_center = not ct.FlowSignal.check(center)
- has_size = not ct.FlowSignal.check(size)
- has_grid = not ct.FlowSignal.check(grid)
- has_medium = not ct.FlowSignal.check(medium)
-
- if has_duration and has_center and has_size and has_grid and has_medium:
- return duration | center | size | grid | medium
- return ct.FlowSignal.FlowPending
+ return duration | center | size | grid | medium
####################
# - Preview
####################
@events.computes_output_socket(
'Domain',
- kind=ct.FlowKind.Previews,
+ kind=FK.Previews,
# Loaded
props={'sim_node_name'},
)
@@ -192,35 +175,37 @@ class SimDomainNode(base.MaxwellSimNode):
@events.on_value_changed(
# Trigger
- socket_name={'Center', 'Size'},
+ socket_name={
+ 'Center': {FK.Func, FK.Params},
+ 'Size': {FK.Func, FK.Params},
+ },
run_on_init=True,
# Loaded
- input_sockets={'Center', 'Size'},
managed_objs={'modifier'},
- output_sockets={'Domain'},
- output_socket_kinds={'Domain': ct.FlowKind.Params},
+ inscks_kinds={
+ 'Center': {FK.Func, FK.Params},
+ 'Size': {FK.Func, FK.Params},
+ },
+ scale_input_sockets={
+ 'Center': ct.UNITS_BLENDER,
+ 'Size': ct.UNITS_BLENDER,
+ },
)
- def on_input_changed(self, managed_objs, input_sockets, output_sockets) -> None:
+ def on_input_changed(self, managed_objs, input_sockets) -> None:
"""Preview the simulation domain based on input parameters, so long as they are not dependent on unrealized symbols."""
- output_params = output_sockets['Domain']
- center = input_sockets['Center']
+ center = events.realize_preview(input_sockets['Center'])
+ size = events.realize_preview(input_sockets['Size'])
- has_output_params = not ct.FlowSignal.check(output_params)
- has_center = not ct.FlowSignal.check(center)
-
- if has_center and has_output_params and not output_params.symbols:
- # Push Loose Input Values to GeoNodes Modifier
- managed_objs['modifier'].bl_modifier(
- 'NODES',
- {
- 'node_group': import_geonodes(GeoNodes.SimulationSimDomain),
- 'unit_system': ct.UNITS_BLENDER,
- 'inputs': {
- 'Size': input_sockets['Size'],
- },
+ managed_objs['modifier'].bl_modifier(
+ 'NODES',
+ {
+ 'node_group': import_geonodes(GeoNodes.SimulationSimDomain),
+ 'inputs': {
+ 'Size': size,
},
- location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
- )
+ },
+ location=center,
+ )
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_grid.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_grid.py
index 5b7d824..6603d7c 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_grid.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_grid.py
@@ -14,8 +14,164 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `SimGridNode`."""
+
+import typing as typ
+
+import sympy.physics.units as spu
+import tidy3d as td
+
+from blender_maxwell.utils import logger
+
+from ... import contracts as ct
+from ... import sockets
+from .. import base, events
+
+log = logger.get(__name__)
+
+
+SSA = ct.SimSpaceAxis
+FK = ct.FlowKind
+FS = ct.FlowSignal
+
+
+class SimGridNode(base.MaxwellSimNode):
+ """Provides a hub for joining custom simulation domain boundary conditions by-axis."""
+
+ node_type = ct.NodeType.SimGrid
+ bl_label = 'Sim Grid'
+
+ ####################
+ # - Sockets
+ ####################
+ input_sockets: typ.ClassVar = {
+ 'X': sockets.MaxwellSimGridAxisSocketDef(active_kind=FK.Func),
+ 'Y': sockets.MaxwellSimGridAxisSocketDef(active_kind=FK.Func),
+ 'Z': sockets.MaxwellSimGridAxisSocketDef(active_kind=FK.Func),
+ }
+ input_socket_sets: typ.ClassVar = {
+ 'Relative': {},
+ 'Absolute': {
+ 'WL': sockets.ExprSocketDef(
+ default_unit=spu.nm,
+ default_value=500,
+ abs_min=0,
+ abs_min_closed=False,
+ )
+ },
+ }
+ output_sockets: typ.ClassVar = {
+ 'Grid': sockets.MaxwellSimGridSocketDef(),
+ }
+
+ ####################
+ # - FlowKind.Value
+ ####################
+ @events.computes_output_socket(
+ 'Grid',
+ kind=FK.Value,
+ # Loaded
+ outscks_kinds={'Grid': {FK.Func, FK.Params}},
+ )
+ def compute_bcs_value(self, output_sockets) -> td.BoundarySpec:
+ """Compute the simulation boundary conditions, by combining the individual input by specified half axis."""
+ value = events.realize_known(output_sockets['Grid'])
+ if value is not None:
+ return value
+ return FS.FlowPending
+
+ ####################
+ # - FlowKind.Func
+ ####################
+ @events.computes_output_socket(
+ 'Grid',
+ kind=FK.Func,
+ # Loaded
+ props={'active_socket_set'},
+ inscks_kinds={
+ 'X': FK.Func,
+ 'Y': FK.Func,
+ 'Z': FK.Func,
+ 'WL': FK.Func,
+ },
+ input_sockets_optional={'WL'},
+ scale_input_sockets={
+ 'WL': ct.UNITS_TIDY3D,
+ },
+ )
+ def compute_grid_func(self, props, input_sockets) -> ct.ParamsFlow | FS:
+ """Compute the simulation grid lazily, at the specified wavelength."""
+ # Deduce "Doubledness"
+ ## -> A "doubled" axis defines the same bound cond both ways
+ x = input_sockets['X']
+ y = input_sockets['Y']
+ z = input_sockets['Z']
+
+ wl = input_sockets['WL']
+
+ active_socket_set = props['active_socket_set']
+ common_func = x | y | z
+ match active_socket_set:
+ case 'Absolute' if not FS.check(wl):
+ return (common_func | wl).compose_within(
+ lambda els: td.GridSpec(
+ grid_x=els[0], grid_y=els[1], grid_z=els[2], wavelength=els[3]
+ )
+ )
+
+ case 'Relative':
+ return common_func.compose_within(
+ lambda els: td.GridSpec(
+ grid_x=els[0],
+ grid_y=els[1],
+ grid_z=els[2],
+ )
+ )
+
+ return FS.FlowPending
+
+ ####################
+ # - FlowKind.Params
+ ####################
+ @events.computes_output_socket(
+ 'Grid',
+ kind=FK.Params,
+ # Loaded
+ props={'active_socket_set'},
+ inscks_kinds={
+ 'X': FK.Params,
+ 'Y': FK.Params,
+ 'Z': FK.Params,
+ 'WL': FK.Params,
+ },
+ input_sockets_optional={'WL'},
+ )
+ def compute_bcs_params(self, props, input_sockets) -> ct.ParamsFlow | FS:
+ """Compute the simulation boundary conditions, by combining the individual input by specified half axis."""
+ # Deduce "Doubledness"
+ ## -> A "doubled" axis defines the same bound cond both ways
+ x = input_sockets['X']
+ y = input_sockets['Y']
+ z = input_sockets['Z']
+
+ wl = input_sockets['WL']
+
+ active_socket_set = props['active_socket_set']
+ common_params = x | y | z
+ match active_socket_set:
+ case 'Relative':
+ return common_params
+
+ case 'Absolute' if not FS.check(wl):
+ return common_params | wl
+
+ return FS.FlowPending
+
+
####################
# - Blender Registration
####################
-BL_REGISTER = []
-BL_NODES = {}
+BL_REGISTER = [
+ SimGridNode,
+]
+BL_NODES = {ct.NodeType.SimGrid: (ct.NodeCategory.MAXWELLSIM_SIMS)}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_grid_axes/automatic_sim_grid_axis.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_grid_axes/automatic_sim_grid_axis.py
index 5b7d824..163c074 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_grid_axes/automatic_sim_grid_axis.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_grid_axes/automatic_sim_grid_axis.py
@@ -14,8 +14,126 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `AutoSimGridAxisNode`."""
+
+import typing as typ
+
+import sympy.physics.units as spu
+import tidy3d as td
+
+from blender_maxwell.utils import logger
+
+from .... import contracts as ct
+from .... import sockets
+from ... import base, events
+
+log = logger.get(__name__)
+
+
+SSA = ct.SimSpaceAxis
+FK = ct.FlowKind
+FS = ct.FlowSignal
+
+
+class AutoSimGridAxisNode(base.MaxwellSimNode):
+ """Declare a uniform grid along a simulation axis."""
+
+ node_type = ct.NodeType.AutoSimGridAxis
+ bl_label = 'Auto Grid Axis'
+
+ ####################
+ # - Sockets
+ ####################
+ input_sockets: typ.ClassVar = {
+ 'min N/λ': sockets.ExprSocketDef(
+ default_value=10,
+ ),
+ 'min Δℓ': sockets.ExprSocketDef(
+ default_unit=spu.nm,
+ default_value=0,
+ abs_min=0,
+ ),
+ 'max ratio': sockets.ExprSocketDef(
+ default_value=1.4,
+ abs_min=1,
+ ),
+ }
+ output_sockets: typ.ClassVar = {
+ 'Grid Axis': sockets.MaxwellSimGridAxisSocketDef(active_kind=FK.Func),
+ }
+
+ ####################
+ # - FlowKind.Value
+ ####################
+ @events.computes_output_socket(
+ 'Grid Axis',
+ kind=FK.Value,
+ # Loaded
+ outscks_kinds={'Grid Axis': {FK.Func, FK.Params}},
+ )
+ def compute_bcs_value(self, output_sockets) -> td.BoundarySpec:
+ """Compute the simulation boundary conditions, by combining the individual input by specified half axis."""
+ value = events.realize_known(output_sockets['Grid Axis'])
+ if value is not None:
+ return value
+ return FS.FlowPending
+
+ ####################
+ # - FlowKind.Func
+ ####################
+ @events.computes_output_socket(
+ 'Grid Axis',
+ kind=FK.Func,
+ # Loaded
+ inscks_kinds={
+ 'min N/λ': FK.Func,
+ 'max ratio': FK.Func,
+ 'min Δℓ': FK.Func,
+ },
+ scale_input_sockets={
+ 'min Δℓ': ct.UNITS_TIDY3D,
+ },
+ )
+ def compute_grid_func(self, input_sockets) -> ct.ParamsFlow | FS:
+ """Compute the simulation grid lazily, at the specified wavelength."""
+ min_steps_per_wl = input_sockets['min N/λ']
+ max_consecutive_ratio = input_sockets['max ratio']
+ min_length = input_sockets['min Δℓ']
+
+ return (min_steps_per_wl | max_consecutive_ratio | min_length).compose_within(
+ lambda els: td.AutoGrid(
+ min_steps_per_wvl=els[0],
+ max_scale=els[1],
+ dl_min=els[2],
+ )
+ )
+
+ ####################
+ # - FlowKind.Params
+ ####################
+ @events.computes_output_socket(
+ 'Grid Axis',
+ kind=FK.Params,
+ # Loaded
+ inscks_kinds={
+ 'min N/λ': FK.Params,
+ 'max ratio': FK.Params,
+ 'min Δℓ': FK.Params,
+ },
+ )
+ def compute_grid_params(self, input_sockets) -> ct.ParamsFlow | FS:
+ """Compute the simulation grid lazily, at the specified wavelength."""
+ min_steps_per_wl = input_sockets['min N/λ']
+ max_consecutive_ratio = input_sockets['max ratio']
+ min_length = input_sockets['min Δℓ']
+
+ return min_steps_per_wl | max_consecutive_ratio | min_length
+
+
####################
# - Blender Registration
####################
-BL_REGISTER = []
-BL_NODES = {}
+BL_REGISTER = [
+ AutoSimGridAxisNode,
+]
+BL_NODES = {ct.NodeType.AutoSimGridAxis: (ct.NodeCategory.MAXWELLSIM_SIMS_SIMGRIDAXES)}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_grid_axes/uniform_sim_grid_axis.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_grid_axes/uniform_sim_grid_axis.py
index 5b7d824..3df430a 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_grid_axes/uniform_sim_grid_axis.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_grid_axes/uniform_sim_grid_axis.py
@@ -14,8 +14,111 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `UniformSimGridAxisNode`."""
+
+import typing as typ
+
+import sympy.physics.units as spu
+import tidy3d as td
+
+from blender_maxwell.utils import logger
+
+from .... import contracts as ct
+from .... import sockets
+from ... import base, events
+
+log = logger.get(__name__)
+
+
+SSA = ct.SimSpaceAxis
+FK = ct.FlowKind
+FS = ct.FlowSignal
+
+
+class UniformSimGridAxisNode(base.MaxwellSimNode):
+ """Declare a uniform grid along a simulation axis."""
+
+ node_type = ct.NodeType.UniformSimGridAxis
+ bl_label = 'Uniform Grid Axis'
+
+ ####################
+ # - Sockets
+ ####################
+ input_sockets: typ.ClassVar = {
+ 'Δℓ': sockets.ExprSocketDef(
+ default_unit=spu.nm,
+ default_value=10,
+ ),
+ }
+ output_sockets: typ.ClassVar = {
+ 'Grid Axis': sockets.MaxwellSimGridAxisSocketDef(active_kind=FK.Func),
+ }
+
+ ####################
+ # - FlowKind.Value
+ ####################
+ @events.computes_output_socket(
+ 'Grid Axis',
+ kind=FK.Value,
+ # Loaded
+ outscks_kinds={'Grid Axis': {FK.Func, FK.Params}},
+ )
+ def compute_bcs_value(self, output_sockets) -> td.BoundarySpec:
+ """Compute the simulation boundary conditions, by combining the individual input by specified half axis."""
+ value = events.realize_known(output_sockets['Grid Axis'])
+ if value is not None:
+ return value
+ return FS.FlowPending
+
+ ####################
+ # - FlowKind.Func
+ ####################
+ @events.computes_output_socket(
+ 'Grid Axis',
+ kind=FK.Func,
+ # Loaded
+ inscks_kinds={
+ 'Δℓ': FK.Func,
+ },
+ scale_input_sockets={
+ 'Δℓ': ct.UNITS_TIDY3D,
+ },
+ )
+ def compute_grid_func(self, input_sockets) -> ct.ParamsFlow | FS:
+ """Compute the simulation grid lazily, at the specified wavelength."""
+ # Deduce "Doubledness"
+ ## -> A "doubled" axis defines the same bound cond both ways
+ dl = input_sockets['Δℓ']
+
+ return dl.compose_within(
+ lambda _dl: td.UniformGrid(dl=_dl),
+ )
+
+ ####################
+ # - FlowKind.Params
+ ####################
+ @events.computes_output_socket(
+ 'Grid Axis',
+ kind=FK.Params,
+ # Loaded
+ inscks_kinds={
+ 'Δℓ': FK.Params,
+ },
+ )
+ def compute_grid_params(self, input_sockets) -> ct.ParamsFlow | FS:
+ """Compute the simulation grid lazily, at the specified wavelength."""
+ # Deduce "Doubledness"
+ ## -> A "doubled" axis defines the same bound cond both ways
+ dl = input_sockets['Δℓ']
+ return dl # noqa: RET504
+
+
####################
# - Blender Registration
####################
-BL_REGISTER = []
-BL_NODES = {}
+BL_REGISTER = [
+ UniformSimGridAxisNode,
+]
+BL_NODES = {
+ ct.NodeType.UniformSimGridAxis: (ct.NodeCategory.MAXWELLSIM_SIMS_SIMGRIDAXES)
+}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/solvers/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/solvers/__init__.py
new file mode 100644
index 0000000..29a213f
--- /dev/null
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/solvers/__init__.py
@@ -0,0 +1,32 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from . import (
+ eme_solver,
+ fdtd_solver,
+ mode_solver,
+)
+
+BL_REGISTER = [
+ *fdtd_solver.BL_REGISTER,
+ *mode_solver.BL_REGISTER,
+ *eme_solver.BL_REGISTER,
+]
+BL_NODES = {
+ **fdtd_solver.BL_NODES,
+ **mode_solver.BL_NODES,
+ **eme_solver.BL_NODES,
+}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/solvers/eme_solver.py
similarity index 79%
rename from src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/__init__.py
rename to src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/solvers/eme_solver.py
index ce82240..eb0c936 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/bounds/__init__.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/solvers/eme_solver.py
@@ -14,13 +14,5 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from . import bound_cond_nodes, bound_conds
-
-BL_REGISTER = [
- *bound_conds.BL_REGISTER,
- *bound_cond_nodes.BL_REGISTER,
-]
-BL_NODES = {
- **bound_conds.BL_NODES,
- **bound_cond_nodes.BL_NODES,
-}
+BL_REGISTER = []
+BL_NODES = {}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/solvers/fdtd_solver.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/solvers/fdtd_solver.py
new file mode 100644
index 0000000..4df3805
--- /dev/null
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/solvers/fdtd_solver.py
@@ -0,0 +1,337 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+"""Implements `FDTDSolverNode`."""
+
+import typing as typ
+
+import bpy
+
+from blender_maxwell.services import tdcloud
+from blender_maxwell.utils import bl_cache, logger
+from blender_maxwell.utils import sympy_extra as spux
+
+from ... import contracts as ct
+from ... import sockets
+from .. import base, events
+
+log = logger.get(__name__)
+
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MT = spux.MathType
+
+
+####################
+# - Operators
+####################
+class RunSimulation(bpy.types.Operator):
+ """Run a Tidy3D simulation given to a `FDTDSolverNode`."""
+
+ bl_idname = ct.OperatorType.NodeRunSimulation
+ bl_label = 'Run Sim'
+ bl_description = 'Run the currently tracked simulation task'
+
+ @classmethod
+ def poll(cls, context):
+ """Allow running when there are runnable tasks."""
+ return (
+ # Check Tidy3D Cloud
+ tdcloud.IS_AUTHENTICATED
+ # Check FDTDSolverNode is Accessible
+ and hasattr(context, 'node')
+ and hasattr(context.node, 'node_type')
+ and context.node.node_type == ct.NodeType.FDTDSolver
+ # Check Task is Runnable
+ and context.node.are_tasks_runnable
+ )
+
+ def execute(self, context):
+ """Run all uploaded, runnable tasks."""
+ node = context.node
+
+ for cloud_task in node.cloud_tasks:
+ log.debug('Submitting Cloud Task %s', cloud_task.task_id)
+ cloud_task.submit()
+
+ return {'FINISHED'}
+
+
+class ReloadTrackedTask(bpy.types.Operator):
+ """Reload information of the selected task in a `FDTDSolverNode`."""
+
+ bl_idname = ct.OperatorType.NodeReloadTrackedTask
+ bl_label = 'Reload Tracked Tidy3D Cloud Task'
+ bl_description = 'Reload the currently tracked simulation task'
+
+ @classmethod
+ def poll(cls, context):
+ """Always allow reloading tasks."""
+ return (
+ # Check Tidy3D Cloud
+ tdcloud.IS_AUTHENTICATED
+ # Check FDTDSolverNode is Accessible
+ and hasattr(context, 'node')
+ and hasattr(context.node, 'node_type')
+ and context.node.node_type == ct.NodeType.FDTDSolver
+ )
+
+ def execute(self, context):
+ """Reload all tasks in all folders for which there are uploaded tasks in the node."""
+ node = context.node
+ for folder_id in {cloud_task.folder_id for cloud_task in node.cloud_tasks}:
+ tdcloud.TidyCloudTasks.update_tasks(folder_id)
+
+ return {'FINISHED'}
+
+
+####################
+# - Node
+####################
+class FDTDSolverNode(base.MaxwellSimNode):
+ """Solve an FDTD simulation problem using the Tidy3D cloud solver."""
+
+ node_type = ct.NodeType.FDTDSolver
+ bl_label = 'FDTD Solver'
+
+ input_socket_sets: typ.ClassVar = {
+ 'Single': {
+ 'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
+ should_exist=True,
+ ),
+ },
+ 'Batch': {
+ 'Cloud Tasks': sockets.Tidy3DCloudTaskSocketDef(
+ active_kind=FK.Array,
+ should_exist=True,
+ ),
+ },
+ }
+ output_socket_sets: typ.ClassVar = {
+ 'Single': {
+ 'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
+ should_exist=True,
+ ),
+ },
+ 'Batch': {
+ 'Cloud Tasks': sockets.Tidy3DCloudTaskSocketDef(
+ active_kind=FK.Array,
+ should_exist=True,
+ ),
+ },
+ }
+
+ ####################
+ # - Properties: Incoming InfoFlow
+ ####################
+ @events.on_value_changed(
+ # Trigger
+ socket_name={'Cloud Task': FK.Value, 'Cloud Tasks': FK.Array},
+ )
+ def on_cloud_tasks_changed(self) -> None: # noqa: D102
+ self.cloud_tasks = bl_cache.Signal.InvalidateCache
+
+ @bl_cache.cached_bl_property()
+ def cloud_tasks(self) -> list[tdcloud.CloudTask] | None:
+ """Retrieve the current cloud tasks from the input.
+
+ If one can't be loaded, return None.
+ """
+ if self.active_socket_set == 'Single':
+ cloud_task_single = self._compute_input(
+ 'Cloud Task',
+ kind=FK.Value,
+ )
+ has_cloud_task_single = not FS.check(cloud_task_single)
+ if has_cloud_task_single and cloud_task_single is not None:
+ return [cloud_task_single]
+
+ if self.active_socket_set == 'Batch':
+ cloud_task_array = self._compute_input(
+ 'Cloud Tasks',
+ kind=FK.Array,
+ )
+ has_cloud_task_array = not FS.check(cloud_task_array)
+ if has_cloud_task_array and all(
+ cloud_task is not None for cloud_task in cloud_task_array
+ ):
+ return cloud_task_array
+
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'cloud_tasks'})
+ def task_infos(self) -> list[tdcloud.CloudTaskInfo | None] | None:
+ """Retrieve the current cloud task information from the input socket.
+
+ If it can't be loaded, return None.
+ """
+ if self.cloud_tasks is not None:
+ task_infos = [
+ tdcloud.TidyCloudTasks.task_info(cloud_task.task_id)
+ for cloud_task in self.cloud_tasks
+ ]
+ if task_infos:
+ return task_infos
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'cloud_tasks'})
+ def task_progress(self) -> tuple[float | None, float | None] | None:
+ """Retrieve current progress percent (in terms of time steps) and current field decay (normalized to max value).
+
+ Either entry can be None, denoting that they aren't yet available.
+ """
+ if self.cloud_tasks is not None:
+ task_progress = [
+ cloud_task.get_running_info() for cloud_task in self.cloud_tasks
+ ]
+ if task_progress:
+ return task_progress
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'task_progress'})
+ def total_progress_pct(self) -> float | None:
+ """Retrieve current progress percent, averaged across all running tasks."""
+ if self.task_progress is not None and all(
+ progress[0] is not None and progress[1] is not None
+ for progress in self.task_progress
+ ):
+ return sum([progress[0] for progress in self.task_progress]) / len(
+ self.task_progress
+ )
+ return None
+
+ @bl_cache.cached_bl_property(depends_on={'task_infos'})
+ def are_tasks_runnable(self) -> bool:
+ """Checks whether all conditions are satisfied to be able to actually run a simulation."""
+ return self.task_infos is not None and all(
+ task_info is not None and task_info.status == 'draft'
+ for task_info in self.task_infos
+ )
+
+ ####################
+ # - UI
+ ####################
+ def draw_operators(self, _, layout):
+ """Draw the button that runs the active simulation(s)."""
+ # Row: Run Sim Buttons
+ row = layout.row(align=True)
+ row.operator(
+ ct.OperatorType.NodeRunSimulation,
+ text='Run Sim',
+ )
+
+ def draw_info(self, _, layout):
+ """Draw information about the running simulation."""
+ tdcloud.draw_cloud_status(layout)
+
+ # Cloud Task Info
+ if self.task_infos is not None and self.task_progress is not None:
+ for task_info, task_progress in zip(
+ self.task_infos, self.task_progress, strict=True
+ ):
+ if self.task_infos is not None:
+ # Header
+ row = layout.row()
+ row.alignment = 'CENTER'
+ row.label(text='Task Info')
+
+ # Task Run Progress
+ row = layout.row(align=True)
+ progress_pct = (
+ task_progress[0]
+ if task_progress is not None and task_progress[0] is not None
+ else 0.0
+ )
+ row.progress(
+ factor=progress_pct,
+ type='BAR',
+ text=f'{task_info.status.capitalize()}',
+ )
+ row.operator(
+ ct.OperatorType.NodeReloadTrackedTask,
+ text='',
+ icon='FILE_REFRESH',
+ )
+
+ # Task Information
+ box = layout.box()
+
+ split = box.split(factor=0.4)
+
+ col = split.column(align=False)
+ col.label(text='Status')
+
+ col = split.column(align=False)
+ col.alignment = 'RIGHT'
+ col.label(text=task_info.status.capitalize())
+
+ ####################
+ # - FlowKind.Value
+ ####################
+ @events.computes_output_socket(
+ 'Cloud Task',
+ kind=FK.Value,
+ # Loaded
+ props={'cloud_tasks', 'task_infos'},
+ )
+ def compute_cloud_task(self, props) -> str:
+ """A single simulation data object, when there only is one."""
+ cloud_tasks = props['cloud_tasks']
+ task_infos = props['task_infos']
+ if (
+ cloud_tasks is not None
+ and len(cloud_tasks) == 1
+ and task_infos is not None
+ and task_infos[0].status == 'success'
+ ):
+ return cloud_tasks[0]
+ return FS.FlowPending
+
+ ####################
+ # - FlowKind.Array
+ ####################
+ @events.computes_output_socket(
+ 'Sim Datas',
+ kind=FK.Array,
+ # Loaded
+ props={'cloud_tasks', 'task_infos'},
+ )
+ def compute_cloud_tasks(self, props) -> list[tdcloud.CloudTask]:
+ """All simulation data objects, for when there are more than one.
+
+ Generally part of the same batch.
+ """
+ cloud_tasks = props['cloud_tasks']
+ task_infos = props['task_infos']
+ if (
+ cloud_tasks is not None
+ and len(cloud_tasks) > 1
+ and task_infos is not None
+ and all(task_info.status == 'success' for task_info in task_infos)
+ ):
+ return cloud_tasks
+ return FS.FlowPending
+
+
+####################
+# - Blender Registration
+####################
+BL_REGISTER = [
+ RunSimulation,
+ ReloadTrackedTask,
+ FDTDSolverNode,
+]
+BL_NODES = {ct.NodeType.FDTDSolver: (ct.NodeCategory.MAXWELLSIM_SOLVERS)}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/solvers/mode_solver.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/solvers/mode_solver.py
new file mode 100644
index 0000000..eb0c936
--- /dev/null
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/solvers/mode_solver.py
@@ -0,0 +1,18 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+BL_REGISTER = []
+BL_NODES = {}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/gaussian_beam_source.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/gaussian_beam_source.py
index 94eb827..3f2076d 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/gaussian_beam_source.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/gaussian_beam_source.py
@@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `GaussianBeamSourceNode`."""
+
import typing as typ
import bpy
@@ -29,6 +31,8 @@ from ... import managed_objs, sockets
from .. import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
class GaussianBeamSourceNode(base.MaxwellSimNode):
@@ -57,13 +61,13 @@ class GaussianBeamSourceNode(base.MaxwellSimNode):
size=spux.NumberSize1D.Vec3,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Length,
- default_value=sp.Matrix([0, 0, 0]),
+ default_value=sp.ImmutableMatrix([0, 0, 0]),
),
'Size': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec2,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Length,
- default_value=sp.Matrix([1, 1]),
+ default_value=sp.ImmutableMatrix([1, 1]),
),
'Waist Dist': sockets.ExprSocketDef(
mathtype=spux.MathType.Real,
@@ -80,7 +84,7 @@ class GaussianBeamSourceNode(base.MaxwellSimNode):
size=spux.NumberSize1D.Vec2,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Angle,
- default_value=sp.Matrix([0, 0]),
+ default_value=sp.ImmutableMatrix([0, 0]),
),
'Pol ∡': sockets.ExprSocketDef(
physical_type=spux.PhysicalType.Angle,
@@ -106,129 +110,225 @@ class GaussianBeamSourceNode(base.MaxwellSimNode):
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
+ """Draw choices of injection axis and direction."""
layout.prop(self, self.blfields['injection_axis'], expand=True)
layout.prop(self, self.blfields['injection_direction'], expand=True)
# layout.prop(self, self.blfields['num_freqs'], text='f Points')
## TODO: UI is a bit crowded already!
####################
- # - Outputs
+ # - FlowKind.Value
####################
@events.computes_output_socket(
'Angled Source',
- props={'sim_node_name', 'injection_axis', 'injection_direction', 'num_freqs'},
- input_sockets={
- 'Temporal Shape',
- 'Center',
- 'Size',
- 'Waist Dist',
- 'Waist Radius',
- 'Spherical',
- 'Pol ∡',
- },
- unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
- scale_input_sockets={
- 'Center': 'Tidy3DUnits',
- 'Size': 'Tidy3DUnits',
- 'Waist Dist': 'Tidy3DUnits',
- 'Waist Radius': 'Tidy3DUnits',
- 'Spherical': 'Tidy3DUnits',
- 'Pol ∡': 'Tidy3DUnits',
+ kind=FK.Value,
+ # Loaded
+ outscks_kinds={
+ 'Angled Source': {FK.Func, FK.Params},
},
)
- def compute_source(self, props, input_sockets, unit_systems):
- size_2d = input_sockets['Size']
- size = {
- ct.SimSpaceAxis.X: (0, *size_2d),
- ct.SimSpaceAxis.Y: (size_2d[0], 0, size_2d[1]),
- ct.SimSpaceAxis.Z: (*size_2d, 0),
- }[props['injection_axis']]
+ def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
+ """Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
+ output_func = output_sockets['Structure'][FK.Func]
+ output_params = output_sockets['Structure'][FK.Params]
- # Display the results
- return td.GaussianBeam(
- name=props['sim_node_name'],
- center=input_sockets['Center'],
- size=size,
- source_time=input_sockets['Temporal Shape'],
- num_freqs=props['num_freqs'],
- direction=props['injection_direction'].plus_or_minus,
- angle_theta=input_sockets['Spherical'][0],
- angle_phi=input_sockets['Spherical'][1],
- pol_angle=input_sockets['Pol ∡'],
- waist_radius=input_sockets['Waist Radius'],
- waist_distance=input_sockets['Waist Dist'],
- ## NOTE: Waist is place at this signed dist along neg. direction
+ if not output_params.symbols:
+ return output_func.realize(output_params, disallow_jax=True)
+ return ct.FlowSignal.FlowPending
+
+ ####################
+ # - FlowKind.Value
+ ####################
+ @events.computes_output_socket(
+ 'Angled Source',
+ kind=FK.Func,
+ # Loaded
+ props={'sim_node_name', 'injection_axis', 'injection_direction', 'num_freqs'},
+ inscks_kinds={
+ 'Temporal Shape': FK.Func,
+ 'Center': FK.Func,
+ 'Size': FK.Func,
+ 'Waist Dist': FK.Func,
+ 'Waist Radius': FK.Func,
+ 'Spherical': FK.Func,
+ 'Pol ∡': FK.Func,
+ },
+ scale_input_sockets={
+ 'Center': ct.UNITS_TIDY3D,
+ 'Size': ct.UNITS_TIDY3D,
+ 'Waist Dist': ct.UNITS_TIDY3D,
+ 'Waist Radius': ct.UNITS_TIDY3D,
+ 'Spherical': ct.UNITS_TIDY3D,
+ 'Pol ∡': ct.UNITS_TIDY3D,
+ },
+ )
+ def compute_func(self, props, input_sockets) -> td.GaussianBeam:
+ """Compute a function that returns a gaussian beam object, given sufficient parameters."""
+ inj_dir = props['injection_axis']
+
+ def size(size_2d: tuple[float, float]) -> tuple[float, float, float]:
+ return {
+ ct.SimSpaceAxis.X: (0, *size_2d),
+ ct.SimSpaceAxis.Y: (size_2d[0], 0, size_2d[1]),
+ ct.SimSpaceAxis.Z: (*size_2d, 0),
+ }[inj_dir]
+
+ center = input_sockets['Center']
+ size_2d = input_sockets['Size']
+ temporal_shape = input_sockets['Temporal Shape']
+ spherical = input_sockets['Spherical']
+ pol_ang = input_sockets['Pol ∡']
+ waist_radius = input_sockets['Waist Radius']
+ waist_dist = input_sockets['Waist Dist']
+
+ sim_node_name = props['sim_node_name']
+ num_freqs = props['num_freqs']
+ return (
+ center
+ | size_2d
+ | temporal_shape
+ | spherical
+ | pol_ang
+ | waist_radius
+ | waist_dist
+ ).compose_within(
+ lambda els: td.GaussianBeam(
+ name=sim_node_name,
+ center=els[0],
+ size=size(els[1]),
+ source_time=els[2],
+ num_freqs=num_freqs,
+ direction=inj_dir.plus_or_minus,
+ angle_theta=els[3].item(0),
+ angle_phi=els[3].item(1),
+ pol_angle=els[4],
+ waist_radius=els[5],
+ waist_distance=els[6],
+ )
)
####################
- # - Preview - Changes to Input Sockets
+ # - FlowKind.Params
+ ####################
+ @events.computes_output_socket(
+ 'Structure',
+ kind=FK.Params,
+ # Loaded
+ inscks_kinds={
+ 'Temporal Shape': FK.Func,
+ 'Center': FK.Func,
+ 'Size': FK.Func,
+ 'Waist Dist': FK.Func,
+ 'Waist Radius': FK.Func,
+ 'Spherical': FK.Func,
+ 'Pol ∡': FK.Func,
+ },
+ )
+ def compute_params(self, input_sockets) -> ct.ParamsFlow:
+ """Propagate function parameters from inputs."""
+ center = input_sockets['Center']
+ size_2d = input_sockets['Size']
+ temporal_shape = input_sockets['Temporal Shape']
+ spherical = input_sockets['Spherical']
+ pol_ang = input_sockets['Pol ∡']
+ waist_radius = input_sockets['Waist Radius']
+ waist_dist = input_sockets['Waist Dist']
+
+ return (
+ center
+ | size_2d
+ | temporal_shape
+ | spherical
+ | pol_ang
+ | waist_radius
+ | waist_dist
+ )
+
+ ####################
+ # - Events: Preview
####################
@events.computes_output_socket(
'Angled Source',
- kind=ct.FlowKind.Previews,
+ kind=FK.Previews,
# Loaded
props={'sim_node_name'},
+ outscks_kinds={'Angled Source': ct.FlowKind.Func},
+ output_sockets_optional={'Angled Source'},
)
- def compute_previews(self, props):
- return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
+ def compute_previews(self, props, output_sockets):
+ """Update the preview state when the name or output function change."""
+ if not FS.check(output_sockets['Structure']):
+ return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
+ return ct.PreviewsFlow()
@events.on_value_changed(
# Trigger
socket_name={
- 'Center',
- 'Size',
- 'Waist Dist',
- 'Waist Radius',
- 'Spherical',
- 'Pol ∡',
+ 'Center': {FK.Func, FK.Params},
+ 'Size': {FK.Func, FK.Params},
+ 'Waist Dist': {FK.Func, FK.Params},
+ 'Waist Radius': {FK.Func, FK.Params},
+ 'Spherical': {FK.Func, FK.Params},
+ 'Pol ∡': {FK.Func, FK.Params},
},
prop_name={'injection_axis', 'injection_direction'},
run_on_init=True,
# Loaded
managed_objs={'modifier'},
props={'injection_axis', 'injection_direction'},
- input_sockets={
- 'Temporal Shape',
- 'Center',
- 'Size',
- 'Waist Dist',
- 'Waist Radius',
- 'Spherical',
- 'Pol ∡',
+ inscks_kinds={
+ 'Center': {FK.Func, FK.Params},
+ 'Size': {FK.Func, FK.Params},
+ 'Waist Dist': {FK.Func, FK.Params},
+ 'Waist Radius': {FK.Func, FK.Params},
+ 'Spherical': {FK.Func, FK.Params},
+ 'Pol ∡': {FK.Func, FK.Params},
},
- unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
scale_input_sockets={
- 'Center': 'BlenderUnits',
+ 'Center': ct.UNITS_BLENDER,
+ 'Size': ct.UNITS_BLENDER,
+ 'Waist Dist': ct.UNITS_BLENDER,
+ 'Waist Radius': ct.UNITS_BLENDER,
+ 'Spherical': ct.UNITS_BLENDER,
+ 'Pol ∡': ct.UNITS_BLENDER,
},
)
- def on_inputs_changed(self, managed_objs, props, input_sockets, unit_systems):
- size_2d = input_sockets['Size']
+ def on_previewable_chnaged(self, managed_objs, props, input_sockets):
+ """Update the preview when relevant inputs change."""
+ center = events.realize_preview(input_sockets['Center'])
+ size_2d = events.realize_preview(input_sockets['Size'])
+ spherical = events.realize_preview(input_sockets['Spherical'])
+ pol_ang = events.realize_preview(input_sockets['Pol ∡'])
+ waist_radius = events.realize_preview(input_sockets['Waist Radius'])
+ waist_dist = events.realize_preview(input_sockets['Waist Dist'])
+
+ # Retrieve Properties
+ inj_dir = props['injection_axis']
size = {
- ct.SimSpaceAxis.X: sp.Matrix([0, *size_2d]),
- ct.SimSpaceAxis.Y: sp.Matrix([size_2d[0], 0, size_2d[1]]),
- ct.SimSpaceAxis.Z: sp.Matrix([*size_2d, 0]),
+ ct.SimSpaceAxis.X: sp.ImmutableMatrix([0, *size_2d]),
+ ct.SimSpaceAxis.Y: sp.ImmutableMatrix([size_2d[0], 0, size_2d[1]]),
+ ct.SimSpaceAxis.Z: sp.ImmutableMatrix([*size_2d, 0]),
}[props['injection_axis']]
- # Push Input Values to GeoNodes Modifier
+ # Push Updated Values
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.SourceGaussianBeam),
- 'unit_system': unit_systems['BlenderUnits'],
'inputs': {
# Orientation
- 'Inj Axis': props['injection_axis'].axis,
- 'Direction': props['injection_direction'].true_or_false,
- 'theta': input_sockets['Spherical'][0],
- 'phi': input_sockets['Spherical'][1],
- 'Pol Angle': input_sockets['Pol ∡'],
+ 'Inj Axis': inj_dir.axis,
+ 'Direction': inj_dir.true_or_false,
+ 'theta': spherical[0],
+ 'phi': spherical[1],
+ 'Pol Angle': pol_ang,
# Gaussian Beam
'Size': size,
- 'Waist Dist': input_sockets['Waist Dist'],
- 'Waist Radius': input_sockets['Waist Radius'],
+ 'Waist Radius': waist_radius,
+ 'Waist Dist': waist_dist,
},
},
- location=input_sockets['Center'],
+ location=center,
)
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/plane_wave_source.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/plane_wave_source.py
index 3f60b9e..a541ccc 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/plane_wave_source.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/plane_wave_source.py
@@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `PlaneWaveSourceNode`."""
+
import typing as typ
import bpy
@@ -30,6 +32,11 @@ from .. import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MT = spux.MathType
+PT = spux.PhysicalType
+
class PlaneWaveSourceNode(base.MaxwellSimNode):
"""An infinite-extent angled source simulating an plane wave with linear polarization.
@@ -52,26 +59,26 @@ class PlaneWaveSourceNode(base.MaxwellSimNode):
# - Sockets
####################
input_sockets: typ.ClassVar = {
- 'Temporal Shape': sockets.MaxwellTemporalShapeSocketDef(),
+ 'Temporal Shape': sockets.MaxwellTemporalShapeSocketDef(active_kind=FK.Func),
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
- mathtype=spux.MathType.Real,
- physical_type=spux.PhysicalType.Length,
- default_value=sp.Matrix([0, 0, 0]),
+ mathtype=MT.Real,
+ physical_type=PT.Length,
+ default_value=sp.ImmutableMatrix([0, 0, 0]),
),
'Spherical': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec2,
- mathtype=spux.MathType.Real,
- physical_type=spux.PhysicalType.Angle,
- default_value=sp.Matrix([0, 0]),
+ mathtype=MT.Real,
+ physical_type=PT.Angle,
+ default_value=sp.ImmutableMatrix([0, 0]),
),
'Pol ∡': sockets.ExprSocketDef(
- physical_type=spux.PhysicalType.Angle,
+ physical_type=PT.Angle,
default_value=0,
),
}
output_sockets: typ.ClassVar = {
- 'Angled Source': sockets.MaxwellSourceSocketDef(active_kind=ct.FlowKind.Func),
+ 'Angled Source': sockets.MaxwellSourceSocketDef(active_kind=FK.Func),
}
managed_obj_types: typ.ClassVar = {
@@ -88,6 +95,7 @@ class PlaneWaveSourceNode(base.MaxwellSimNode):
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
+ """Draw choices of injection axis and direction."""
layout.prop(self, self.blfields['injection_axis'], expand=True)
layout.prop(self, self.blfields['injection_direction'], expand=True)
@@ -96,175 +104,145 @@ class PlaneWaveSourceNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Angled Source',
- kind=ct.FlowKind.Value,
+ kind=FK.Value,
# Loaded
- output_sockets={'Angled Source'},
- output_socket_kinds={'Angled Source': {ct.FlowKind.Func, ct.FlowKind.Params}},
+ outscks_kinds={'Angled Source': {FK.Func, FK.Params}},
)
- def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
+ def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
- output_func = output_sockets['Angled Source'][ct.FlowKind.Func]
- output_params = output_sockets['Angled Source'][ct.FlowKind.Params]
-
- has_output_func = not ct.FlowSignal.check(output_func)
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if has_output_func and has_output_params and not output_params.symbols:
- return output_func.realize(output_params, disallow_jax=True)
- return ct.FlowSignal.FlowPending
+ value = events.realize_known(output_sockets['Angled Source'])
+ if value is not None:
+ return value
+ return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Angled Source',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
# Loaded
props={'sim_node_name', 'injection_axis', 'injection_direction'},
- input_sockets={'Temporal Shape', 'Center', 'Spherical', 'Pol ∡'},
- input_socket_kinds={
- 'Temporal Shape': ct.FlowKind.Func,
- 'Center': ct.FlowKind.Func,
- 'Spherical': ct.FlowKind.Func,
- 'Pol ∡': ct.FlowKind.Func,
+ inscks_kinds={
+ 'Temporal Shape': FK.Func,
+ 'Center': FK.Func,
+ 'Spherical': FK.Func,
+ 'Pol ∡': FK.Func,
+ },
+ scale_input_sockets={
+ 'Center': ct.UNITS_TIDY3D,
+ 'Spherical': ct.UNITS_TIDY3D,
+ 'Pol ∡': ct.UNITS_TIDY3D,
},
)
def compute_func(self, props, input_sockets) -> None:
+ """Compute a lazy function for the plane wave source."""
center = input_sockets['Center']
temporal_shape = input_sockets['Temporal Shape']
spherical = input_sockets['Spherical']
pol_ang = input_sockets['Pol ∡']
- has_center = not ct.FlowSignal.check(center)
- has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
- has_spherical = not ct.FlowSignal.check(spherical)
- has_pol_ang = not ct.FlowSignal.check(pol_ang)
+ name = props['sim_node_name']
+ inj_dir = props['injection_direction'].plus_or_minus
+ size = {
+ ct.SimSpaceAxis.X: (0, td.inf, td.inf),
+ ct.SimSpaceAxis.Y: (td.inf, 0, td.inf),
+ ct.SimSpaceAxis.Z: (td.inf, td.inf, 0),
+ }[props['injection_axis']]
- if has_center and has_temporal_shape and has_spherical and has_pol_ang:
- name = props['sim_node_name']
- inj_dir = props['injection_direction'].plus_or_minus
- size = {
- ct.SimSpaceAxis.X: (0, td.inf, td.inf),
- ct.SimSpaceAxis.Y: (td.inf, 0, td.inf),
- ct.SimSpaceAxis.Z: (td.inf, td.inf, 0),
- }[props['injection_axis']]
-
- return (
- center.scale_to_unit_system(ct.UNITS_TIDY3D)
- | temporal_shape
- | spherical.scale_to_unit_system(ct.UNITS_TIDY3D)
- | pol_ang.scale_to_unit_system(ct.UNITS_TIDY3D)
- ).compose_within(
- lambda els: td.PlaneWave(
- name=name,
- center=els[0].flatten().tolist(),
- size=size,
- source_time=els[1],
- direction=inj_dir,
- angle_theta=els[2][0].item(0),
- angle_phi=els[2][1].item(0),
- pol_angle=els[3],
- )
+ return (center | temporal_shape | spherical | pol_ang).compose_within(
+ lambda els: td.PlaneWave(
+ name=name,
+ center=els[0].flatten().tolist(),
+ size=size,
+ source_time=els[1],
+ direction=inj_dir,
+ angle_theta=els[2].item(0),
+ angle_phi=els[2].item(1),
+ pol_angle=els[3],
)
- return ct.FlowSignal.FlowPending
+ )
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Angled Source',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
# Loaded
- input_sockets={'Temporal Shape', 'Center', 'Spherical', 'Pol ∡'},
- input_socket_kinds={
- 'Temporal Shape': ct.FlowKind.Params,
- 'Center': ct.FlowKind.Params,
- 'Spherical': ct.FlowKind.Params,
- 'Pol ∡': ct.FlowKind.Params,
+ inscks_kinds={
+ 'Temporal Shape': FK.Params,
+ 'Center': FK.Params,
+ 'Spherical': FK.Params,
+ 'Pol ∡': FK.Params,
},
)
def compute_params(self, input_sockets) -> None:
+ """Compute the function parameters of the lazy function."""
center = input_sockets['Center']
temporal_shape = input_sockets['Temporal Shape']
spherical = input_sockets['Spherical']
pol_ang = input_sockets['Pol ∡']
- has_center = not ct.FlowSignal.check(center)
- has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
- has_spherical = not ct.FlowSignal.check(spherical)
- has_pol_ang = not ct.FlowSignal.check(pol_ang)
-
- if has_center and has_temporal_shape and has_spherical and has_pol_ang:
- return center | temporal_shape | spherical | pol_ang
- return ct.FlowSignal.FlowPending
+ return center | temporal_shape | spherical | pol_ang
####################
# - Preview - Changes to Input Sockets
####################
@events.computes_output_socket(
'Angled Source',
- kind=ct.FlowKind.Previews,
+ kind=FK.Previews,
# Loaded
props={'sim_node_name'},
- output_sockets={'Angled Source'},
- output_socket_kinds={'Angled Source': ct.FlowKind.Params},
)
- def compute_previews(self, props, output_sockets):
- output_params = output_sockets['Angled Source']
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if has_output_params and not output_params.symbols:
- return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
- return ct.PreviewsFlow()
+ def compute_previews(self, props):
+ """Mark the plane wave as participating in the 3D preview."""
+ return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed(
# Trigger
- socket_name={'Center', 'Spherical', 'Pol ∡'},
+ socket_name={
+ 'Center': {FK.Func, FK.Params},
+ 'Spherical': {FK.Func, FK.Params},
+ 'Pol ∡': {FK.Func, FK.Params},
+ },
prop_name={'injection_axis', 'injection_direction'},
run_on_init=True,
# Loaded
managed_objs={'modifier'},
props={'injection_axis', 'injection_direction'},
- input_sockets={'Temporal Shape', 'Center', 'Spherical', 'Pol ∡'},
- output_sockets={'Angled Source'},
- output_socket_kinds={'Angled Source': ct.FlowKind.Params},
+ inscks_kinds={
+ 'Center': {FK.Func, FK.Params},
+ 'Spherical': {FK.Func, FK.Params},
+ 'Pol ∡': {FK.Func, FK.Params},
+ },
+ scale_input_sockets={
+ 'Center': ct.UNITS_BLENDER,
+ 'Spherical': ct.UNITS_BLENDER,
+ 'Pol ∡': ct.UNITS_BLENDER,
+ },
)
- def on_previewable_changed(
- self, managed_objs, props, input_sockets, output_sockets
- ):
- center = input_sockets['Center']
- spherical = input_sockets['Spherical']
- pol_ang = input_sockets['Pol ∡']
- output_params = output_sockets['Angled Source']
+ def on_previewable_changed(self, managed_objs, props, input_sockets):
+ """Push changes in the inputs to the pol/center."""
+ center = events.realize_preview(input_sockets['Center'])
+ spherical = events.realize_preview(input_sockets['Spherical'])
+ pol_ang = events.realize_preview(input_sockets['Pol ∡'])
- has_center = not ct.FlowSignal.check(center)
- has_spherical = not ct.FlowSignal.check(spherical)
- has_pol_ang = not ct.FlowSignal.check(pol_ang)
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if (
- has_center
- and has_spherical
- and has_pol_ang
- and has_output_params
- and not output_params.symbols
- ):
- # Push Input Values to GeoNodes Modifier
- managed_objs['modifier'].bl_modifier(
- 'NODES',
- {
- 'node_group': import_geonodes(GeoNodes.SourcePlaneWave),
- 'unit_system': ct.UNITS_BLENDER,
- 'inputs': {
- 'Inj Axis': props['injection_axis'].axis,
- 'Direction': props['injection_direction'].true_or_false,
- 'theta': spherical[0],
- 'phi': spherical[1],
- 'Pol Angle': pol_ang,
- },
+ # Push Input Values to GeoNodes Modifier
+ managed_objs['modifier'].bl_modifier(
+ 'NODES',
+ {
+ 'node_group': import_geonodes(GeoNodes.SourcePlaneWave),
+ 'inputs': {
+ 'Inj Axis': props['injection_axis'].axis,
+ 'Direction': props['injection_direction'].true_or_false,
+ 'theta': spherical.item(0),
+ 'phi': spherical.item(1),
+ 'Pol Angle': pol_ang,
},
- location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
- )
+ },
+ location=center,
+ )
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py
index 0457790..1c6f8cd 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/point_dipole_source.py
@@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `PointDipoleSourceNode`."""
+
import typing as typ
import bpy
@@ -30,8 +32,15 @@ from .. import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MT = spux.MathType
+PT = spux.PhysicalType
+
class PointDipoleSourceNode(base.MaxwellSimNode):
+ """A point dipole with E or H oriented linear polarization along an axis-aligned angle."""
+
node_type = ct.NodeType.PointDipoleSource
bl_label = 'Point Dipole Source'
use_sim_node_name = True
@@ -43,16 +52,16 @@ class PointDipoleSourceNode(base.MaxwellSimNode):
'Temporal Shape': sockets.MaxwellTemporalShapeSocketDef(),
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
- mathtype=spux.MathType.Real,
- physical_type=spux.PhysicalType.Length,
- default_value=sp.Matrix([0, 0, 0]),
+ mathtype=MT.Real,
+ physical_type=PT.Length,
+ default_value=sp.ImmutableMatrix([0, 0, 0]),
),
'Interpolate': sockets.BoolSocketDef(
default_value=True,
),
}
output_sockets: typ.ClassVar = {
- 'Source': sockets.MaxwellSourceSocketDef(active_kind=ct.FlowKind.Func),
+ 'Source': sockets.MaxwellSourceSocketDef(active_kind=FK.Func),
}
managed_obj_types: typ.ClassVar = {
@@ -68,6 +77,7 @@ class PointDipoleSourceNode(base.MaxwellSimNode):
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
+ """Draw choices of polarization direction."""
layout.prop(self, self.blfields['pol'], expand=True)
####################
@@ -75,36 +85,33 @@ class PointDipoleSourceNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Source',
- kind=ct.FlowKind.Value,
+ kind=FK.Value,
# Loaded
output_sockets={'Source'},
- output_socket_kinds={'Source': {ct.FlowKind.Func, ct.FlowKind.Params}},
+ output_socket_kinds={'Source': {FK.Func, FK.Params}},
)
- def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
+ def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
- output_func = output_sockets['Source'][ct.FlowKind.Func]
- output_params = output_sockets['Source'][ct.FlowKind.Params]
-
- has_output_func = not ct.FlowSignal.check(output_func)
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if has_output_func and has_output_params and not output_params.symbols:
- return output_func.realize(output_params, disallow_jax=True)
- return ct.FlowSignal.FlowPending
+ value = events.realize_known(output_sockets['Source'])
+ if value is not None:
+ return value
+ return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Source',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
# Loaded
props={'pol'},
- input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
- input_socket_kinds={
- 'Temporal Shape': ct.FlowKind.Func,
- 'Center': ct.FlowKind.Func,
- 'Interpolate': ct.FlowKind.Func,
+ inscks_kinds={
+ 'Temporal Shape': FK.Func,
+ 'Center': FK.Func,
+ 'Interpolate': FK.Func,
+ },
+ scale_input_sockets={
+ 'Center': ct.UNITS_TIDY3D,
},
)
def compute_func(self, props, input_sockets) -> td.Box:
@@ -113,121 +120,95 @@ class PointDipoleSourceNode(base.MaxwellSimNode):
temporal_shape = input_sockets['Temporal Shape']
interpolate = input_sockets['Interpolate']
- has_center = not ct.FlowSignal.check(center)
- has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
- has_interpolate = not ct.FlowSignal.check(interpolate)
+ pol = props['pol']
- if has_temporal_shape and has_center and has_interpolate:
- pol = props['pol']
- return (
- center.scale_to_unit_system(ct.UNITS_TIDY3D)
- | temporal_shape
- | interpolate
- ).compose_within(
- enclosing_func=lambda els: td.PointDipole(
- center=els[0].flatten().tolist(),
- source_time=els[1],
- interpolate=els[2],
- polarization=pol.name,
- )
+ return (center | temporal_shape | interpolate).compose_within(
+ lambda els: td.PointDipole(
+ center=els[0].flatten().tolist(),
+ source_time=els[1],
+ interpolate=els[2],
+ polarization=pol.name,
)
- return ct.FlowSignal.FlowPending
+ )
+ return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Source',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
# Loaded
input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
input_socket_kinds={
- 'Temporal Shape': ct.FlowKind.Params,
- 'Center': ct.FlowKind.Params,
- 'Interpolate': ct.FlowKind.Params,
+ 'Temporal Shape': FK.Params,
+ 'Center': FK.Params,
+ 'Interpolate': FK.Params,
},
)
def compute_params(
self,
input_sockets,
- ) -> td.PointDipole | ct.FlowSignal:
- """Compute the point dipole source, given that all inputs are non-symbolic."""
+ ) -> td.PointDipole:
+ """Compute the function parameters of the lazy function."""
center = input_sockets['Center']
temporal_shape = input_sockets['Temporal Shape']
interpolate = input_sockets['Interpolate']
- has_center = not ct.FlowSignal.check(center)
- has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
- has_interpolate = not ct.FlowSignal.check(interpolate)
-
- if has_center and has_temporal_shape and has_interpolate:
- return center | temporal_shape | interpolate
- return ct.FlowSignal.FlowPending
+ return center | temporal_shape | interpolate
####################
# - Preview
####################
@events.computes_output_socket(
'Source',
- kind=ct.FlowKind.Previews,
+ kind=FK.Previews,
# Loaded
props={'sim_node_name'},
- output_sockets={'Source'},
- output_socket_kinds={'Source': ct.FlowKind.Params},
)
- def compute_previews(self, props, output_sockets):
- output_params = output_sockets['Source']
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if has_output_params and not output_params.symbols:
- return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
- return ct.PreviewsFlow()
+ def compute_previews(self, props):
+ """Mark the point dipole as participating in the 3D preview."""
+ return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed(
# Trigger
- socket_name={'Center'},
+ socket_name={'Center': {FK.Func, FK.Params}},
prop_name='pol',
run_on_init=True,
# Loaded
managed_objs={'modifier'},
props={'pol'},
- input_sockets={'Center'},
- output_sockets={'Source'},
- output_socket_kinds={'Source': ct.FlowKind.Params},
+ inscks_kinds={
+ 'Center': {FK.Func, FK.Params},
+ },
+ scale_input_sockets={
+ 'Center': ct.UNITS_BLENDER,
+ },
)
- def on_previewable_changed(
- self, managed_objs, props, input_sockets, output_sockets
- ) -> None:
+ def on_previewable_changed(self, managed_objs, props, input_sockets) -> None:
+ """Push changes in the inputs to the pol/center."""
SFP = ct.SimFieldPols
- center = input_sockets['Center']
- output_params = output_sockets['Source']
+ center = events.realize_preview(input_sockets['Center'])
+ axis = {
+ SFP.Ex: 0,
+ SFP.Ey: 1,
+ SFP.Ez: 2,
+ SFP.Hx: 0,
+ SFP.Hy: 1,
+ SFP.Hz: 2,
+ }[props['pol']]
- has_center = not ct.FlowSignal.check(center)
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if has_center and has_output_params and not output_params.symbols:
- axis = {
- SFP.Ex: 0,
- SFP.Ey: 1,
- SFP.Ez: 2,
- SFP.Hx: 0,
- SFP.Hy: 1,
- SFP.Hz: 2,
- }[props['pol']]
-
- # Push Loose Input Values to GeoNodes Modifier
- managed_objs['modifier'].bl_modifier(
- 'NODES',
- {
- 'node_group': import_geonodes(GeoNodes.SourcePointDipole),
- 'unit_system': ct.UNITS_BLENDER,
- 'inputs': {
- 'Axis': axis,
- },
+ managed_objs['modifier'].bl_modifier(
+ 'NODES',
+ {
+ 'node_group': import_geonodes(GeoNodes.SourcePointDipole),
+ 'inputs': {
+ 'Axis': axis,
},
- location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
- )
+ },
+ location=center,
+ )
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py
index c307143..c5ae22a 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shape.py
@@ -36,11 +36,15 @@ from .. import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+PT = spux.PhysicalType
+
# Select Default Time Unit for Envelope
## -> Chosen to align with the default envelope_time_unit.
## -> This causes it to be correct from the start.
-t_def = sim_symbols.t(spux.PhysicalType.Time.valid_units[0])
+t_def = sim_symbols.t(PT.Time.valid_units[0])
class TemporalShapeNode(base.MaxwellSimNode):
@@ -54,19 +58,19 @@ class TemporalShapeNode(base.MaxwellSimNode):
####################
input_sockets: typ.ClassVar = {
'μ Freq': sockets.ExprSocketDef(
- physical_type=spux.PhysicalType.Freq,
+ physical_type=PT.Freq,
default_unit=spux.THz,
default_value=500,
),
'σ Freq': sockets.ExprSocketDef(
- physical_type=spux.PhysicalType.Freq,
+ physical_type=PT.Freq,
default_unit=spux.THz,
default_value=200,
),
'max E': sockets.ExprSocketDef(
mathtype=spux.MathType.Complex,
- physical_type=spux.PhysicalType.EField,
- default_value=1 + 0j,
+ physical_type=PT.EField,
+ default_value=1,
),
'Offset Time': sockets.ExprSocketDef(default_value=5, abs_min=2.5),
}
@@ -77,8 +81,8 @@ class TemporalShapeNode(base.MaxwellSimNode):
'Constant': {},
'Symbolic': {
't Range': sockets.ExprSocketDef(
- active_kind=ct.FlowKind.Range,
- physical_type=spux.PhysicalType.Time,
+ active_kind=FK.Range,
+ physical_type=PT.Time,
default_unit=spu.picosecond,
default_min=0,
default_max=10,
@@ -92,11 +96,7 @@ class TemporalShapeNode(base.MaxwellSimNode):
}
output_sockets: typ.ClassVar = {
- 'Temporal Shape': sockets.MaxwellTemporalShapeSocketDef(),
- }
-
- managed_obj_types: typ.ClassVar = {
- 'plot': managed_objs.ManagedBLImage,
+ 'Temporal Shape': sockets.MaxwellTemporalShapeSocketDef(active_kind=FK.Func),
}
####################
@@ -110,7 +110,7 @@ class TemporalShapeNode(base.MaxwellSimNode):
"""Compute all valid time units."""
return [
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
- for i, unit in enumerate(spux.PhysicalType.Time.valid_units)
+ for i, unit in enumerate(PT.Time.valid_units)
]
@bl_cache.cached_bl_property(depends_on={'active_envelope_time_unit'})
@@ -178,8 +178,9 @@ class TemporalShapeNode(base.MaxwellSimNode):
)
def on_envelope_time_unit_changed(self, props) -> None:
"""Ensure the envelope expression's time symbol has the time unit defined by the node."""
- active_socket_set = props['active_socket_set']
envelope_time_unit = props['envelope_time_unit']
+
+ active_socket_set = props['active_socket_set']
if active_socket_set == 'Symbolic':
bl_socket = self.inputs['Envelope']
wanted_t_sym = sim_symbols.t(envelope_time_unit)
@@ -192,48 +193,40 @@ class TemporalShapeNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Temporal Shape',
- kind=ct.FlowKind.Value,
+ kind=FK.Value,
# Loaded
- output_sockets={'Temporal Shape'},
- output_socket_kinds={'Temporal Shape': {ct.FlowKind.Func, ct.FlowKind.Params}},
+ outscks_kinds={'Temporal Shape': {FK.Func, FK.Params}},
)
- def compute_domain_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
+ def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Compute a single temporal shape."""
- output_func = output_sockets['Temporal Shape'][ct.FlowKind.Func]
- output_params = output_sockets['Temporal Shape'][ct.FlowKind.Params]
-
- has_output_func = not ct.FlowSignal.check(output_func)
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if has_output_func and has_output_params and not output_params.symbols:
- return output_func.realize(output_params)
- return ct.FlowSignal.FlowPending
+ value = events.realize_known(output_sockets['Temporal Shape'])
+ if value is not None:
+ return value
+ return FS.FlowPending
####################
# - FlowKind: Func
####################
@events.computes_output_socket(
'Temporal Shape',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
# Loaded
props={'active_socket_set'},
- input_sockets={
- 'max E',
- 'μ Freq',
- 'σ Freq',
- 'Offset Time',
- 'Remove DC',
- 't Range',
- 'Envelope',
+ inscks_kinds={
+ 'μ Freq': FK.Func,
+ 'σ Freq': FK.Func,
+ 'max E': FK.Func,
+ 'Offset Time': FK.Func,
+ 'Remove DC': FK.Value,
+ 't Range': FK.Func,
+ 'Envelope': {FK.Func, FK.Params},
},
- input_socket_kinds={
- 'max E': ct.FlowKind.Func,
- 'μ Freq': ct.FlowKind.Func,
- 'σ Freq': ct.FlowKind.Func,
- 'Offset Time': ct.FlowKind.Func,
- 'Remove DC': ct.FlowKind.Value,
- 't Range': ct.FlowKind.Func,
- 'Envelope': {ct.FlowKind.Func, ct.FlowKind.Params},
+ input_sockets_optional={'Remove DC', 't Range', 'Envelope'},
+ scale_input_sockets={
+ 'μ Freq': ct.UNITS_TIDY3D,
+ 'σ Freq': ct.UNITS_TIDY3D,
+ 'max E': ct.UNITS_TIDY3D,
+ 't Range': ct.UNITS_TIDY3D,
},
)
def compute_temporal_shape_func(
@@ -247,158 +240,114 @@ class TemporalShapeNode(base.MaxwellSimNode):
max_e = input_sockets['max E']
offset = input_sockets['Offset Time']
- has_mean_freq = not ct.FlowSignal.check(mean_freq)
- has_std_freq = not ct.FlowSignal.check(std_freq)
- has_max_e = not ct.FlowSignal.check(max_e)
- has_offset = not ct.FlowSignal.check(offset)
+ remove_dc = input_sockets['Remove DC']
+ t_range = input_sockets['t Range']
+ envelope = input_sockets['Envelope'][FK.Func]
+ envelope_params = input_sockets['Envelope'][FK.Params]
- if has_mean_freq and has_std_freq and has_max_e and has_offset:
- common_func = (
- max_e.scale_to_unit_system(ct.UNITS_TIDY3D)
- | mean_freq.scale_to_unit_system(ct.UNITS_TIDY3D)
- | std_freq.scale_to_unit_system(ct.UNITS_TIDY3D)
- | offset ## Already unitless
- )
- match props['active_socket_set']:
- case 'Pulse':
- remove_dc = input_sockets['Remove DC']
+ common_func = max_e | mean_freq | std_freq | offset
+ active_socket_set = props['active_socket_set']
+ match active_socket_set:
+ case 'Pulse' if not FS.check(remove_dc):
+ return common_func.compose_within(
+ lambda els: td.GaussianPulse(
+ amplitude=complex(els[0]).real,
+ phase=complex(els[0]).imag,
+ freq0=els[1],
+ fwidth=els[2],
+ offset=els[3],
+ remove_dc_component=remove_dc,
+ ),
+ )
- has_remove_dc = not ct.FlowSignal.check(remove_dc)
+ case 'Constant':
+ return common_func.compose_within(
+ lambda els: td.ContinuousWave(
+ amplitude=complex(els[0]).real,
+ phase=complex(els[0]).imag,
+ freq0=els[1],
+ fwidth=els[2],
+ offset=els[3],
+ ),
+ )
- if has_remove_dc:
- return common_func.compose_within(
- lambda els: td.GaussianPulse(
- amplitude=complex(els[0]).real,
- phase=complex(els[0]).imag,
- freq0=els[1],
- fwidth=els[2],
- offset=els[3],
- remove_dc_component=remove_dc,
- ),
- )
+ case 'Symbolic' if (
+ not FS.check(t_range)
+ and not FS.check(envelope)
+ and not FS.check(envelope_params)
+ and len(envelope_params.symbols) == 1
+ and next(iter(envelope_params.symbols)).physical_type is PT.Time
+ and any(sym.physical_type is PT.Time for sym in envelope_params.symbols)
+ ):
+ envelope_time_unit = next(iter(envelope_params.symbols)).unit
- case 'Constant':
- return common_func.compose_within(
- lambda els: td.GaussianPulse(
- amplitude=complex(els[0]).real,
- phase=complex(els[0]).imag,
- freq0=els[1],
- fwidth=els[2],
- offset=els[3],
+ # Deduce Partially Realized Envelope Function
+ ## -> We need a pure-numerical function w/pre-realized stuff baked in.
+ ## -> 'generate_realizer' does this for us.
+ envelope_realizer = envelope.generate_realizer(envelope_params)
+
+ # Compose w/Envelope Function
+ ## -> First, the numerical time values must be converted.
+ ## -> This ensures that the raw array is compatible w/the envelope.
+ ## -> Then, we can compose w/the purely numerical 'envelope_realizer'.
+ ## -> Because of the checks, we've guaranteed that all this is correct.
+ return (
+ common_func
+ | t_range
+ | t_range.scale_to_unit(envelope_time_unit).compose_within(
+ lambda t: envelope_realizer(t)
+ )
+ ).compose_within(
+ lambda els: td.CustomSourceTime(
+ amplitude=complex(els[0]).real,
+ phase=complex(els[0]).imag,
+ freq0=els[1],
+ fwidth=els[2],
+ offset=els[3],
+ source_time_dataset=td_TimeDataset(
+ values=td_TimeDataArray(
+ els[5], coords={'t': np.array(els[4])}
+ )
),
)
+ )
- case 'Symbolic':
- t_range = input_sockets['t Range']
- envelope = input_sockets['Envelope'][ct.FlowKind.Func]
- envelope_params = input_sockets['Envelope'][ct.FlowKind.Params]
-
- has_t_range = not ct.FlowSignal.check(t_range)
- has_envelope = not ct.FlowSignal.check(envelope)
- has_envelope_params = not ct.FlowSignal.check(envelope_params)
-
- if (
- has_t_range
- and has_envelope
- and has_envelope_params
- and len(envelope_params.symbols) == 1
- ## TODO: Allow unrealized envelope symbols
- and any(
- sym.physical_type is spux.PhysicalType.Time
- for sym in envelope_params.symbols
- )
- ):
- envelope_time_unit = next(
- sym.unit
- for sym in envelope_params.symbols
- if sym.physical_type is spux.PhysicalType.Time
- )
-
- # Deduce Partially Realized Envelope Function
- ## -> We need a pure-numerical function w/pre-realized stuff baked in.
- ## -> 'realize_partial' does this for us.
- envelope_realizer = envelope.realize_partial(envelope_params)
-
- # Compose w/Envelope Function
- ## -> First, the numerical time values must be converted.
- ## -> This ensures that the raw array is compatible w/the envelope.
- ## -> Then, we can compose w/the purely numerical 'envelope_realizer'.
- ## -> Because of the checks, we've guaranteed that all this is correct.
- return (
- common_func ## 1 | freq0, 2 | fwidth, 3 | offset
- | t_range.scale_to_unit_system(ct.UNITS_TIDY3D) ## 4
- | t_range.scale_to_unit(envelope_time_unit).compose_within(
- lambda t: envelope_realizer(t)
- ) ## 5
- ).compose_within(
- lambda els: td.CustomSourceTime(
- amplitude=complex(els[0]).real,
- phase=complex(els[0]).imag,
- freq0=els[1],
- fwidth=els[2],
- offset=els[3],
- source_time_dataset=td_TimeDataset(
- values=td_TimeDataArray(
- els[5], coords={'t': np.array(els[4])}
- )
- ),
- )
- )
-
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
# - FlowKind: Params
####################
@events.computes_output_socket(
'Temporal Shape',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
# Loaded
- props={'active_socket_set', 'envelope_time_unit'},
- input_sockets={
- 'max E',
- 'μ Freq',
- 'σ Freq',
- 'Offset Time',
- 't Range',
- },
- input_socket_kinds={
- 'max E': ct.FlowKind.Params,
- 'μ Freq': ct.FlowKind.Params,
- 'σ Freq': ct.FlowKind.Params,
- 'Offset Time': ct.FlowKind.Params,
- 't Range': ct.FlowKind.Params,
+ props={'active_socket_set'},
+ inscks_kinds={
+ 'μ Freq': FK.Params,
+ 'σ Freq': FK.Params,
+ 'max E': FK.Params,
+ 'Offset Time': FK.Params,
+ 't Range': FK.Params,
},
+ input_sockets_optional={'t Range'},
)
- def compute_temporal_shape_params(
- self,
- props,
- input_sockets,
- ) -> td.GaussianPulse:
+ def compute_temporal_shape_params(self, props, input_sockets) -> td.GaussianPulse:
"""Compute a single temporal shape from non-parameterized inputs."""
mean_freq = input_sockets['μ Freq']
std_freq = input_sockets['σ Freq']
max_e = input_sockets['max E']
offset = input_sockets['Offset Time']
- has_mean_freq = not ct.FlowSignal.check(mean_freq)
- has_std_freq = not ct.FlowSignal.check(std_freq)
- has_max_e = not ct.FlowSignal.check(max_e)
- has_offset = not ct.FlowSignal.check(offset)
+ t_range = input_sockets['t Range']
- if has_mean_freq and has_std_freq and has_max_e and has_offset:
- common_params = max_e | mean_freq | std_freq | offset
- match props['active_socket_set']:
- case 'Pulse' | 'Constant':
- return common_params
+ common_params = max_e | mean_freq | std_freq | offset
+ match props['active_socket_set']:
+ case 'Pulse' | 'Constant':
+ return common_params
- case 'Symbolic':
- t_range = input_sockets['t Range']
- has_t_range = not ct.FlowSignal.check(t_range)
-
- if has_t_range:
- return common_params | t_range | t_range
- return ct.FlowSignal.FlowPending
+ case 'Symbolic' if not FS.check(t_range):
+ return common_params | t_range | t_range
+ return FS.FlowPending
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py
index 6669a50..0f099c8 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/geonodes_structure.py
@@ -14,14 +14,17 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `GeoNodesStructureNode`."""
+
+import functools
import typing as typ
import sympy as sp
import sympy.physics.units as spu
import tidy3d as td
-from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
+from blender_maxwell.utils import sympy_extra as spux
from ... import bl_socket_map, managed_objs, sockets
from ... import contracts as ct
@@ -29,8 +32,14 @@ from .. import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MT = spux.MathType
+
class GeoNodesStructureNode(base.MaxwellSimNode):
+ """A generic mesh structure defined by an arbitrary Geometry Nodes tree."""
+
node_type = ct.NodeType.GeoNodesStructure
bl_label = 'GeoNodes Structure'
use_sim_node_name = True
@@ -40,44 +49,106 @@ class GeoNodesStructureNode(base.MaxwellSimNode):
####################
input_sockets: typ.ClassVar = {
'GeoNodes': sockets.BlenderGeoNodesSocketDef(),
- 'Medium': sockets.MaxwellMediumSocketDef(),
+ 'Medium': sockets.MaxwellMediumSocketDef(active_kind=FK.Func),
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
default_unit=spu.micrometer,
- default_value=sp.Matrix([0, 0, 0]),
+ default_value=sp.ImmutableMatrix([0, 0, 0]),
),
}
output_sockets: typ.ClassVar = {
- 'Structure': sockets.MaxwellStructureSocketDef(),
+ 'Structure': sockets.MaxwellStructureSocketDef(active_kind=FK.Func),
}
managed_obj_types: typ.ClassVar = {
- 'modifier': managed_objs.ManagedBLModifier,
+ 'preview_mesh': managed_objs.ManagedBLModifier,
+ 'mesh': managed_objs.ManagedBLModifier,
}
####################
- # - Outputs
+ # - FlowKind.Value
####################
@events.computes_output_socket(
'Structure',
- input_sockets={'Medium'},
- managed_objs={'modifier'},
+ kind=FK.Value,
+ # Loaded
+ outscks_kinds={
+ 'Structure': {FK.Func, FK.Params},
+ },
)
- def compute_structure(
+ def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
+ """Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
+ value = events.realize_known(output_sockets['Structure'])
+ if value is not None:
+ return value
+ return FS.FlowPending
+
+ ####################
+ # - FlowKind.Func
+ ####################
+ @events.computes_output_socket(
+ 'Structure',
+ kind=FK.Func,
+ # Loaded
+ props={'sim_node_name'},
+ managed_objs={'mesh'},
+ inscks_kinds={
+ 'GeoNodes': FK.Value,
+ 'Medium': FK.Func,
+ 'Center': FK.Func,
+ },
+ scale_input_sockets={
+ 'Center': ct.UNITS_TIDY3D,
+ },
+ all_loose_input_sockets=True,
+ loose_input_sockets_kind=FK.Func,
+ scale_loose_input_sockets=ct.UNITS_BLENDER,
+ )
+ def compute_func(
self,
- input_sockets,
managed_objs,
+ props,
+ input_sockets,
+ loose_input_sockets,
) -> td.Structure:
- """Computes a triangle-mesh based Tidy3D structure, by manually copying mesh data from Blender to a `td.TriangleMesh`."""
+ """Lazily computes a triangle-mesh based Tidy3D structure, by manually copying mesh data from Blender to a `td.TriangleMesh`."""
## TODO: mesh_as_arrays might not take the Center into account.
## - Alternatively, Tidy3D might have a way to transform?
- mesh_as_arrays = managed_objs['modifier'].mesh_as_arrays
- return td.Structure(
- geometry=td.TriangleMesh.from_vertices_faces(
- mesh_as_arrays['verts'],
- mesh_as_arrays['faces'],
- ),
- medium=input_sockets['Medium'],
+
+ mesh = managed_objs['mesh']
+ geonodes = input_sockets['GeoNodes']
+ medium = input_sockets['Medium']
+ center = input_sockets['Center']
+
+ sim_node_name = props['sim_node_name']
+ gn_inputs = list(loose_input_sockets.keys())
+
+ def verts_faces(els: tuple[typ.Any]) -> dict[str, typ.Any]:
+ # Push Realized Values to Managed Mesh
+ mesh.bl_modifier(
+ 'NODES',
+ {
+ 'node_group': geonodes,
+ 'inputs': dict(zip(gn_inputs, els[2:], strict=True)),
+ },
+ location=els[1],
+ )
+
+ # Extract Vert/Face Data
+ mesh_as_arrays = mesh.mesh_as_arrays
+ return (mesh_as_arrays['verts'], mesh_as_arrays['faces'])
+
+ loose_sck_values = functools.reduce(
+ lambda a, b: a | b, loose_input_sockets.values()
+ )
+ return (medium | center | loose_sck_values).compose_within(
+ lambda els: td.Structure(
+ name=sim_node_name,
+ geometry=td.TriangleMesh.from_vertices_faces(
+ *verts_faces(els),
+ ),
+ medium=els[0],
+ )
)
####################
@@ -99,61 +170,73 @@ class GeoNodesStructureNode(base.MaxwellSimNode):
# - Events: Swapped GN Node Tree
####################
@events.on_value_changed(
- socket_name={'GeoNodes'},
+ socket_name={
+ 'GeoNodes': FK.Value,
+ },
# Loaded
- managed_objs={'modifier'},
- input_sockets={'GeoNodes'},
+ inscks_kinds={
+ 'GeoNodes': FK.Value,
+ },
)
def on_input_changed(
self,
- managed_objs,
input_sockets,
) -> None:
- """Declares new loose input sockets in response to a new GeoNodes tree (if any)."""
- geonodes = input_sockets['GeoNodes']
- has_geonodes = not ct.FlowSignal.check(geonodes) and geonodes is not None
+ """Synchronizes the GeoNodes tree input sockets to the loose input sockets of this node.
- if has_geonodes:
+ Utilizes `bl_socket_map.sockets_from_geonodes` to generate, primarily, `ExprSocketDef`s for use with `self.loose_input_sockets.
+ """
+ geonodes = input_sockets['GeoNodes']
+ if geonodes is not None:
# Fill the Loose Input Sockets
## -> The SocketDefs contain the default values from the interface.
- log.info(
+ log.debug(
'Initializing GeoNodes Structure Node "%s" from GeoNodes Group "%s"',
self.bl_label,
str(geonodes),
)
self.loose_input_sockets = bl_socket_map.sockets_from_geonodes(geonodes)
- ## -> The loose socket creation triggers 'on_input_socket_changed'
-
elif self.loose_input_sockets:
self.loose_input_sockets = {}
- managed_objs['modifier'].free()
####################
# - Events: Preview
####################
@events.computes_output_socket(
'Structure',
- kind=ct.FlowKind.Previews,
+ kind=FK.Previews,
# Loaded
props={'sim_node_name'},
)
def compute_previews(self, props):
- return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
+ """Mark the box structure as participating in the preview."""
+ sim_node_name = props['sim_node_name']
+ return ct.PreviewsFlow(bl_object_names={sim_node_name + '_1'})
@events.on_value_changed(
# Trigger
- socket_name={'Center', 'GeoNodes'}, ## MUST run after on_input_changed
+ socket_name={
+ 'GeoNodes': FK.Value,
+ 'Center': {FK.Func, FK.Params},
+ },
any_loose_input_socket=True,
+ run_on_init=True,
# Loaded
- managed_objs={'modifier'},
- input_sockets={'Center', 'GeoNodes'},
+ managed_objs={'preview_mesh'},
+ inscks_kinds={
+ 'GeoNodes': FK.Value,
+ 'Center': {FK.Func, FK.Params},
+ },
+ scale_input_sockets={
+ 'Center': ct.UNITS_BLENDER,
+ },
all_loose_input_sockets=True,
- unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
- scale_input_sockets={'Center': 'BlenderUnits'},
+ loose_input_sockets_kind={FK.Func, FK.Params},
+ scale_loose_input_sockets=ct.UNITS_BLENDER,
)
def on_input_socket_changed(
- self, managed_objs, input_sockets, loose_input_sockets, unit_systems
+ self, managed_objs, input_sockets, loose_input_sockets
) -> None:
"""Pushes any change in GeoNodes-bound input sockets to the GeoNodes modifier.
@@ -163,19 +246,16 @@ class GeoNodesStructureNode(base.MaxwellSimNode):
Also pushes the `Center:Value` socket to govern the object's center in 3D space.
"""
geonodes = input_sockets['GeoNodes']
- has_geonodes = not ct.FlowSignal.check(geonodes) and geonodes is not None
+ center = events.realize_preview(input_sockets['Center'])
- if has_geonodes:
- # Push Loose Input Values to GeoNodes Modifier
- managed_objs['modifier'].bl_modifier(
- 'NODES',
- {
- 'node_group': geonodes,
- 'inputs': loose_input_sockets,
- 'unit_system': unit_systems['BlenderUnits'],
- },
- location=input_sockets['Center'],
- )
+ managed_objs['modifier'].bl_modifier(
+ 'NODES',
+ {
+ 'node_group': geonodes,
+ 'inputs': loose_input_sockets,
+ },
+ location=center,
+ )
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py
index 754ac29..c69c8d5 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py
@@ -14,16 +14,16 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `BoxStructureNode`."""
+
import typing as typ
-import bpy
import sympy as sp
import sympy.physics.units as spu
import tidy3d as td
-import tidy3d.plugins.adjoint as tdadj
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
-from blender_maxwell.utils import bl_cache, logger
+from blender_maxwell.utils import logger
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
@@ -32,9 +32,13 @@ from ... import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MT = spux.MathType
+
class BoxStructureNode(base.MaxwellSimNode):
- """A generic, differentiable box structure with configurable size and center."""
+ """A generic box structure with configurable size and center."""
node_type = ct.NodeType.BoxStructure
bl_label = 'Box Structure'
@@ -48,198 +52,141 @@ class BoxStructureNode(base.MaxwellSimNode):
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
default_unit=spu.micrometer,
- default_value=sp.Matrix([0, 0, 0]),
+ default_value=sp.ImmutableMatrix([0, 0, 0]),
),
'Size': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
default_unit=spu.nanometer,
- default_value=sp.Matrix([500, 500, 500]),
+ default_value=sp.ImmutableMatrix([500, 500, 500]),
abs_min=0.001,
),
}
output_sockets: typ.ClassVar = {
- 'Structure': sockets.MaxwellStructureSocketDef(active_kind=ct.FlowKind.Func),
+ 'Structure': sockets.MaxwellStructureSocketDef(active_kind=FK.Func),
}
managed_obj_types: typ.ClassVar = {
'modifier': managed_objs.ManagedBLModifier,
}
- ####################
- # - Properties
- ####################
- differentiable: bool = bl_cache.BLField(False)
-
- ####################
- # - UI
- ####################
- def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
- layout.prop(
- self,
- self.blfields['differentiable'],
- text='Differentiable',
- toggle=True,
- )
-
####################
# - FlowKind.Value
####################
@events.computes_output_socket(
'Structure',
- kind=ct.FlowKind.Value,
+ kind=FK.Value,
# Loaded
- output_sockets={'Structure'},
- output_socket_kinds={'Structure': {ct.FlowKind.Func, ct.FlowKind.Params}},
+ outscks_kinds={
+ 'Structure': {FK.Func, FK.Params},
+ },
)
- def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
+ def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
- output_func = output_sockets['Structure'][ct.FlowKind.Func]
- output_params = output_sockets['Structure'][ct.FlowKind.Params]
-
- has_output_func = not ct.FlowSignal.check(output_func)
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if has_output_func and has_output_params and not output_params.symbols:
- return output_func.realize(output_params, disallow_jax=True)
- return ct.FlowSignal.FlowPending
+ value = events.realize_known(output_sockets['Structure'])
+ if value is not None:
+ return value
+ return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Structure',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
# Loaded
- props={'differentiable'},
- input_sockets={'Medium', 'Center', 'Size'},
- input_socket_kinds={
- 'Medium': ct.FlowKind.Func,
- 'Center': ct.FlowKind.Func,
- 'Size': ct.FlowKind.Func,
+ inscks_kinds={
+ 'Medium': FK.Func,
+ 'Center': FK.Func,
+ 'Size': FK.Func,
+ },
+ scale_input_sockets={
+ 'Center': ct.UNITS_TIDY3D,
+ 'Size': ct.UNITS_TIDY3D,
},
)
- def compute_structure_func(self, props, input_sockets) -> td.Box:
- """Compute a possibly-differentiable function, producing a box structure from the input parameters."""
+ def compute_func(self, input_sockets) -> td.Box:
+ """Compute a function, producing a box structure from the input parameters."""
center = input_sockets['Center']
size = input_sockets['Size']
medium = input_sockets['Medium']
- has_center = not ct.FlowSignal.check(center)
- has_size = not ct.FlowSignal.check(size)
- has_medium = not ct.FlowSignal.check(medium)
-
- if has_center and has_size and has_medium:
- differentiable = props['differentiable']
- if differentiable:
- return (
- center.scale_to_unit_system(ct.UNITS_TIDY3D)
- | size.scale_to_unit_system(ct.UNITS_TIDY3D)
- | medium
- ).compose_within(
- lambda els: tdadj.JaxStructure(
- geometry=tdadj.JaxBox(
- center_jax=tuple(els[0].flatten()),
- size_jax=tuple(els[1].flatten()),
- ),
- medium=els[2],
- ),
- supports_jax=True,
- )
- return (
- center.scale_to_unit_system(ct.UNITS_TIDY3D)
- | size.scale_to_unit_system(ct.UNITS_TIDY3D)
- | medium
- ).compose_within(
- lambda els: td.Structure(
- geometry=td.Box(
- center=els[0].flatten().tolist(),
- size=els[1].flatten().tolist(),
- ),
- medium=els[2],
+ return (center | size | medium).compose_within(
+ lambda els: td.Structure(
+ geometry=td.Box(
+ center=els[0].flatten().tolist(),
+ size=els[1].flatten().tolist(),
),
- supports_jax=False,
- )
- return ct.FlowSignal.FlowPending
+ medium=els[2],
+ ),
+ )
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Structure',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
# Loaded
- input_sockets={'Medium', 'Center', 'Size'},
- input_socket_kinds={
- 'Medium': ct.FlowKind.Params,
- 'Center': ct.FlowKind.Params,
- 'Size': ct.FlowKind.Params,
+ inscks_kinds={
+ 'Medium': FK.Params,
+ 'Center': FK.Params,
+ 'Size': FK.Params,
},
)
def compute_params(self, input_sockets) -> td.Box:
+ """Aggregate the function parameters needed by the box."""
center = input_sockets['Center']
size = input_sockets['Size']
medium = input_sockets['Medium']
- has_center = not ct.FlowSignal.check(center)
- has_size = not ct.FlowSignal.check(size)
- has_medium = not ct.FlowSignal.check(medium)
-
- if has_center and has_size and has_medium:
- return center | size | medium
- return ct.FlowSignal.FlowPending
+ return center | size | medium
####################
# - Events: Preview
####################
@events.computes_output_socket(
'Structure',
- kind=ct.FlowKind.Previews,
+ kind=FK.Previews,
# Loaded
props={'sim_node_name'},
- output_sockets={'Structure'},
- output_socket_kinds={'Structure': ct.FlowKind.Params},
)
- def compute_previews(self, props, output_sockets):
- """Mark the managed preview object when recursively linked to a viewer."""
- output_params = output_sockets['Structure']
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if has_output_params and not output_params.symbols:
- return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
- return ct.PreviewsFlow()
+ def compute_previews(self, props):
+ """Mark the box structure as participating in the preview."""
+ return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed(
# Trigger
- socket_name={'Center', 'Size'},
+ socket_name={
+ 'Center': {FK.Func, FK.Params},
+ 'Size': {FK.Func, FK.Params},
+ },
run_on_init=True,
# Loaded
managed_objs={'modifier'},
- input_sockets={'Center', 'Size'},
- output_sockets={'Structure'},
- output_socket_kinds={'Structure': ct.FlowKind.Params},
+ inscks_kinds={
+ 'Center': {FK.Func, FK.Params},
+ 'Size': {FK.Func, FK.Params},
+ },
+ scale_input_sockets={
+ 'Center': ct.UNITS_BLENDER,
+ 'Size': ct.UNITS_BLENDER,
+ },
)
- def on_previewable_changed(self, managed_objs, input_sockets, output_sockets):
- center = input_sockets['Center']
- size = input_sockets['Size']
- output_params = output_sockets['Structure']
+ def on_previewable_changed(self, managed_objs, input_sockets) -> None:
+ """Push changes in the inputs to the center / size."""
+ center = events.realize_preview(input_sockets['Center'])
+ size = events.realize_preview(input_sockets['Size'])
- has_center = not ct.FlowSignal.check(center)
- has_size = not ct.FlowSignal.check(size)
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if has_center and has_size and has_output_params and not output_params.symbols:
- # Push Loose Input Values to GeoNodes Modifier
- managed_objs['modifier'].bl_modifier(
- 'NODES',
- {
- 'node_group': import_geonodes(GeoNodes.StructurePrimitiveBox),
- 'unit_system': ct.UNITS_BLENDER,
- 'inputs': {
- 'Size': size,
- },
+ managed_objs['modifier'].bl_modifier(
+ 'NODES',
+ {
+ 'node_group': import_geonodes(GeoNodes.StructurePrimitiveBox),
+ 'inputs': {
+ 'Size': size,
},
- location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
- )
+ },
+ location=center,
+ )
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py
index 1177dbd..99a6de1 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py
@@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `CylinderStructureNode`."""
+
import typing as typ
import sympy as sp
@@ -21,8 +23,8 @@ import sympy.physics.units as spu
import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
-from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
+from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import managed_objs, sockets
@@ -30,6 +32,10 @@ from ... import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MT = spux.MathType
+
class CylinderStructureNode(base.MaxwellSimNode):
"""A generic cylinder structure with configurable radius and height."""
@@ -46,7 +52,7 @@ class CylinderStructureNode(base.MaxwellSimNode):
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
default_unit=spu.micrometer,
- default_value=sp.Matrix([0, 0, 0]),
+ default_value=sp.ImmutableMatrix([0, 0, 0]),
),
'Radius': sockets.ExprSocketDef(
default_unit=spu.nanometer,
@@ -58,7 +64,7 @@ class CylinderStructureNode(base.MaxwellSimNode):
),
}
output_sockets: typ.ClassVar = {
- 'Structure': sockets.MaxwellStructureSocketDef(active_kind=ct.FlowKind.Func),
+ 'Structure': sockets.MaxwellStructureSocketDef(active_kind=FK.Func),
}
managed_obj_types: typ.ClassVar = {
@@ -70,36 +76,35 @@ class CylinderStructureNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Structure',
- kind=ct.FlowKind.Value,
+ kind=FK.Value,
# Loaded
output_sockets={'Structure'},
- output_socket_kinds={'Structure': {ct.FlowKind.Func, ct.FlowKind.Params}},
+ output_socket_kinds={'Structure': {FK.Func, FK.Params}},
)
- def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
+ def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
- output_func = output_sockets['Structure'][ct.FlowKind.Func]
- output_params = output_sockets['Structure'][ct.FlowKind.Params]
-
- has_output_func = not ct.FlowSignal.check(output_func)
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if has_output_func and has_output_params and not output_params.symbols:
- return output_func.realize(output_params, disallow_jax=True)
- return ct.FlowSignal.FlowPending
+ value = events.realize_known(output_sockets['Structure'])
+ if value is not None:
+ return value
+ return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Structure',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
# Loaded
- input_sockets={'Center', 'Radius', 'Medium', 'Height'},
- input_socket_kinds={
- 'Center': ct.FlowKind.Func,
- 'Radius': ct.FlowKind.Func,
- 'Height': ct.FlowKind.Func,
- 'Medium': ct.FlowKind.Func,
+ inscks_kinds={
+ 'Center': FK.Func,
+ 'Radius': FK.Func,
+ 'Height': FK.Func,
+ 'Medium': FK.Func,
+ },
+ scale_input_sockets={
+ 'Center': ct.UNITS_TIDY3D,
+ 'Radius': ct.UNITS_TIDY3D,
+ 'Height': ct.UNITS_TIDY3D,
},
)
def compute_func(self, input_sockets) -> td.Box:
@@ -109,119 +114,90 @@ class CylinderStructureNode(base.MaxwellSimNode):
height = input_sockets['Height']
medium = input_sockets['Medium']
- has_center = not ct.FlowSignal.check(center)
- has_radius = not ct.FlowSignal.check(radius)
- has_height = not ct.FlowSignal.check(height)
- has_medium = not ct.FlowSignal.check(medium)
-
- if has_center and has_radius and has_height and has_medium:
- return (
- center.scale_to_unit_system(ct.UNITS_TIDY3D)
- | radius.scale_to_unit_system(ct.UNITS_TIDY3D)
- | height.scale_to_unit_system(ct.UNITS_TIDY3D)
- | medium
- ).compose_within(
- lambda els: td.Structure(
- geometry=td.Cylinder(
- center=els[0].flatten().tolist(),
- radius=els[1],
- length=els[2],
- ),
- medium=els[3],
- )
+ return (center | radius | height | medium).compose_within(
+ lambda els: td.Structure(
+ geometry=td.Cylinder(
+ center=els[0].flatten().tolist(),
+ radius=els[1],
+ length=els[2],
+ ),
+ medium=els[3],
)
- return ct.FlowSignal.FlowPending
+ )
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Structure',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
# Loaded
- input_sockets={'Center', 'Radius', 'Medium', 'Height'},
- input_socket_kinds={
- 'Center': ct.FlowKind.Params,
- 'Radius': ct.FlowKind.Params,
- 'Height': ct.FlowKind.Params,
- 'Medium': ct.FlowKind.Params,
+ inscks_kinds={
+ 'Center': FK.Params,
+ 'Radius': FK.Params,
+ 'Height': FK.Params,
+ 'Medium': FK.Params,
},
)
def compute_params(self, input_sockets) -> td.Box:
+ """Aggregate the function parameters needed by the cylinder."""
center = input_sockets['Center']
radius = input_sockets['Radius']
height = input_sockets['Height']
medium = input_sockets['Medium']
- has_center = not ct.FlowSignal.check(center)
- has_radius = not ct.FlowSignal.check(radius)
- has_height = not ct.FlowSignal.check(height)
- has_medium = not ct.FlowSignal.check(medium)
-
- if has_center and has_radius and has_height and has_medium:
- return center | radius | height | medium
- return ct.FlowSignal.FlowPending
+ return center | radius | height | medium
####################
# - Preview
####################
@events.computes_output_socket(
'Structure',
- kind=ct.FlowKind.Previews,
+ kind=FK.Previews,
# Loaded
props={'sim_node_name'},
- output_sockets={'Structure'},
- output_socket_kinds={'Structure': ct.FlowKind.Params},
)
- def compute_previews(self, props, output_sockets):
- output_params = output_sockets['Structure']
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if has_output_params and not output_params.symbols:
- return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
- return ct.PreviewsFlow()
+ def compute_previews(self, props):
+ """Mark the cylinder structure as participating in the preview."""
+ return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed(
# Trigger
socket_name={'Center', 'Radius', 'Medium', 'Height'},
run_on_init=True,
# Loaded
- input_sockets={'Center', 'Radius', 'Medium', 'Height'},
managed_objs={'modifier'},
- output_sockets={'Structure'},
- output_socket_kinds={'Structure': ct.FlowKind.Params},
+ inscks_kinds={
+ 'Center': {FK.Func, FK.Params},
+ 'Radius': {FK.Func, FK.Params},
+ 'Medium': {FK.Func, FK.Params},
+ 'Height': {FK.Func, FK.Params},
+ },
+ scale_input_sockets={
+ 'Center': ct.UNITS_BLENDER,
+ 'Radius': ct.UNITS_BLENDER,
+ 'Height': ct.UNITS_BLENDER,
+ },
)
- def on_previewable_changed(self, managed_objs, input_sockets, output_sockets):
- center = input_sockets['Center']
- radius = input_sockets['Radius']
- height = input_sockets['Height']
- output_params = output_sockets['Structure']
+ def on_previewable_changed(self, managed_objs, input_sockets) -> None:
+ """Push changes in the inputs to the center / size."""
+ center = events.realize_preview(input_sockets['Center'])
+ radius = events.realize_preview(input_sockets['Radius'])
+ height = events.realize_preview(input_sockets['Height'])
- has_center = not ct.FlowSignal.check(center)
- has_radius = not ct.FlowSignal.check(radius)
- has_height = not ct.FlowSignal.check(height)
- has_output_params = not ct.FlowSignal.check(output_params)
-
- if (
- has_center
- and has_radius
- and has_height
- and has_output_params
- and not output_params.symbols
- ):
- # Push Loose Input Values to GeoNodes Modifier
- managed_objs['modifier'].bl_modifier(
- 'NODES',
- {
- 'node_group': import_geonodes(GeoNodes.StructurePrimitiveCylinder),
- 'inputs': {
- 'Radius': radius,
- 'Height': height,
- },
- 'unit_system': ct.UNITS_BLENDER,
+ # Push Loose Input Values to GeoNodes Modifier
+ managed_objs['modifier'].bl_modifier(
+ 'NODES',
+ {
+ 'node_group': import_geonodes(GeoNodes.StructurePrimitiveCylinder),
+ 'inputs': {
+ 'Radius': radius,
+ 'Height': height,
},
- location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
- )
+ 'unit_system': ct.UNITS_BLENDER,
+ },
+ location=center,
+ )
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/sphere_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/sphere_structure.py
index 8a51183..de6f102 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/sphere_structure.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/sphere_structure.py
@@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `SphereStructureNode`."""
+
import typing as typ
import sympy as sp
@@ -21,8 +23,8 @@ import sympy.physics.units as spu
import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
-from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
+from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import managed_objs, sockets
@@ -30,8 +32,14 @@ from ... import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MT = spux.MathType
+
class SphereStructureNode(base.MaxwellSimNode):
+ """A generic sphere structure with configurable size and center."""
+
node_type = ct.NodeType.SphereStructure
bl_label = 'Sphere Structure'
use_sim_node_name = True
@@ -44,7 +52,7 @@ class SphereStructureNode(base.MaxwellSimNode):
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
default_unit=spu.micrometer,
- default_value=sp.Matrix([0, 0, 0]),
+ default_value=sp.ImmutableMatrix([0, 0, 0]),
),
'Radius': sockets.ExprSocketDef(
default_unit=spu.nanometer,
@@ -52,7 +60,7 @@ class SphereStructureNode(base.MaxwellSimNode):
),
}
output_sockets: typ.ClassVar = {
- 'Structure': sockets.MaxwellStructureSocketDef(),
+ 'Structure': sockets.MaxwellStructureSocketDef(active_kind=FK.Func),
}
managed_obj_types: typ.ClassVar = {
@@ -60,70 +68,122 @@ class SphereStructureNode(base.MaxwellSimNode):
}
####################
- # - Outputs
+ # - FlowKind.Value
####################
@events.computes_output_socket(
'Structure',
- input_sockets={'Center', 'Radius', 'Medium'},
- unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
- scale_input_sockets={
- 'Center': 'Tidy3DUnits',
- 'Radius': 'Tidy3DUnits',
+ kind=FK.Value,
+ # Loaded
+ outscks_kinds={
+ 'Structure': {FK.Func, FK.Params},
},
)
- def compute_structure(self, input_sockets, unit_systems) -> td.Box:
- return td.Structure(
- geometry=td.Sphere(
- radius=input_sockets['Radius'],
- center=input_sockets['Center'],
+ def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
+ """Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
+ value = events.realize_known(output_sockets['Structure'])
+ if value is not None:
+ return value
+ return FS.FlowPending
+
+ ####################
+ # - FlowKind.Func
+ ####################
+ @events.computes_output_socket(
+ 'Structure',
+ kind=FK.Func,
+ # Loaded
+ inscks_kinds={
+ 'Medium': FK.Func,
+ 'Center': FK.Func,
+ 'Radius': FK.Func,
+ },
+ scale_input_sockets={
+ 'Center': ct.UNITS_TIDY3D,
+ 'Radius': ct.UNITS_TIDY3D,
+ },
+ )
+ def compute_func(self, input_sockets) -> td.Box:
+ """Compute a function, producing a box structure from the input parameters."""
+ center = input_sockets['Center']
+ radius = input_sockets['Radius']
+ medium = input_sockets['Medium']
+
+ return (center | radius | medium).compose_within(
+ lambda els: td.Structure(
+ geometry=td.Sphere(
+ center=els[0].flatten().tolist(),
+ radius=els[1],
+ ),
+ medium=els[2],
),
- medium=input_sockets['Medium'],
)
+ ####################
+ # - FlowKind.Params
+ ####################
+ @events.computes_output_socket(
+ 'Structure',
+ kind=FK.Params,
+ # Loaded
+ inscks_kinds={
+ 'Medium': FK.Params,
+ 'Center': FK.Params,
+ 'Radius': FK.Params,
+ },
+ )
+ def compute_params(self, input_sockets) -> td.Box:
+ """Aggregate the function parameters needed by the sphere."""
+ center = input_sockets['Center']
+ radius = input_sockets['Radius']
+ medium = input_sockets['Medium']
+
+ return center | radius | medium
+
####################
# - Preview
####################
@events.computes_output_socket(
'Structure',
- kind=ct.FlowKind.Previews,
+ kind=FK.Previews,
# Loaded
props={'sim_node_name'},
)
def compute_previews(self, props):
+ """Mark the sphere structure as participating in the preview."""
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed(
# Trigger
- socket_name={'Center', 'Radius'},
+ socket_name={
+ 'Center': {FK.Func, FK.Params},
+ 'Size': {FK.Func, FK.Params},
+ },
run_on_init=True,
# Loaded
- input_sockets={'Center', 'Radius'},
managed_objs={'modifier'},
- unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
+ inscks_kinds={
+ 'Center': {FK.Func, FK.Params},
+ 'Radius': {FK.Func, FK.Params},
+ },
scale_input_sockets={
- 'Center': 'BlenderUnits',
+ 'Center': ct.UNITS_BLENDER,
+ 'Radius': ct.UNITS_BLENDER,
},
)
- def on_inputs_changed(
- self,
- managed_objs,
- input_sockets,
- unit_systems,
- ):
- modifier = managed_objs['modifier']
- unit_system = unit_systems['BlenderUnits']
+ def on_previewable_changed(self, managed_objs, input_sockets):
+ """Push changes in the inputs to the center / size."""
+ center = events.realize_preview(input_sockets['Center'])
+ radius = events.realize_preview(input_sockets['Radius'])
- # Push Loose Input Values to GeoNodes Modifier
- modifier.bl_modifier(
+ managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.StructurePrimitiveSphere),
'inputs': {
- 'Radius': input_sockets['Radius'],
+ 'Radius': radius,
},
- 'unit_system': unit_system,
},
- location=input_sockets['Center'],
+ location=center,
)
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/utilities/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/utilities/__init__.py
index 39c4b01..563e9d7 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/utilities/__init__.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/utilities/__init__.py
@@ -14,18 +14,15 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-# from . import math
-from . import combine
-
-# from . import separate
+from . import combine, view_text, wave_constant
BL_REGISTER = [
- # *math.BL_REGISTER,
+ *wave_constant.BL_REGISTER,
*combine.BL_REGISTER,
- # *separate.BL_REGISTER,
+ *view_text.BL_REGISTER,
]
BL_NODES = {
- # **math.BL_NODES,
+ **wave_constant.BL_NODES,
**combine.BL_NODES,
- # **separate.BL_NODES,
+ **view_text.BL_NODES,
}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/combine.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/utilities/combine.py
similarity index 75%
rename from src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/combine.py
rename to src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/utilities/combine.py
index 29e9138..d2b8fcc 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/combine.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/utilities/combine.py
@@ -20,12 +20,17 @@ import typing as typ
import bpy
import sympy as sp
-from blender_maxwell.utils import bl_cache
+from blender_maxwell.utils import bl_cache, logger
from ... import contracts as ct
from ... import sockets
from .. import base, events
+log = logger.get(__name__)
+
+FK = ct.FlowKind
+FS = ct.FlowSignal
+
class CombineNode(base.MaxwellSimNode):
"""Combine single objects (ex. Source, Monitor, Structure) into a list."""
@@ -44,17 +49,17 @@ class CombineNode(base.MaxwellSimNode):
output_socket_sets: typ.ClassVar = {
'Sources': {
'Sources': sockets.MaxwellSourceSocketDef(
- active_kind=ct.FlowKind.Array,
+ active_kind=FK.Array,
),
},
'Structures': {
'Structures': sockets.MaxwellStructureSocketDef(
- active_kind=ct.FlowKind.Array,
+ active_kind=FK.Array,
),
},
'Monitors': {
'Monitors': sockets.MaxwellMonitorSocketDef(
- active_kind=ct.FlowKind.Array,
+ active_kind=FK.Array,
),
},
}
@@ -63,14 +68,14 @@ class CombineNode(base.MaxwellSimNode):
# - Properties
####################
concatenate_first: bool = bl_cache.BLField(False)
- value_or_func: ct.FlowKind = bl_cache.BLField(
+ value_or_func: FK = bl_cache.BLField(
enum_cb=lambda self, _: self._value_or_func(),
)
def _value_or_func(self):
return [
flow_kind.bl_enum_element(i)
- for i, flow_kind in enumerate([ct.FlowKind.Value, ct.FlowKind.Func])
+ for i, flow_kind in enumerate([FK.Value, FK.Func])
]
####################
@@ -79,7 +84,7 @@ class CombineNode(base.MaxwellSimNode):
def draw_props(self, _, layout: bpy.types.UILayout):
layout.prop(self, self.blfields['value_or_func'], text='')
- if self.value_or_func is ct.FlowKind.Value:
+ if self.value_or_func is FK.Value:
layout.prop(
self,
self.blfields['concatenate_first'],
@@ -91,15 +96,17 @@ class CombineNode(base.MaxwellSimNode):
# - Events
####################
@events.on_value_changed(
- any_loose_input_socket=True,
prop_name={'active_socket_set', 'concatenate_first', 'value_or_func'},
+ any_loose_input_socket=True,
run_on_init=True,
# Loaded
props={'active_socket_set', 'concatenate_first', 'value_or_func'},
)
def on_inputs_changed(self, props) -> None:
- """Always create one extra loose input socket."""
+ """Always create one extra loose input socket off the end of the last linked loose socket."""
active_socket_set = props['active_socket_set']
+ concatenate_first = props['concatenate_first']
+ flow_kind = props['value_or_func']
# Deduce SocketDef
## -> Cheat by retrieving the class from the output sockets.
@@ -121,14 +128,11 @@ class CombineNode(base.MaxwellSimNode):
new_amount = current_filled + 1
# Deduce SocketDef | Current Amount
- concatenate_first = props['concatenate_first']
- flow_kind = props['value_or_func']
-
self.loose_input_sockets = {
'#0': SocketDef(
active_kind=flow_kind
- if flow_kind is ct.FlowKind.Func or not concatenate_first
- else ct.FlowKind.Array
+ if flow_kind is FK.Func or not concatenate_first
+ else FK.Array
)
} | {f'#{i}': SocketDef(active_kind=flow_kind) for i in range(1, new_amount)}
@@ -138,29 +142,25 @@ class CombineNode(base.MaxwellSimNode):
def compute_combined(
self,
loose_input_sockets,
- input_flow_kind: typ.Literal[ct.FlowKind.Value, ct.FlowKind.Func],
- output_flow_kind: typ.Literal[ct.FlowKind.Array, ct.FlowKind.Func],
- ) -> list[typ.Any] | ct.FuncFlow | ct.FlowSignal:
+ input_flow_kind: typ.Literal[FK.Value, FK.Func],
+ output_flow_kind: typ.Literal[FK.Array, FK.Func],
+ ) -> list[typ.Any] | ct.FuncFlow | FS:
"""Correctly compute the combined loose input sockets, given a valid combination of input and output `FlowKind`s.
If there is no output, or the flows aren't compatible, return `FlowPending`.
"""
match (input_flow_kind, output_flow_kind):
- case (ct.FlowKind.Value, ct.FlowKind.Array):
+ case (FK.Value, FK.Array):
value_flows = [
- inp
- for inp in loose_input_sockets.values()
- if not ct.FlowSignal.check(inp)
+ inp for inp in loose_input_sockets.values() if not FS.check(inp)
]
if value_flows:
return value_flows
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
- case (ct.FlowKind.Func, ct.FlowKind.Func):
+ case (FK.Func, FK.Func):
func_flows = [
- inp
- for inp in loose_input_sockets.values()
- if not ct.FlowSignal.check(inp)
+ inp for inp in loose_input_sockets.values() if not FS.check(inp)
]
if len(func_flows) > 1:
@@ -171,63 +171,59 @@ class CombineNode(base.MaxwellSimNode):
if len(func_flows) == 1:
return func_flows[0].compose_within(lambda el: [el])
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
- case (ct.FlowKind.Func, ct.FlowKind.Params):
+ case (FK.Func, FK.Params):
params_flows = [
params_flow
for inp_sckname in self.inputs.keys() # noqa: SIM118
- if not ct.FlowSignal.check(
- params_flow := self._compute_input(
- inp_sckname, kind=ct.FlowKind.Params
- )
+ if not FS.check(
+ params_flow := self._compute_input(inp_sckname, kind=FK.Params)
)
]
if params_flows:
return functools.reduce(lambda a, b: a | b, params_flows)
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
- return ct.FlowSignal.FlowPending
+ return FS.FlowPending
####################
# - Output: Sources
####################
@events.computes_output_socket(
'Sources',
- kind=ct.FlowKind.Array,
+ kind=FK.Array,
all_loose_input_sockets=True,
props={'value_or_func'},
)
- def compute_sources_array(
- self, props, loose_input_sockets
- ) -> list[typ.Any] | ct.FlowSignal:
+ def compute_sources_array(self, props, loose_input_sockets) -> list[typ.Any] | FS:
"""Compute sources."""
return self.compute_combined(
- loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
+ loose_input_sockets, props['value_or_func'], FK.Array
)
@events.computes_output_socket(
'Sources',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_sources_func(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) sources."""
return self.compute_combined(
- loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
+ loose_input_sockets, props['value_or_func'], FK.Func
)
@events.computes_output_socket(
'Sources',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_sources_params(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) sources."""
return self.compute_combined(
- loose_input_sockets, props['value_or_func'], ct.FlowKind.Params
+ loose_input_sockets, props['value_or_func'], FK.Params
)
####################
@@ -235,38 +231,38 @@ class CombineNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Structures',
- kind=ct.FlowKind.Array,
+ kind=FK.Array,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_structures_array(self, props, loose_input_sockets) -> sp.Expr:
"""Compute structures."""
return self.compute_combined(
- loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
+ loose_input_sockets, props['value_or_func'], FK.Array
)
@events.computes_output_socket(
'Structures',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_structures_func(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) structures."""
return self.compute_combined(
- loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
+ loose_input_sockets, props['value_or_func'], FK.Func
)
@events.computes_output_socket(
'Structures',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_structures_params(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) structures."""
return self.compute_combined(
- loose_input_sockets, props['value_or_func'], ct.FlowKind.Params
+ loose_input_sockets, props['value_or_func'], FK.Params
)
####################
@@ -274,38 +270,38 @@ class CombineNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Monitors',
- kind=ct.FlowKind.Array,
+ kind=FK.Array,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_monitors_array(self, props, loose_input_sockets) -> sp.Expr:
"""Compute monitors."""
return self.compute_combined(
- loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
+ loose_input_sockets, props['value_or_func'], FK.Array
)
@events.computes_output_socket(
'Monitors',
- kind=ct.FlowKind.Func,
+ kind=FK.Func,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_monitors_func(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) monitors."""
return self.compute_combined(
- loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
+ loose_input_sockets, props['value_or_func'], FK.Func
)
@events.computes_output_socket(
'Monitors',
- kind=ct.FlowKind.Params,
+ kind=FK.Params,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_monitors_params(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) structures."""
return self.compute_combined(
- loose_input_sockets, props['value_or_func'], ct.FlowKind.Params
+ loose_input_sockets, props['value_or_func'], FK.Params
)
@@ -315,4 +311,4 @@ class CombineNode(base.MaxwellSimNode):
BL_REGISTER = [
CombineNode,
]
-BL_NODES = {ct.NodeType.Combine: (ct.NodeCategory.MAXWELLSIM_SIMS)}
+BL_NODES = {ct.NodeType.Combine: (ct.NodeCategory.MAXWELLSIM_UTILITIES)}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/utilities/view_text.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/utilities/view_text.py
new file mode 100644
index 0000000..b38b578
--- /dev/null
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/utilities/view_text.py
@@ -0,0 +1,127 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import typing as typ
+
+import bpy
+import sympy as sp
+import tidy3d as td
+
+from blender_maxwell.utils import bl_cache, logger
+from blender_maxwell.utils import sympy_extra as spux
+
+from ... import contracts as ct
+from ... import sockets
+from .. import base, events
+
+log = logger.get(__name__)
+
+FK = ct.FlowKind
+FS = ct.FlowSignal
+
+
+class ConsoleViewOperator(bpy.types.Operator):
+ bl_idname = 'blender_maxwell.console_view_operator'
+ bl_label = 'View Plots'
+
+ @classmethod
+ def poll(cls, _: bpy.types.Context):
+ return True
+
+ def execute(self, context):
+ node = context.node
+
+ node.print_data_to_console()
+ return {'FINISHED'}
+
+
+class RefreshPlotViewOperator(bpy.types.Operator):
+ bl_idname = 'blender_maxwell.refresh_plot_view_operator'
+ bl_label = 'Refresh Plots'
+
+ @classmethod
+ def poll(cls, _: bpy.types.Context):
+ return True
+
+ def execute(self, context):
+ node = context.node
+ node.on_changed_plot_preview()
+ return {'FINISHED'}
+
+
+####################
+# - Node
+####################
+class ViewTextNode(base.MaxwellSimNode):
+ node_type = ct.NodeType.ViewText
+ bl_label = 'View Text'
+ # use_sim_node_name = True
+
+ input_sockets: typ.ClassVar = {
+ 'Text': sockets.StringSocketDef(),
+ }
+
+ ####################
+ # - Properties
+ ####################
+ push_live: bool = bl_cache.BLField(True)
+
+ ####################
+ # - Properties: Computed FlowKinds
+ ####################
+ @events.on_value_changed(
+ socket_name={'Text': FK.Value},
+ prop_name={'push_live', 'sim_node_name'},
+ # Loaded
+ inscks_kinds={'Text': FK.Value},
+ input_sockets_optional={'Text'},
+ props={'push_live', 'sim_node_name'},
+ )
+ def on_text_changed(self, props, input_sockets) -> None:
+ sim_node_name = props['sim_node_name']
+ push_live = props['push_live']
+
+ if push_live:
+ if bpy.data.texts.get(sim_node_name) is None:
+ bpy.data.texts.new(sim_node_name)
+ bl_text = bpy.data.texts[sim_node_name]
+
+ bl_text.clear()
+
+ text = input_sockets['Text']
+ has_text = not FS.check(text)
+ if has_text:
+ bl_text.write(text)
+
+ ####################
+ # - UI
+ ####################
+ def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
+ row = layout.row(align=True)
+
+ row.prop(self, self.blfields['push_live'], text='Write Live', toggle=True)
+
+ def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout):
+ pass
+
+
+####################
+# - Blender Registration
+####################
+BL_REGISTER = [
+ ViewTextNode,
+]
+BL_NODES = {ct.NodeType.ViewText: (ct.NodeCategory.MAXWELLSIM_UTILITIES)}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/utilities/wave_constant.py
similarity index 57%
rename from src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py
rename to src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/utilities/wave_constant.py
index 55f96c6..33e312c 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/utilities/wave_constant.py
@@ -22,7 +22,7 @@ import bpy
import sympy as sp
import sympy.physics.units as spu
-from blender_maxwell.utils import bl_cache, logger, sci_constants
+from blender_maxwell.utils import bl_cache, logger, sci_constants, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
@@ -31,9 +31,13 @@ from .. import base, events
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+MT = spux.MathType
+
class WaveConstantNode(base.MaxwellSimNode):
- """Translates vaccum wavelength/frequency into both, either as a scalar, or as a memory-efficient uniform range of values.
+ """Translates vacuum wavelength / non-angular frequency into both, either as a scalar, or as a memory-efficient uniform range of values.
Socket Sets:
Wavelength: Input a wavelength (range) to produce both wavelength/frequency (ranges).
@@ -100,7 +104,7 @@ class WaveConstantNode(base.MaxwellSimNode):
run_on_init=True,
)
def on_use_range_changed(self, props: dict) -> None:
- """Synchronize the `active_kind` of input/output sockets, to either produce a `ct.FlowKind.Value` or a `ct.FlowKind.Range`."""
+ """Synchronize the `active_kind` of input/output sockets, to either produce a `FK.Value` or a `FK.Range`."""
if self.inputs.get('WL') is not None:
active_input = self.inputs['WL']
else:
@@ -108,46 +112,46 @@ class WaveConstantNode(base.MaxwellSimNode):
# Modify Active Kind(s)
## Input active_kind -> Value/Range
- active_input_uses_range = active_input.active_kind == ct.FlowKind.Range
+ active_input_uses_range = active_input.active_kind == FK.Range
if active_input_uses_range != props['use_range']:
- active_input.active_kind = (
- ct.FlowKind.Range if props['use_range'] else ct.FlowKind.Value
- )
+ active_input.active_kind = FK.Range if props['use_range'] else FK.Value
## Output active_kind -> Value/Range
for active_output in self.outputs.values():
- active_output_uses_range = active_output.active_kind == ct.FlowKind.Range
+ active_output_uses_range = active_output.active_kind == FK.Range
if active_output_uses_range != props['use_range']:
- active_output.active_kind = (
- ct.FlowKind.Range if props['use_range'] else ct.FlowKind.Value
- )
+ active_output.active_kind = FK.Range if props['use_range'] else FK.Value
####################
- # - FlowKinds
+ # - FlowKind.Value
####################
@events.computes_output_socket(
'WL',
- kind=ct.FlowKind.Value,
- input_sockets={'WL', 'Freq'},
- input_sockets_optional={'WL': True, 'Freq': True},
+ kind=FK.Value,
+ # Loaded
+ inscks_kinds={'WL': FK.Value, 'Freq': FK.Value},
+ input_sockets_optional={'WL', 'Freq'},
)
def compute_wl_value(self, input_sockets: dict) -> sp.Expr:
"""Compute a single wavelength value from either wavelength/frequency."""
- has_wl = not ct.FlowSignal.check(input_sockets['WL'])
+ has_wl = not FS.check(input_sockets['WL'])
if has_wl:
return input_sockets['WL']
- return sci_constants.vac_speed_of_light / input_sockets['Freq']
+ return spu.convert_to(
+ sci_constants.vac_speed_of_light / input_sockets['Freq'], spu.um
+ )
@events.computes_output_socket(
'Freq',
- kind=ct.FlowKind.Value,
- input_sockets={'WL', 'Freq'},
- input_sockets_optional={'WL': True, 'Freq': True},
+ kind=FK.Value,
+ # Loaded
+ inscks_kinds={'WL': FK.Value, 'Freq': FK.Value},
+ input_sockets_optional={'WL', 'Freq'},
)
def compute_freq_value(self, input_sockets: dict) -> sp.Expr:
"""Compute a single frequency value from either wavelength/frequency."""
- has_freq = not ct.FlowSignal.check(input_sockets['Freq'])
+ has_freq = not FS.check(input_sockets['Freq'])
if has_freq:
return input_sockets['Freq']
@@ -155,62 +159,117 @@ class WaveConstantNode(base.MaxwellSimNode):
sci_constants.vac_speed_of_light / input_sockets['WL'], spux.THz
)
+ ####################
+ # - FlowKind.Range
+ ####################
@events.computes_output_socket(
'WL',
- kind=ct.FlowKind.Range,
- input_sockets={'WL', 'Freq'},
- input_socket_kinds={
- 'WL': ct.FlowKind.Range,
- 'Freq': ct.FlowKind.Range,
- },
- input_sockets_optional={'WL': True, 'Freq': True},
+ kind=FK.Range,
+ # Loaded
+ inscks_kinds={'WL': FK.Range, 'Freq': FK.Range},
+ input_sockets_optional={'WL', 'Freq'},
)
- def compute_wl_range(self, input_sockets: dict) -> sp.Expr:
+ def compute_wl_range(self, input_sockets) -> sp.Expr:
"""Compute wavelength range from either wavelength/frequency ranges."""
- has_wl = not ct.FlowSignal.check(input_sockets['WL'])
+ has_wl = not FS.check(input_sockets['WL'])
if has_wl:
return input_sockets['WL']
freq = input_sockets['Freq']
- return ct.RangeFlow(
- start=spux.scale_to_unit(
- sci_constants.vac_speed_of_light / (freq.stop * freq.unit), spu.um
- ),
- stop=spux.scale_to_unit(
- sci_constants.vac_speed_of_light / (freq.start * freq.unit), spu.um
- ),
- steps=freq.steps,
- scaling=freq.scaling,
- unit=spu.um,
+ return freq.rescale(
+ lambda bound: sci_constants.vac_speed_of_light / bound,
+ reverse=True,
+ new_unit=spu.um,
)
@events.computes_output_socket(
'Freq',
- kind=ct.FlowKind.Range,
+ kind=FK.Range,
input_sockets={'WL', 'Freq'},
input_socket_kinds={
- 'WL': ct.FlowKind.Range,
- 'Freq': ct.FlowKind.Range,
+ 'WL': FK.Range,
+ 'Freq': FK.Range,
},
input_sockets_optional={'WL': True, 'Freq': True},
)
def compute_freq_range(self, input_sockets: dict) -> sp.Expr:
"""Compute frequency range from either wavelength/frequency ranges."""
- has_freq = not ct.FlowSignal.check(input_sockets['Freq'])
+ has_freq = not FS.check(input_sockets['Freq'])
if has_freq:
return input_sockets['Freq']
wl = input_sockets['WL']
- return ct.RangeFlow(
- start=spux.scale_to_unit(
- sci_constants.vac_speed_of_light / (wl.stop * wl.unit), spux.THz
- ),
- stop=spux.scale_to_unit(
- sci_constants.vac_speed_of_light / (wl.start * wl.unit), spux.THz
- ),
- steps=wl.steps,
- scaling=wl.scaling,
- unit=spux.THz,
+ return wl.rescale(
+ lambda bound: sci_constants.vac_speed_of_light / bound,
+ reverse=True,
+ new_unit=spux.THz,
+ )
+
+ ####################
+ # - FlowKind.Func
+ ####################
+ # @events.computes_output_socket(
+ # 'WL',
+ # kind=FK.Func,
+ # # Loaded
+ # inscks_kinds={'WL': FK.Func, 'Freq': FK.Func},
+ # input_sockets_optional={'WL', 'Freq'},
+ # )
+ # def compute_wl_func(self, input_sockets: dict) -> ct.FuncFlow | FS:
+ # """Compute a single wavelength value from either wavelength/frequency."""
+ # wl = input_sockets['WL']
+ # has_wl = not FS.check(wl)
+ # if has_wl:
+ # return wl
+
+ # freq = input_sockets['Freq']
+ # has_freq = not FS.check(freq)
+ # if has_freq:
+ # return wl.compose_within(
+ # return spu.convert_to(
+ # sci_constants.vac_speed_of_light / input_sockets['Freq'], spu.um
+ # )
+
+ # return FS.FlowPending
+
+ # @events.computes_output_socket(
+ # 'Freq',
+ # kind=FK.Value,
+ # # Loaded
+ # inscks_kinds={'WL': FK.Value, 'Freq': FK.Value},
+ # input_sockets_optional={'WL', 'Freq'},
+ # )
+ # def compute_freq_value(self, input_sockets: dict) -> sp.Expr:
+ # """Compute a single frequency value from either wavelength/frequency."""
+ # has_freq = not FS.check(input_sockets['Freq'])
+ # if has_freq:
+ # return input_sockets['Freq']
+
+ # return spu.convert_to(
+ # sci_constants.vac_speed_of_light / input_sockets['WL'], spux.THz
+ # )
+
+ ####################
+ # - FlowKind.Info
+ ####################
+ @events.computes_output_socket(
+ 'WL',
+ kind=FK.Info,
+ )
+ def compute_wl_info(self) -> ct.InfoFlow:
+ """Just enough InfoFlow to enable `linked_capabilities`."""
+ return ct.InfoFlow(
+ output=sim_symbols.wl(spu.um),
+ )
+
+ @events.computes_output_socket(
+ 'Freq',
+ kind=FK.Info,
+ )
+ def compute_freq_info(self) -> sp.Expr:
+ """Compute frequency range from either wavelength/frequency ranges."""
+ return ct.InfoFlow(
+ output=sim_symbols.freq(spux.THz),
)
@@ -220,4 +279,4 @@ class WaveConstantNode(base.MaxwellSimNode):
BL_REGISTER = [
WaveConstantNode,
]
-BL_NODES = {ct.NodeType.WaveConstant: (ct.NodeCategory.MAXWELLSIM_INPUTS)}
+BL_NODES = {ct.NodeType.WaveConstant: (ct.NodeCategory.MAXWELLSIM_UTILITIES)}
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py
index 533d0ba..ba996be 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py
@@ -26,6 +26,9 @@ from .. import contracts as ct
log = logger.get(__name__)
+FK = ct.FlowKind
+FS = ct.FlowSignal
+
####################
# - SocketDef
@@ -171,6 +174,7 @@ class SocketDef(pyd.BaseModel, abc.ABC):
####################
# - Socket
####################
+FLOW_ERROR_COLOR: tuple[float, float, float, float] = (1.0, 0.0, 0.0, 1.0)
MANDATORY_PROPS: set[str] = {'socket_type', 'bl_label'}
@@ -211,6 +215,8 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
(0, 0, 0, 0), use_prop_update=False
)
+ flow_error: bool = bl_cache.BLField(False, use_prop_update=False)
+
####################
# - Initialization
####################
@@ -352,7 +358,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
## -> The tradeoff: No link if there is no InfoFlow.
if self.use_linked_capabilities:
info = self.compute_data(kind=ct.FlowKind.Info)
- has_info = not ct.FlowSignal.check(info)
+ has_info = not FS.check(info)
if has_info:
incoming_capabilities = link.from_socket.linked_capabilities(info)
else:
@@ -477,30 +483,10 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
# Run Socket Callbacks
self.on_socket_data_changed(socket_kinds)
- # Mark Active FlowKind Links as Invalid
- ## -> Mark link as invalid (very red) if a FlowSignal is traveling.
- ## -> This helps explain why whatever isn't working isn't working.
- ## -> TODO: We need a different approach.
- # log.debug(
- # '[%s] Checking FlowKind Validity (socket_kinds=%s)',
- # self.name,
- # str(socket_kinds),
- # )
- # if self.is_linked and not self.is_output:
- # link = self.links[0]
- # linked_flow = self.compute_data(kind=self.active_kind)
-
- # if (
- # link.is_valid
- # and self.active_kind in socket_kinds
- # and ct.FlowSignal.check_single(linked_flow, ct.FlowSignal.FlowPending)
- # ):
- # node_tree = self.id_data
- # node_tree.report_link_validity(link, False)
-
- # elif not link.is_valid:
- # node_tree = self.id_data
- # node_tree.report_link_validity(link, True)
+ # Clear FlowErrors
+ ## -> We should presume by default that the updated value is OK.
+ if self.flow_error:
+ bpy.app.timers.register(self.clear_flow_error)
def on_socket_data_changed(self, socket_kinds: set[ct.FlowKind]) -> None:
"""Called when `ct.FlowEvent.DataChanged` flows through this socket.
@@ -646,7 +632,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Returns:
An empty `ct.InfoFlow`.
"""
- return ct.FlowSignal.NoFlow
+ return FS.NoFlow
# Param
@property
@@ -659,7 +645,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Returns:
An empty `ct.ParamsFlow`.
"""
- return ct.FlowSignal.NoFlow
+ return FS.NoFlow
####################
# - FlowKind: Auxiliary
@@ -675,7 +661,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Raises:
NotImplementedError: When used without being overridden.
"""
- return ct.FlowSignal.NoFlow
+ return FS.NoFlow
@value.setter
def value(self, value: ct.ValueFlow) -> None:
@@ -701,7 +687,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Raises:
NotImplementedError: When used without being overridden.
"""
- return ct.FlowSignal.NoFlow
+ return FS.NoFlow
@array.setter
def array(self, value: ct.ArrayFlow) -> None:
@@ -727,7 +713,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Raises:
NotImplementedError: When used without being overridden.
"""
- return ct.FlowSignal.NoFlow
+ return FS.NoFlow
@lazy_func.setter
def lazy_func(self, lazy_func: ct.FuncFlow) -> None:
@@ -753,7 +739,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Raises:
NotImplementedError: When used without being overridden.
"""
- return ct.FlowSignal.NoFlow
+ return FS.NoFlow
@lazy_range.setter
def lazy_range(self, value: ct.RangeFlow) -> None:
@@ -818,29 +804,42 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
"""
# Compute Output Socket
if self.is_output:
- return self.node.compute_output(self.name, kind=kind)
+ flow = self.node.compute_output(self.name, kind=kind)
# Compute Input Socket
## -> Unlinked: Retrieve Socket Value
- if not self.is_linked:
- return self._compute_data(kind)
+ elif not self.is_linked:
+ flow = self._compute_data(kind)
- ## Linked: Compute Data on Linked Socket
- ## -> Capabilities are guaranteed compatible by 'allow_link_add'.
- ## -> There is no point in rechecking every time data flows.
- linked_values = [link.from_socket.compute_data(kind) for link in self.links]
+ else:
+ # Linked: Compute Data on Linked Socket
+ ## -> Capabilities are guaranteed compatible by 'allow_link_add'.
+ ## -> There is no point in rechecking every time data flows.
+ linked_values = [link.from_socket.compute_data(kind) for link in self.links]
- # Return Single Value / List of Values
- ## -> Multi-input sockets are not (yet) supported.
- if linked_values:
- return linked_values[0]
+ # Return Single Value / List of Values
+ ## -> Multi-input sockets are not (yet) supported.
+ if linked_values: # noqa: SIM108
+ flow = linked_values[0]
- # Edge Case: While Dragging Link (but not yet removed)
- ## While the user is dragging a link:
- ## - self.is_linked = True, since the user hasn't confirmed anything.
- ## - self.links will be empty, since the link object was freed.
- ## When this particular condition is met, pretend that we're not linked.
- return self._compute_data(kind)
+ # Edge Case: While Dragging Link (but not yet removed)
+ ## While the user is dragging a link:
+ ## - self.is_linked = True, since the user hasn't confirmed anything.
+ ## - self.links will be empty, since the link object was freed.
+ ## When this particular condition is met, pretend that we're not linked.
+ else:
+ flow = self._compute_data(kind)
+
+ if FS.check_single(flow, FS.FlowPending) and not self.flow_error:
+ bpy.app.timers.register(self.declare_flow_error)
+
+ return flow
+
+ def declare_flow_error(self):
+ self.flow_error = True
+
+ def clear_flow_error(self):
+ self.flow_error = False
####################
# - UI - Color
@@ -858,6 +857,8 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Notes:
Called by Blender to call the socket color.
"""
+ if self.flow_error:
+ return FLOW_ERROR_COLOR
if self.use_socket_color:
return self.socket_color
return ct.SOCKET_COLORS[self.socket_type]
@@ -978,7 +979,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
# Info Drawing
if self.use_info_draw:
info = self.compute_data(kind=ct.FlowKind.Info)
- if not ct.FlowSignal.check(info):
+ if not FS.check(info):
self.draw_info(info, col)
def draw_output(
@@ -1009,7 +1010,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
# Draw FlowKind.Info related Information
if self.use_info_draw:
info = self.compute_data(kind=ct.FlowKind.Info)
- if not ct.FlowSignal.check(info):
+ if not FS.check(info):
self.draw_info(info, col)
####################
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/string.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/string.py
index b198935..cb30675 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/string.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/basic/string.py
@@ -38,7 +38,7 @@ class StringBLSocket(base.MaxwellSimSocket):
# - Socket UI
####################
def draw_label_row(self, label_col_row: bpy.types.UILayout, text: str) -> None:
- label_col_row.prop(self, 'raw_value', text=text)
+ label_col_row.prop(self, self.blfields['raw_value'], text=text)
####################
# - Computation of Default Value
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py
index 6c553a7..eb63506 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py
@@ -25,11 +25,17 @@ import sympy as sp
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
+from blender_maxwell.utils.frozendict import frozendict
from .. import contracts as ct
from . import base
log = logger.get(__name__)
+FK = ct.FlowKind
+MT = spux.MathType
+
+UI_FLOAT_EPS = sp.Float(0.0001, 1)
+UI_FLOAT_PREC = 4
Int2: typ.TypeAlias = tuple[int, int]
Int3: typ.TypeAlias = tuple[int, int, int]
@@ -44,19 +50,21 @@ Float32: typ.TypeAlias = tuple[
####################
-# - Utilitives
+# - Utilities
####################
def unicode_superscript(n: int) -> str:
"""Transform an integer into its unicode-based superscript character."""
return ''.join(['⁰¹²³⁴⁵⁶⁷⁸⁹'[ord(c) - ord('0')] for c in str(n)])
-def _check_sym_oo(sym):
- return sym.is_real or sym.is_rational or sym.is_integer
-
-
class InfoDisplayCol(enum.StrEnum):
- """Valid columns for specifying displayed information from an `ct.InfoFlow`."""
+ """Valid columns for specifying displayed information from an `ct.InfoFlow`.
+
+ Attributes:
+ Length: Display the size of the dimensional index.
+ MathType: Display the `MT` of the dimensional symbol.
+ Unit: Display the unit of the dimensional symbol.
+ """
Length = enum.auto()
MathType = enum.auto()
@@ -112,25 +120,33 @@ class ExprBLSocket(base.MaxwellSimSocket):
use_socket_color = True
####################
- # - Socket Interface
+ # - Identifier
####################
size: spux.NumberSize1D = bl_cache.BLField(spux.NumberSize1D.Scalar)
- mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real)
+ mathtype: MT = bl_cache.BLField(MT.Real)
physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical)
+ ####################
+ # - Output Symbol
+ ####################
@bl_cache.cached_bl_property(
+ ## -> CAREFUL: 'output_sym' changes recompiles `FuncFlow`.
depends_on={
- 'active_kind',
- 'symbols',
- 'raw_value_spstr',
- 'raw_min_spstr',
- 'raw_max_spstr',
+ # Identity
'output_name',
+ 'active_kind',
'mathtype',
'physical_type',
'unit',
'size',
- 'value',
+ # Symbols / Symbolic Expression
+ 'symbols',
+ 'raw_value_spstr',
+ 'raw_min_spstr',
+ 'raw_max_spstr',
+ # Domain
+ 'domain',
+ 'steps', ## -> Func needs to recompile anyway if steps changes.
}
)
def output_sym(self) -> sim_symbols.SimSymbol | None:
@@ -142,12 +158,12 @@ class ExprBLSocket(base.MaxwellSimSocket):
NotImplementedError: When `active_kind` is neither `Value`, `Func`, or `Range`.
"""
match self.active_kind:
- case ct.FlowKind.Value | ct.FlowKind.Func if self.symbols:
+ case FK.Value | FK.Func if self.symbols:
return self._parse_expr_symbol(
self._parse_expr_str(self.raw_value_spstr)
)
- case ct.FlowKind.Value | ct.FlowKind.Func if not self.symbols:
+ case FK.Value | FK.Func if not self.symbols:
return sim_symbols.SimSymbol(
sym_name=self.output_name,
mathtype=self.mathtype,
@@ -155,17 +171,10 @@ class ExprBLSocket(base.MaxwellSimSocket):
unit=self.unit,
rows=self.size.rows,
cols=self.size.cols,
- is_constant=True,
- ## TODO: Should we set preview values
- exclude_zero=(
- not self.value.is_zero
- if self.value.is_zero is not None
- else False
- ),
- ## TODO: Does this 0-check work for matrix elements?
+ domain=self.domain,
)
- case ct.FlowKind.Range if self.symbols:
+ case FK.Range if self.symbols:
## TODO: Support RangeFlow
## -- It's hard; we need a min-span set over bound domains.
## -- We... Don't use this anywhere. Yet?
@@ -178,29 +187,181 @@ class ExprBLSocket(base.MaxwellSimSocket):
msg = 'RangeFlow support not yet implemented for when self.symbols is not empty'
raise NotImplementedError(msg)
- case ct.FlowKind.Range if not self.symbols:
+ case FK.Range if not self.symbols:
return sim_symbols.SimSymbol(
sym_name=self.output_name,
mathtype=self.mathtype,
physical_type=self.physical_type,
unit=self.unit,
- rows=self.lazy_range.steps,
+ rows=self.steps,
cols=1,
- exclude_zero=not self.lazy_range.is_always_nonzero,
+ domain=self.domain,
)
+ ####################
+ # - Domain
+ ####################
+ exclude_zero: bool = bl_cache.BLField(True)
+
+ abs_min_infinite: bool = bl_cache.BLField(True)
+ abs_max_infinite: bool = bl_cache.BLField(True)
+ abs_min_infinite_im: bool = bl_cache.BLField(True)
+ abs_max_infinite_im: bool = bl_cache.BLField(True)
+
+ abs_min_closed: bool = bl_cache.BLField(True)
+ abs_max_closed: bool = bl_cache.BLField(True)
+ abs_min_closed_im: bool = bl_cache.BLField(True)
+ abs_max_closed_im: bool = bl_cache.BLField(True)
+
+ abs_min_int: int = bl_cache.BLField(0)
+ abs_min_rat: Int2 = bl_cache.BLField((0, 1))
+ abs_min_float: float = bl_cache.BLField(0.0, float_prec=UI_FLOAT_PREC)
+ abs_min_complex: Float2 = bl_cache.BLField((0.0, 0.0), float_prec=UI_FLOAT_PREC)
+
+ abs_max_int: int = bl_cache.BLField(0)
+ abs_max_rat: Int2 = bl_cache.BLField((0, 1))
+ abs_max_float: float = bl_cache.BLField(0.0, float_prec=UI_FLOAT_PREC)
+ abs_max_complex: Float2 = bl_cache.BLField((0.0, 0.0), float_prec=UI_FLOAT_PREC)
+
+ @bl_cache.cached_bl_property(
+ depends_on={
+ 'mathtype',
+ 'abs_min_infinite',
+ 'abs_min_infinite_im',
+ 'abs_min_int',
+ 'abs_min_rat',
+ 'abs_min_float',
+ 'abs_min_complex',
+ }
+ )
+ def abs_inf(self) -> sp.Integer | sp.Rational | sp.Float | spux.ComplexNumber:
+ """Deduce the infimum of values expressable by this socket."""
+ match self.mathtype:
+ case MT.Integer | MT.Rational | MT.Real if self.abs_min_infinite:
+ return -sp.oo
+ case MT.Integer:
+ abs_min = sp.Integer(self.abs_min_int)
+ case MT.Rational:
+ abs_min = sp.Rational(*self.abs_min_rat)
+ case MT.Real:
+ abs_min = sp.Float(self.abs_min_float, UI_FLOAT_PREC)
+ case MT.Complex:
+ cplx = self.abs_min_complex
+ abs_min_re = (
+ sp.Float(cplx[0], UI_FLOAT_PREC)
+ if not self.abs_min_infinite
+ else -sp.oo
+ )
+ abs_min_im = (
+ sp.Float(cplx[1], UI_FLOAT_PREC)
+ if not self.abs_min_infinite_im
+ else -sp.oo
+ )
+ abs_min = abs_min_re + sp.I * abs_min_im
+
+ return abs_min
+
+ @bl_cache.cached_bl_property(
+ depends_on={
+ 'mathtype',
+ 'abs_max_infinite',
+ 'abs_max_infinite_im',
+ 'abs_max_int',
+ 'abs_max_rat',
+ 'abs_max_float',
+ 'abs_max_complex',
+ }
+ )
+ def abs_sup(self) -> sp.Integer | sp.Rational | sp.Float | spux.ComplexNumber:
+ """Deduce the infimum of values expressable by this socket."""
+ match self.mathtype:
+ case MT.Integer | MT.Rational | MT.Real if self.abs_max_infinite:
+ return sp.oo
+ case MT.Integer:
+ abs_max = sp.Integer(self.abs_max_int)
+ case MT.Rational:
+ abs_max = sp.Rational(*self.abs_max_rat)
+ case MT.Real:
+ abs_max = sp.Float(self.abs_max_float, UI_FLOAT_PREC)
+ case MT.Complex:
+ cplx = self.abs_max_complex
+ abs_max_re = (
+ sp.Float(cplx[0], UI_FLOAT_PREC)
+ if not self.abs_max_infinite
+ else sp.oo
+ )
+ abs_max_im = (
+ sp.Float(cplx[1], UI_FLOAT_PREC)
+ if not self.abs_max_infinite_im
+ else sp.oo
+ )
+ abs_max = abs_max_re + sp.I * abs_max_im
+
+ return abs_max
+
+ @bl_cache.cached_bl_property(
+ depends_on={
+ 'abs_inf',
+ 'abs_sup',
+ 'exclude_zero',
+ 'abs_min_closed',
+ 'abs_max_closed',
+ 'abs_min_closed_im',
+ 'abs_max_closed_im',
+ }
+ )
+ def domain(self) -> spux.BlessedSet:
+ """Deduce the domain of the socket's expression."""
+ match self.mathtype:
+ case MT.Integer:
+ domain = spux.BlessedSet(
+ sp.Range(
+ self.abs_inf if self.abs_min_closed else self.abs_inf + 1,
+ self.abs_sup + 1 if self.abs_max_closed else self.abs_sup,
+ )
+ )
+ case MT.Rational | MT.Real:
+ domain = spux.BlessedSet(
+ sp.Interval(
+ self.abs_inf,
+ self.abs_sup,
+ left_open=not self.abs_min_closed,
+ right_open=not self.abs_max_closed,
+ )
+ )
+ case MT.Complex:
+ domain = spux.BlessedSet.reals_to_complex(
+ sp.Interval(
+ sp.re(self.abs_inf),
+ sp.re(self.abs_sup),
+ left_open=not self.abs_min_closed,
+ right_open=not self.abs_max_closed,
+ ),
+ sp.Interval(
+ sp.im(self.abs_inf),
+ sp.im(self.abs_sup),
+ left_open=not self.abs_min_closed_im,
+ right_open=not self.abs_max_closed_im,
+ ),
+ )
+
+ if self.exclude_zero:
+ return domain - sp.FiniteSet(0)
+ return domain
+
####################
# - Value|Range Swapper
####################
use_value_range_swapper: bool = bl_cache.BLField(False)
- selected_value_range: ct.FlowKind = bl_cache.BLField(
- enum_cb=lambda self, _: self._value_or_range(),
+ selected_value_range: FK = bl_cache.BLField(
+ enum_cb=lambda self, _: self.search_value_or_range(),
)
- def _value_or_range(self):
+ def search_value_or_range(self):
+ """Either `FlowKind.Value` or `FlowKind.Range`."""
return [
flow_kind.bl_enum_element(i)
- for i, flow_kind in enumerate([ct.FlowKind.Value, ct.FlowKind.Range])
+ for i, flow_kind in enumerate([FK.Value, FK.Range])
]
####################
@@ -369,7 +530,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
####################
# - Event Callbacks
####################
- def on_socket_data_changed(self, socket_kinds: set[ct.FlowKind]) -> None:
+ def on_socket_data_changed(self, socket_kinds: set[FK]) -> None:
"""Alter the socket's color in response to flow.
- `FlowKind.Info`: Any change causes the socket color to be updated with the physical type of the output symbol.
@@ -381,8 +542,8 @@ class ExprBLSocket(base.MaxwellSimSocket):
"""
## NOTE: Depends on suppressed on_prop_changed
- if ct.FlowKind.Info in socket_kinds:
- info = self.compute_data(kind=ct.FlowKind.Info)
+ if FK.Info in socket_kinds:
+ info = self.compute_data(kind=FK.Info)
has_info = not ct.FlowSignal.check(info)
# Alter Color
@@ -443,7 +604,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
####################
def _to_raw_value(self, expr: spux.SympyExpr, force_complex: bool = False):
"""Cast the given expression to the appropriate raw value, with scaling guided by `self.unit`."""
- pyvalue = spux.scale_to_unit(expr, self.unit)
+ pyvalue = spux.scale_to_unit(expr, self.unit, cast_to_pytype=True)
# Cast complex -> tuple[float, float]
## -> We can't set complex to BLProps.
@@ -452,6 +613,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
isinstance(pyvalue, int | float) and force_complex
):
return (pyvalue.real, pyvalue.imag)
+
if isinstance(pyvalue, tuple) and all(
isinstance(v, complex)
or (isinstance(pyvalue, int | float) and force_complex)
@@ -557,6 +719,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
'unit',
'mathtype',
'size',
+ 'domain',
'raw_value_sp',
'raw_value_int',
'raw_value_rat',
@@ -601,20 +764,21 @@ class ExprBLSocket(base.MaxwellSimSocket):
if self.size is NS.Vec4:
return ct.FlowSignal.NoFlow
- MT_Z = spux.MathType.Integer
- MT_Q = spux.MathType.Rational
- MT_R = spux.MathType.Real
- MT_C = spux.MathType.Complex
+ MT_Z = MT.Integer
+ MT_Q = MT.Rational
+ MT_R = MT.Real
+ MT_C = MT.Complex
Z = sp.Integer
Q = sp.Rational
- R = sp.RealNumber
- return {
+ R = sp.Float
+ raw_value = {
NS.Scalar: {
MT_Z: lambda: Z(self.raw_value_int),
MT_Q: lambda: Q(self.raw_value_rat[0], self.raw_value_rat[1]),
- MT_R: lambda: R(self.raw_value_float),
+ MT_R: lambda: R(self.raw_value_float, UI_FLOAT_PREC),
MT_C: lambda: (
- self.raw_value_complex[0] + sp.I * self.raw_value_complex[1]
+ R(self.raw_value_complex[0], UI_FLOAT_PREC)
+ + sp.I * R(self.raw_value_complex[1], UI_FLOAT_PREC)
),
},
NS.Vec2: {
@@ -622,9 +786,14 @@ class ExprBLSocket(base.MaxwellSimSocket):
MT_Q: lambda: sp.ImmutableMatrix(
[Q(q[0], q[1]) for q in self.raw_value_rat2]
),
- MT_R: lambda: sp.ImmutableMatrix([R(r) for r in self.raw_value_float2]),
+ MT_R: lambda: sp.ImmutableMatrix(
+ [R(r, UI_FLOAT_PREC) for r in self.raw_value_float2]
+ ),
MT_C: lambda: sp.ImmutableMatrix(
- [c[0] + sp.I * c[1] for c in self.raw_value_complex2]
+ [
+ R(c[0], UI_FLOAT_PREC) + sp.I * R(c[1], UI_FLOAT_PREC)
+ for c in self.raw_value_complex2
+ ]
),
},
NS.Vec3: {
@@ -632,12 +801,21 @@ class ExprBLSocket(base.MaxwellSimSocket):
MT_Q: lambda: sp.ImmutableMatrix(
[Q(q[0], q[1]) for q in self.raw_value_rat3]
),
- MT_R: lambda: sp.ImmutableMatrix([R(r) for r in self.raw_value_float3]),
+ MT_R: lambda: sp.ImmutableMatrix(
+ [R(r, UI_FLOAT_PREC) for r in self.raw_value_float3]
+ ),
MT_C: lambda: sp.ImmutableMatrix(
- [c[0] + sp.I * c[1] for c in self.raw_value_complex3]
+ [
+ R(c[0], UI_FLOAT_PREC) + sp.I * R(c[1], UI_FLOAT_PREC)
+ for c in self.raw_value_complex3
+ ]
),
},
- }[self.size][self.mathtype]() * (self.unit if self.unit is not None else 1)
+ }[self.size][self.mathtype]()
+
+ if raw_value not in self.domain:
+ return ct.FlowSignal.FlowPending
+ return raw_value * self.unit_factor
@value.setter
def value(self, expr: spux.SympyExpr) -> None:
@@ -650,7 +828,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
self.raw_value_spstr = sp.sstr(expr)
else:
NS = spux.NumberSize1D
- MT = spux.MathType
match (self.size, self.mathtype):
case (NS.Scalar, MT.Integer):
self.raw_value_int = self._to_raw_value(expr)
@@ -694,6 +871,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
'unit',
'mathtype',
'size',
+ 'domain',
'steps',
'scaling',
'raw_min_sp',
@@ -723,10 +901,10 @@ class ExprBLSocket(base.MaxwellSimSocket):
symbols=self.symbols,
)
- MT_Z = spux.MathType.Integer
- MT_Q = spux.MathType.Rational
- MT_R = spux.MathType.Real
- MT_C = spux.MathType.Complex
+ MT_Z = MT.Integer
+ MT_Q = MT.Rational
+ MT_R = MT.Real
+ MT_C = MT.Complex
Z = sp.Integer
Q = sp.Rational
R = sp.RealNumber
@@ -739,10 +917,12 @@ class ExprBLSocket(base.MaxwellSimSocket):
bound[0] + sp.I * bound[1] for bound in self.raw_range_complex
],
}[self.mathtype]()
+ if min_bound not in self.domain or max_bound not in self.domain:
+ return ct.FlowSignal.FlowPending
return ct.RangeFlow(
- start=min_bound,
- stop=max_bound,
+ start=sp.Float(min_bound, 4),
+ stop=sp.Float(max_bound, 4),
steps=self.steps,
scaling=self.scaling,
unit=self.unit,
@@ -763,10 +943,10 @@ class ExprBLSocket(base.MaxwellSimSocket):
self.raw_max_spstr = sp.sstr(lazy_range.stop)
else:
- MT_Z = spux.MathType.Integer
- MT_Q = spux.MathType.Rational
- MT_R = spux.MathType.Real
- MT_C = spux.MathType.Complex
+ MT_Z = MT.Integer
+ MT_Q = MT.Rational
+ MT_R = MT.Real
+ MT_C = MT.Complex
unit = lazy_range.unit if lazy_range.unit is not None else 1
if self.mathtype == MT_Z:
@@ -807,6 +987,12 @@ class ExprBLSocket(base.MaxwellSimSocket):
func_output=self.output_sym,
supports_jax=True,
)
+ return ct.FuncFlow(
+ func=lambda v: v,
+ func_args=[self.output_sym],
+ func_output=self.output_sym,
+ supports_jax=True,
+ )
return ct.FlowSignal.FlowPending
@@ -821,27 +1007,23 @@ class ExprBLSocket(base.MaxwellSimSocket):
"""
if self.output_sym is not None:
match self.active_kind:
- case ct.FlowKind.Value | ct.FlowKind.Func if (
- not ct.FlowSignal.check(self.value)
- ):
+ case FK.Value | FK.Func if (not ct.FlowSignal.check(self.value)):
return ct.ParamsFlow(
- arg_targets=[self.output_sym],
func_args=[self.value],
symbols=set(self.sorted_symbols),
)
- case ct.FlowKind.Range if self.sorted_symbols:
+ case FK.Range if self.sorted_symbols:
msg = 'RangeFlow support not yet implemented for when self.sorted_symbols is not empty'
raise NotImplementedError(msg)
- case ct.FlowKind.Range if (
+ case FK.Range if (
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
):
return ct.ParamsFlow(
- arg_targets=[self.output_sym],
func_args=[self.output_sym.sp_symbol_matsym],
symbols={self.output_sym},
- ).realize_partial({self.output_sym: self.lazy_range})
+ ).realize_partial(frozendict({self.output_sym: self.lazy_range}))
return ct.FlowSignal.FlowPending
@@ -861,17 +1043,17 @@ class ExprBLSocket(base.MaxwellSimSocket):
"""
if self.output_sym is not None:
match self.active_kind:
- case ct.FlowKind.Value | ct.FlowKind.Func:
+ case FK.Value | FK.Func:
return ct.InfoFlow(
dims={sym: None for sym in self.sorted_symbols},
output=self.output_sym,
)
- case ct.FlowKind.Range if self.sorted_symbols:
+ case FK.Range if self.sorted_symbols:
msg = 'InfoFlow support not yet implemented for when self.sorted_symbols is not empty'
raise NotImplementedError(msg)
- case ct.FlowKind.Range if (
+ case FK.Range if (
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
):
return ct.InfoFlow(
@@ -893,11 +1075,11 @@ class ExprBLSocket(base.MaxwellSimSocket):
socket_type=self.socket_type,
active_kind=self.active_kind,
allow_out_to_in={
- ct.FlowKind.Func: ct.FlowKind.Value,
+ FK.Func: FK.Value,
},
allow_out_to_in_if_matches={
- ct.FlowKind.Value: (
- ct.FlowKind.Func,
+ FK.Value: (
+ FK.Func,
(
info.output.physical_type,
info.output.mathtype,
@@ -917,8 +1099,8 @@ class ExprBLSocket(base.MaxwellSimSocket):
output_sym = self.output_sym
if output_sym is not None:
allow_out_to_in_if_matches = {
- ct.FlowKind.Value: (
- ct.FlowKind.Func,
+ FK.Value: (
+ FK.Func,
(
output_sym.physical_type,
output_sym.mathtype,
@@ -934,7 +1116,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
socket_type=self.socket_type,
active_kind=self.active_kind,
allow_out_to_in={
- ct.FlowKind.Func: ct.FlowKind.Value,
+ FK.Func: FK.Value,
},
allow_out_to_in_if_matches=allow_out_to_in_if_matches,
)
@@ -964,8 +1146,8 @@ class ExprBLSocket(base.MaxwellSimSocket):
Notes:
Whether information about the expression passing through a linked socket is shown is governed by `self.show_info_columns`.
"""
- if self.active_kind is ct.FlowKind.Func:
- info = self.compute_data(kind=ct.FlowKind.Info)
+ if self.active_kind is FK.Func:
+ info = self.compute_data(kind=FK.Info)
has_info = not ct.FlowSignal.check(info)
if has_info:
@@ -999,8 +1181,8 @@ class ExprBLSocket(base.MaxwellSimSocket):
Notes:
Whether information about the expression passing through a linked socket is shown is governed by `self.show_info_columns`.
"""
- if self.active_kind is ct.FlowKind.Func:
- info = self.compute_data(kind=ct.FlowKind.Info)
+ if self.active_kind is FK.Func:
+ info = self.compute_data(kind=FK.Info)
has_info = not ct.FlowSignal.check(info)
if has_info:
@@ -1054,7 +1236,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
else:
NS = spux.NumberSize1D
- MT = spux.MathType
match (self.size, self.mathtype):
case (NS.Scalar, MT.Integer):
col.prop(self, self.blfields['raw_value_int'], text='')
@@ -1115,10 +1296,10 @@ class ExprBLSocket(base.MaxwellSimSocket):
col.prop(self, self.blfields['raw_max_spstr'], text='')
else:
- MT_Z = spux.MathType.Integer
- MT_Q = spux.MathType.Rational
- MT_R = spux.MathType.Real
- MT_C = spux.MathType.Complex
+ MT_Z = MT.Integer
+ MT_Q = MT.Rational
+ MT_R = MT.Real
+ MT_C = MT.Complex
if self.mathtype == MT_Z:
col.prop(self, self.blfields['raw_range_int'], text='')
elif self.mathtype == MT_Q:
@@ -1173,7 +1354,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None:
"""Visualize the `InfoFlow` information passing through the socket."""
if (
- self.active_kind is ct.FlowKind.Func
+ self.active_kind is FK.Func
and self.show_info_columns
and (self.is_linked or self.is_output)
):
@@ -1219,7 +1400,7 @@ class ExprSocketDef(base.SocketDef):
# Socket Interface
size: spux.NumberSize1D = spux.NumberSize1D.Scalar
- mathtype: spux.MathType = spux.MathType.Real
+ mathtype: MT = MT.Real
physical_type: spux.PhysicalType = spux.PhysicalType.NonPhysical
default_unit: spux.Unit | None = None
@@ -1227,8 +1408,6 @@ class ExprSocketDef(base.SocketDef):
# FlowKind: Value
default_value: spux.SympyExpr = 0
- abs_min: spux.SympyExpr | None = None
- abs_max: spux.SympyExpr | None = None
# FlowKind: Range
default_min: spux.SympyExpr = 0
@@ -1236,6 +1415,15 @@ class ExprSocketDef(base.SocketDef):
default_steps: int = 2
default_scaling: ct.ScalingMode = ct.ScalingMode.Lin
+ # Domain
+ abs_min: spux.SympyExpr | None = None
+ abs_max: spux.SympyExpr | None = None
+ abs_min_closed: bool = True
+ abs_max_closed: bool = True
+ abs_min_closed_im: bool = True
+ abs_max_closed_im: bool = True
+ exclude_zero: bool = False
+
# UI
show_name_selector: bool = False
show_func_ui: bool = True
@@ -1347,14 +1535,14 @@ class ExprSocketDef(base.SocketDef):
# Coerce Number -> Column 0-Vector
## -> TODO: We don't strictly know if default_value is a number.
if len(self.size.shape) == 1:
- self.default_value = self.default_value * sp.Matrix.ones(
+ self.default_value = self.default_value * sp.ImmutableMatrix.ones(
self.size.shape[0], 1
)
# Coerce Number -> 0-Matrix
## -> TODO: We don't strictly know if default_value is a number.
if len(self.size.shape) > 1:
- self.default_value = self.default_value * sp.Matrix.ones(
+ self.default_value = self.default_value * sp.ImmutableMatrix.ones(
*self.size.shape
)
@@ -1367,9 +1555,9 @@ class ExprSocketDef(base.SocketDef):
If `self.default_value` is a scalar Python type, it will be coerced into the corresponding Sympy type using `sp.S`, after coersion to the correct Python type using `self.mathtype.coerce_compatible_pyobj()`.
Raises:
- ValueError: If `self.default_value` has no obvious, coerceable `spux.MathType` compatible with `self.mathtype`, as determined by `spux.MathType.has_mathtype`.
+ ValueError: If `self.default_value` has no obvious, coerceable `MT` compatible with `self.mathtype`, as determined by `MT.has_mathtype`.
"""
- mathtype_guide = spux.MathType.has_mathtype(self.default_value)
+ mathtype_guide = MT.has_mathtype(self.default_value)
# None: No Obvious Mathtype
if mathtype_guide is None:
@@ -1378,7 +1566,7 @@ class ExprSocketDef(base.SocketDef):
# PyType: Coerce from PyType
if mathtype_guide == 'pytype':
- dv_mathtype = spux.MathType.from_pytype(type(self.default_value))
+ dv_mathtype = MT.from_pytype(type(self.default_value))
if self.mathtype.is_compatible(dv_mathtype):
self.default_value = sp.S(
self.mathtype.coerce_compatible_pyobj(self.default_value)
@@ -1389,7 +1577,7 @@ class ExprSocketDef(base.SocketDef):
# Expr: Merely Check MathType Compatibility
if mathtype_guide == 'expr':
- dv_mathtype = spux.MathType.from_expr(self.default_value)
+ dv_mathtype = MT.from_expr(self.default_value)
if not self.mathtype.is_compatible(dv_mathtype):
msg = f'ExprSocket: Mathtype {dv_mathtype} of default value expression {self.default_value} (type {type(self.default_value)}) is incompatible with socket MathType {self.mathtype}'
raise ValueError(msg)
@@ -1416,11 +1604,11 @@ class ExprSocketDef(base.SocketDef):
If `self.default_value` is a scalar Python type, it will be coerced into the corresponding Sympy type using `sp.S`.
Raises:
- ValueError: If `self.default_value` has no obvious `spux.MathType`, as determined by `spux.MathType.has_mathtype`.
+ ValueError: If `self.default_value` has no obvious `MT`, as determined by `MT.has_mathtype`.
"""
new_bounds = [None, None]
for i, bound in enumerate([self.default_min, self.default_max]):
- mathtype_guide = spux.MathType.has_mathtype(bound)
+ mathtype_guide = MT.has_mathtype(bound)
# None: No Obvious Mathtype
if mathtype_guide is None:
@@ -1429,7 +1617,7 @@ class ExprSocketDef(base.SocketDef):
# PyType: Coerce from PyType
if mathtype_guide == 'pytype':
- dv_mathtype = spux.MathType.from_pytype(type(bound))
+ dv_mathtype = MT.from_pytype(type(bound))
if self.mathtype.is_compatible(dv_mathtype):
new_bounds[i] = sp.S(self.mathtype.coerce_compatible_pyobj(bound))
else:
@@ -1438,18 +1626,18 @@ class ExprSocketDef(base.SocketDef):
# Expr: Merely Check MathType Compatibility
if mathtype_guide == 'expr':
- dv_mathtype = spux.MathType.from_expr(bound)
+ dv_mathtype = MT.from_expr(bound)
if not self.mathtype.is_compatible(dv_mathtype):
msg = f'ExprSocket: Mathtype {dv_mathtype} of a default Range min or max expression {bound} (type {type(self.default_value)}) is incompatible with socket MathType {self.mathtype}'
raise ValueError(msg)
# Coerce from Infinite
if isinstance(bound, spux.SympyType):
- if bound.is_infinite and self.mathtype is spux.MathType.Integer:
+ if bound.is_infinite and self.mathtype is MT.Integer:
new_bounds[i] = sp.S(-1) if i == 0 else sp.S(1)
- if bound.is_infinite and self.mathtype is spux.MathType.Rational:
+ if bound.is_infinite and self.mathtype is MT.Rational:
new_bounds[i] = sp.Rational(-1, 1) if i == 0 else sp.Rational(1, 1)
- if bound.is_infinite and self.mathtype is spux.MathType.Real:
+ if bound.is_infinite and self.mathtype is MT.Real:
new_bounds[i] = sp.S(-1.0) if i == 0 else sp.S(1.0)
if new_bounds[0] is not None:
@@ -1468,10 +1656,7 @@ class ExprSocketDef(base.SocketDef):
"""
# Check ActiveKind and Size
## -> NOTE: This doesn't protect against dynamic changes to either.
- if (
- self.active_kind is ct.FlowKind.Range
- and self.size is not spux.NumberSize1D.Scalar
- ):
+ if self.active_kind is FK.Range and self.size is not spux.NumberSize1D.Scalar:
msg = "Can't have a non-Scalar size when Range is set as the active kind."
raise ValueError(msg)
@@ -1524,8 +1709,61 @@ class ExprSocketDef(base.SocketDef):
bl_socket.size = self.size
bl_socket.mathtype = self.mathtype
bl_socket.physical_type = self.physical_type
+ bl_socket.active_unit = bl_cache.Signal.ResetEnumItems
+ bl_socket.unit = bl_cache.Signal.InvalidateCache
+ bl_socket.unit_factor = bl_cache.Signal.InvalidateCache
bl_socket.symbols = self.default_symbols
+ # Domain
+ bl_socket.exclude_zero = self.exclude_zero
+
+ if self.abs_min is None:
+ bl_socket.abs_min_infinite = True
+ bl_socket.abs_min_infinite_im = True
+ else:
+ bl_socket.abs_min_closed = self.abs_min_closed
+
+ if self.abs_max is None:
+ bl_socket.abs_max_infinite = True
+ bl_socket.abs_max_infinite_im = True
+ else:
+ bl_socket.abs_max_closed = self.abs_max_closed
+
+ match self.mathtype:
+ case MT.Integer if self.abs_min is not None:
+ bl_socket.abs_min_int = int(self.abs_min)
+ case MT.Integer if self.abs_max is not None:
+ bl_socket.abs_max_int = int(self.abs_max)
+
+ case MT.Rational if self.abs_min is not None:
+ bl_socket.abs_min_rat = (
+ self.abs_min.numerator,
+ self.abs_min.denominator,
+ )
+ case MT.Rational if self.abs_max is not None:
+ bl_socket.abs_max_rat = (
+ self.abs_max.numerator,
+ self.abs_max.denominator,
+ )
+
+ case MT.Real if self.abs_min is not None:
+ bl_socket.abs_min_float = float(self.abs_min)
+ case MT.Real if self.abs_max is not None:
+ bl_socket.abs_max_float = float(self.abs_max)
+
+ case MT.Complex if self.abs_min is not None:
+ bl_socket.abs_min_complex = (
+ float(sp.re(self.abs_min)),
+ float(sp.im(self.abs_min)),
+ )
+ bl_socket.abs_min_closed_im = self.abs_min_closed_im
+ case MT.Complex if self.abs_max is not None:
+ bl_socket.abs_max_complex = (
+ float(sp.re(self.abs_max)),
+ float(sp.im(self.abs_max)),
+ )
+ bl_socket.abs_max_closed_im = self.abs_max_closed_im
+
# FlowKind.Value
## -> We must take units into account when setting bl_socket.value
if self.physical_type is not spux.PhysicalType.NonPhysical:
@@ -1570,7 +1808,9 @@ class ExprSocketDef(base.SocketDef):
and cmp('show_func_ui')
and cmp('show_info_columns')
and cmp('show_name_selector')
+ and cmp('show_name_selector')
and bl_socket.use_info_draw
+ ## TODO: Include domain?
)
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py
index aa41351..7dea86c 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py
@@ -14,16 +14,13 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-import math
-
import bpy
-import jax.numpy as jnp
import scipy as sc
+import sympy as sp
import sympy.physics.units as spu
import tidy3d as td
-import tidy3d.plugins.adjoint as tdadj
-from blender_maxwell.utils import bl_cache, logger
+from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
@@ -51,49 +48,44 @@ class MaxwellMediumBLSocket(base.MaxwellSimSocket):
####################
eps_rel: tuple[float, float] = bl_cache.BLField((1.0, 0.0), float_prec=2)
- differentiable: bool = bl_cache.BLField(False)
-
####################
# - FlowKinds
####################
- @bl_cache.cached_bl_property(depends_on={'eps_rel', 'differentiable'})
+ @bl_cache.cached_bl_property(depends_on={'eps_rel'})
def value(self) -> td.Medium:
eps_r_re = self.eps_rel[0]
- conductivity = FIXED_FREQ * self.eps_rel[1] ## TODO: Um?
+ # conductivity = FIXED_FREQ * self.eps_rel[1] ## TODO: Um?
- if self.differentiable:
- return tdadj.JaxMedium(
- permittivity_jax=jnp.array(eps_r_re, dtype=float),
- conductivity_jax=jnp.array(conductivity, dtype=float),
- )
return td.Medium(
permittivity=eps_r_re,
- conductivity=conductivity,
+ # conductivity=conductivity,
)
@value.setter
def value(self, eps_rel: tuple[float, float]) -> None:
self.eps_rel = eps_rel
- @bl_cache.cached_bl_property(depends_on={'value'})
+ @bl_cache.cached_bl_property()
def lazy_func(self) -> ct.FuncFlow:
return ct.FuncFlow(
- func=lambda: self.value,
- supports_jax=self.differentiable,
+ func=lambda eps_r_re, eps_r_im: td.Medium(
+ permittivity=eps_r_re,
+ # conductivity=FIXED_FREQ * eps_r_im,
+ ),
+ func_args=[sim_symbols.rel_eps_re(None), sim_symbols.rel_eps_im(None)],
)
- @bl_cache.cached_bl_property(depends_on={'differentiable'})
+ @bl_cache.cached_bl_property(depends_on={'eps_rel'})
def params(self) -> ct.FuncFlow:
+ return ct.ParamsFlow(
+ func_args=[sp.S(self.eps_rel[0]), sp.S(self.eps_rel[1])],
+ )
return ct.ParamsFlow()
####################
# - UI
####################
def draw_value(self, col: bpy.types.UILayout) -> None:
- col.prop(
- self, self.blfields['differentiable'], text='Differentiable', toggle=True
- )
- col.separator()
split = col.split(factor=0.25, align=False)
_col = split.column(align=True)
diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/tidy3d/cloud_task.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/tidy3d/cloud_task.py
index 28fe9e8..11ee1d2 100644
--- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/tidy3d/cloud_task.py
+++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/tidy3d/cloud_task.py
@@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+"""Implements `Tidy3DCloudTaskBLSocket`."""
+
import enum
import bpy
@@ -29,49 +31,59 @@ from .. import base
# - Operators
####################
class ReloadFolderList(bpy.types.Operator):
+ """Reload the list of available folders."""
+
bl_idname = ct.OperatorType.SocketReloadCloudFolderList
bl_label = 'Reload Tidy3D Folder List'
bl_description = 'Reload the the cached Tidy3D folder list'
@classmethod
def poll(cls, context):
+ """Allow reloading the folder list when online, authenticated, and attached to a `Tidy3DCloudTask` socket."""
return (
tdcloud.IS_ONLINE
and tdcloud.IS_AUTHENTICATED
and hasattr(context, 'socket')
and hasattr(context.socket, 'socket_type')
- and context.socket.socket_type == ct.SocketType.Tidy3DCloudTask
+ and context.socket.socket_type is ct.SocketType.Tidy3DCloudTask
)
def execute(self, context):
+ """Update the folder list, as well as any tasks attached to any existing folder ID."""
bl_socket = context.socket
tdcloud.TidyCloudFolders.update_folders()
- tdcloud.TidyCloudTasks.update_tasks(bl_socket.existing_folder_id)
-
bl_socket.existing_folder_id = bl_cache.Signal.ResetEnumItems
bl_socket.existing_folder_id = bl_cache.Signal.InvalidateCache
+
+ if bl_socket.existing_folder_id is not None:
+ tdcloud.TidyCloudTasks.update_tasks(bl_socket.existing_folder_id)
+
bl_socket.existing_task_id = bl_cache.Signal.ResetEnumItems
bl_socket.existing_task_id = bl_cache.Signal.InvalidateCache
-
return {'FINISHED'}
class Authenticate(bpy.types.Operator):
+ """Authenticate with the Tidy3D API."""
+
bl_idname = ct.OperatorType.SocketCloudAuthenticate
bl_label = 'Authenticate Tidy3D'
bl_description = 'Authenticate the Tidy3D Web API from a Cloud Task socket'
@classmethod
def poll(cls, context):
+ """Allow authenticating when online, not authenticated, and attached to a `Tidy3DCloudTask` socket."""
return (
- not tdcloud.IS_AUTHENTICATED
+ tdcloud.IS_ONLINE
+ and not tdcloud.IS_AUTHENTICATED
and hasattr(context, 'socket')
and hasattr(context.socket, 'socket_type')
- and context.socket.socket_type == ct.SocketType.Tidy3DCloudTask
+ and context.socket.socket_type is ct.SocketType.Tidy3DCloudTask
)
def execute(self, context):
+ """Try authenticating the socket with the web service."""
bl_socket = context.socket
if not tdcloud.check_authentication():
@@ -104,11 +116,14 @@ class Tidy3DCloudTaskBLSocket(base.MaxwellSimSocket):
bl_label = 'Tidy3D Cloud Task'
####################
- # - Properties
+ # - Properties: Authentication
####################
api_key: str = bl_cache.BLField('', str_secret=True)
- should_exist: bool = bl_cache.BLField(False)
+ ####################
+ # - Properties: Existance
+ ####################
+ should_exist: bool = bl_cache.BLField(False)
new_task_name: str = bl_cache.BLField('')
####################
@@ -119,10 +134,11 @@ class Tidy3DCloudTaskBLSocket(base.MaxwellSimSocket):
)
def search_cloud_folders(self) -> list[ct.BLEnumElement]:
+ """Get all Tidy3D cloud folders."""
if tdcloud.IS_AUTHENTICATED:
return [
(
- cloud_folder.folder_id,
+ folder_id,
cloud_folder.folder_name,
f'Folder {cloud_folder.folder_name} (ID={folder_id})',
'',
@@ -139,10 +155,12 @@ class Tidy3DCloudTaskBLSocket(base.MaxwellSimSocket):
# - Properties: Cloud Tasks
####################
existing_task_id: enum.StrEnum = bl_cache.BLField(
- enum_cb=lambda self, _: self.search_cloud_tasks()
+ enum_cb=lambda self, _: self.search_cloud_tasks(),
+ cb_depends_on={'existing_folder_id'},
)
def search_cloud_tasks(self) -> list[ct.BLEnumElement]:
+ """Retrieve all Tidy3D cloud folders from the cloud service."""
if self.existing_folder_id is None or not tdcloud.IS_AUTHENTICATED:
return []
@@ -200,6 +218,7 @@ class Tidy3DCloudTaskBLSocket(base.MaxwellSimSocket):
####################
@bl_cache.cached_bl_property(depends_on={'active_kind', 'should_exist'})
def capabilities(self) -> ct.CapabilitiesFlow:
+ """Cloud task sockets are compatible by presumtion of existance."""
return ct.CapabilitiesFlow(
socket_type=self.socket_type,
active_kind=self.active_kind,
@@ -217,48 +236,41 @@ class Tidy3DCloudTaskBLSocket(base.MaxwellSimSocket):
def value(
self,
) -> ct.NewSimCloudTask | tdcloud.CloudTask | ct.FlowSignal:
+ """Return either the existing cloud task, or an object describing how to make it."""
if tdcloud.IS_AUTHENTICATED:
# Retrieve Folder
cloud_folder = tdcloud.TidyCloudFolders.folders().get(
self.existing_folder_id
)
- if cloud_folder is None:
+ if cloud_folder is not None:
+ # Case: New Task
+ if not self.should_exist:
+ return ct.NewSimCloudTask(
+ task_name=self.new_task_name, cloud_folder=cloud_folder
+ )
+
+ # Case: Existing Task
+ if self.existing_task_id is not None:
+ cloud_task = tdcloud.TidyCloudTasks.tasks(cloud_folder)[
+ self.existing_task_id
+ ]
+ if cloud_task is not None:
+ return cloud_task
+ else:
return ct.FlowSignal.NoFlow ## Folder deleted somewhere else
-
- # Case: New Task
- if not self.should_exist:
- return ct.NewSimCloudTask(
- task_name=self.new_task_name, cloud_folder=cloud_folder
- )
-
- # Case: Existing Task
- if self.existing_task_id is not None:
- cloud_task = tdcloud.TidyCloudTasks.tasks(cloud_folder).get(
- self.existing_task_id
- )
- if cloud_folder is None:
- return ct.FlowSignal.NoFlow ## Task deleted somewhere else
-
- return cloud_task
-
return ct.FlowSignal.FlowPending
####################
# - UI
####################
- def draw_label_row(self, row: bpy.types.UILayout, text: str):
- row.label(text=text)
-
- auth_icon = 'LOCKVIEW_ON' if tdcloud.IS_AUTHENTICATED else 'LOCKVIEW_OFF'
- row.label(text='', icon=auth_icon)
-
def draw_prelock(
self,
- context: bpy.types.Context,
+ _: bpy.types.Context,
col: bpy.types.UILayout,
- node: bpy.types.Node,
- text: str,
+ node: bpy.types.Node, # noqa: ARG002
+ text: str, # noqa: ARG002
) -> None:
+ """Draw a lock-immune interface for authenticating with the Tidy3D API, when the authentication is invalid."""
if not tdcloud.IS_AUTHENTICATED:
row = col.row()
row.alignment = 'CENTER'
@@ -274,6 +286,7 @@ class Tidy3DCloudTaskBLSocket(base.MaxwellSimSocket):
)
def draw_value(self, col: bpy.types.UILayout) -> None:
+ """When authenticated, draw the node UI."""
if not tdcloud.IS_AUTHENTICATED:
return
@@ -306,11 +319,14 @@ class Tidy3DCloudTaskBLSocket(base.MaxwellSimSocket):
# - Socket Configuration
####################
class Tidy3DCloudTaskSocketDef(base.SocketDef):
+ """Declarative object guiding the creation of a `Tidy3DCloudTask` socket."""
+
socket_type: ct.SocketType = ct.SocketType.Tidy3DCloudTask
should_exist: bool
def init(self, bl_socket: Tidy3DCloudTaskBLSocket) -> None:
+ """Initialize the passed socket with the properties of this object."""
bl_socket.should_exist = self.should_exist
bl_socket.use_prelock = True
diff --git a/src/blender_maxwell/services/tdcloud.py b/src/blender_maxwell/services/tdcloud.py
index ee6ed20..8965be9 100644
--- a/src/blender_maxwell/services/tdcloud.py
+++ b/src/blender_maxwell/services/tdcloud.py
@@ -28,6 +28,7 @@ import urllib
from dataclasses import dataclass
from pathlib import Path
+import bpy
import tidy3d as td
import tidy3d.web as td_web
@@ -495,3 +496,28 @@ class TidyCloudTasks:
raise RuntimeError(msg) from ex
return cls.update_task(cloud_task)
+
+
+####################
+# - Blender UI Integration
+####################
+def draw_cloud_status(layout: bpy.types.UILayout) -> None:
+ """Draw up-to-date information about the connection to the Tidy3D cloud to a Blender UI."""
+ # Connection Info
+ auth_icon = 'CHECKBOX_HLT' if IS_AUTHENTICATED else 'CHECKBOX_DEHLT'
+ conn_icon = 'CHECKBOX_HLT' if IS_ONLINE else 'CHECKBOX_DEHLT'
+
+ row = layout.row()
+ row.alignment = 'CENTER'
+ row.label(text='Cloud Status')
+
+ box = layout.box()
+ split = box.split(factor=0.85)
+
+ col = split.column(align=False)
+ col.label(text='Authed')
+ col.label(text='Connected')
+
+ col = split.column(align=False)
+ col.label(icon=auth_icon)
+ col.label(icon=conn_icon)
diff --git a/src/blender_maxwell/utils/bl_cache/bl_prop_type.py b/src/blender_maxwell/utils/bl_cache/bl_prop_type.py
index ca2961e..578105c 100644
--- a/src/blender_maxwell/utils/bl_cache/bl_prop_type.py
+++ b/src/blender_maxwell/utils/bl_cache/bl_prop_type.py
@@ -582,7 +582,7 @@ class BLPropType(enum.StrEnum):
return str(value)
# Single Enum: Coerce to set[str]
- case BPT.SetEnum | BPT.SetDynEnum if isinstance(value, set):
+ case BPT.SetEnum | BPT.SetDynEnum if isinstance(value, set | frozenset):
return {str(v) for v in value}
# BLPointer: Don't Alter
@@ -594,7 +594,7 @@ class BLPropType(enum.StrEnum):
case BPT.Serialized:
return serialize.encode(value).decode('utf-8')
- msg = f'{self}: No encoder defined for argument {value}'
+ msg = f'{self}: No encoder defined for argument {value} (type={type(value)})'
raise NotImplementedError(msg)
####################
@@ -603,7 +603,7 @@ class BLPropType(enum.StrEnum):
def decode(self, raw_value: typ.Any, obj_type: type) -> typ.Any: # noqa: PLR0911
"""Transform a raw value from a form read directly from the Blender property returned by `self.bl_type`, to its intended value of approximate type `obj_type`.
- Notes:
+ Ntes:
`obj_type` is only a hint - for example, `obj_type = enum.StrEnum` is an indicator for a dynamic enum.
Its purpose is to guide ex. sizing and `StrEnum` coersion, not to guarantee a particular output type.
@@ -661,7 +661,7 @@ class BLPropType(enum.StrEnum):
## -> This happens independent of whether there's a enum_cb.
case BPT.SingleEnum if isinstance(raw_value, str):
return obj_type(raw_value)
- case BPT.SetEnum if isinstance(raw_value, set):
+ case BPT.SetEnum if isinstance(raw_value, set | frozenset):
SubStrEnum = typ.get_args(obj_type)[0]
return {SubStrEnum(v) for v in raw_value}
@@ -670,7 +670,7 @@ class BLPropType(enum.StrEnum):
## -> All dynamic enums have an enum_cb, but this is merely a symptom of ^.
case BPT.SingleDynEnum if isinstance(raw_value, str):
return raw_value
- case BPT.SetDynEnum if isinstance(raw_value, set):
+ case BPT.SetDynEnum if isinstance(raw_value, set | frozenset):
return raw_value
# BLPointer
diff --git a/src/blender_maxwell/utils/frozendict.py b/src/blender_maxwell/utils/frozendict.py
new file mode 100644
index 0000000..a707fc6
--- /dev/null
+++ b/src/blender_maxwell/utils/frozendict.py
@@ -0,0 +1,57 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+"""Implements a `pydantic`-compatible field, `FrozenDict`, which encapsulates a `frozendict` in a serializable way, with semantics identical to `dict`."""
+
+import typing as typ
+
+import pydantic as pyd
+from frozendict import deepfreeze, frozendict
+from pydantic_core import core_schema as pyd_core_schema
+
+
+class _PydanticFrozenDictAnnotation:
+ """Annotated validator providing interoperability between `frozendict` and `pydantic` models.
+
+ Semantics are almost identical to `dict`, except for a chained conversion to `frozendict`.
+ """
+
+ @classmethod
+ def __get_pydantic_core_schema__(
+ cls, source_type: typ.Any, handler: pyd.GetCoreSchemaHandler
+ ) -> pyd_core_schema.CoreSchema:
+ def validate_from_dict(d: dict | frozendict) -> frozendict:
+ return frozendict(d)
+
+ frozendict_schema = pyd_core_schema.chain_schema(
+ [
+ handler.generate_schema(dict[*typ.get_args(source_type)]),
+ pyd_core_schema.no_info_plain_validator_function(validate_from_dict),
+ pyd_core_schema.is_instance_schema(frozendict),
+ ]
+ )
+ return pyd_core_schema.json_or_python_schema(
+ json_schema=frozendict_schema,
+ python_schema=frozendict_schema,
+ serialization=pyd_core_schema.plain_serializer_function_ser_schema(dict),
+ )
+
+
+_K = typ.TypeVar('_K')
+_V = typ.TypeVar('_V')
+FrozenDict = typ.Annotated[frozendict[_K, _V], _PydanticFrozenDictAnnotation]
+
+__all__ = ['deepfreeze', 'frozendict', 'FrozenDict']
diff --git a/src/blender_maxwell/utils/image_ops.py b/src/blender_maxwell/utils/image_ops.py
index 68820c3..3398c67 100644
--- a/src/blender_maxwell/utils/image_ops.py
+++ b/src/blender_maxwell/utils/image_ops.py
@@ -17,7 +17,6 @@
"""Useful image processing operations for use in the addon."""
import enum
-import functools
import typing as typ
import jax
@@ -28,10 +27,11 @@ import matplotlib.axis as mpl_ax
import matplotlib.backends.backend_agg
import matplotlib.figure
import seaborn as sns
+import sympy as sp
from blender_maxwell import contracts as ct
-from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
+from blender_maxwell.utils import sympy_extra as spux
sns.set_theme()
@@ -138,7 +138,7 @@ def rgba_image_from_2d_map(
####################
# - MPL Helpers
####################
-@functools.lru_cache(maxsize=4)
+# @functools.lru_cache(maxsize=4)
def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
fig = matplotlib.figure.Figure(
figsize=[width_inches, height_inches], dpi=dpi, layout='tight'
@@ -154,30 +154,51 @@ def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
####################
# - Plotters
####################
+def _parse_val(val):
+ return spux.sp_to_str(sp.S(val).n(2))
+
+
+def pinned_labels(pinned_data) -> str:
+ return (
+ '\n'
+ + ', '.join(
+ [
+ f'{sym.name_pretty}:' + _parse_val(val)
+ for sym, val in pinned_data.items()
+ ]
+ )
+ # + ']'
+ )
+
+
# (ℤ) -> ℝ
def plot_box_plot_1d(data, ax: mpl_ax.Axis) -> None:
- x_sym, y_sym = list(data.keys())
+ x_sym, y_sym, pinned = list(data.keys())
ax.boxplot([data[y_sym]])
- ax.set_title(f'{x_sym.name_pretty} → {y_sym.name_pretty}')
+ ax.set_title(
+ f'{x_sym.name_pretty} → {y_sym.name_pretty} {pinned_labels(data[pinned])}'
+ )
ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(y_sym.plot_label)
def plot_bar(data, ax: mpl_ax.Axis) -> None:
- x_sym, heights_sym = list(data.keys())
+ x_sym, heights_sym, pinned = list(data.keys())
p = ax.bar(data[x_sym], data[heights_sym])
ax.bar_label(p, label_type='center')
- ax.set_title(f'{x_sym.name_pretty} -> {heights_sym.name_pretty}')
+ ax.set_title(
+ f'{x_sym.name_pretty} → {heights_sym.name_pretty} {pinned_labels(data[pinned])}'
+ )
ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(heights_sym.plot_label)
# (ℝ) -> ℝ (| sometimes complex)
def plot_curve_2d(data, ax: mpl_ax.Axis) -> None:
- x_sym, y_sym = list(data.keys())
+ x_sym, y_sym, pinned = list(data.keys())
if y_sym.mathtype is spux.MathType.Complex:
ax.plot(data[x_sym], data[y_sym].real, label='ℝ')
@@ -185,38 +206,47 @@ def plot_curve_2d(data, ax: mpl_ax.Axis) -> None:
ax.legend()
ax.plot(data[x_sym], data[y_sym])
+ ax.set_title(
+ f'{x_sym.name_pretty} → {y_sym.name_pretty} {pinned_labels(data[pinned])}'
+ )
ax.set_title(f'{x_sym.name_pretty} → {y_sym.name_pretty}')
ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(y_sym.plot_label)
def plot_points_2d(data, ax: mpl_ax.Axis) -> None:
- x_sym, y_sym = list(data.keys())
+ x_sym, y_sym, pinned = list(data.keys())
ax.scatter(data[x_sym], data[y_sym])
- ax.set_title(f'{x_sym.name_pretty} → {y_sym.name_pretty}')
+ ax.set_title(
+ f'{x_sym.name_pretty} → {y_sym.name_pretty} {pinned_labels(data[pinned])}'
+ )
ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(y_sym.plot_label)
# (ℝ, ℤ) -> ℝ
def plot_curves_2d(data, ax: mpl_ax.Axis) -> None:
- x_sym, label_sym, y_sym = list(data.keys())
+ x_sym, label_sym, y_sym, pinned = list(data.keys())
for i, label in enumerate(data[label_sym]):
ax.plot(data[x_sym], data[y_sym][:, i], label=label)
- ax.set_title(f'{x_sym.name_pretty} → {y_sym.name_pretty}')
+ ax.set_title(
+ f'{x_sym.name_pretty} → {y_sym.name_pretty} {pinned_labels(data[pinned])}'
+ )
ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(y_sym.plot_label)
ax.legend()
def plot_filled_curves_2d(data, ax: mpl_ax.Axis) -> None:
- x_sym, _, y_sym = list(data.keys(data))
+ x_sym, _, y_sym, pinned = list(data.keys(data))
ax.fill_between(data[x_sym], data[y_sym][:, 0], data[x_sym], data[y_sym][:, 1])
- ax.set_title(f'{x_sym.name_pretty} → {y_sym.name_pretty}')
+ ax.set_title(
+ f'{x_sym.name_pretty} → {y_sym.name_pretty} {pinned_labels(data[pinned])}'
+ )
ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(y_sym.plot_label)
ax.legend()
@@ -224,12 +254,14 @@ def plot_filled_curves_2d(data, ax: mpl_ax.Axis) -> None:
# (ℝ, ℝ) -> ℝ
def plot_heatmap_2d(data, ax: mpl_ax.Axis) -> None:
- x_sym, y_sym, c_sym = list(data.keys())
+ x_sym, y_sym, c_sym, pinned = list(data.keys())
heatmap = ax.imshow(data[c_sym], aspect='equal', interpolation='none')
- ax.figure.colorbar(heatmap, cax=ax)
+ # ax.figure.colorbar(heatmap, ax=ax)
- ax.set_title(f'({x_sym.name_pretty}, {y_sym.name_pretty}) → {c_sym.plot_label}')
+ ax.set_title(
+ f'({x_sym.name_pretty}, {y_sym.name_pretty}) → {c_sym.plot_label} {pinned_labels(data[pinned])}'
+ )
ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
ax.legend()
diff --git a/src/blender_maxwell/utils/jaxarray.py b/src/blender_maxwell/utils/jaxarray.py
new file mode 100644
index 0000000..4790c3e
--- /dev/null
+++ b/src/blender_maxwell/utils/jaxarray.py
@@ -0,0 +1,142 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+"""Implements a `pydantic`-compatible field, `JaxArray`, which encapsulates a `jax.Array` in a serializable way."""
+
+import base64
+import io
+import typing as typ
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pydantic as pyd
+from pydantic_core import core_schema as pyd_core_schema
+
+
+####################
+# - Simple JAX Array
+####################
+class _JaxArray:
+ """Annotated validator providing interoperability between `jax.Array` and `pydantic` models.
+
+ Serializes to base64 bytes, for compatibility.
+ """
+
+ @classmethod
+ def __get_pydantic_core_schema__(
+ cls, source_type: typ.Any, handler: pyd.GetCoreSchemaHandler
+ ) -> pyd_core_schema.CoreSchema:
+ def validate_from_any(
+ raw_array: bytes | jax.Array | list | tuple,
+ ) -> jax.Array:
+ if isinstance(raw_array, np.ndarray):
+ return jnp.array(raw_array)
+
+ if isinstance(raw_array, jax.Array):
+ return raw_array
+
+ if isinstance(raw_array, bytes | str):
+ with io.BytesIO() as memfile:
+ memfile.write(base64.b64decode(raw_array.encode('utf-8')))
+ memfile.seek(0)
+ return jnp.load(memfile)
+
+ if isinstance(raw_array, list | tuple):
+ return jnp.array(raw_array)
+
+ raise TypeError
+
+ # def make_hashable(array: jax.Array) -> HashableJaxArray:
+ # return HashableJaxArray(array)
+
+ def serialize_to_bytes(array: jax.Array) -> bytes:
+ with io.BytesIO() as memfile:
+ jnp.save(memfile, array)
+ return base64.b64encode(memfile.getvalue())
+
+ jnp_array_schema = pyd_core_schema.chain_schema(
+ [
+ pyd_core_schema.no_info_plain_validator_function(validate_from_any),
+ # pyd_core_schema.no_info_plain_validator_function(make_hashable),
+ pyd_core_schema.is_instance_schema(jax.Array),
+ ]
+ )
+ return pyd_core_schema.json_or_python_schema(
+ json_schema=jnp_array_schema,
+ python_schema=jnp_array_schema,
+ serialization=pyd_core_schema.plain_serializer_function_ser_schema(
+ serialize_to_bytes
+ ),
+ )
+
+
+JaxArray = typ.Annotated[jax.Array, _JaxArray]
+
+
+####################
+# - Hashable JAX Array as-Bytes
+####################
+class _JaxArrayBytes:
+ """Annotated validator providing interoperability between `jax.Array` and `pydantic` models.
+
+ Serializes to base64 bytes, for compatibility.
+ """
+
+ @classmethod
+ def __get_pydantic_core_schema__(
+ cls, source_type: typ.Any, handler: pyd.GetCoreSchemaHandler
+ ) -> pyd_core_schema.CoreSchema:
+ def validate_from_any(
+ raw_array: bytes | jax.Array | list | tuple,
+ ) -> jax.Array:
+ if isinstance(raw_array, np.ndarray):
+ return jnp.array(raw_array)
+
+ if isinstance(raw_array, jax.Array):
+ return raw_array
+
+ if isinstance(raw_array, bytes | str):
+ with io.BytesIO() as memfile:
+ memfile.write(base64.b64decode(raw_array.encode('utf-8')))
+ memfile.seek(0)
+ return jnp.load(memfile)
+
+ if isinstance(raw_array, list | tuple):
+ return jnp.array(raw_array)
+
+ raise TypeError
+
+ def to_bytes(array: jax.Array) -> bytes:
+ with io.BytesIO() as memfile:
+ jnp.save(memfile, array)
+ return base64.b64encode(memfile.getvalue())
+
+ jnp_array_bytes_schema = pyd_core_schema.chain_schema(
+ [
+ pyd_core_schema.no_info_plain_validator_function(validate_from_any),
+ pyd_core_schema.no_info_plain_validator_function(to_bytes),
+ pyd_core_schema.is_instance_schema(bytes),
+ ]
+ )
+ return pyd_core_schema.json_or_python_schema(
+ json_schema=jnp_array_bytes_schema,
+ python_schema=jnp_array_bytes_schema,
+ serialization=pyd_core_schema.plain_serializer_function_ser_schema(bytes),
+ )
+
+
+JaxArrayBytes = typ.Annotated[bytes, _JaxArrayBytes]
diff --git a/src/blender_maxwell/utils/lru_method.py b/src/blender_maxwell/utils/lru_method.py
new file mode 100644
index 0000000..7743340
--- /dev/null
+++ b/src/blender_maxwell/utils/lru_method.py
@@ -0,0 +1,38 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import functools
+import weakref
+
+
+def method_lru(maxsize=2048, typed=False):
+ """LRU for methods.
+
+ Uses a weak reference to `self` to support memoized methods without memory leaks.
+ """
+
+ def wrapped_method(method):
+ @functools.lru_cache(maxsize, typed)
+ def _method(_self, *args, **kwargs):
+ return method(_self(), *args, **kwargs)
+
+ @functools.wraps(method)
+ def inner_method(self, *args, **kwargs):
+ return _method(weakref.ref(self), *args, **kwargs)
+
+ return inner_method
+
+ return wrapped_method
diff --git a/src/blender_maxwell/utils/serialize.py b/src/blender_maxwell/utils/serialize.py
index b58aebc..30e34be 100644
--- a/src/blender_maxwell/utils/serialize.py
+++ b/src/blender_maxwell/utils/serialize.py
@@ -14,7 +14,30 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-"""Robust serialization tool for use in the addon.
+"""A fast, robust `msgspec`-based serialization tool providing for string-based persistance of many objects.
+
+Blender provides for strong persistence guarantees based on its `bpy.types.Property` system.
+In essence, properties are defined on instances of objects like nodes and sockets, and in turn, these properties can help ex. drive update chains, are **persisted on save**, and more.
+
+The problem is fundamentally one of type support: Only "natural" types like `bool` `int`, `float`, `str`, and particular variants thereof are supported, with notable inclusions in the form of 1D/2D vectors and internal Blender pointers (which _are also persisted_).
+While this forms the extent of UI support, we do extremely often want to persist things that _aren't_ one of these blessed types: At the very least things like sympy types, immutable `pydantic` models, Tidy3D objects, and more.
+We want these "special" types to have the same guarantees as Blender's builtin properties.
+
+This brings us to the intersection of `msgspec`, `pydantic`, and `bpy.props.StringProperty`.
+
+- With a string-based property, we can "simply" serialize whatever object we want to persist, then later deserialize it into its original form.
+- The property access pattern is often in a hot-loop where that same property is accessed, so even with a cache layer above it, we cannot afford to wait ex. `100ms` per a one-way (de)serialization operation.
+- `pydantic` (especially V2) provides a very sane story for _many_ models, but fails in edge cases: It is simply too flexible, providing few robustness and completeness guarantees when it comes to generic serialization, while also demanding complete conformance to its `BaseModel` schema to do anything at all, which does nothing to cover the use-case of transparent type-driven serialization. To boot, its speed is fundamentally great, but still sometimes lacking in nuanced ways.
+- Conversely, `msgspec` is a far, far simpler approach to _best-in-class_, type-driven serialization of _almost_ all natural Python types. It has several very important caveats, but it also supports defining custom encoding/decoding as fallbacks to normal operation.
+
+Therefore, this module provides custom wrappers on top of `msgspec`, which are tailored to the use of several common types, including specially-enabled `pydantic` models and any `tidy3d` model encapsulated by its `dict(*)`.
+We standardize on `json`, which can be easily inserted into an internal `bpy.props.StringProperty` for persistance, with access times very low.
+
+What else did we consider?
+
+- Direct: We tried, quite thoroughly, to keep serialization of arbitrary objects a specialized use case. The result was thousands of lines of unbearably slow, error-prone boilerplate with severe limitations.
+- `json`: The standard-libary module `json` is rather inflexible, far too slow for our use case, and has no mechanisms for hooking custom objects into it.
+- Use of `MsgPack`: Unfortunately, while `bpy.props.StringProperty` does have a "bytes" mode, it refuses to encode arbitrary non-UTF8 bytes. Therefore, the formal binary `MsgPack` format is out of the question, though it is preferrable in almost every other context due to both density, speed, and flexibility.
Attributes:
NaiveEncodableType: See for details.
@@ -25,11 +48,14 @@ import datetime as dt
import decimal
import enum
import functools
+import json
import typing as typ
import uuid
import msgspec
+import numpy as np
import sympy as sp
+import tidy3d as td
from . import logger
from . import sympy_extra as spux
@@ -101,6 +127,7 @@ class TypeID(enum.StrEnum):
SocketDef: str = '!type=socketdef'
SimSymbol: str = '!type=simsymbol'
ManagedObj: str = '!type=managedobj'
+ Tidy3DObj: str = '!type=tidy3dobj'
NaiveRepresentation: typ.TypeAlias = list[TypeID, str | None, typ.Any]
@@ -131,6 +158,9 @@ def _enc_hook(obj: typ.Any) -> NaivelyEncodableType:
if isinstance(obj, spux.SympyType):
return ['!type=sympytype', None, sp.srepr(obj)]
+ if isinstance(obj, td.components.base.Tidy3dBaseModel):
+ return ['!type=tidy3dobj', None, obj._json()]
+
if hasattr(obj, 'dump_as_msgspec'):
return obj.dump_as_msgspec()
@@ -161,6 +191,14 @@ def _dec_hook(_type: type, obj: NaivelyEncodableType) -> typ.Any:
obj_value = obj[2]
return sp.sympify(obj_value).subs(spux.UNIT_BY_SYMBOL)
+ if (
+ issubclass(_type, td.components.base.Tidy3dBaseModel)
+ and is_representation(obj)
+ and obj[0] == TypeID.Tidy3DObj
+ ):
+ obj_json = obj[2]
+ return _type.parse_obj(json.loads(obj_json))
+
if hasattr(_type, 'parse_as_msgspec') and (
is_representation(obj)
and obj[0] in [TypeID.SocketDef, TypeID.ManagedObj, TypeID.SimSymbol]
diff --git a/src/blender_maxwell/utils/sim_symbols/__init__.py b/src/blender_maxwell/utils/sim_symbols/__init__.py
new file mode 100644
index 0000000..e6aa1e9
--- /dev/null
+++ b/src/blender_maxwell/utils/sim_symbols/__init__.py
@@ -0,0 +1,97 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+"""Declares a useful, flexible symbolic representation."""
+
+from .common import (
+ CommonSimSymbol,
+ ang_phi,
+ ang_r,
+ ang_theta,
+ diff_order_x,
+ diff_order_y,
+ dir_x,
+ dir_y,
+ dir_z,
+ field_e,
+ field_ex,
+ field_ey,
+ field_ez,
+ field_h,
+ field_hx,
+ field_hy,
+ field_hz,
+ flux,
+ freq,
+ idx,
+ rel_eps_im,
+ rel_eps_re,
+ sim_axis_idx,
+ space_x,
+ space_y,
+ space_z,
+ t,
+ wl,
+)
+from .name import SimSymbolName
+from .sim_symbol import SimSymbol
+from .utils import (
+ float_max,
+ float_min,
+ int_max,
+ int_min,
+ mk_interval,
+ unicode_superscript,
+)
+
+__all__ = [
+ 'CommonSimSymbol',
+ 'idx',
+ 'rel_eps_im',
+ 'rel_eps_re',
+ 'sim_axis_idx',
+ 't',
+ 'wl',
+ 'freq',
+ 'space_x',
+ 'space_y',
+ 'space_z',
+ 'dir_x',
+ 'dir_y',
+ 'dir_z',
+ 'ang_r',
+ 'ang_theta',
+ 'ang_phi',
+ 'field_e',
+ 'field_ex',
+ 'field_ey',
+ 'field_ez',
+ 'field_h',
+ 'field_hx',
+ 'field_hy',
+ 'field_hz',
+ 'flux',
+ 'diff_order_x',
+ 'diff_order_y',
+ 'SimSymbolName'
+ 'SimSymbol'
+ 'float_max'
+ 'float_min'
+ 'int_max'
+ 'int_min'
+ 'mk_interval'
+ 'unicode_superscript',
+]
diff --git a/src/blender_maxwell/utils/sim_symbols/common.py b/src/blender_maxwell/utils/sim_symbols/common.py
new file mode 100644
index 0000000..0f01b61
--- /dev/null
+++ b/src/blender_maxwell/utils/sim_symbols/common.py
@@ -0,0 +1,327 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import enum
+import typing as typ
+
+import sympy as sp
+
+from blender_maxwell.utils import logger
+from blender_maxwell.utils import sympy_extra as spux
+
+from .name import SimSymbolName
+from .sim_symbol import SimSymbol
+
+log = logger.get(__name__)
+
+
+####################
+# - Common Sim Symbols
+####################
+class CommonSimSymbol(enum.StrEnum):
+ """Identifiers for commonly used `SimSymbol`s, with all information about ex. `MathType`, `PhysicalType`, and (in general) valid intervals all pre-loaded.
+
+ The enum is UI-compatible making it easy to declare a UI-driven dropdown of commonly used symbols that will all behave as expected.
+
+ Attributes:
+ Time: A symbol representing a real-valued wavelength.
+ Wavelength: A symbol representing a real-valued wavelength.
+ Implicitly, this symbol often represents "vacuum wavelength" in particular.
+ Wavelength: A symbol representing a real-valued frequency.
+ Generally, this is the non-angular frequency.
+ """
+
+ Index = enum.auto()
+ SimAxisIdx = enum.auto()
+
+ # Space|Time
+ SpaceX = enum.auto()
+ SpaceY = enum.auto()
+ SpaceZ = enum.auto()
+
+ AngR = enum.auto()
+ AngTheta = enum.auto()
+ AngPhi = enum.auto()
+
+ DirX = enum.auto()
+ DirY = enum.auto()
+ DirZ = enum.auto()
+
+ Time = enum.auto()
+
+ # Fields
+ FieldE = enum.auto()
+ FieldH = enum.auto()
+ FieldEx = enum.auto()
+ FieldEy = enum.auto()
+ FieldEz = enum.auto()
+ FieldHx = enum.auto()
+ FieldHy = enum.auto()
+ FieldHz = enum.auto()
+
+ FieldEr = enum.auto()
+ FieldEtheta = enum.auto()
+ FieldEphi = enum.auto()
+ FieldHr = enum.auto()
+ FieldHtheta = enum.auto()
+ FieldHphi = enum.auto()
+
+ # Optics
+ Wavelength = enum.auto()
+ Frequency = enum.auto()
+
+ Flux = enum.auto()
+
+ DiffOrderX = enum.auto()
+ DiffOrderY = enum.auto()
+
+ RelEpsRe = enum.auto()
+ RelEpsIm = enum.auto()
+
+ ####################
+ # - UI
+ ####################
+ @staticmethod
+ def to_name(v: typ.Self) -> str:
+ """Convert the enum value to a human-friendly name.
+
+ Notes:
+ Used to print names in `EnumProperty`s based on this enum.
+
+ Returns:
+ A human-friendly name corresponding to the enum value.
+ """
+ return CommonSimSymbol(v).name
+
+ @staticmethod
+ def to_icon(_: typ.Self) -> str:
+ """Convert the enum value to a Blender icon.
+
+ Notes:
+ Used to print icons in `EnumProperty`s based on this enum.
+
+ Returns:
+ A human-friendly name corresponding to the enum value.
+ """
+ return ''
+
+ ####################
+ # - Properties
+ ####################
+ @property
+ def name(self) -> str:
+ SSN = SimSymbolName
+ CSS = CommonSimSymbol
+ return {
+ CSS.Index: SSN.LowerI,
+ CSS.SimAxisIdx: SSN.SimAxisIdx,
+ # Space|Time
+ CSS.SpaceX: SSN.LowerX,
+ CSS.SpaceY: SSN.LowerY,
+ CSS.SpaceZ: SSN.LowerZ,
+ CSS.AngR: SSN.LowerR,
+ CSS.AngTheta: SSN.LowerTheta,
+ CSS.AngPhi: SSN.LowerPhi,
+ CSS.DirX: SSN.LowerX,
+ CSS.DirY: SSN.LowerY,
+ CSS.DirZ: SSN.LowerZ,
+ CSS.Time: SSN.LowerT,
+ # Fields
+ CSS.FieldE: SSN.FieldE,
+ CSS.FieldH: SSN.FieldH,
+ CSS.FieldEx: SSN.Ex,
+ CSS.FieldEy: SSN.Ey,
+ CSS.FieldEz: SSN.Ez,
+ CSS.FieldHx: SSN.Hx,
+ CSS.FieldHy: SSN.Hy,
+ CSS.FieldHz: SSN.Hz,
+ CSS.FieldEr: SSN.Er,
+ CSS.FieldHr: SSN.Hr,
+ # Optics
+ CSS.Frequency: SSN.Frequency,
+ CSS.Wavelength: SSN.Wavelength,
+ CSS.Flux: SSN.Flux,
+ CSS.DiffOrderX: SSN.DiffOrderX,
+ CSS.DiffOrderY: SSN.DiffOrderY,
+ CSS.RelEpsRe: SSN.RelEpsRe,
+ CSS.RelEpsIm: SSN.RelEpsIm,
+ }[self]
+
+ def sim_symbol(self, unit: spux.Unit | None) -> SimSymbol:
+ """Retrieve the `SimSymbol` associated with the `CommonSimSymbol`."""
+ CSS = CommonSimSymbol
+
+ # Space
+ sym_space = SimSymbol(
+ sym_name=self.name,
+ physical_type=spux.PhysicalType.Length,
+ unit=unit,
+ )
+ sym_ang = SimSymbol(
+ sym_name=self.name,
+ physical_type=spux.PhysicalType.Angle,
+ unit=unit,
+ )
+
+ # Fields
+ def sym_field(eh: typ.Literal['e', 'h']) -> SimSymbol:
+ return SimSymbol(
+ sym_name=self.name,
+ mathtype=spux.MathType.Complex,
+ physical_type=(
+ spux.PhysicalType.EField if eh == 'e' else spux.PhysicalType.HField
+ ),
+ unit=unit,
+ domain=spux.BlessedSet(
+ sp.ComplexRegion(sp.Interval(0, sp.oo) * sp.Reals)
+ ),
+ )
+
+ return {
+ CSS.Index: SimSymbol(
+ sym_name=self.name,
+ mathtype=spux.MathType.Integer,
+ domain=spux.BlessedSet(sp.Naturals0),
+ ),
+ CSS.SimAxisIdx: SimSymbol(
+ sym_name=self.name,
+ mathtype=spux.MathType.Integer,
+ domain=spux.BlessedSet(sp.FiniteSet(0, 1, 2)),
+ ),
+ # Space|Time
+ CSS.SpaceX: sym_space,
+ CSS.SpaceY: sym_space,
+ CSS.SpaceZ: sym_space,
+ CSS.AngR: sym_space,
+ CSS.AngTheta: sym_ang,
+ CSS.AngPhi: sym_ang,
+ CSS.DirX: SimSymbol(
+ sym_name=self.name,
+ mathtype=spux.MathType.Real,
+ physical_type=spux.PhysicalType.Length,
+ unit=unit,
+ domain=spux.BlessedSet(sp.Interval(-sp.oo, sp.oo)),
+ ),
+ CSS.DirY: SimSymbol(
+ sym_name=self.name,
+ mathtype=spux.MathType.Real,
+ physical_type=spux.PhysicalType.Length,
+ unit=unit,
+ domain=spux.BlessedSet(sp.Interval(-sp.oo, sp.oo)),
+ ),
+ CSS.Time: SimSymbol(
+ sym_name=self.name,
+ physical_type=spux.PhysicalType.Time,
+ unit=unit,
+ domain=spux.BlessedSet(sp.Interval(0, sp.oo)),
+ ),
+ # Fields
+ CSS.FieldE: sym_field('e'),
+ CSS.FieldH: sym_field('h'),
+ CSS.FieldEx: sym_field('e'),
+ CSS.FieldEy: sym_field('e'),
+ CSS.FieldEz: sym_field('e'),
+ CSS.FieldHx: sym_field('h'),
+ CSS.FieldHy: sym_field('h'),
+ CSS.FieldHz: sym_field('h'),
+ CSS.FieldEr: sym_field('e'),
+ CSS.FieldEtheta: sym_field('e'),
+ CSS.FieldEphi: sym_field('e'),
+ CSS.FieldHr: sym_field('h'),
+ CSS.FieldHtheta: sym_field('h'),
+ CSS.FieldHphi: sym_field('h'),
+ # Optics
+ CSS.Wavelength: SimSymbol(
+ sym_name=self.name,
+ mathtype=spux.MathType.Real,
+ physical_type=spux.PhysicalType.Length,
+ unit=unit,
+ domain=spux.BlessedSet(sp.Interval.open(0, sp.oo)),
+ ),
+ CSS.Frequency: SimSymbol(
+ sym_name=self.name,
+ mathtype=spux.MathType.Real,
+ physical_type=spux.PhysicalType.Freq,
+ unit=unit,
+ domain=spux.BlessedSet(sp.Interval.open(0, sp.oo)),
+ ),
+ CSS.Flux: SimSymbol(
+ sym_name=SimSymbolName.Flux,
+ mathtype=spux.MathType.Real,
+ physical_type=spux.PhysicalType.Power,
+ unit=unit,
+ domain=spux.BlessedSet(sp.Interval.open(0, sp.oo)),
+ ),
+ CSS.DiffOrderX: SimSymbol(
+ sym_name=self.name,
+ mathtype=spux.MathType.Integer,
+ domain=spux.BlessedSet(sp.Integers),
+ ),
+ CSS.DiffOrderY: SimSymbol(
+ sym_name=self.name,
+ mathtype=spux.MathType.Integer,
+ domain=spux.BlessedSet(sp.Integers),
+ ),
+ CSS.RelEpsRe: SimSymbol(
+ sym_name=self.name,
+ mathtype=spux.MathType.Real,
+ domain=spux.BlessedSet(sp.Reals),
+ ),
+ CSS.RelEpsIm: SimSymbol(
+ sym_name=self.name,
+ mathtype=spux.MathType.Real,
+ domain=spux.BlessedSet(sp.Reals),
+ ),
+ }[self]
+
+
+####################
+# - Selected Direct-Access to SimSymbols
+####################
+idx = CommonSimSymbol.Index.sim_symbol
+sim_axis_idx = CommonSimSymbol.SimAxisIdx.sim_symbol
+t = CommonSimSymbol.Time.sim_symbol
+wl = CommonSimSymbol.Wavelength.sim_symbol
+freq = CommonSimSymbol.Frequency.sim_symbol
+
+space_x = CommonSimSymbol.SpaceX.sim_symbol
+space_y = CommonSimSymbol.SpaceY.sim_symbol
+space_z = CommonSimSymbol.SpaceZ.sim_symbol
+
+dir_x = CommonSimSymbol.DirX.sim_symbol
+dir_y = CommonSimSymbol.DirY.sim_symbol
+dir_z = CommonSimSymbol.DirZ.sim_symbol
+
+ang_r = CommonSimSymbol.AngR.sim_symbol
+ang_theta = CommonSimSymbol.AngTheta.sim_symbol
+ang_phi = CommonSimSymbol.AngPhi.sim_symbol
+
+field_e = CommonSimSymbol.FieldE.sim_symbol
+field_h = CommonSimSymbol.FieldH.sim_symbol
+field_ex = CommonSimSymbol.FieldEx.sim_symbol
+field_ey = CommonSimSymbol.FieldEy.sim_symbol
+field_ez = CommonSimSymbol.FieldEz.sim_symbol
+field_hx = CommonSimSymbol.FieldHx.sim_symbol
+field_hy = CommonSimSymbol.FieldHx.sim_symbol
+field_hz = CommonSimSymbol.FieldHx.sim_symbol
+
+flux = CommonSimSymbol.Flux.sim_symbol
+
+diff_order_x = CommonSimSymbol.DiffOrderX.sim_symbol
+diff_order_y = CommonSimSymbol.DiffOrderY.sim_symbol
+
+rel_eps_re = CommonSimSymbol.RelEpsRe.sim_symbol
+rel_eps_im = CommonSimSymbol.RelEpsIm.sim_symbol
diff --git a/src/blender_maxwell/utils/sim_symbols/name.py b/src/blender_maxwell/utils/sim_symbols/name.py
new file mode 100644
index 0000000..7cc45c4
--- /dev/null
+++ b/src/blender_maxwell/utils/sim_symbols/name.py
@@ -0,0 +1,208 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import enum
+import string
+import typing as typ
+
+from blender_maxwell.utils import logger
+
+log = logger.get(__name__)
+
+####################
+# - Simulation Symbol Names
+####################
+_l = ''
+_it_lower = iter(string.ascii_lowercase)
+
+
+class SimSymbolName(enum.StrEnum):
+ # Generic
+ Constant = enum.auto()
+ Expr = enum.auto()
+ Data = enum.auto()
+
+ # Ascii Letters
+ while True:
+ try:
+ globals()['_l'] = next(globals()['_it_lower'])
+ except StopIteration:
+ break
+
+ locals()[f'Lower{globals()["_l"].upper()}'] = enum.auto()
+ locals()[f'Upper{globals()["_l"].upper()}'] = enum.auto()
+
+ # Greek Letters
+ LowerTheta = enum.auto()
+ LowerPhi = enum.auto()
+
+ # EM Fields
+ FieldE = enum.auto()
+ FieldH = enum.auto()
+ Ex = enum.auto()
+ Ey = enum.auto()
+ Ez = enum.auto()
+ Hx = enum.auto()
+ Hy = enum.auto()
+ Hz = enum.auto()
+
+ Er = enum.auto()
+ Etheta = enum.auto()
+ Ephi = enum.auto()
+ Hr = enum.auto()
+ Htheta = enum.auto()
+ Hphi = enum.auto()
+
+ # Optics
+ Wavelength = enum.auto()
+ Frequency = enum.auto()
+
+ Perm = enum.auto()
+ PermXX = enum.auto()
+ PermYY = enum.auto()
+ PermZZ = enum.auto()
+
+ Flux = enum.auto()
+
+ DiffOrderX = enum.auto()
+ DiffOrderY = enum.auto()
+
+ BlochX = enum.auto()
+ BlochY = enum.auto()
+ BlochZ = enum.auto()
+
+ # New Backwards Compatible Entries
+ ## -> Ordered lists carry a particular enum integer index.
+ ## -> Therefore, anything but adding an index breaks backwards compat.
+ ## -> ...With all previous files.
+ ConstantRange = enum.auto()
+ Count = enum.auto()
+
+ RelEpsRe = enum.auto()
+ RelEpsIm = enum.auto()
+
+ SimAxisIdx = enum.auto()
+
+ ####################
+ # - UI
+ ####################
+ @staticmethod
+ def to_name(v: typ.Self) -> str:
+ """Convert the enum value to a human-friendly name.
+
+ Notes:
+ Used to print names in `EnumProperty`s based on this enum.
+
+ Returns:
+ A human-friendly name corresponding to the enum value.
+ """
+ return SimSymbolName(v).name
+
+ @staticmethod
+ def to_icon(_: typ.Self) -> str:
+ """Convert the enum value to a Blender icon.
+
+ Notes:
+ Used to print icons in `EnumProperty`s based on this enum.
+
+ Returns:
+ A human-friendly name corresponding to the enum value.
+ """
+ return ''
+
+ ####################
+ # - Computed Properties
+ ####################
+ @property
+ def name(self) -> str:
+ SSN = SimSymbolName
+ return (
+ # Ascii Letters
+ {SSN[f'Lower{letter.upper()}']: letter for letter in string.ascii_lowercase}
+ | {
+ SSN[f'Upper{letter.upper()}']: letter.upper()
+ for letter in string.ascii_lowercase
+ }
+ | {
+ # Generic
+ SSN.Constant: 'cst',
+ SSN.ConstantRange: 'cst_range',
+ SSN.Expr: 'expr',
+ SSN.Data: 'data',
+ SSN.Count: 'count',
+ # Greek Letters
+ SSN.LowerTheta: 'theta',
+ SSN.LowerPhi: 'phi',
+ # Fields
+ SSN.FieldE: 'E*',
+ SSN.FieldH: 'H*',
+ SSN.Ex: 'Ex',
+ SSN.Ey: 'Ey',
+ SSN.Ez: 'Ez',
+ SSN.Hx: 'Hx',
+ SSN.Hy: 'Hy',
+ SSN.Hz: 'Hz',
+ SSN.Er: 'Er',
+ SSN.Etheta: 'Ey',
+ SSN.Ephi: 'Ez',
+ SSN.Hr: 'Hx',
+ SSN.Htheta: 'Hy',
+ SSN.Hphi: 'Hz',
+ # Optics
+ SSN.Wavelength: 'wl',
+ SSN.Frequency: 'freq',
+ SSN.Perm: 'eps_r',
+ SSN.PermXX: 'eps_xx',
+ SSN.PermYY: 'eps_yy',
+ SSN.PermZZ: 'eps_zz',
+ SSN.Flux: 'flux',
+ SSN.DiffOrderX: 'order_x',
+ SSN.DiffOrderY: 'order_y',
+ SSN.BlochX: 'bloch_x',
+ SSN.BlochY: 'bloch_y',
+ SSN.BlochZ: 'bloch_z',
+ SSN.RelEpsRe: 'eps_r_re',
+ SSN.RelEpsIm: 'eps_r_im',
+ SSN.SimAxisIdx: '[xyz]',
+ }
+ )[self]
+
+ @property
+ def name_pretty(self) -> str:
+ SSN = SimSymbolName
+ return {
+ # Generic
+ SSN.Count: '#',
+ # Greek Letters
+ SSN.LowerTheta: 'θ',
+ SSN.LowerPhi: 'φ',
+ # Fields
+ SSN.Er: 'Er',
+ SSN.Etheta: 'Eθ',
+ SSN.Ephi: 'Eφ',
+ SSN.Hr: 'Hr',
+ SSN.Htheta: 'Hθ',
+ SSN.Hphi: 'Hφ',
+ # Optics
+ SSN.Wavelength: 'λ',
+ SSN.Frequency: 'fᵣ',
+ SSN.Perm: 'εᵣ',
+ SSN.PermXX: 'εᵣ[xx]',
+ SSN.PermYY: 'εᵣ[yy]',
+ SSN.PermZZ: 'εᵣ[zz]',
+ SSN.RelEpsRe: 'ℝ[εᵣ]',
+ SSN.RelEpsIm: '𝕀[εᵣ]',
+ }.get(self, self.name)
diff --git a/src/blender_maxwell/utils/sim_symbols.py b/src/blender_maxwell/utils/sim_symbols/sim_symbol.py
similarity index 56%
rename from src/blender_maxwell/utils/sim_symbols.py
rename to src/blender_maxwell/utils/sim_symbols/sim_symbol.py
index f663d36..a3f0ba9 100644
--- a/src/blender_maxwell/utils/sim_symbols.py
+++ b/src/blender_maxwell/utils/sim_symbols/sim_symbol.py
@@ -14,217 +14,32 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-import enum
-import functools
-import string
-import sys
-import typing as typ
-from fractions import Fraction
+"""Implements `SimSymbol`, a convenient representation of a symbolic variable suiteable for use when describing various mathematical and numerical interfaces."""
+import functools
+import random
+import typing as typ
+
+import jax
+import jax.numpy as jnp
import jaxtyping as jtyp
+import numpy as np
import pydantic as pyd
import sympy as sp
+import sympy.stats as sps
+from sympy.tensor.array.expressions import ArraySymbol
-from . import logger, serialize
-from . import sympy_extra as spux
+from blender_maxwell.utils import logger, serialize
+from blender_maxwell.utils import sympy_extra as spux
+from blender_maxwell.utils.frozendict import frozendict
+from blender_maxwell.utils.lru_method import method_lru
-int_min = -(2**64)
-int_max = 2**64
-float_min = sys.float_info.min
-float_max = sys.float_info.max
+from .name import SimSymbolName
+from .utils import unicode_superscript
log = logger.get(__name__)
-
-def unicode_superscript(n: int) -> str:
- """Transform an integer into its unicode-based superscript character."""
- return ''.join(['⁰¹²³⁴⁵⁶⁷⁸⁹'[ord(c) - ord('0')] for c in str(n)])
-
-
-####################
-# - Simulation Symbol Names
-####################
-_l = ''
-_it_lower = iter(string.ascii_lowercase)
-
-
-class SimSymbolName(enum.StrEnum):
- # Generic
- Constant = enum.auto()
- Expr = enum.auto()
- Data = enum.auto()
-
- # Ascii Letters
- while True:
- try:
- globals()['_l'] = next(globals()['_it_lower'])
- except StopIteration:
- break
-
- locals()[f'Lower{globals()["_l"].upper()}'] = enum.auto()
- locals()[f'Upper{globals()["_l"].upper()}'] = enum.auto()
-
- # Greek Letters
- LowerTheta = enum.auto()
- LowerPhi = enum.auto()
-
- # EM Fields
- Ex = enum.auto()
- Ey = enum.auto()
- Ez = enum.auto()
- Hx = enum.auto()
- Hy = enum.auto()
- Hz = enum.auto()
-
- Er = enum.auto()
- Etheta = enum.auto()
- Ephi = enum.auto()
- Hr = enum.auto()
- Htheta = enum.auto()
- Hphi = enum.auto()
-
- # Optics
- Wavelength = enum.auto()
- Frequency = enum.auto()
-
- Perm = enum.auto()
- PermXX = enum.auto()
- PermYY = enum.auto()
- PermZZ = enum.auto()
-
- Flux = enum.auto()
-
- DiffOrderX = enum.auto()
- DiffOrderY = enum.auto()
-
- BlochX = enum.auto()
- BlochY = enum.auto()
- BlochZ = enum.auto()
-
- # New Backwards Compatible Entries
- ## -> Ordered lists carry a particular enum integer index.
- ## -> Therefore, anything but adding an index breaks backwards compat.
- ## -> ...With all previous files.
- ConstantRange = enum.auto()
-
- ####################
- # - UI
- ####################
- @staticmethod
- def to_name(v: typ.Self) -> str:
- """Convert the enum value to a human-friendly name.
-
- Notes:
- Used to print names in `EnumProperty`s based on this enum.
-
- Returns:
- A human-friendly name corresponding to the enum value.
- """
- return SimSymbolName(v).name
-
- @staticmethod
- def to_icon(_: typ.Self) -> str:
- """Convert the enum value to a Blender icon.
-
- Notes:
- Used to print icons in `EnumProperty`s based on this enum.
-
- Returns:
- A human-friendly name corresponding to the enum value.
- """
- return ''
-
- ####################
- # - Computed Properties
- ####################
- @property
- def name(self) -> str:
- SSN = SimSymbolName
- return (
- # Ascii Letters
- {SSN[f'Lower{letter.upper()}']: letter for letter in string.ascii_lowercase}
- | {
- SSN[f'Upper{letter.upper()}']: letter.upper()
- for letter in string.ascii_lowercase
- }
- | {
- # Generic
- SSN.Constant: 'cst',
- SSN.ConstantRange: 'cst_range',
- SSN.Expr: 'expr',
- SSN.Data: 'data',
- # Greek Letters
- SSN.LowerTheta: 'theta',
- SSN.LowerPhi: 'phi',
- # Fields
- SSN.Ex: 'Ex',
- SSN.Ey: 'Ey',
- SSN.Ez: 'Ez',
- SSN.Hx: 'Hx',
- SSN.Hy: 'Hy',
- SSN.Hz: 'Hz',
- SSN.Er: 'Ex',
- SSN.Etheta: 'Ey',
- SSN.Ephi: 'Ez',
- SSN.Hr: 'Hx',
- SSN.Htheta: 'Hy',
- SSN.Hphi: 'Hz',
- # Optics
- SSN.Wavelength: 'wl',
- SSN.Frequency: 'freq',
- SSN.Perm: 'eps_r',
- SSN.PermXX: 'eps_xx',
- SSN.PermYY: 'eps_yy',
- SSN.PermZZ: 'eps_zz',
- SSN.Flux: 'flux',
- SSN.DiffOrderX: 'order_x',
- SSN.DiffOrderY: 'order_y',
- SSN.BlochX: 'bloch_x',
- SSN.BlochY: 'bloch_y',
- SSN.BlochZ: 'bloch_z',
- }
- )[self]
-
- @property
- def name_pretty(self) -> str:
- SSN = SimSymbolName
- return {
- # Generic
- # Greek Letters
- SSN.LowerTheta: 'θ',
- SSN.LowerPhi: 'φ',
- # Fields
- SSN.Er: 'Er',
- SSN.Etheta: 'Eθ',
- SSN.Ephi: 'Eφ',
- SSN.Hr: 'Hr',
- SSN.Htheta: 'Hθ',
- SSN.Hphi: 'Hφ',
- # Optics
- SSN.Wavelength: 'λ',
- SSN.Frequency: 'fᵣ',
- SSN.Perm: 'εᵣ',
- SSN.PermXX: 'εᵣ[xx]',
- SSN.PermYY: 'εᵣ[yy]',
- SSN.PermZZ: 'εᵣ[zz]',
- }.get(self, self.name)
-
-
-####################
-# - Simulation Symbol
-####################
-def mk_interval(
- interval_finite: tuple[int | Fraction | float, int | Fraction | float],
- interval_inf: tuple[bool, bool],
- interval_closed: tuple[bool, bool],
-) -> sp.Interval:
- """Create a symbolic interval from the tuples (and unit) defining it."""
- return sp.Interval(
- start=(interval_finite[0] if not interval_inf[0] else -sp.oo),
- end=(interval_finite[1] if not interval_inf[1] else sp.oo),
- left_open=(True if interval_inf[0] else not interval_closed[0]),
- right_open=(True if interval_inf[1] else not interval_closed[1]),
- )
+MT = spux.MathType
class SimSymbol(pyd.BaseModel):
@@ -258,93 +73,205 @@ class SimSymbol(pyd.BaseModel):
model_config = pyd.ConfigDict(frozen=True)
+ # Name | Type
sym_name: SimSymbolName
- mathtype: spux.MathType = spux.MathType.Real
+ mathtype: MT = MT.Real
physical_type: spux.PhysicalType = spux.PhysicalType.NonPhysical
# Units
## -> 'None' indicates that no particular unit has yet been chosen.
- ## -> When 'self.physical_type' is NonPhysical, can only be None.
+ ## -> When 'self.physical_type' is NonPhysical, _no unit_ can be chosen.
unit: spux.Unit | None = None
+ ####################
+ # - Dimensionality
+ ####################
# Size
- ## -> All SimSymbol sizes are "2D", but interpreted by convention.
- ## -> 1x1: "Scalar".
- ## -> nx1: "Vector".
- ## -> 1xn: "Covector".
- ## -> nxn: "Matrix".
+ ## -> 1*1: "Scalar".
+ ## -> n*1: "Vector".
+ ## -> 1*n: "Covector".
+ ## -> n*m: "Matrix".
+ ## -> n*m*...: "Tensor".
rows: int = 1
cols: int = 1
-
- # Valid Domain
- ## -> Declares the valid set of values that may be given to this symbol.
- ## -> By convention, units are not encoded in the domain sp.Set.
- ## -> 'sp.Set's are extremely expressive and cool.
- domain: spux.SympyExpr | None = None
+ depths: tuple[int, ...] = ()
@functools.cached_property
- def domain_mat(self) -> sp.Set | sp.matrices.MatrixSet:
- if self.rows > 1 or self.cols > 1:
- return sp.matrices.MatrixSet(self.rows, self.cols, self.domain)
- return self.domain
+ def is_scalar(self) -> bool:
+ """Whether the symbol represents a scalar."""
+ return self.rows == self.cols == 1 and self.depths == ()
- preview_value: spux.SympyExpr | None = None
+ @functools.cached_property
+ def is_vector(self) -> bool:
+ """Whether the symbol represents a (column) vector."""
+ return self.rows > 1 and self.cols == 1 and self.depths == ()
- ####################
- # - Validators
- ####################
- ## TODO: Check domain against MathType
- ## -- Surprisingly hard without a lot of special-casing.
+ @functools.cached_property
+ def is_covector(self) -> bool:
+ """Whether the symbol represents a covector."""
+ return self.rows == 1 and self.cols > 1 and self.depths == ()
- ## TODO: Check that size is valid for the PhysicalType.
+ @functools.cached_property
+ def is_matrix(self) -> bool:
+ """Whether the symbol represents a matrix, which isn't better described as a scalar/vector/covector."""
+ return self.rows > 1 and self.cols > 1 and self.depths == ()
- ## TODO: Check that constant value (domain=FiniteSet(cst)) is compatible with the MathType.
-
- ## TODO: Check that preview_value is in the domain.
-
- @pyd.model_validator(mode='after')
- def set_undefined_domain_from_mathtype(self) -> typ.Self:
- """When the domain is not set, then set it using the symbolic set of the MathType."""
- if self.domain is None:
- object.__setattr__(self, 'domain', self.mathtype.symbolic_set)
- return self
-
- @pyd.model_validator(mode='after')
- def conform_undefined_preview_value_to_constant(self) -> typ.Self:
- """When the `SimSymbol` is a constant, but the preview value is not set, then set the preview value from the constant."""
- if self.is_constant and not self.preview_value:
- object.__setattr__(self, 'preview_value', self.constant_value)
- return self
-
- @pyd.model_validator(mode='after')
- def conform_preview_value(self) -> typ.Self:
- """Conform the given preview value to the `SimSymbol`."""
- if self.is_constant and not self.preview_value:
- object.__setattr__(
- self,
- 'preview_value',
- self.conform(self.preview_value, strip_units=True),
- )
- return self
+ @functools.cached_property
+ def is_ndim(self) -> bool:
+ """Whether the symbol represents an n-dimensional tensor, which isn't better described as a scalar/vector/covector/matrix."""
+ return self.depths != ()
####################
# - Domain
####################
- @functools.cached_property
- def is_constant(self) -> bool:
- """When the symbol domain is a single-element `sp.FiniteSet`, then the symbol can be considered to be a constant."""
- return isinstance(self.domain, sp.FiniteSet) and len(self.domain) == 1
+ # Representation of Valid Symbolic Domain
+ ## -> Declares the valid set of values that may be given to this symbol.
+ ## -> By convention, units are not encoded in the domain sp.Set.
+ ## -> 'sp.Set's are extremely expressive and cool.
+ domain: spux.BlessedSet | None = None
@functools.cached_property
- def constant_value(self) -> bool:
- """Get the constant when `is_constant` is True.
+ def domain_mat(self) -> sp.Set | sp.matrices.MatrixSet:
+ """Get the domain as a `MatrixSet`, if the symbol represents a vector/covector/matrix, otherwise .
- The `self.unit_factor` is multiplied onto the constant at this point.
+ Raises:
+ ValueError: If the symbol is an arbitrary n-dimensional tensor.
"""
- if self.is_constant:
- return next(iter(self.domain)) * self.unit_factor
+ if self.is_scalar:
+ return self.domain
+ if self.is_vector or self.is_covector or self.is_matrix:
+ return self.domain.bset_mat(self.rows, self.cols)
- msg = 'Tried to get constant value of non-constant SimSymbol.'
+ msg = f"Can't determine set representation of arbitrary n-dimensional tensor (SimSymbol = {self})"
+ raise ValueError(msg)
+
+ ####################
+ # - Stochastic Variables
+ ####################
+ # Stochastic Sample Space | PDF
+ ## -> When stoch_var is set, this variable should be considered stochastic.
+ ## -> When stochastic, the SimSymbol must be real | unitless | [0,1] dm.
+ stoch_var: spux.SympyExpr | None = None
+ stoch_seed: int = 0
+
+ @functools.cached_property
+ def stoch_key_jax(self) -> jax._src.prng.PRNGKeyArray:
+ """The key guaranteeing a deterministic random sample based on `self.stoch_seed`."""
+ return jax.random.key(self.stock_seed)
+
+ @functools.cached_property
+ def sample_space(self) -> sp.Set | None:
+ """The space of valid inputs to the PDF."""
+ return self.stoch_var.pspace.set
+
+ @functools.cached_property
+ def pdf(self) -> spux.SympyExpr | None:
+ """The expression of probability density function.
+
+ When `self.sample_space` is `None`, then this too returns `None`, since the symbol is not stochastic (and therefore has no "measurable space").
+ """
+ if self.stoch_var is not None:
+ return self.stoch_var.pspace.pdf
+ return None
+
+ @functools.cached_property
+ def cdf(self) -> spux.SympyExpr | None:
+ """The expression of cumulative distribution function."""
+ if self.stoch_var is not None:
+ return sps.cdf(self.stoch_var)
+ return None
+
+ @functools.cached_property
+ def pdf_jax(self) -> typ.Callable[[jax.Array], jax.Array] | None:
+ """The stochastic PDF as a `jax`-registered function."""
+ if self.stoch_var is not None:
+ return sp.lambdify(self.stoch_var, self.pdf, 'jax')
+ return None
+
+ @method_lru()
+ def sample_np(self, repeat: int = 1) -> typ.Callable[[jax.Array], jax.Array] | None:
+ """Sample the stochastic variable `repeat` times, using 'numpy'."""
+ if self.stoch_var is not None:
+ sample_shape = (*self.shape, repeat)
+ return sps.sample(self.stoch_var, size=sample_shape, library='numpy')
+ return None
+
+ @method_lru()
+ def sample_jax(
+ self, repeat: int = 1
+ ) -> typ.Callable[[jax.Array], jax.Array] | None:
+ """Sample the stochastic variable `repeat` times, using 'jax'.
+
+ For now, this cannot be used in `@jit`, since it merely converts the output of `sample_np`.
+ In the long run, we'll need a way of directly sampling `sympy` stochastic variables with `jax` functions
+ """
+ if self.stoch_var is not None:
+ return jnp.array(self.sample_np(repeat))
+ return None
+
+ # TODO: 'Ecosystem'!
+ ## -- Ideally we'd expand sympy.stats.sample to support 'jax' directly.
+ ## -- 'None' dims in InfoFlow would mean either cont. or stochastic.
+ ## -- Stochastic dims would be realized w/integer n argument.
+ ## -- I suppose a dedicated 'stochastic symbol' node would be warranted.
+ ## -- With a nice dropdown for cool distributions and cool jazz.
+
+ ####################
+ # - Validators: Stochastic Variable
+ ####################
+ @pyd.model_validator(mode='after')
+ def randomize_stochvar_key(self) -> typ.Self:
+ """Select a random integer value for the stochastic seed.
+
+ Repeated calls to `self.sample_jax()`
+ """
+ if self.stoch_var is not None:
+ self.stoch_seed = random.randint(0, 2**32)
+ return self
+
+ @pyd.model_validator(mode='after')
+ def set_stochvar_output_space_to_domain(self) -> typ.Self:
+ """When the symbol is stochastic, set `self.domain` to the range of the stochastic variable's probability density function.
+
+ All PDFs are real, unitless values defined on the closed interval `[0,1]`.
+ Therefore, the symbol itself must conform to these preconditions.
+ """
+ if self.stoch_var is not None:
+ if (
+ self.domain is None
+ and self.physical_type is spux.PhysicalType.NonPhysical
+ and self.unit is None
+ ):
+ object.__setattr__(self, 'domain', spux.BlessedSet(sp.Interval(0, 1)))
+ else:
+ msg = 'Stochastic variables must be unitless'
+ raise ValueError(msg)
+ return self
+
+ ####################
+ # - Validators: Set Domain
+ ####################
+ @pyd.model_validator(mode='after')
+ def set_undefined_domain_from_mathtype(self) -> typ.Self:
+ """When the domain is not set, then set it using the symbolic set of the MathType."""
+ if self.domain is None:
+ object.__setattr__(
+ self, 'domain', spux.BlessedSet(self.mathtype.symbolic_set)
+ )
+ return self
+
+ ####################
+ # - Validators: Asserters
+ ####################
+ @pyd.model_validator(mode='after')
+ def assert_domain_in_mathtype_set(self) -> typ.Self:
+ """Verify that the domain is a (non-strict) subset of the MathType."""
+ if self.domain.bset is sp.EmptySet or self.domain.bset.issubset(
+ self.mathtype.symbolic_set
+ ):
+ return self
+
+ msg = f'Domain {self.domain} is not in the mathtype {self.mathtype}'
raise ValueError(msg)
@functools.cached_property
@@ -422,12 +349,14 @@ class SimSymbol(pyd.BaseModel):
@functools.cached_property
def size(self) -> spux.NumberSize1D | None:
"""The 1D number size of this `SimSymbol`, if it has one; else None."""
- return {
- (1, 1): spux.NumberSize1D.Scalar,
- (2, 1): spux.NumberSize1D.Vec2,
- (3, 1): spux.NumberSize1D.Vec3,
- (4, 1): spux.NumberSize1D.Vec4,
- }.get((self.rows, self.cols))
+ if self.depths == ():
+ return {
+ (1, 1): spux.NumberSize1D.Scalar,
+ (2, 1): spux.NumberSize1D.Vec2,
+ (3, 1): spux.NumberSize1D.Vec3,
+ (4, 1): spux.NumberSize1D.Vec4,
+ }.get((self.rows, self.cols))
+ return None
@functools.cached_property
def shape(self) -> tuple[int, ...]:
@@ -437,13 +366,16 @@ class SimSymbol(pyd.BaseModel):
Is never `None`; instead, empty tuple `()` is used.
"""
- match (self.rows, self.cols):
- case (1, 1):
- return ()
- case (_, 1):
- return (self.rows,)
- case (_, _):
- return (self.rows, self.cols)
+ if self.depths == ():
+ match (self.rows, self.cols):
+ case (1, 1):
+ return ()
+ case (_, 1):
+ return (self.rows,)
+ case (_, _):
+ return (self.rows, self.cols)
+
+ return (self.rows, self.cols, *self.depths)
@functools.cached_property
def shape_len(self) -> spux.SympyExpr:
@@ -471,13 +403,13 @@ class SimSymbol(pyd.BaseModel):
# MathType Assumption
mathtype_kwargs = {}
match self.mathtype:
- case spux.MathType.Integer:
+ case MT.Integer:
mathtype_kwargs |= {'integer': True}
- case spux.MathType.Rational:
+ case MT.Rational:
mathtype_kwargs |= {'rational': True}
- case spux.MathType.Real:
+ case MT.Real:
mathtype_kwargs |= {'real': True}
- case spux.MathType.Complex:
+ case MT.Complex:
mathtype_kwargs |= {'complex': True}
# Non-Zero Assumption
@@ -485,28 +417,38 @@ class SimSymbol(pyd.BaseModel):
mathtype_kwargs |= {'nonzero': True}
# Positive/Negative Assumption
- if self.mathtype is not spux.MathType.Complex:
- if self.domain.inf >= 0:
+ if self.mathtype is not MT.Complex:
+ has_pos = self.domain & sp.Interval.open(0, sp.oo) is not sp.EmptySet
+ has_neg = self.domain & sp.Interval.open(0, sp.oo) is not sp.EmptySet
+ if has_pos and not has_neg:
mathtype_kwargs |= {'positive': True}
- elif self.domain.sup < 0:
+ if has_neg and not has_pos:
mathtype_kwargs |= {'negative': True}
# Scalar: Return Symbol
- if self.rows == 1 and self.cols == 1:
+ if self.is_scalar:
return sp.Symbol(self.sym_name.name, **mathtype_kwargs)
# Vector|Matrix: Return Matrix of Symbols
## -> MatrixSymbol doesn't support assumptions.
## -> This little construction does.
- return sp.ImmutableMatrix(
- [
+ if not self.is_ndim:
+ return sp.ImmutableMatrix(
[
- sp.Symbol(self.sym_name.name + f'_{row}{col}', **mathtype_kwargs)
- for col in range(self.cols)
+ [
+ sp.Symbol(
+ self.sym_name.name + f'_{row}{col}', **mathtype_kwargs
+ )
+ for col in range(self.cols)
+ ]
+ for row in range(self.rows)
]
- for row in range(self.rows)
- ]
- )
+ )
+
+ # Arbitrary Tensor: Just Return Symbol
+ ## -> Maybe we'll do the other stuff later to keep assumptions.
+ ## -> Maybe we'll retire matrix-of-symbol entirely instead.
+ return self.sp_symbol_matsym
@functools.cached_property
def sp_symbol_matsym(self) -> sp.Symbol | sp.MatrixSymbol:
@@ -525,7 +467,9 @@ class SimSymbol(pyd.BaseModel):
"""
if self.shape_len == 0:
return self.sp_symbol
- return sp.MatrixSymbol(self.sym_name.name, self.rows, self.cols)
+ if self.depths == ():
+ return sp.MatrixSymbol(self.sym_name.name, self.rows, self.cols)
+ return ArraySymbol(self.sym_name.name, self.shape)
@functools.cached_property
def sp_symbol_phy(self) -> spux.SympyExpr:
@@ -545,7 +489,9 @@ class SimSymbol(pyd.BaseModel):
Since `ExprSocketDef` allows the use of infinite bounds for `default_min` and `default_max`, we defer the decision of how to treat finite-fallback to the `ExprSocketDef`.
"""
if self.size is not None:
- if self.unit in self.physical_type.valid_units:
+ if self.unit in self.physical_type.valid_units or (
+ self.unit is None and self.physical_type is None
+ ):
socket_info = {
'output_name': self.sym_name,
# Socket Interface
@@ -555,27 +501,29 @@ class SimSymbol(pyd.BaseModel):
# Defaults: Units
'default_unit': self.unit,
'default_symbols': [],
+ 'exclude_zero': 0 not in self.domain,
+ # Domain Info
+ 'abs_min': self.domain.inf,
+ 'abs_max': self.domain.sup,
+ 'abs_min_closed': self.domain.min_closed,
+ 'abs_max_closed': self.domain.max_closed,
}
- # Defaults: FlowKind.Value
- if self.preview_value:
+ # Complex Domain: Closure of Imaginary Axis
+ if self.mathtype is MT.Complex:
socket_info |= {
- 'default_value': self.conform(
- self.preview_value, strip_unit=True
- )
+ 'abs_min_closed_im': self.domain.min_closed_im,
+ 'abs_max_closed_im': self.domain.min_closed_im,
}
- # Defaults: FlowKind.Range
- if (
- self.mathtype is not spux.MathType.Complex
- and self.rows == 1
- and self.cols == 1
- ):
+ # FlowKind.Range: Min/Max
+ if self.mathtype is not MT.Complex and self.shape_len == 0:
socket_info |= {
'default_min': self.domain.inf,
'default_max': self.domain.sup,
}
- ## TODO: Handle discontinuities / disjointness / open boundaries.
+
+ return socket_info
msg = f'Tried to generate an ExprSocket from a SymSymbol "{self.name}", but its unit ({self.unit}) is not a valid unit of its physical type ({self.physical_type}) (SimSymbol={self})'
raise NotImplementedError(msg)
@@ -586,8 +534,8 @@ class SimSymbol(pyd.BaseModel):
####################
# - Operations: Raw Update
####################
- def update(self, **kwargs) -> typ.Self:
- """Create a new `SimSymbol`, such that the given keyword arguments override the existing values."""
+ @method_lru()
+ def _update(self, kwargs: frozendict) -> typ.Self:
if not kwargs:
return self
@@ -607,9 +555,29 @@ class SimSymbol(pyd.BaseModel):
domain=get_attr('domain'),
)
+ def update(self, **kwargs) -> typ.Self:
+ """Create a new `SimSymbol`, such that the given keyword arguments override the existing values."""
+ return self._update(frozendict(kwargs))
+
+ @method_lru(maxsize=512)
+ def scale_to_unit_system(self, unit_system: spux.UnitSystem | None) -> typ.Self:
+ """Compute the SimSymbol resulting from the unit system conversion."""
+ if self.unit is not None:
+ new_unit = spux.convert_to_unit_system(self.unit, unit_system)
+
+ scaling_factor = spux.convert_to_unit_system(
+ self.unit, unit_system, strip_units=True
+ )
+ return self.update(
+ unit=new_unit,
+ domain=self.domain * scaling_factor,
+ )
+ return self
+
####################
# - Operations: Comparison
####################
+ @method_lru(maxsize=256)
def compare(self, other: typ.Self) -> typ.Self:
"""Whether this SimSymbol can be considered equivalent to another, and thus universally usable in arbitrary mathematical operations together.
@@ -626,36 +594,44 @@ class SimSymbol(pyd.BaseModel):
and self.domain == other.domain
)
+ @method_lru(maxsize=256)
def compare_size(self, other: typ.Self) -> typ.Self:
"""Compare the size of this `SimSymbol` with another."""
- return self.rows == other.rows and self.cols == other.cols
+ return (
+ self.rows == other.rows
+ and self.cols == other.cols
+ and self.depths == other.depths
+ )
+ @method_lru(maxsize=256)
def compare_addable(
self, other: typ.Self, allow_differing_unit: bool = False
) -> bool:
"""Whether two `SimSymbol`s can be added."""
common = (
- self.compare_size(other.output)
+ self.compare_size(other)
and self.physical_type is other.physical_type
and not (
- self.physical_type is spux.NonPhysical
+ self.physical_type is spux.PhysicalType.NonPhysical
and self.unit is not None
and self.unit != other.unit
)
and not (
- other.physical_type is spux.NonPhysical
+ other.physical_type is spux.PhysicalType.NonPhysical
and other.unit is not None
and self.unit != other.unit
)
)
if not allow_differing_unit:
- return common and self.output.unit == other.output.unit
+ return common and self.unit == other.unit
return common
+ @method_lru(maxsize=256)
def compare_multiplicable(self, other: typ.Self) -> bool:
"""Whether two `SimSymbol`s can be multiplied."""
- return self.shape_len == 0 or self.compare_size(other)
+ return self.shape_len == 0 or other.shape_len == 0 or self.compare_size(other)
+ @method_lru(maxsize=256)
def compare_exponentiable(self, other: typ.Self) -> bool:
"""Whether two `SimSymbol`s can be exponentiated.
@@ -672,35 +648,11 @@ class SimSymbol(pyd.BaseModel):
other.physical_type is spux.PhysicalType.NonPhysical and other.unit is None
)
- ####################
- # - Operations: Copying Setters
- ####################
- def set_constant(self, constant_value: spux.SympyType) -> typ.Self:
- """Set the constant value of this `SimSymbol`, by setting it as the only value in a `sp.FiniteSet` domain.
-
- The `constant_value` will be conformed and stripped (with `self.conform()`) before being injected into the new `sp.FiniteSet` domain.
-
- Warnings:
- Keep in mind that domains do not encode units, for practical reasons related to the diverging ways in which various `sp.Set` subclasses interpret units.
-
- This isn't noticeable in normal constant-symbol workflows, where the constant is retrieved using `self.constant_value` (which adds `self.unit_factor`).
- However, **remember that retrieving the domain directly won't add the unit**.
-
- Ye been warned!
- """
- if self.is_constant:
- return self.update(
- domain=sp.FiniteSet(self.conform(constant_value, strip_unit=True))
- )
-
- msg = 'Tried to set constant value of non-constant SimSymbol.'
- raise ValueError(msg)
-
####################
# - Operations: Conforming Mappers
####################
def conform(
- self, sp_obj: spux.SympyType, strip_unit: bool = False
+ self, obj: np.ndarray | jax.Array | spux.SympyType, strip_unit: bool = False
) -> spux.SympyType:
"""Conform a sympy object to the properties of this `SimSymbol`, if possible.
@@ -715,30 +667,41 @@ class SimSymbol(pyd.BaseModel):
Raises:
ValueError: If the units of `sp_obj` can't be cleanly converted to `self.unit`.
"""
- res = sp_obj
+ if isinstance(obj, np.ndarray | jax.Array):
+ if obj.shape == (): # noqa: SIM108
+ res = sp.S(obj.item(0))
+ else:
+ res = sp.S(np.array(obj))
+ else:
+ res = sp.S(obj)
# Unit Conversion
- match (spux.uses_units(sp_obj), self.unit is not None):
+ match (spux.uses_units(res), self.unit is not None):
case (True, True):
- res = spux.scale_to_unit(sp_obj, self.unit) * self.unit
+ res = spux.scale_to_unit(res, self.unit) * self.unit
case (False, True):
- res = sp_obj * self.unit
+ res = res * self.unit
case (True, False):
- res = spux.strip_unit_system(sp_obj)
+ res = spux.strip_unit_system(res)
if strip_unit:
- res = spux.strip_unit_system(sp_obj)
+ res = spux.strip_unit_system(res)
# Broadcast Expansion
- if (self.rows > 1 or self.cols > 1) and not isinstance(
- res, sp.MatrixBase | sp.MatrixSymbol
+ if (
+ self.depths == ()
+ and (self.rows > 1 or self.cols > 1)
+ and not isinstance(res, sp.MatrixBase | sp.MatrixSymbol)
):
- res = res * sp.ImmutableMatrix.ones(self.rows, self.cols)
+ res = sp.ImmutableMatrix.ones(self.rows, self.cols).applyfunc(
+ lambda el: 5 * el
+ )
return res
+ @method_lru(maxsize=256)
def scale(
self, sp_obj: spux.SympyType, use_jax_array: bool = True
) -> int | float | complex | jtyp.Inexact[jtyp.Array, '...']:
@@ -769,12 +732,13 @@ class SimSymbol(pyd.BaseModel):
####################
# - Creation
####################
+ @functools.lru_cache(maxsize=512)
@staticmethod
def from_expr(
sym_name: SimSymbolName,
expr: spux.SympyExpr,
unit_expr: spux.SympyExpr,
- is_constant: bool = False,
+ new_domain: spux.BlessedSet | None = None,
optional: bool = False,
) -> typ.Self | None:
"""Deduce a `SimSymbol` that matches the output of a given expression (and unit expression).
@@ -802,7 +766,7 @@ class SimSymbol(pyd.BaseModel):
# MathType from Expr Assumptions
## -> All input symbols have assumptions, because we are very pedantic.
## -> Therefore, we should be able to reconstruct the MathType.
- mathtype = spux.MathType.from_expr(expr, optional=optional)
+ mathtype = MT.from_expr(expr, optional=optional)
if mathtype is None:
return None
@@ -830,6 +794,13 @@ class SimSymbol(pyd.BaseModel):
expr.shape if isinstance(expr, sp.MatrixBase | sp.MatrixSymbol) else (1, 1)
)
+ # Set Domain
+ domain = (
+ new_domain
+ if new_domain is not None
+ else spux.BlessedSet(mathtype.symbolic_set)
+ )
+
return SimSymbol(
sym_name=sym_name,
mathtype=mathtype,
@@ -837,8 +808,11 @@ class SimSymbol(pyd.BaseModel):
unit=unit,
rows=rows,
cols=cols,
- is_constant=is_constant,
- exclude_zero=expr.is_zero is not None and not expr.is_zero,
+ domain=(
+ domain - sp.FiniteSet(0)
+ if expr.is_zero is not None and not expr.is_zero
+ else domain
+ ),
)
####################
@@ -873,266 +847,3 @@ class SimSymbol(pyd.BaseModel):
A new instance of `SimSymbol`, initialized using the `model_dump()` dictionary.
"""
return SimSymbol(**obj[2])
-
-
-####################
-# - Common Sim Symbols
-####################
-class CommonSimSymbol(enum.StrEnum):
- """Identifiers for commonly used `SimSymbol`s, with all information about ex. `MathType`, `PhysicalType`, and (in general) valid intervals all pre-loaded.
-
- The enum is UI-compatible making it easy to declare a UI-driven dropdown of commonly used symbols that will all behave as expected.
-
- Attributes:
- X:
- Time: A symbol representing a real-valued wavelength.
- Wavelength: A symbol representing a real-valued wavelength.
- Implicitly, this symbol often represents "vacuum wavelength" in particular.
- Wavelength: A symbol representing a real-valued frequency.
- Generally, this is the non-angular frequency.
- """
-
- Index = enum.auto()
-
- # Space|Time
- SpaceX = enum.auto()
- SpaceY = enum.auto()
- SpaceZ = enum.auto()
-
- AngR = enum.auto()
- AngTheta = enum.auto()
- AngPhi = enum.auto()
-
- DirX = enum.auto()
- DirY = enum.auto()
- DirZ = enum.auto()
-
- Time = enum.auto()
-
- # Fields
- FieldEx = enum.auto()
- FieldEy = enum.auto()
- FieldEz = enum.auto()
- FieldHx = enum.auto()
- FieldHy = enum.auto()
- FieldHz = enum.auto()
-
- FieldEr = enum.auto()
- FieldEtheta = enum.auto()
- FieldEphi = enum.auto()
- FieldHr = enum.auto()
- FieldHtheta = enum.auto()
- FieldHphi = enum.auto()
-
- # Optics
- Wavelength = enum.auto()
- Frequency = enum.auto()
-
- Flux = enum.auto()
-
- DiffOrderX = enum.auto()
- DiffOrderY = enum.auto()
-
- ####################
- # - UI
- ####################
- @staticmethod
- def to_name(v: typ.Self) -> str:
- """Convert the enum value to a human-friendly name.
-
- Notes:
- Used to print names in `EnumProperty`s based on this enum.
-
- Returns:
- A human-friendly name corresponding to the enum value.
- """
- return CommonSimSymbol(v).name
-
- @staticmethod
- def to_icon(_: typ.Self) -> str:
- """Convert the enum value to a Blender icon.
-
- Notes:
- Used to print icons in `EnumProperty`s based on this enum.
-
- Returns:
- A human-friendly name corresponding to the enum value.
- """
- return ''
-
- ####################
- # - Properties
- ####################
- @property
- def name(self) -> str:
- SSN = SimSymbolName
- CSS = CommonSimSymbol
- return {
- CSS.Index: SSN.LowerI,
- # Space|Time
- CSS.SpaceX: SSN.LowerX,
- CSS.SpaceY: SSN.LowerY,
- CSS.SpaceZ: SSN.LowerZ,
- CSS.AngR: SSN.LowerR,
- CSS.AngTheta: SSN.LowerTheta,
- CSS.AngPhi: SSN.LowerPhi,
- CSS.DirX: SSN.LowerX,
- CSS.DirY: SSN.LowerY,
- CSS.DirZ: SSN.LowerZ,
- CSS.Time: SSN.LowerT,
- # Fields
- CSS.FieldEx: SSN.Ex,
- CSS.FieldEy: SSN.Ey,
- CSS.FieldEz: SSN.Ez,
- CSS.FieldHx: SSN.Hx,
- CSS.FieldHy: SSN.Hy,
- CSS.FieldHz: SSN.Hz,
- CSS.FieldEr: SSN.Er,
- CSS.FieldHr: SSN.Hr,
- # Optics
- CSS.Frequency: SSN.Frequency,
- CSS.Wavelength: SSN.Wavelength,
- CSS.DiffOrderX: SSN.DiffOrderX,
- CSS.DiffOrderY: SSN.DiffOrderY,
- }[self]
-
- def sim_symbol(self, unit: spux.Unit | None) -> SimSymbol:
- """Retrieve the `SimSymbol` associated with the `CommonSimSymbol`."""
- CSS = CommonSimSymbol
-
- # Space
- sym_space = SimSymbol(
- sym_name=self.name,
- physical_type=spux.PhysicalType.Length,
- unit=unit,
- )
- sym_ang = SimSymbol(
- sym_name=self.name,
- physical_type=spux.PhysicalType.Angle,
- unit=unit,
- )
-
- # Fields
- def sym_field(eh: typ.Literal['e', 'h']) -> SimSymbol:
- return SimSymbol(
- sym_name=self.name,
- physical_type=spux.PhysicalType.EField
- if eh == 'e'
- else spux.PhysicalType.HField,
- unit=unit,
- interval_finite_re=(0, float_max),
- interval_inf_re=(False, True),
- interval_closed_re=(True, False),
- interval_finite_im=(float_min, float_max),
- interval_inf_im=(True, True),
- )
-
- return {
- CSS.Index: SimSymbol(
- sym_name=self.name,
- mathtype=spux.MathType.Integer,
- interval_finite_z=(0, 2**64),
- interval_inf=(False, True),
- interval_closed=(True, False),
- ),
- # Space|Time
- CSS.SpaceX: sym_space,
- CSS.SpaceY: sym_space,
- CSS.SpaceZ: sym_space,
- CSS.AngR: sym_space,
- CSS.AngTheta: sym_ang,
- CSS.AngPhi: sym_ang,
- CSS.Time: SimSymbol(
- sym_name=self.name,
- physical_type=spux.PhysicalType.Time,
- unit=unit,
- interval_finite_re=(0, float_max),
- interval_inf=(False, True),
- interval_closed=(True, False),
- ),
- # Fields
- CSS.FieldEx: sym_field('e'),
- CSS.FieldEy: sym_field('e'),
- CSS.FieldEz: sym_field('e'),
- CSS.FieldHx: sym_field('h'),
- CSS.FieldHy: sym_field('h'),
- CSS.FieldHz: sym_field('h'),
- CSS.FieldEr: sym_field('e'),
- CSS.FieldEtheta: sym_field('e'),
- CSS.FieldEphi: sym_field('e'),
- CSS.FieldHr: sym_field('h'),
- CSS.FieldHtheta: sym_field('h'),
- CSS.FieldHphi: sym_field('h'),
- # Optics
- CSS.Wavelength: SimSymbol(
- sym_name=self.name,
- mathtype=spux.MathType.Real,
- physical_type=spux.PhysicalType.Length,
- unit=unit,
- interval_finite=(0, float_max),
- interval_inf=(False, True),
- interval_closed=(False, False),
- ),
- CSS.Frequency: SimSymbol(
- sym_name=self.name,
- mathtype=spux.MathType.Real,
- physical_type=spux.PhysicalType.Freq,
- unit=unit,
- interval_finite=(0, float_max),
- interval_inf=(False, True),
- interval_closed=(False, False),
- ),
- CSS.Flux: SimSymbol(
- sym_name=SimSymbolName.Flux,
- mathtype=spux.MathType.Real,
- physical_type=spux.PhysicalType.Power,
- unit=unit,
- ),
- CSS.DiffOrderX: SimSymbol(
- sym_name=self.name,
- mathtype=spux.MathType.Integer,
- interval_finite=(int_min, int_max),
- interval_inf=(True, True),
- interval_closed=(False, False),
- ),
- CSS.DiffOrderY: SimSymbol(
- sym_name=self.name,
- mathtype=spux.MathType.Integer,
- interval_finite=(int_min, int_max),
- interval_inf=(True, True),
- interval_closed=(False, False),
- ),
- }[self]
-
-
-####################
-# - Selected Direct-Access to SimSymbols
-####################
-idx = CommonSimSymbol.Index.sim_symbol
-t = CommonSimSymbol.Time.sim_symbol
-wl = CommonSimSymbol.Wavelength.sim_symbol
-freq = CommonSimSymbol.Frequency.sim_symbol
-
-space_x = CommonSimSymbol.SpaceX.sim_symbol
-space_y = CommonSimSymbol.SpaceY.sim_symbol
-space_z = CommonSimSymbol.SpaceZ.sim_symbol
-
-dir_x = CommonSimSymbol.DirX.sim_symbol
-dir_y = CommonSimSymbol.DirY.sim_symbol
-dir_z = CommonSimSymbol.DirZ.sim_symbol
-
-ang_r = CommonSimSymbol.AngR.sim_symbol
-ang_theta = CommonSimSymbol.AngTheta.sim_symbol
-ang_phi = CommonSimSymbol.AngPhi.sim_symbol
-
-field_ex = CommonSimSymbol.FieldEx.sim_symbol
-field_ey = CommonSimSymbol.FieldEy.sim_symbol
-field_ez = CommonSimSymbol.FieldEz.sim_symbol
-field_hx = CommonSimSymbol.FieldHx.sim_symbol
-field_hy = CommonSimSymbol.FieldHx.sim_symbol
-field_hz = CommonSimSymbol.FieldHx.sim_symbol
-
-flux = CommonSimSymbol.Flux.sim_symbol
-
-diff_order_x = CommonSimSymbol.DiffOrderX.sim_symbol
-diff_order_y = CommonSimSymbol.DiffOrderY.sim_symbol
diff --git a/src/blender_maxwell/utils/sim_symbols/utils.py b/src/blender_maxwell/utils/sim_symbols/utils.py
new file mode 100644
index 0000000..a08300c
--- /dev/null
+++ b/src/blender_maxwell/utils/sim_symbols/utils.py
@@ -0,0 +1,44 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import sys
+from fractions import Fraction
+
+import sympy as sp
+
+int_min = -(2**64)
+int_max = 2**64
+float_min = sys.float_info.min
+float_max = sys.float_info.max
+
+
+def unicode_superscript(n: int) -> str:
+ """Transform an integer into its unicode-based superscript character."""
+ return ''.join(['⁰¹²³⁴⁵⁶⁷⁸⁹'[ord(c) - ord('0')] for c in str(n)])
+
+
+def mk_interval(
+ interval_finite: tuple[int | Fraction | float, int | Fraction | float],
+ interval_inf: tuple[bool, bool],
+ interval_closed: tuple[bool, bool],
+) -> sp.Interval:
+ """Create a symbolic interval from the tuples (and unit) defining it."""
+ return sp.Interval(
+ start=(interval_finite[0] if not interval_inf[0] else -sp.oo),
+ end=(interval_finite[1] if not interval_inf[1] else sp.oo),
+ left_open=(True if interval_inf[0] else not interval_closed[0]),
+ right_open=(True if interval_inf[1] else not interval_closed[1]),
+ )
diff --git a/src/blender_maxwell/utils/sympy_extra/__init__.py b/src/blender_maxwell/utils/sympy_extra/__init__.py
index 6ac4fca..187864c 100644
--- a/src/blender_maxwell/utils/sympy_extra/__init__.py
+++ b/src/blender_maxwell/utils/sympy_extra/__init__.py
@@ -41,6 +41,16 @@ from .sympy_expr import (
Unit,
UnitDimension,
)
+from .sympy_sets import (
+ BlessedDomainOp,
+ BlessedSet,
+ BlessedSetType,
+ ComplexRegion,
+ MatrixSet,
+ bless_set,
+ set_expr_op,
+ simplify_blessed_set,
+)
from .sympy_type import SympyType
from .unit_analysis import (
compare_unit_dims,
@@ -51,6 +61,7 @@ from .unit_analysis import (
scaling_factor,
strip_units,
unit_dim_to_unit_dim_deps,
+ unit_scaling_func_n,
unit_str_to_unit,
unit_to_unit_dim_deps,
uses_units,
@@ -123,6 +134,19 @@ __all__ = [
'SympyExpr',
'Unit',
'UnitDimension',
+ 'BlessedDomainOp',
+ 'BlessedSet',
+ 'BlessedSetType',
+ 'ComplexRegion',
+ 'MatrixSet',
+ 'bless_set',
+ 'set_expr_op',
+ 'simplify_blessed_set',
+ 'BlessedSet',
+ 'minkowski_add',
+ 'minkowski_sub',
+ 'set_element_div',
+ 'set_element_mul',
'SympyType',
'compare_unit_dims',
'compare_units_by_unit_dims',
@@ -132,6 +156,7 @@ __all__ = [
'scaling_factor',
'strip_units',
'unit_dim_to_unit_dim_deps',
+ 'unit_scaling_func_n',
'unit_str_to_unit',
'unit_to_unit_dim_deps',
'uses_units',
diff --git a/src/blender_maxwell/utils/sympy_extra/math_type.py b/src/blender_maxwell/utils/sympy_extra/math_type.py
index 8830d85..91495d6 100644
--- a/src/blender_maxwell/utils/sympy_extra/math_type.py
+++ b/src/blender_maxwell/utils/sympy_extra/math_type.py
@@ -32,6 +32,8 @@ from .sympy_type import SympyType
log = logger.get(__name__)
+MatrixSet: typ.TypeAlias = sp.matrices.MatrixSet
+
class MathType(enum.StrEnum):
"""A convenient, UI-friendly identifier of a numerical object's identity."""
@@ -131,6 +133,7 @@ class MathType(enum.StrEnum):
@staticmethod
def from_pytype(dtype: type) -> type:
+ """The `MathType` corresponding to a particular pure-Python type."""
return {
int: MathType.Integer,
Fraction: MathType.Rational,
@@ -213,15 +216,7 @@ class MathType(enum.StrEnum):
@staticmethod
def from_symbolic_set(
- s: typ.Literal[
- sp.Naturals
- | sp.Naturals0
- | sp.Integers
- | sp.Rationals
- | sp.Reals
- | sp.Complexes
- ]
- | sp.Set,
+ s: sp.Set,
optional: bool = False,
) -> typ.Self | None:
"""Deduce the `MathType` from a particular symbolic set.
@@ -243,6 +238,24 @@ class MathType(enum.StrEnum):
case sp.Complexes:
return MT.Complex
+ if isinstance(s, sp.ProductSet):
+ return MT.combine(*[MT.from_symbolic_set(arg) for arg in s.sets])
+
+ if isinstance(s, MatrixSet):
+ return MT.from_symbolic_set(s.set)
+
+ valid_mathtype = MT.Complex
+ for test_set, mathtype in [
+ (sp.Complexes, MT.Complex),
+ (sp.Reals, MT.Real),
+ (sp.Rationals, MT.Rational),
+ (sp.Integers, MT.Integer),
+ ]:
+ if s.issubset(test_set):
+ valid_mathtype = mathtype
+ else:
+ return valid_mathtype
+
if optional:
return None
@@ -263,6 +276,17 @@ class MathType(enum.StrEnum):
MT.Complex: complex,
}[self]
+ @property
+ def dtype(self) -> type:
+ """Deduce the type that corresponds to this `MathType`, which is usable with `numpy`/`jax`."""
+ MT = MathType
+ return {
+ MT.Integer: int,
+ MT.Rational: float,
+ MT.Real: float,
+ MT.Complex: complex,
+ }[self]
+
@property
def inf_finite(self) -> type:
"""Opinionated finite representation of "infinity" within this `MathType`.
diff --git a/src/blender_maxwell/utils/sympy_extra/physical_type.py b/src/blender_maxwell/utils/sympy_extra/physical_type.py
index 2adedda..758f0d2 100644
--- a/src/blender_maxwell/utils/sympy_extra/physical_type.py
+++ b/src/blender_maxwell/utils/sympy_extra/physical_type.py
@@ -242,7 +242,9 @@ class PhysicalType(enum.StrEnum):
# - Creation
####################
@staticmethod
- def from_unit(unit: Unit | None, optional: bool = False) -> typ.Self | None:
+ def from_unit(
+ unit: Unit | None, optional: bool = False, optional_nonphy: bool = False
+ ) -> typ.Self | None:
"""Attempt to determine a matching `PhysicalType` from a unit.
NOTE: It is not guaranteed that `unit` is within `valid_units`, only that it can be converted to any unit in `valid_units`.
@@ -258,7 +260,7 @@ class PhysicalType(enum.StrEnum):
if unit is None:
return PhysicalType.NonPhysical
- ## TODO_ This enough?
+ ## TODO: This enough?
if unit in [spu.radian, spu.degree]:
return PhysicalType.Angle
@@ -269,6 +271,8 @@ class PhysicalType(enum.StrEnum):
return physical_type
if optional:
+ if optional_nonphy:
+ return PhysicalType.NonPhysical
return None
msg = f'Could not determine PhysicalType for {unit}'
raise ValueError(msg)
@@ -303,9 +307,11 @@ class PhysicalType(enum.StrEnum):
# - Valid Properties
####################
@functools.cached_property
- def valid_units(self) -> list[Unit]:
+ def valid_units(self) -> list[Unit | None]:
"""Retrieve an ordered (by subjective usefulness) list of units for this physical type.
+ `None` denotes "no units are valid".
+
Warnings:
**Altering the order of units hard-breaks backwards compatibility**, since enums based on it only keep an integer index.
@@ -463,29 +469,32 @@ class PhysicalType(enum.StrEnum):
}[self]
@functools.cached_property
- def valid_shapes(self) -> list[typ.Literal[(3,), (2,)] | None]:
- """All shapes with physical meaning in the context of a particular unit dimension."""
+ def valid_shapes(self) -> list[typ.Literal[(3,), (2,), ()] | None]:
+ """All shapes with physical meaning in the context of a particular unit dimension.
+
+ Don't use with `NonPhysical`.
+ """
PT = PhysicalType
overrides = {
# Cartesian
- PT.Length: [None, (2,), (3,)],
+ PT.Length: [(), (2,), (3,)],
# Mechanical
- PT.Vel: [None, (2,), (3,)],
- PT.Accel: [None, (2,), (3,)],
- PT.Force: [None, (2,), (3,)],
+ PT.Vel: [(), (2,), (3,)],
+ PT.Accel: [(), (2,), (3,)],
+ PT.Force: [(), (2,), (3,)],
# Energy
- PT.Work: [None, (2,), (3,)],
- PT.PowerFlux: [None, (2,), (3,)],
+ PT.Work: [(), (2,), (3,)],
+ PT.PowerFlux: [(), (2,), (3,)],
# Electrodynamics
- PT.CurrentDensity: [None, (2,), (3,)],
- PT.MFluxDensity: [None, (2,), (3,)],
- PT.EField: [None, (2,), (3,)],
- PT.HField: [None, (2,), (3,)],
+ PT.CurrentDensity: [(), (2,), (3,)],
+ PT.MFluxDensity: [(), (2,), (3,)],
+ PT.EField: [(), (2,), (3,)],
+ PT.HField: [(), (2,), (3,)],
# Luminal
- PT.LumFlux: [None, (2,), (3,)],
+ PT.LumFlux: [(), (2,), (3,)],
}
- return overrides.get(self, [None])
+ return overrides.get(self, [()])
@functools.cached_property
def valid_mathtypes(self) -> list[MathType]:
diff --git a/src/blender_maxwell/utils/sympy_extra/sympy_expr.py b/src/blender_maxwell/utils/sympy_extra/sympy_expr.py
index cafd421..b518dd2 100644
--- a/src/blender_maxwell/utils/sympy_extra/sympy_expr.py
+++ b/src/blender_maxwell/utils/sympy_extra/sympy_expr.py
@@ -137,6 +137,23 @@ SympyExpr = typx.Annotated[
## TODO: The type game between SympyType, SympyExpr, and the various flavors of ConstrSympyExpr(), is starting to be a bit much. Let's consolidate.
+def SympyObj(instance_of: set[typ.Any]) -> typx.Annotated: # noqa: N802
+ """Declare that a sympy object guaranteed to be an instance of the given bases."""
+
+ def validate_sp_obj(sp_obj: SympyType):
+ if any(isinstance(sp_obj, Base) for Base in instance_of):
+ return sp_obj
+
+ msg = f'Sympy object {sp_obj} is not an instance of a specified valid base {instance_of}.'
+ raise ValueError(msg)
+
+ return typx.Annotated[
+ sp.Basic,
+ _SympyExpr,
+ pyd.AfterValidator(validate_sp_obj),
+ ]
+
+
def ConstrSympyExpr( # noqa: N802, PLR0913
# Features
allow_variables: bool = True,
@@ -237,7 +254,7 @@ def ConstrSympyExpr( # noqa: N802, PLR0913
####################
-# - Common ConstrSympyExpr
+# - Numbers
####################
# Expression
ScalarUnitlessRealExpr: typ.TypeAlias = ConstrSympyExpr(
diff --git a/src/blender_maxwell/utils/sympy_extra/sympy_sets.py b/src/blender_maxwell/utils/sympy_extra/sympy_sets.py
new file mode 100644
index 0000000..9d354d4
--- /dev/null
+++ b/src/blender_maxwell/utils/sympy_extra/sympy_sets.py
@@ -0,0 +1,1491 @@
+# blender_maxwell
+# Copyright (C) 2024 blender_maxwell Project Contributors
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+"""Implements a wrapper for the use of `sympy` sets in deducing the image of a key selection of functions/operations.
+
+Using sets to represent valid domains of symbols delves into territory that `sympy` only theoretically supports.
+The internal `sympy.sets.setexpr.SetExpr`, which drives most of the available simplifications, is only of help in some of the absolute most simple cases.
+To remedy this, we've decided to "bless" a few sets that are absolutely essential for our needs.
+
+In total, we're left with a distinctly usable object capable of tracking symbolic domains of validity through extensive mathematical operations.
+"""
+
+import functools
+import itertools
+import operator
+import typing as typ
+from fractions import Fraction
+
+import jax
+import jax.numpy as jnp
+import pydantic as pyd
+import sympy as sp
+from sympy.sets.setexpr import SetExpr
+
+from blender_maxwell.utils.lru_method import method_lru
+
+from .. import logger
+from .math_type import MathType as MT # noqa: N817
+from .sympy_expr import ScalarUnitlessComplexExpr, SympyExpr
+
+log = logger.get(__name__)
+
+
+####################
+# - Types
+####################
+Scalar: typ.TypeAlias = ScalarUnitlessComplexExpr
+MatrixSet: typ.TypeAlias = sp.matrices.MatrixSet
+ComplexRegion: typ.TypeAlias = sp.sets.fancysets.CartesianComplexRegion
+
+# Valid BlessedSet Types:
+## Trivial:
+## - sp.EmptySet: Null Set
+## Points:
+## - sp.FiniteSet ## -> For some reason, no work, but it's blessed too.
+## Region:
+## - sp.Range (stride=1): [a,b]_Z w/1-spacing
+## - sp.Interval
+## - ComplexRegion: tuple[.args[0], .args[1]] w/blessed args.
+## - sp.Complexes: Inf ComplexRegion unfortunately simplifies to Complexes.
+## Composite:
+## - sp.Union[*, *, ...]: Arbitrary depth.
+## TODO: Bless ConditionSets (for arbitrarily expressive domain shapes).
+
+BlessedSetType: typ.TypeAlias = SympyExpr # SympyObj(
+# instance_of={
+# sp.EmptySet,
+# sp.FiniteSet,
+# sp.Range,
+# sp.Interval,
+# ComplexRegion,
+# sp.Complexes,
+# sp.Union,
+# }
+# )
+BlessedDomainOp: typ.TypeAlias = typ.Literal[
+ operator.add, operator.sub, operator.mul, operator.truediv, operator.pow
+]
+
+# The following set types are "blessable" (aka. parseable to a BlessedSet):
+## Universal:
+## - sp.Naturals: Range(1, oo)
+## - sp.Naturals0: Range(0, oo)
+## - sp.Integers: Range(-oo, oo)
+## - sp.Rationals: Interval(-oo, oo) ## TODO: Custom rationals interval?
+## - sp.Reals: Interval(-oo, oo)
+## Composite:
+## - MatrixSet[*]: Any parseable set.
+
+
+####################
+# - Bless Arbitrary Sets
+####################
+@functools.lru_cache(maxsize=8192)
+def simplify_blessed_set(s: BlessedSetType) -> BlessedSetType:
+ """Perform some principled simplifications on computed `BlessedSetType`s."""
+ # log.critical(s)
+
+ sset = s
+
+ if isinstance(sset, sp.FiniteSet):
+ sset -= {-sp.oo, sp.oo, -sp.zoo, sp.zoo}
+ # sset = sp.FiniteSet(*[sp.nsimplify(fs_el, tolerance=10**-7) for fs_el in sset])
+
+ elif isinstance(sset, sp.Range):
+ if sset.start == 0 and sset.stop == 0:
+ return sp.EmptySet
+
+ elif isinstance(sset, sp.Interval):
+ return sp.Interval(
+ sset.start,
+ sset.end,
+ # sp.nsimplify(sset.start, tolerance=10**-7),
+ # sp.nsimplify(sset.end, tolerance=10**-7),
+ left_open=sset.left_open,
+ right_open=sset.right_open,
+ )
+ if sset.start == 0 and sset.stop == 0:
+ return sp.EmptySet
+
+ elif (
+ isinstance(sset, ComplexRegion)
+ and isinstance(sset.b_interval, sp.FiniteSet)
+ and sset.b_interval == sp.FiniteSet(0)
+ ):
+ sset = s.interval_a
+
+ elif isinstance(sset, sp.Union):
+ sset = sp.Union(*(simplify_blessed_set(arg) for arg in s.args))
+
+ return sset
+
+
+@functools.lru_cache(maxsize=8192)
+def bless_set(s: sp.Set) -> BlessedSetType: # noqa: C901, PLR0911, PLR0912, PLR0915
+ """Attempt to conform an arbitrary input set to an equivalent `BlessedSetType`.
+
+ This may not always be possible.
+ Therefore, as a fallback, we detect the `MT` of the given set and return its corresponding symbolic set.
+ """
+ # log.critical(s)
+
+ if isinstance(s, set | frozenset):
+ return sp.S(s)
+ # Explicitly Parse
+ ## -> There are some sets we don't want to special-case later.
+ if s is sp.Naturals:
+ return sp.Range(1, sp.oo, 1)
+ if s is sp.Naturals0:
+ return sp.Range(0, sp.oo, 1)
+ if s is sp.Integers:
+ return sp.Range(-sp.oo, sp.oo, 1)
+ if s is sp.Rationals or s is sp.Reals:
+ return sp.Interval(-sp.oo, sp.oo)
+ if s is sp.Complexes:
+ return sp.Complexes
+ ## -> ComplexRegion(R*R) unfortunately simplifies to Complexes.
+ ## -> This means that we unfortunately must resort to use of a tuple.
+ return (sp.Interval(-sp.oo, sp.oo), sp.Interval(-sp.oo, sp.oo))
+
+ if isinstance(s, MatrixSet):
+ return bless_set(s.set)
+
+ # Explicitly Refuse
+ ## -> There are some sets we simply cannot transcribe.
+ error = False
+ if isinstance(s, sp.Range) and s.step != 1:
+ log.error(
+ 'Range set %s w/range != 1 cannot be parsed; expanding to step size of 1',
+ str(s),
+ )
+ return simplify_blessed_set(sp.Range(s.start, s.stop, 1))
+
+ if isinstance(s, sp.Intersection):
+ log.error('Intersection set %s cannot be parsed', str(s))
+ error = True
+
+ elif isinstance(s, sp.ProductSet):
+ log.error('Product set %s cannot be parsed', str(s))
+ raise NotImplementedError
+
+ elif isinstance(s, sp.Complement):
+ # Remove Segments from Range
+ ## -> We construct an equivalent union of disjoint ranges.
+ ## -> We search for start/end points by linear search w/'in' operator.
+ ## -> We determine removed points by simple set intersection.
+ ## -> This works with absolutely any sp.Set - even bs like ImageSet!
+ if isinstance(s.args[0], sp.Range):
+ rng = s.args[0]
+ other = s.args[1]
+
+ start = next(el for el in rng if el not in other)
+ end = next(el for el in reversed(rng) if el not in other)
+ removed = rng & other
+
+ return bless_set(
+ functools.reduce(
+ lambda S, rng_i: S | rng_i, # noqa: N803
+ [
+ sp.Range(s + 1, e, 1) if s != e else sp.EmptySet
+ for s, e in itertools.pairwise(
+ sp.FiniteSet(start - 1) | removed | sp.FiniteSet(end + 1)
+ )
+ ],
+ )
+ )
+
+ # Punch Finite "Holes" in Interval
+ ## -> Interval - Interval simplifies just fine.
+ ## -> BUT, Sympy gives up when trying to "punch holes" using integers.
+ if isinstance(s.args[0], sp.Interval) and isinstance(
+ s.args[1], sp.FiniteSet | sp.Range
+ ):
+ ivl = s.args[0]
+ other = s.args[1]
+
+ removed = ivl & other
+ holes = sp.S(set(removed))
+
+ return functools.reduce(lambda s, hole: s - {hole}, holes, ivl)
+
+ # Split ComplexRegion
+ ## -> Just split along each axis.
+ if isinstance(s.args[0], ComplexRegion):
+ cpx = s.args[0]
+ other = s.args[1]
+
+ removed = cpx & other
+ return bless_set(
+ ComplexRegion(
+ bless_set(cpx.a_interval - removed.a_interval)
+ * bless_set(cpx.b_interval - removed.b_interval),
+ )
+ )
+
+ log.error('Complement set %s cannot be parsed', str(s))
+ error = True
+
+ elif isinstance(s, sp.SymmetricDifference):
+ log.error('Symmetric Difference set %s cannot be parsed', str(s))
+ error = True
+
+ elif isinstance(s, sp.DisjointUnion):
+ log.error('Disjoint Union set %s cannot be parsed', str(s))
+ raise NotImplementedError
+
+ elif s is sp.UniversalSet:
+ log.error('Universal set %s cannot be parsed', str(s))
+ raise NotImplementedError
+
+ elif isinstance(s, sp.sets.fancysets.ImageSet):
+ log.error('Image set %s cannot be parsed', str(s))
+ raise NotImplementedError
+
+ elif isinstance(
+ s, sp.sets.fancysets.PolarComplexRegion | sp.sets.powerset.PowerSet
+ ):
+ log.error('Polar complex region set %s cannot be parsed', str(s))
+ error = True
+
+ elif isinstance(s, sp.sets.conditionset.ConditionSet):
+ log.error('Condition set %s cannot be parsed', str(s))
+ error = True
+
+ if error:
+ return bless_set(MT.from_symbolic_set(s).symbolic_set)
+ return simplify_blessed_set(s)
+
+
+####################
+# - SetExpr Operations
+####################
+@functools.lru_cache(maxsize=8192)
+def set_expr_op(
+ op: BlessedDomainOp, lhs: BlessedSetType, rhs: BlessedSetType
+) -> BlessedSetType:
+ """Compute the elementwise between two sets using `SetExpr`.
+
+ It is the user's responsibility to ensure the particular operation between `lhs` and `rhs` can actually be evaluated by `SetExpr`.
+ While the constraint of `BlessedDomainOp` is a prerequisite, there is absolutely no guarantee that `SetExpr` can evaluate.
+
+ Raises:
+ TypeError: If the SetExpr did not simplify away the internal `ImageSet`, causing the computed set to not to be usable as a `BlessedSetType`.
+ """
+ computed_set = simplify_blessed_set(op(SetExpr(lhs), SetExpr(rhs)).set)
+ if isinstance(computed_set, sp.ImageSet):
+ msg = 'SetExpr did not evaluate for {op} between {lhs} and {rhs}'
+ raise TypeError(msg)
+
+ return computed_set
+
+
+####################
+# - Blessed Set
+####################
+## -> Generated using 'sp.nsolve(sp.diff(sp.sinc(x)), x, -5.0)'
+SINC_MINIMUM = -4.49340945790906
+
+
+class BlessedSet(pyd.BaseModel):
+ """A wrapper enabling reliable elementwise operations on and between a constrained set of `sp.Set`s.
+
+ In particular, it provides correct implementations of:
+
+ - Enriched Set Operations: The usual `|` and `&`, along with a few extra well-defined semantics.
+ - Unary Operations: Compute the set resulting from the application of a particular, single-argument function.
+ - Binary Operations w/Scalar: Compute the set resulting from the application of an operation between all elements of a set, and a given scalar.
+ - Minkowski Operations: Compute the set resulting from the application of an operation between all elements of two sets.
+ """
+
+ model_config = pyd.ConfigDict(frozen=True)
+
+ bset: BlessedSetType
+
+ ####################
+ # - Creation
+ ####################
+ def __init__(self, bset: typ.Self | sp.Set | set) -> None:
+ if isinstance(bset, BlessedSet):
+ super().__init__(bset=bset.bset)
+ elif isinstance(bset, set | frozenset):
+ super().__init__(bset=bless_set(frozenset(bset)))
+ super().__init__(bset=bless_set(bset))
+
+ @functools.lru_cache(maxsize=8192)
+ @staticmethod
+ def reals_to_complex(a: typ.Self | sp.Set, b: typ.Self | sp.Set) -> None:
+ # log.critical([a, b])
+ BS = BlessedSet
+ return BS(ComplexRegion(BS(a).bset * BS(b).bset))
+
+ ####################
+ # - Sympy
+ ####################
+ def _sympy_(self) -> typ.Self:
+ return self.bset
+
+ ####################
+ # - Properties
+ ####################
+ @functools.cached_property
+ def real(self) -> typ.Self:
+ """The real subset of the represented set."""
+ if isinstance(self.bset, ComplexRegion):
+ return BlessedSet(self.bset.a_interval)
+ return self
+
+ @functools.cached_property
+ def imag(self) -> typ.Self:
+ """The imaginary subset of the represented set."""
+ if isinstance(self.bset, ComplexRegion):
+ return BlessedSet(self.bset.b_interval)
+ return BlessedSet(sp.FiniteSet(0))
+
+ @functools.cached_property
+ def is_empty(self) -> typ.Self:
+ """Whether this is the null set."""
+ return self.bset is sp.EmptySet
+
+ @functools.cached_property
+ def is_nonzero(self) -> typ.Self:
+ """Whether `0` occurs in this set."""
+ return 0 in self.bset
+
+ @functools.cached_property
+ def inf(self) -> typ.Self:
+ """The largest lower bound on this set.
+
+ For complex sets, we define this across both real and imaginary axes.
+ """
+ if self.bset is sp.EmptySet:
+ msg = 'Empty set has no infimum'
+ raise TypeError(msg)
+
+ if isinstance(self.bset, sp.FiniteSet | sp.Range | sp.Interval):
+ return self.bset.inf
+
+ if isinstance(self.bset, ComplexRegion):
+ return self.real.inf + sp.I * self.real.inf
+
+ if isinstance(self.bset, sp.Complexes):
+ return -sp.zoo
+
+ if isinstance(self.bset, sp.Union):
+ return min([BlessedSet(subset).inf for subset in self.bset.args])
+
+ raise TypeError
+
+ @functools.cached_property
+ def sup(self) -> typ.Self:
+ """The smallest upper bound on this set.
+
+ For complex sets, we define this across both real and imaginary axes.
+ """
+ if self.bset is sp.EmptySet:
+ msg = 'Empty set has no infimum'
+ raise TypeError(msg)
+
+ if isinstance(self.bset, sp.FiniteSet | sp.Range | sp.Interval):
+ return self.bset.sup
+
+ if isinstance(self.bset, ComplexRegion):
+ return self.real.sup + sp.I * self.imag.sup
+
+ if isinstance(self.bset, sp.Complexes):
+ return sp.zoo
+
+ if isinstance(self.bset, sp.Union):
+ return min([BlessedSet(subset).inf for subset in self.bset.args])
+
+ raise TypeError
+
+ @functools.cached_property
+ def min_closed(self) -> typ.Self:
+ """The closure of the largest lower bound.
+
+ For complex sets, this refers to the closure of the real part.
+ """
+ if self.bset is sp.EmptySet or self.bset is sp.Complexes:
+ return False
+
+ if isinstance(self.bset, sp.FiniteSet | sp.Range):
+ return True
+
+ if isinstance(self.bset, sp.Interval):
+ return not self.bset.left_open
+
+ if isinstance(self.bset, ComplexRegion):
+ return self.real.min_closed
+
+ if isinstance(self.bset, sp.Union):
+ real_subsets = sorted(
+ [BlessedSet(subset).real for subset in self.bset.args],
+ lambda els: els.inf,
+ )
+ return real_subsets[0].inf in real_subsets[0]
+
+ raise TypeError
+
+ @functools.cached_property
+ def max_closed(self) -> typ.Self:
+ """The closure of the smallest upper bound.
+
+ For complex sets, this refers to the closure of the real part.
+ """
+ if self.bset is sp.EmptySet or self.bset is sp.Complexes:
+ return False
+
+ if isinstance(self.bset, sp.FiniteSet | sp.Range):
+ return True
+
+ if isinstance(self.bset, sp.Interval):
+ return not self.bset.right_open
+
+ if isinstance(self.bset, ComplexRegion):
+ return self.real.max_closed
+
+ if isinstance(self.bset, sp.Union):
+ real_subsets = sorted(
+ [BlessedSet(subset).real for subset in self.bset.args],
+ lambda els: els.sup,
+ )
+ return real_subsets[0].sup in real_subsets[0]
+
+ raise TypeError
+
+ @functools.cached_property
+ def min_closed_im(self) -> bool:
+ """The closure of the largest lower bound of the imaginary axis.
+
+ Purely real sets will generally have closed imaginary axes, as it is just the point `0`.
+ """
+ if self.bset is sp.EmptySet or self.bset is sp.Complexes:
+ return False
+
+ if isinstance(self.bset, sp.FiniteSet | sp.Range | sp.Interval):
+ return True ## {0} is closed
+
+ if isinstance(self.bset, ComplexRegion):
+ return self.imag.min_closed
+
+ if isinstance(self.bset, sp.Union):
+ imag_subsets = sorted(
+ [BlessedSet(subset).imag for subset in self.bset.args],
+ lambda els: els.inf,
+ )
+ return imag_subsets[0].inf in imag_subsets[0]
+
+ raise TypeError
+
+ @functools.cached_property
+ def max_closed_im(self) -> bool:
+ """The closure of the smallest upper bound of the imaginary axis.
+
+ Purely real sets will generally have closed imaginary axes, as it is just the point `0`.
+ """
+ if self.bset is sp.EmptySet or self.bset is sp.Complexes:
+ return False
+
+ if isinstance(self.bset, sp.FiniteSet | sp.Range | sp.Interval):
+ return True ## {0} is closed
+
+ if isinstance(self.bset, ComplexRegion):
+ return self.imag.max_closed
+
+ if isinstance(self.bset, sp.Union):
+ imag_subsets = sorted(
+ [BlessedSet(subset).imag for subset in self.bset.args],
+ lambda els: els.inf,
+ )
+ return imag_subsets[0].sup in imag_subsets[0]
+
+ raise TypeError
+
+ ####################
+ # - Methods
+ ####################
+ @method_lru()
+ def bset_mat(self, rows: int, cols: int) -> typ.Self:
+ # log.critical([self, rows, cols])
+ if rows == cols == 1:
+ return self.bset
+ return MatrixSet(rows, cols, BlessedSet(self.bset))
+
+ def sample_uniform_jax(self, key, sample_shape: tuple[int, ...]) -> jax.Array:
+ """Sample elements uniformly from this set, returning a `jax.Array` of the specified `sample_shape`.
+
+ - `FiniteSet`: Select from individual elements.
+ - `Range`: Select from integers within the range.
+ - `Interval`: Sample from the interior.
+ - `ComplexRegion`: Sample real/imag axes seperately and combine to a complex array.
+ - `Union` of `Interval`: Sample each disjoint `Interval` seperately, then use their relative size to weigh which individually sampled `Interval` to select an element from.
+ """
+ if self.bset is sp.EmptySet:
+ msg = 'Cant sample an empty set'
+ raise TypeError(msg)
+
+ if isinstance(self.bset, sp.FiniteSet):
+ if (
+ -sp.oo in self.bset
+ or sp.oo in self.bset
+ or -sp.zoo in self.bset
+ or sp.zoo in self.bset
+ ):
+ msg = 'Cant uniformly sample a finite set with infinite elements.'
+ raise ValueError(msg)
+
+ mathtype = MT.combine(self.bset)
+ finite_domain = jnp.array(list(self.bset), dtype=mathtype.dtype)
+
+ return jax.random.choice(key, sample_shape, a=finite_domain)
+
+ if isinstance(self.bset, sp.Range):
+ if self.bset.start == -sp.oo or self.bset.stop == sp.oo:
+ msg = 'Cant uniformly sample a range with infinite bounds.'
+ raise ValueError(msg)
+
+ start = int(self.bset.start)
+ stop_excl = int(self.bset.stop)
+
+ return jax.random.randint(key, sample_shape, minval=start, maxval=stop_excl)
+
+ if isinstance(self.bset, sp.Interval):
+ if self.bset.start == -sp.oo or self.bset.end == sp.oo:
+ msg = 'Cant uniformly sample an infinite Interval.'
+ raise ValueError(msg)
+
+ start = float(self.bset.start)
+ end = float(self.bset.end)
+
+ return jax.random.uniform(
+ key,
+ sample_shape,
+ minval=start,
+ maxval=end,
+ )
+
+ if isinstance(self.bset, ComplexRegion):
+ re = BlessedSet(self.bset.a_interval).sample_uniform_jax(sample_shape)
+ im = BlessedSet(self.bset.b_interval).sample_uniform_jax(sample_shape)
+
+ return re + 1j * im
+
+ if isinstance(self.bset, sp.Complexes):
+ msg = 'Cant uniformly sample a set with infinite elements.'
+ raise TypeError(msg)
+
+ if isinstance(self.bset, sp.Union):
+ if all(
+ isinstance(disjoint_subset, sp.Interval)
+ for disjoint_subset in self.bset.args
+ ):
+ # Deduce Subkeys
+ ## -> This enables deterministic sampling of subsets.
+ subkeys = jax.random.split(key, len(self.bset.args) + 1)
+
+ # Deduce Relative Weights
+ ## -> This enables proper weighting of subset samples.
+ weights = jnp.array(
+ [
+ disjoint_subset.sup - disjoint_subset.inf
+ for disjoint_subset in self.bset.args
+ ]
+ )
+ weights /= sum(weights)
+
+ # Assemble Uniform Samples of Disjoint Subsets
+ ## -> Disjointness allows independent sampling.
+ disjoint_subset_samples = [
+ BlessedSet(disjoint_subset).sample_uniform_jax(
+ subkeys[i], sample_shape
+ )
+ for i, disjoint_subset in enumerate(self.bset.args)
+ ]
+
+ # Perform Weighted Selection of Disjoint Uniform Samplings
+ ## -> We randomly generate a weighted index array.
+ ## -> This produces an index per-sample_shape pos.
+ ## -> This per-pos index selects values from disjoint samplings.
+ ## -> Thus, we've achieved weighted uniform disjoint sampling.
+ weighted_idx_sampling = jax.random.choice(
+ subkeys[-1],
+ shape=sample_shape,
+ a=jnp.arange(0, len(self.bset.args)),
+ p=weights,
+ )
+ return jnp.choose(
+ weighted_idx_sampling, choices=disjoint_subset_samples
+ )
+
+ raise NotImplementedError
+
+ raise TypeError
+
+ ####################
+ # - Set Operations
+ ####################
+ def __contains__(self, value: typ.Self | sp.MatrixBase | typ.Any) -> typ.Self:
+ """Deduce whether `value` is contained within this BlessedSet.
+
+ Generally, `in` only refers to "membership".
+ However, we provide two special conveniences to make it easier to work with `BlessedSet`s in practice.
+
+ - Breaking from the pure mathematical interpretation of membership, `value` is a `BlessedSet`, then it is considered to be "contained" within `self` it it is either the `sp.EmptySet` **or a subset** of `self.bset`.
+ - If `value` is itself a `MatrixBase`, then we will wrap `self.bset` in an appropriately shaped `MatrixSet` before determining membership.
+ """
+ # log.critical([self, value])
+ # BlessedSet: Compute Subset
+ ## -> In mathematics, 'has me as a subset' is an explicit op.
+ ## -> Generally, there's a difference btwn "element of" and "subset".
+ ## -> BUT, here, it can be well-defined of a 'BlessedSet' wrapper.
+ ## -> Therefore, we can match Python's ex. string behavior.
+ ## -> This convenience eliminates a ton of boilerplate.
+ ## -> NOTE: The EmptySet check is also easy to miss.
+ if isinstance(value, BlessedSet):
+ return value.bset is sp.EmptySet or value.bset.issubset(self.bset)
+
+ # MatrixSet: Deduce Shape-Uniform Membership
+ ## -> In our system, we've decided that matrices must have uniform dms.
+ ## -> This prevents a lot of very elaborate matrix domain math.
+ ## -> This also lets us express matrix domain math as normal domain math.
+ ## -> However, we still need to work with matrix domains _as matrices_.
+ ## -> Aka. 'sp.Matrix([1, 2, 3]) in BlessedSet(sp.Reals)' must work.
+ ## -> This avoids a ton of fragile boilerplate everywhere.
+ if isinstance(value, sp.MatrixBase):
+ return value in MatrixSet(*value.shape, self.bset)
+
+ return value in self.bset
+
+ def __or__(self, other: typ.Self | sp.Set | set) -> typ.Self:
+ """Compute the set union."""
+ # log.critical([self, other])
+ if isinstance(other, BlessedSet):
+ return BlessedSet(self.bset | other.bset)
+ return BlessedSet(self.bset | other)
+
+ def __and__(self, other: typ.Self | sp.Set | set) -> typ.Self:
+ """Compute the set intersection."""
+ # log.critical([self, other])
+ if isinstance(other, BlessedSet):
+ return BlessedSet(self.bset & other.bset)
+ return BlessedSet(self.bset & other)
+
+ ####################
+ # - Unary Operations
+ ####################
+ def __abs__(self) -> typ.Self:
+ return self.abs
+
+ @functools.cached_property
+ def abs(self) -> typ.Self:
+ """Apply the absolute value to all set elements."""
+ # log.critical(self)
+
+ s = self.bset
+ nset = s
+
+ # Trivial
+ if s is sp.EmptySet:
+ nset = sp.EmptySet
+
+ # Points
+ elif isinstance(s, sp.FiniteSet):
+ nset = sp.FiniteSet(*[abs(el) for el in self.bset])
+
+ # Ranges
+ elif isinstance(s, sp.Range):
+ start, _, end = sorted([sp.S(0), abs(s.inf), abs(s.sup)])
+ if start == end: # noqa: SIM108
+ nset = sp.FiniteSet(start)
+ else:
+ nset = sp.Range(start, end + 1)
+
+ elif isinstance(s, sp.Interval):
+ (start, start_open), (end, end_open) = sorted(
+ [
+ (abs(s.start), s.left_open),
+ (abs(s.end), s.right_open),
+ ## Can't use the same trick, as open/closed bounds don't sort.
+ ],
+ key=lambda el: el[0],
+ )
+
+ if 0 in s:
+ nset = sp.Interval(0, end, left_open=False, right_open=end_open)
+ else:
+ nset = sp.Interval(
+ start,
+ end,
+ left_open=start_open,
+ right_open=end_open,
+ )
+
+ elif isinstance(s, ComplexRegion):
+ nset = ((self.real**2 + self.imag**2) ** sp.Rational(1, 2)).bset
+
+ elif s is sp.Complexes:
+ nset = sp.Interval(0, sp.oo)
+
+ # Union: Recurse
+ elif isinstance(s, sp.Union):
+ nset = sp.Union(*[abs(arg).bset for arg in s.args])
+
+ else:
+ msg = f'abs() not implemented for set {s}'
+ raise TypeError(msg)
+
+ return BlessedSet(nset)
+
+ @functools.cached_property
+ def reciprocal(self) -> typ.Self:
+ """Compute the `BlessedSet` resulting from applying the reciprocal function $1/x$ to each element."""
+ s = self.bset
+ nset = s
+
+ # Trivial
+ ## -> Sympy considers zero-divison as sp.S(x) / 0 = sp.zoo.
+ ## -> This 'complex infinity' is just sign-matching complex reals.
+ ## -> BUT, we need to do numerical stuff, and 'inf' is a problem.
+ ## -> Therefore, RESET EVERYTHING to EmptySet if */0 can happen.
+ if s is sp.EmptySet or 0 in s:
+ nset = sp.EmptySet
+
+ # Recursive
+ elif isinstance(s, sp.Union):
+ nset = sp.Union(*[BlessedSet(arg).reciprocal.bset for arg in s.args])
+
+ # Points
+ elif isinstance(s, sp.FiniteSet):
+ nset = sp.FiniteSet(*[1 / el for el in s])
+
+ # Ranges
+ elif isinstance(s, sp.Range):
+ start, end = sorted([1 / s.inf, 1 / s.sup])
+ if start == end:
+ nset = sp.FiniteSet(start)
+ elif start.is_integer and end.is_integer:
+ nset = sp.Range(start, end + 1, 1)
+ else:
+ nset = sp.Interval(start, end, 1)
+
+ elif isinstance(s, sp.Interval):
+ (start, start_open), (end, end_open) = sorted(
+ [
+ (1 / s.start, s.left_open),
+ (1 / s.end, s.right_open),
+ ],
+ key=lambda el: el[0],
+ )
+ nset = sp.Interval(
+ start,
+ end,
+ left_open=start_open,
+ right_open=end_open,
+ )
+
+ elif isinstance(s, ComplexRegion):
+ denominator = self.real**2 + self.imag**2
+ nset = BlessedSet.reals_to_complex(
+ self.real / denominator,
+ -self.imag / denominator,
+ ).bset
+
+ elif s is sp.Complexes:
+ nset = sp.Reals
+
+ else:
+ msg = f'abs() not implemented for set {s}'
+ raise ValueError(msg)
+
+ # log.critical(BlessedSet(nset - {0}))
+ return BlessedSet(nset - {0})
+
+ @functools.cached_property
+ def cos(self) -> typ.Self:
+ r"""Compute the `BlessedSet` resulting from applying the reciprocal function $\cos x$ to each element."""
+ # log.critical(self)
+
+ s = self.bset
+ # Trivial
+ if s is sp.EmptySet:
+ nset = sp.EmptySet
+
+ # Points
+ elif isinstance(s, sp.FiniteSet):
+ nset = sp.FiniteSet(*[sp.cos(el) for el in s])
+
+ # Ranges
+ elif isinstance(s, sp.Range):
+ x = sp.Symbol('x', real=True)
+ nset = sp.calculus.util.function_range(
+ sp.cos(x), x, sp.Interval(s.inf, s.sup)
+ )
+ ## TODO: Don't just cast to Interval.
+ ## -- Easily infinite points.
+ ## -- Would need a blessed ConditionSet.
+
+ elif isinstance(s, sp.Interval):
+ x = sp.Symbol('x', real=True)
+ nset = sp.calculus.util.function_range(sp.cos(x), x, s)
+
+ elif isinstance(s, ComplexRegion):
+ ## TODO: cos(x)*cosh(y) - sin(x)*sinh(y)
+ log.error(
+ 'cos(x) not (yet) implemented for complex region: %s. Falling back to sp.Complexes.',
+ str(s),
+ )
+ nset = sp.Complexes
+
+ elif s is sp.Complexes:
+ ## -> Applying cos(x)*cosh(y) - sin(x)*sinh(y) gives all of C.
+ nset = sp.Complexes
+
+ # Unions: Recurse
+ elif isinstance(s, sp.Union):
+ nset = sp.Union(*(BlessedSet(arg).cos.bset for arg in s.args))
+
+ else:
+ msg = f'cos()/sin() not implemented for set {s}'
+ raise TypeError(msg)
+
+ # log.critical(BlessedSet(nset))
+ return BlessedSet(nset)
+
+ @functools.cached_property
+ def sin(self) -> typ.Self:
+ r"""Compute the `BlessedSet` resulting from applying the reciprocal function $\cos x$ to each element."""
+ # log.critical(self)
+
+ return (self - sp.pi / 2).cos
+
+ @functools.cached_property
+ def arctan(self) -> typ.Self:
+ r"""Compute the `BlessedSet` resulting from applying the reciprocal function $\cos x$ to each element."""
+ # log.critical(self)
+
+ s = self.bset
+
+ # Trivial
+ if s is sp.EmptySet:
+ nset = sp.EmptySet
+
+ # Points
+ elif isinstance(s, sp.FiniteSet):
+ nset = sp.FiniteSet(*[sp.arctan(el) for el in s])
+
+ elif isinstance(s, sp.Range):
+ x = sp.Symbol('x', real=True)
+ nset = sp.calculus.util.function_range(
+ sp.arctan(x), x, sp.Interval(s.inf, s.sup)
+ )
+ ## TODO: It's more a bunch of points than really continuous...
+ ## -- FiniteSet might get way too big, though.
+
+ elif isinstance(s, sp.Interval):
+ x = sp.Symbol('x', real=True)
+ nset = sp.calculus.util.function_range(sp.arctan(x), x, s)
+
+ elif isinstance(s, ComplexRegion) or s is sp.Complexes:
+ raise NotImplementedError
+
+ # Unions: Recurse
+ elif isinstance(s, sp.Union):
+ nset = sp.Union(*(BlessedSet(arg).arctan for arg in s.args))
+
+ else:
+ msg = f'arctan() not implemented for set {s}'
+ raise ValueError(msg)
+
+ return BlessedSet(nset)
+
+ @functools.cached_property
+ def sinc(self) -> typ.Self:
+ # log.critical(self)
+
+ s = self.bset
+ if isinstance(s, ComplexRegion) or s is sp.Complexes:
+ ## TODO: Handle better
+ return sp.Complexes
+
+ return BlessedSet(sp.Interval(SINC_MINIMUM, 1))
+
+ @functools.cached_property
+ def arg(self) -> typ.Self:
+ """Compute the set resulting from applying the complex argument to all elements."""
+ # log.critical(self)
+
+ s = self.bset
+
+ # Trivial
+ ## -> arg of 0 is undefined just like zero-division.
+ if s is sp.EmptySet or 0 in s:
+ nset = sp.EmptySet
+
+ # Points
+ if isinstance(s, sp.FiniteSet):
+ nset = sp.FiniteSet(*[sp.arg(el) for el in s])
+
+ elif isinstance(s, sp.Range | sp.Interval):
+ if s.inf > 0 and s.sup > 0:
+ nset = sp.FiniteSet(0)
+ if s.inf < 0 and s.sup < 0:
+ nset = sp.FiniteSet(sp.pi)
+ if s.inf < 0 and s.sup > 0:
+ nset = sp.FiniteSet(0, sp.pi)
+ raise TypeError
+
+ elif isinstance(s, ComplexRegion):
+ _q1_q4 = ComplexRegion(sp.Interval.open(0, sp.oo) * sp.Reals)
+ _q2 = ComplexRegion(sp.Interval.open(-sp.oo, 0) * sp.Interval(0, sp.oo))
+ _q3 = ComplexRegion(
+ sp.Interval.open(-sp.oo, 0) * sp.Interval.open(0, sp.oo)
+ )
+ _pos_im_axis = ComplexRegion(sp.FiniteSet(0) * sp.Interval.open(0, sp.oo))
+ _neg_im_axis = ComplexRegion(sp.FiniteSet(0) * sp.Interval.open(-sp.oo, 0))
+
+ q1_q4 = s & _q1_q4
+ q2 = s & _q2
+ q3 = s & _q3
+ pos_im_axis = s & _pos_im_axis
+ neg_im_axis = s & _neg_im_axis
+
+ nset = (
+ (q1_q4.imag / q1_q4.real).arctan
+ | (q2.imag / q2.real + sp.pi).arctan
+ | (q3.imag / q3.real - sp.pi).arctan
+ | (
+ sp.FiniteSet(sp.EmptySet)
+ if pos_im_axis.is_empty
+ else sp.FiniteSet(sp.pi / 2)
+ )
+ | (
+ sp.FiniteSet(sp.EmptySet)
+ if neg_im_axis.is_empty
+ else sp.FiniteSet(-sp.pi / 2)
+ )
+ ## origin is unimportant here
+ )
+
+ elif s is sp.Complexes:
+ nset = bless_set(sp.Reals)
+
+ # Unions: Recurse
+ elif isinstance(s, sp.Union):
+ nset = sp.Union(*(BlessedSet(arg).arg.bset for arg in s.args))
+
+ else:
+ msg = f'arg() not implemented for set {s}'
+ raise TypeError(msg)
+
+ # log.critical(BlessedSet(nset))
+ return BlessedSet(nset)
+
+ ####################
+ # - Operation w/Scalar
+ ####################
+ @method_lru()
+ def _operate_scalar(self, op: BlessedDomainOp, scalar: Scalar) -> typ.Self: # noqa: C901, PLR0915, PLR0912
+ """Compute the set resulting from applying an operation by a scalar to each set element."""
+ log.critical(['SCALAR', self.bset, op, scalar])
+
+ s = self.bset
+ # Trivial
+ if s is sp.EmptySet:
+ nset = sp.EmptySet
+
+ # Operator-Specific
+ ## -> This is both an optimization, and to protect SetExpr.
+ elif op in [operator.add, operator.sub] and scalar == 0:
+ nset = s
+
+ elif op is operator.mul and scalar == 0:
+ nset = sp.FiniteSet(0)
+
+ elif op is operator.mul and scalar == 1:
+ nset = s
+
+ elif op is operator.truediv and scalar == 0:
+ nset = sp.EmptySet
+
+ elif op is operator.truediv and scalar == 1:
+ nset = s
+
+ elif op is operator.pow and scalar == 0:
+ nset = sp.FiniteSet(1)
+
+ elif op is operator.pow and scalar == 1:
+ nset = s
+
+ # Points
+ ## -> SetExpr works just fine here.
+ elif isinstance(s, sp.FiniteSet):
+ nset = set_expr_op(op, s, sp.FiniteSet(scalar))
+
+ # Range
+ ## -> SetExpr gives up quite thoroughly whenever Range is involved.
+ elif isinstance(s, sp.Range):
+ start, stop = sorted([op(s.start, scalar), op(s.stop, scalar)])
+ if start == stop: # noqa: SIM108
+ nset = sp.FiniteSet(start)
+ else:
+ nset = sp.Range(start, stop, 1)
+
+ # Points
+ ## -> SetExpr w/Reals is just fine.
+ ## -> SetExpr w/Complexes is NOT fine.
+ elif isinstance(s, sp.Interval):
+ if sp.im(scalar) == 0:
+ nset = set_expr_op(op, s, sp.FiniteSet(scalar))
+ else:
+ A = BlessedSet(s)
+ c = sp.re(scalar)
+ d = sp.im(scalar)
+
+ if op in [operator.add, operator.sub]:
+ nset = BlessedSet.reals_to_complex(
+ self._operate_scalar(s, c),
+ sp.FiniteSet(op(0, d)),
+ ).bset
+
+ elif op is operator.mul:
+ nset = BlessedSet.reals_to_complex(
+ A * c,
+ A * d,
+ ).bset
+
+ elif op is operator.truediv:
+ denominator = c**2 + d**2
+ nset = BlessedSet.reals_to_complex(
+ (A * c) / denominator,
+ (A * d) / denominator,
+ ).bset
+
+ elif op is operator.pow:
+ log.error(
+ 'Exponentiation of Intervals by complex numbers is not (yet) supported; falling back to the entire set of complex numbers'
+ )
+ nset = sp.Complexes
+
+ elif isinstance(s, ComplexRegion) or s is sp.Complexes:
+ # Extract Complex Elements
+ if s is sp.Complexes:
+ A = BlessedSet(sp.Reals)
+ B = BlessedSet(sp.Reals)
+ else:
+ A = self.real
+ B = self.imag
+
+ if isinstance(scalar, complex):
+ c = sp.re(scalar)
+ d = sp.im(scalar)
+ else:
+ c = scalar
+ d = 0
+
+ # + | -: Seperable
+ if op in [operator.add, operator.sub]:
+ nset = BlessedSet.reals_to_complex(
+ A._operate_scalar(op, scalar), # noqa: SLF001
+ B._operate_scalar(op, scalar), # noqa: SLF001
+ ).bset
+
+ # *: Standard Arithmetic Rules
+ elif op is operator.mul:
+ nset = BlessedSet.reals_to_complex(
+ A * c - B * d,
+ B * c + A * d,
+ ).bset
+
+ # /: Standard Arithmetic Rules
+ elif op is operator.truediv:
+ denominator = c**2 + d**2
+ nset = BlessedSet.reals_to_complex(
+ (A * c + B * d) / denominator,
+ (B * c + A * d) / denominator,
+ ).bset
+
+ # **: Complex Exponentiation
+ ## -> In the generic sense, this is a hell of a function.
+ if op is operator.pow:
+ # Trivial Cases
+ if scalar == 0:
+ _nset = sp.FiniteSet(1)
+
+ elif scalar == 1:
+ _nset = s
+
+ # Extract Absolute Value of Exponent
+ ## -> Later, sign decides whether reciprocal will be applied.
+ ## -> For now, the abs() is what we need to make decisions.
+ abs_scalar = abs(scalar)
+ sgn_scalar = 1 if scalar >= 0 else -1
+
+ # Complex | Integer
+ ## -> Apply De Moivre's Formula
+ if abs_scalar.is_integer:
+ N = int(abs_scalar)
+ r_N = abs(self) ** N
+ arg_N = N * self.arg
+
+ _nset = BlessedSet.reals_to_complex(
+ r_N * arg_N.cos,
+ r_N * arg_N.sin,
+ )
+
+ # Complex | Rational
+ elif isinstance(abs_scalar, Fraction):
+ log.error(
+ 'Exponentiation of set w/rational numbers is not (yet) supported; falling back to the entire set of complex numbers'
+ )
+ _nset = sp.Complexes
+
+ # Complex | Real
+ if abs_scalar.is_rational:
+ log.error(
+ 'Exponentiation of set w/real numbers is not (yet) supported; falling back to the entire set of complex numbers'
+ )
+ _nset = sp.Complexes
+
+ # Complex | Complex
+ elif d != 0:
+ log.error(
+ 'Exponentiation of set w/complex number is not (yet) supported; falling back to the entire set of complex numbers'
+ )
+ _nset = sp.Complexes
+
+ # Deduce Reciprocal (if exponent is negative)
+ if sgn_scalar == 1:
+ nset = _nset
+ nset = BlessedSet(_nset).reciprocal.bset
+
+ # Unions: Recurse
+ elif isinstance(s, sp.Union):
+ nset = sp.Union(
+ *(BlessedSet(arg)._operate_scalar(op, scalar) for arg in s.args) # noqa: SLF001
+ )
+ else:
+ raise NotImplementedError
+
+ log.critical(['SCALAR DONE', nset])
+ return BlessedSet(nset)
+
+ ####################
+ # - Operation w/Other Sets
+ ####################
+ @method_lru()
+ def _operate_minkowski(self, op: BlessedDomainOp, _rhs: typ.Self) -> typ.Self: # noqa: PLR0915, C901, PLR0912
+ """Compute the set resulting from applying an operation by a scalar to each set element."""
+ log.critical(['MINKOWSKI', self.bset, op, _rhs])
+
+ lhs = self.bset
+ if isinstance(_rhs, sp.Set | set): # noqa: SIM108
+ rhs = BlessedSet(_rhs).bset
+ else:
+ rhs = _rhs.bset
+
+ # Trivial
+ if lhs is sp.EmptySet:
+ nset = rhs
+ elif rhs is sp.EmptySet:
+ nset = lhs
+
+ elif op is operator.truediv and not rhs.is_nonzero:
+ nset = sp.EmptySet
+
+ # Unions: Recurse
+ ## -> For narrowing reasons, we do this before other checks.
+ elif isinstance(lhs, sp.Union) and isinstance(rhs, sp.Union):
+ nset = sp.Union(
+ *[
+ BlessedSet(l_arg)._operate_minkowski(op, r_arg) # noqa: SLF001
+ for l_arg, r_arg in itertools.product(lhs.args, rhs.args)
+ ]
+ )
+ elif isinstance(lhs, sp.Union):
+ nset = sp.Union(
+ *[BlessedSet(l_arg)._operate_minkowski(op, rhs) for l_arg in lhs.args] # noqa: SLF001
+ )
+ elif isinstance(rhs, sp.Union):
+ nset = sp.Union(
+ *[BlessedSet(lhs)._operate_minkowski(op, r_arg) for r_arg in rhs.args] # noqa: SLF001
+ )
+
+ # Complex: Recurse
+ ## -> We've eliminated EmptySet, Union.
+ ## -> We're left to account for Complex*, FiniteSet, Range, Interval.
+ elif any(isinstance(s, ComplexRegion) or s is sp.Complexes for s in [lhs, rhs]):
+ if isinstance(lhs, ComplexRegion) or lhs is sp.Complexes:
+ A = BlessedSet(lhs.a_interval)
+ B = BlessedSet(lhs.b_interval)
+ else:
+ A = BlessedSet(lhs)
+ B = BlessedSet(sp.FiniteSet(0))
+
+ if isinstance(rhs, ComplexRegion) or rhs is sp.Complexes:
+ C = BlessedSet(rhs.a_interval)
+ D = BlessedSet(rhs.b_interval)
+ else:
+ C = BlessedSet(rhs)
+ D = BlessedSet(sp.FiniteSet(0))
+
+ # + | -: Seperable
+ ## -> Recursively rely on FiniteSet|Range|Interval implementation.
+ if op in [operator.add, operator.sub]:
+ nset = BlessedSet.reals_to_complex(
+ op(A, C),
+ op(B, D),
+ ).bset
+
+ # *: Standard Arithmetic Rules
+ ## -> Recursively rely on FiniteSet|Range|Interval implementation.
+ elif op is operator.mul:
+ nset = BlessedSet.reals_to_complex(
+ A * C - B * D,
+ B * C + A * D,
+ ).bset
+
+ # /: Standard Arithmetic Rules
+ ## -> Recursively rely on FiniteSet|Range|Interval implementation.
+ elif op is operator.truediv:
+ denominator = C**2 + D**2
+ nset = BlessedSet.reals_to_complex(
+ (A * C + B * D) / denominator,
+ (B * C + A * D) / denominator,
+ ).bset
+
+ # **: Complex Exponentiation
+ ## -> May I just say, "oh boy"...
+ ## -> Decidedly more manual than the others...
+ if op is operator.pow:
+ # Account for FiniteSet
+ ## -> Must be done manually; SetExpr just gives up.
+ if isinstance(rhs, sp.FiniteSet):
+ fs = rhs
+ nset = functools.reduce(
+ lambda l, r: l | r, # noqa: E741
+ (self._operate_scalar(op, fs_el) for fs_el in fs),
+ )
+
+ # Account for ComplexRegion/Complexes | Complex
+ ## -> This is an INVOLVED calculation.
+ ## -> So, we give up.
+ elif sp.FiniteSet(0) != D:
+ log.error(
+ 'Exponentiation of set w/complex set is not (yet) supported; falling back to the entire set of complex numbers'
+ )
+ nset = sp.Complexes
+
+ else:
+ pos_C = abs(C) & sp.Interval(0, sp.oo)
+ neg_C = abs(C * -1) & sp.Interval(0, sp.oo)
+
+ halves = []
+ for half_C in [neg_C, pos_C]:
+ if half_C is sp.EmptySet:
+ halves.append(sp.EmptySet)
+
+ # Complex | Integer
+ ## -> Apply De Moivre's Formula
+ elif half_C.issubset(sp.Integers):
+ N = half_C
+ r_N = abs(self) ** N
+ arg_N = N**self.arg
+
+ nset_half = BlessedSet.reals_to_complex(
+ r_N * arg_N.cos,
+ r_N * arg_N.sin,
+ ).bset
+
+ # Complex | Reals
+ elif half_C.issubset(sp.Rationals) or half_C.issubset(sp.Reals):
+ log.error(
+ 'Exponentiation of set w/real or rational set is not (yet) supported; falling back to the entire set of complex numbers'
+ )
+ return BlessedSet(sp.Complexes)
+
+ halves.append(nset_half)
+
+ # Determine Reciprocal
+ nset = (
+ BlessedSet(halves[0]).reciprocal | BlessedSet(halves[1])
+ ).bset
+
+ # Points
+ ## -> (FiniteSet | *) or (* | FiniteSet)
+ ## -> We've eliminated EmptySet, Union, and Complexes|ComplexRegion.
+ ## -> We're left to account for FiniteSet, Range, Interval.
+ elif isinstance(lhs, sp.FiniteSet) or isinstance(rhs, sp.FiniteSet):
+ fs = lhs if isinstance(lhs, sp.FiniteSet) else rhs
+ other = rhs if isinstance(lhs, sp.FiniteSet) else lhs
+
+ if isinstance(other, sp.FiniteSet):
+ nset = set_expr_op(op, fs, other)
+
+ elif isinstance(other, sp.Range):
+ rng = other
+ nset = functools.reduce(
+ lambda l, r: l | r, # noqa: E741
+ (
+ sp.Range(op(rng.inf, fs_el), op(rng.inf, fs_el) + 1, 1)
+ for fs_el in fs
+ ),
+ )
+
+ elif isinstance(other, sp.Interval):
+ nset = set_expr_op(op, lhs, rhs)
+
+ else:
+ raise NotImplementedError
+
+ # Region
+
+ ## -> (Range | *) or (* | Range)
+ ## -> Eliminated EmptySet, Union, Complexes|ComplexRegion, FiniteSet.
+ ## -> We've left to account for Range, Interval.
+ elif isinstance(lhs, sp.Range) or isinstance(rhs, sp.Range):
+ rng = lhs if isinstance(lhs, sp.Range) else rhs
+ other = rhs if isinstance(lhs, sp.Range) else lhs
+
+ if isinstance(other, sp.Range):
+ ## -> Bound scaling is valid for +, -, *, /, **.
+ start, _, _, stop = sorted(
+ [
+ op(lhs.start, rhs.start),
+ op(lhs.start, rhs.stop),
+ op(lhs.stop, rhs.start),
+ op(lhs.stop, rhs.stop),
+ ]
+ )
+ nset = sp.Range(start, stop, 1)
+
+ elif isinstance(other, sp.Interval):
+ ## -> Bound scaling is valid for +, -, *, /, **.
+ itv = other
+ (start, start_open), _, _, (end, end_open) = sorted(
+ [
+ (op(rng.start, itv.start), itv.left_open),
+ (op(rng.start, itv.end), itv.right_open),
+ (op(rng.stop, itv.start), itv.left_open),
+ (op(rng.stop, itv.end), itv.right_open),
+ ],
+ key=lambda el: el[0],
+ )
+ nset = sp.Interval(
+ start,
+ end,
+ left_open=start_open,
+ right_open=end_open,
+ )
+
+ else:
+ raise NotImplementedError
+
+ ## -> Interval | Interval
+ ## -> Eliminated everything but Interval.
+ elif isinstance(lhs, sp.Interval) and isinstance(rhs, sp.Interval):
+ # Edge Case: (0,oo) * (0,oo)
+ if (
+ op is operator.mul
+ and lhs.inf == 0
+ and rhs.inf == 0
+ and lhs.sup == sp.oo
+ and rhs.sup == sp.oo
+ ):
+ nset = sp.Interval(0, sp.oo, left_open=lhs.left_open or rhs.left_open)
+ elif op is not operator.pow:
+ nset = set_expr_op(op, lhs, rhs)
+ else:
+ (start, start_open), _, _, (end, end_open) = sorted(
+ [
+ (lhs.start**rhs.start, lhs.left_open & rhs.left_open),
+ (lhs.start**rhs.end, lhs.left_open & rhs.right_open),
+ (lhs.end**rhs.start, lhs.right_open & rhs.left_open),
+ (lhs.end**rhs.end, lhs.right_open & rhs.right_open),
+ ],
+ key=lambda el: el[0],
+ )
+ nset = sp.Interval(
+ start.n(),
+ end.n(),
+ left_open=start_open,
+ right_open=end_open,
+ )
+ else:
+ raise NotImplementedError
+
+ log.critical(['MINKOWSKI DONE', nset])
+ return BlessedSet(nset)
+
+ ####################
+ # - Operator Overload Dispatch
+ ####################
+ def __add__(self, other: Scalar | typ.Self | sp.Set | set) -> typ.Self:
+ """Deduce the `BlessedSet` resulting from its element-wise addition with a scalar or another `BlessedSet`."""
+ if isinstance(other, BlessedSet | sp.Set | set):
+ return self._operate_minkowski(operator.add, BlessedSet(other))
+ return self._operate_scalar(operator.add, other)
+
+ def __radd__(self, other: Scalar | typ.Self) -> typ.Self:
+ return other + self
+
+ def __sub__(self, other: Scalar | typ.Self) -> typ.Self:
+ """Deduce the `BlessedSet` resulting from its element-wise subtraction with a scalar or another `BlessedSet`."""
+ if isinstance(other, BlessedSet | sp.Set | set):
+ return self._operate_minkowski(operator.sub, BlessedSet(other))
+ return self._operate_scalar(operator.sub, other)
+
+ def __rsub__(self, other: Scalar | typ.Self) -> typ.Self:
+ return other - self
+
+ def __mul__(self, other: Scalar | typ.Self) -> typ.Self:
+ """Deduce the `BlessedSet` resulting from its element-wise multiplication with a scalar or another `BlessedSet`."""
+ if isinstance(other, BlessedSet | sp.Set | set):
+ return self._operate_minkowski(operator.mul, BlessedSet(other))
+ return self._operate_scalar(operator.mul, other)
+
+ def __rmul__(self, other: Scalar | typ.Self) -> typ.Self:
+ return other * self
+
+ def __truediv__(self, other: Scalar | typ.Self) -> typ.Self:
+ """Deduce the `BlessedSet` resulting from its element-wise division with a scalar or another `BlessedSet`."""
+ if isinstance(other, BlessedSet):
+ return self._operate_minkowski(operator.truediv, BlessedSet(other))
+ return self._operate_scalar(operator.truediv, other)
+
+ def __rtruediv__(self, other: Scalar | typ.Self) -> typ.Self:
+ return other / self
+
+ def __pow__(self, other: Scalar | typ.Self) -> typ.Self:
+ """Deduce the `BlessedSet` resulting from its element-wise exponentiation with a scalar or another `BlessedSet`."""
+ if isinstance(other, BlessedSet | sp.Set | set):
+ return self._operate_minkowski(operator.pow, BlessedSet(other))
+ return self._operate_scalar(operator.pow, other)
+
+ def __rpow__(self, other: Scalar | typ.Self) -> typ.Self:
+ return other**self
+
+ def atan2(self, other: Scalar | typ.Self) -> typ.Self:
+ lhs = self.bset
+ if isinstance(other, BlessedSet | sp.Set | set):
+ rhs = BlessedSet(other)
+ else:
+ rhs = other
+
+ if not isinstance(lhs, ComplexRegion) and not isinstance(rhs, ComplexRegion):
+ return BlessedSet.reals_to_complex(lhs, rhs).arg
+
+ raise NotImplementedError
diff --git a/src/blender_maxwell/utils/sympy_extra/unit_analysis.py b/src/blender_maxwell/utils/sympy_extra/unit_analysis.py
index 3234407..c7c9f73 100644
--- a/src/blender_maxwell/utils/sympy_extra/unit_analysis.py
+++ b/src/blender_maxwell/utils/sympy_extra/unit_analysis.py
@@ -17,19 +17,23 @@
"""Functions for characterizaiton, conversion and casting of `sympy` objects that use units."""
import functools
+import typing as typ
+import jax
import sympy as sp
import sympy.physics.units as spu
+from .. import logger
from . import units as spux
from .parse_cast import sympy_to_python
from .sympy_type import SympyType
+log = logger.get(__name__)
+
####################
# - Unit Characterization
####################
-## TODO: Caching w/srepr'ed expression.
## TODO: An LFU cache could do better than an LRU.
def uses_units(sp_obj: SympyType) -> bool:
"""Determines if an expression uses any units.
@@ -43,7 +47,6 @@ def uses_units(sp_obj: SympyType) -> bool:
return sp_obj.has(spu.Quantity)
-## TODO: Caching w/srepr'ed expression.
## TODO: An LFU cache could do better than an LRU.
def get_units(expr: sp.Expr) -> set[spu.Quantity]:
"""Finds all units used by the expression, and returns them as a set.
@@ -73,6 +76,7 @@ def get_units(expr: sp.Expr) -> set[spu.Quantity]:
####################
# - Dimensional Characterization
####################
+@functools.lru_cache(maxsize=8192)
def unit_dim_to_unit_dim_deps(
unit_dims: SympyType,
) -> dict[spu.dimensions.Dimension, int] | None:
@@ -106,6 +110,7 @@ def unit_dim_to_unit_dim_deps(
return None
+@functools.lru_cache(maxsize=8192)
def unit_to_unit_dim_deps(
unit: SympyType,
) -> dict[spu.dimensions.Dimension, int] | None:
@@ -130,6 +135,7 @@ def unit_to_unit_dim_deps(
)
+@functools.lru_cache(maxsize=8192)
def compare_unit_dims(unit_dim_l: SympyType, unit_dim_r: SympyType) -> bool:
"""Compare the dimensional dependencies of two unit dimensions.
@@ -140,6 +146,7 @@ def compare_unit_dims(unit_dim_l: SympyType, unit_dim_r: SympyType) -> bool:
)
+@functools.lru_cache(maxsize=8192)
def compare_units_by_unit_dims(unit_l: SympyType, unit_r: SympyType) -> bool:
"""Compare two units by their unit dimensions."""
return unit_to_unit_dim_deps(unit_l) == unit_to_unit_dim_deps(unit_r)
@@ -155,6 +162,7 @@ def compare_unit_dim_to_unit_dim_deps(
####################
# - Unit Casting
####################
+@functools.lru_cache(maxsize=8192)
def strip_units(sp_obj: SympyType) -> SympyType:
"""Strip all units by replacing them to `1`.
@@ -180,6 +188,7 @@ def strip_units(sp_obj: SympyType) -> SympyType:
return sp_obj.subs(spux.UNIT_TO_1)
+@functools.lru_cache(maxsize=8192)
def convert_to_unit(sp_obj: SympyType, unit: SympyType | None) -> SympyType:
"""Convert a sympy object to the given unit.
@@ -193,8 +202,7 @@ def convert_to_unit(sp_obj: SympyType, unit: SympyType | None) -> SympyType:
# raise ValueError(msg)
-## TODO: Include sympy_to_python in 'scale_to' to match semantics of 'scale_to_unit_system'
-## -- Introduce a 'strip_unit
+@functools.lru_cache(maxsize=8192)
def scale_to_unit(
sp_obj: SympyType,
unit: spu.Quantity | None,
@@ -230,8 +238,10 @@ def scale_to_unit(
return sp_obj_stripped
+@functools.lru_cache(maxsize=8192)
def scaling_factor(
- unit_from: SympyType, unit_to: SympyType
+ unit_from: SympyType | None,
+ unit_to: SympyType | None,
) -> int | float | complex | tuple | None:
"""Compute the numerical scaling factor imposed on the unitless part of the expression when converting from one unit to another.
@@ -247,12 +257,35 @@ def scaling_factor(
Raises:
ValueError: If the two units don't share a common dimension.
"""
+ if unit_from is None or unit_to is None:
+ return 1
+
if compare_units_by_unit_dims(unit_from, unit_to):
- return scale_to_unit(unit_from, unit_to)
+ res = strip_units(spu.convert_to(unit_from, unit_to))
+
+ # Ensure Integer w/Length > 10 Doesn't Cause Overflow
+ ## 'jax' insists on float32, and I guess fair enough.
+ ## Problem is, it overflows with integers larger than 10 digits.
+ ## Easy fix. But we want the full precision for easier cases.
+ if isinstance(res, sp.Rational):
+ max_order = sp.ceiling(sp.log(max([res.numerator, res.denominator]), 10))
+ if max_order > 10: # noqa: PLR2004
+ return float(max_order)
+ return res
return None
-@functools.cache
+@functools.lru_cache(maxsize=64)
+def unit_scaling_func_n(
+ unit_from: SympyType | None,
+ unit_to: SympyType | None,
+) -> typ.Callable[[jax.Array], jax.Array]:
+ """Produce a `jax` function that computes the conversion between the two givne units."""
+ a = sp.Symbol('a')
+ return sp.lambdify(a, a * scaling_factor(unit_from, unit_to), 'jax')
+
+
+@functools.lru_cache(maxsize=8192)
def unit_str_to_unit(unit_str: str, optional: bool = False) -> SympyType | None:
"""Determine the `sympy` unit expression that matches the given unit string.
diff --git a/src/blender_maxwell/utils/sympy_extra/unit_system_analysis.py b/src/blender_maxwell/utils/sympy_extra/unit_system_analysis.py
index cb30375..eb7c8f5 100644
--- a/src/blender_maxwell/utils/sympy_extra/unit_system_analysis.py
+++ b/src/blender_maxwell/utils/sympy_extra/unit_system_analysis.py
@@ -48,16 +48,20 @@ def strip_unit_system(
def convert_to_unit_system(
- sp_obj: SympyType, unit_system: UnitSystem | None
+ sp_obj: SympyType, unit_system: UnitSystem | None, strip_units: bool = False
) -> SympyType:
"""Convert an expression to the units of a given unit system."""
- if unit_system is None:
- return sp_obj
+ if unit_system is not None:
+ converted_sp_obj = spu.convert_to(
+ sp_obj,
+ {unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)},
+ )
+ else:
+ converted_sp_obj = sp_obj
- return spu.convert_to(
- sp_obj,
- {unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)},
- )
+ if strip_units:
+ return strip_unit_system(converted_sp_obj, unit_system)
+ return converted_sp_obj
####################
diff --git a/src/blender_maxwell/utils/sympy_extra/unit_systems.py b/src/blender_maxwell/utils/sympy_extra/unit_systems.py
index 1de298a..a9b983f 100644
--- a/src/blender_maxwell/utils/sympy_extra/unit_systems.py
+++ b/src/blender_maxwell/utils/sympy_extra/unit_systems.py
@@ -25,6 +25,7 @@ Attributes:
import typing as typ
import sympy.physics.units as spu
+from frozendict import frozendict
from . import units as spux
from .physical_type import PhysicalType as PT # noqa: N817
@@ -33,48 +34,50 @@ from .sympy_expr import Unit
####################
# - Unit System Representation
####################
-UnitSystem: typ.TypeAlias = dict[PT, Unit]
+UnitSystem: typ.TypeAlias = frozendict[PT, Unit]
####################
# - Standard Unit Systems
####################
-UNITS_SI: UnitSystem = {
- PT.NonPhysical: None,
- # Global
- PT.Time: spu.second,
- PT.Angle: spu.radian,
- PT.SolidAngle: spu.steradian,
- PT.Freq: spu.hertz,
- PT.AngFreq: spu.radian * spu.hertz,
- # Cartesian
- PT.Length: spu.meter,
- PT.Area: spu.meter**2,
- PT.Volume: spu.meter**3,
- # Mechanical
- PT.Vel: spu.meter / spu.second,
- PT.Accel: spu.meter / spu.second**2,
- PT.Mass: spu.kilogram,
- PT.Force: spu.newton,
- # Energy
- PT.Work: spu.joule,
- PT.Power: spu.watt,
- PT.PowerFlux: spu.watt / spu.meter**2,
- PT.Temp: spu.kelvin,
- # Electrodynamics
- PT.Current: spu.ampere,
- PT.CurrentDensity: spu.ampere / spu.meter**2,
- PT.Voltage: spu.volt,
- PT.Capacitance: spu.farad,
- PT.Impedance: spu.ohm,
- PT.Conductance: spu.siemens,
- PT.Conductivity: spu.siemens / spu.meter,
- PT.MFlux: spu.weber,
- PT.MFluxDensity: spu.tesla,
- PT.Inductance: spu.henry,
- PT.EField: spu.volt / spu.meter,
- PT.HField: spu.ampere / spu.meter,
- # Luminal
- PT.LumIntensity: spu.candela,
- PT.LumFlux: spux.lumen,
- PT.Illuminance: spu.lux,
-}
+UNITS_SI: UnitSystem = frozendict(
+ {
+ PT.NonPhysical: None,
+ # Global
+ PT.Time: spu.second,
+ PT.Angle: spu.radian,
+ PT.SolidAngle: spu.steradian,
+ PT.Freq: spu.hertz,
+ PT.AngFreq: spu.radian * spu.hertz,
+ # Cartesian
+ PT.Length: spu.meter,
+ PT.Area: spu.meter**2,
+ PT.Volume: spu.meter**3,
+ # Mechanical
+ PT.Vel: spu.meter / spu.second,
+ PT.Accel: spu.meter / spu.second**2,
+ PT.Mass: spu.kilogram,
+ PT.Force: spu.newton,
+ # Energy
+ PT.Work: spu.joule,
+ PT.Power: spu.watt,
+ PT.PowerFlux: spu.watt / spu.meter**2,
+ PT.Temp: spu.kelvin,
+ # Electrodynamics
+ PT.Current: spu.ampere,
+ PT.CurrentDensity: spu.ampere / spu.meter**2,
+ PT.Voltage: spu.volt,
+ PT.Capacitance: spu.farad,
+ PT.Impedance: spu.ohm,
+ PT.Conductance: spu.siemens,
+ PT.Conductivity: spu.siemens / spu.meter,
+ PT.MFlux: spu.weber,
+ PT.MFluxDensity: spu.tesla,
+ PT.Inductance: spu.henry,
+ PT.EField: spu.volt / spu.meter,
+ PT.HField: spu.ampere / spu.meter,
+ # Luminal
+ PT.LumIntensity: spu.candela,
+ PT.LumFlux: spux.lumen,
+ PT.Illuminance: spu.lux,
+ }
+)