feat: extensive improvements

main
Sofus Albert Høgsbro Rose 2024-06-13 16:10:00 +02:00
parent b51c4f1889
commit 81a71b2c47
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
116 changed files with 12336 additions and 5713 deletions

56
TODO.md
View File

@ -1,56 +0,0 @@
# Working TODO
- [x] Wave Constant
- Sources
- [x] Temporal Shapes / Continuous Wave Temporal Shape
- [x] Temporal Shapes / Symbolic Temporal Shape
- [x] Plane Wave Source
- [ ] TFSF Source
- [x] Gaussian Beam Source
- [ ] Astig. Gauss Beam
- Monitors
- [x] EH Field
- [x] Power Flux
- [x] Permittivity
- [ ] Diffraction
- Tidy3D / Integration
- [ ] Exporter
- [ ] Combine
- [ ] Importer
- Sim Grid
- [ ] Sim Grid
- [ ] Auto
- [ ] Manual
- [ ] Uniform
- [ ] Data
- Structures
- [x] Cylinder
- [ ] Cylinder Array
- [ ] L-Cavity Cylinder
- [ ] H-Cavity Cylinder
- [ ] FCC Lattice
- [ ] BCC Lattice
- [ ] Monkey
- Expr Socket
- [x] LVF Mode
- Math Nodes
- [ ] Reduce Math
- [x] Transform Math - reindex freq->wl
- Material Data Fitting
- [ ] Data File Import
- [ ] DataFit Medium
- Mediums
- [ ] Non-Linearities
- [ ] PEC Medium
- [ ] Isotropic Medium
- [ ] Sellmeier Medium
- [ ] Drude Medium
- [ ] Debye Medium
- [ ] Anisotropic Medium
- Integration
- [x] Simulation and Analysis of Maxim's Cavity
- Constants
- [x] Number Constant
- [x] Vector Constant
- [x] Physical Constant
- [x] Fix many problems by persisting `_enum_cb_cache` and `_str_cb_cache`.

View File

@ -6,13 +6,13 @@ authors = [
{ name = "Sofus Albert Høgsbro Rose", email = "blender-maxwell@sofusrose.com" }
]
dependencies = [
"tidy3d>=2.6.3",
"tidy3d==2.7.0rc2",
"pydantic>=2.7.1",
"sympy==1.12",
"scipy==1.12.*",
"trimesh==4.2.*",
"networkx==3.2.*",
"rich==12.5.*",
"rich>=13.7.1",
"rtree==1.2.*",
"jax[cpu]==0.4.26",
"msgspec[toml]==0.18.6",
@ -28,6 +28,8 @@ dependencies = [
"certifi==2021.10.8",
"polars>=0.20.26",
"seaborn[stats]>=0.13.2",
"frozendict>=2.4.4",
"pydantic-tensor>=0.2.0",
]
## When it comes to dev-dep conflicts:
## -> It's okay to leave Blender-pinned deps out of prod; Blender still has them.

View File

@ -7,13 +7,17 @@
# all-features: false
# with-sources: false
absl-py==2.1.0
# via chex
# via optax
# via orbax-checkpoint
annotated-types==0.6.0
# via pydantic
argcomplete==3.3.0
# via commitizen
boto3==1.23.1
boto3==1.34.123
# via tidy3d
botocore==1.26.10
botocore==1.34.123
# via boto3
# via s3transfer
certifi==2021.10.8
@ -23,7 +27,9 @@ cfgv==3.4.0
charset-normalizer==2.1.1
# via commitizen
# via requests
click==8.0.3
chex==0.1.86
# via optax
click==8.1.7
# via dask
# via tidy3d
cloudpickle==3.0.0
@ -31,8 +37,6 @@ cloudpickle==3.0.0
colorama==0.4.6
# via commitizen
commitizen==3.25.0
commonmark==0.9.1
# via rich
contourpy==1.2.0
# via matplotlib
cycler==0.12.1
@ -43,13 +47,19 @@ decli==0.6.2
# via commitizen
distlib==0.3.8
# via virtualenv
etils==1.9.1
# via orbax-checkpoint
fake-bpy-module-4-0==20231118
filelock==3.14.0
# via virtualenv
flax==0.8.4
# via tidy3d
fonttools==4.49.0
# via matplotlib
frozendict==2.4.4
fsspec==2024.2.0
# via dask
# via etils
h5netcdf==1.0.2
# via tidy3d
h5py==3.10.0
@ -63,38 +73,61 @@ importlib-metadata==6.11.0
# via commitizen
# via dask
# via tidy3d
importlib-resources==6.4.0
# via etils
jax==0.4.26
# via chex
# via flax
# via optax
# via orbax-checkpoint
jaxlib==0.4.26
# via chex
# via jax
# via optax
# via orbax-checkpoint
jaxtyping==0.2.28
jinja2==3.1.3
# via commitizen
jmespath==1.0.1
# via boto3
# via botocore
joblib==1.4.2
# via tidy3d
kiwisolver==1.4.5
# via matplotlib
llvmlite==0.42.0
# via numba
locket==1.0.0
# via partd
markdown-it-py==3.0.0
# via rich
markupsafe==2.1.5
# via jinja2
matplotlib==3.8.3
# via seaborn
# via tidy3d
mdurl==0.1.2
# via markdown-it-py
ml-dtypes==0.4.0
# via jax
# via jaxlib
# via tensorstore
mpmath==1.3.0
# via sympy
msgpack==1.0.8
# via flax
# via orbax-checkpoint
msgspec==0.18.6
nest-asyncio==1.6.0
# via orbax-checkpoint
networkx==3.2
nodeenv==1.8.0
# via pre-commit
numba==0.59.1
numpy==1.24.3
# via chex
# via contourpy
# via flax
# via h5py
# via jax
# via jaxlib
@ -103,16 +136,24 @@ numpy==1.24.3
# via ml-dtypes
# via numba
# via opt-einsum
# via optax
# via orbax-checkpoint
# via patsy
# via pydantic-tensor
# via scipy
# via seaborn
# via shapely
# via statsmodels
# via tensorstore
# via tidy3d
# via trimesh
# via xarray
opt-einsum==3.3.0
# via jax
optax==0.2.2
# via flax
orbax-checkpoint==0.5.15
# via flax
packaging==24.0
# via commitizen
# via dask
@ -123,6 +164,7 @@ packaging==24.0
pandas==2.2.1
# via seaborn
# via statsmodels
# via tidy3d
# via xarray
partd==1.4.1
# via dask
@ -136,10 +178,14 @@ polars==0.20.26
pre-commit==3.7.0
prompt-toolkit==3.0.36
# via questionary
protobuf==5.27.1
# via orbax-checkpoint
pydantic==2.7.1
# via pydantic-tensor
# via tidy3d
pydantic-core==2.18.2
# via pydantic
pydantic-tensor==0.2.0
pygments==2.17.2
# via rich
pyjwt==2.8.0
@ -157,6 +203,8 @@ pytz==2024.1
pyyaml==6.0.1
# via commitizen
# via dask
# via flax
# via orbax-checkpoint
# via pre-commit
# via responses
# via tidy3d
@ -167,11 +215,12 @@ requests==2.31.0
# via tidy3d
responses==0.23.1
# via tidy3d
rich==12.5.0
rich==13.7.1
# via flax
# via tidy3d
rtree==1.2.0
ruff==0.4.3
s3transfer==0.5.2
s3transfer==0.10.1
# via boto3
scipy==1.12.0
# via jax
@ -190,9 +239,12 @@ six==1.16.0
statsmodels==0.14.2
# via seaborn
sympy==1.12
tensorstore==0.1.61
# via flax
# via orbax-checkpoint
termcolor==2.4.0
# via commitizen
tidy3d==2.6.3
tidy3d==2.7.0rc2
toml==0.10.2
# via tidy3d
tomli-w==1.0.0
@ -200,6 +252,7 @@ tomli-w==1.0.0
tomlkit==0.12.4
# via commitizen
toolz==0.12.1
# via chex
# via dask
# via partd
trimesh==4.2.0
@ -208,6 +261,10 @@ typeguard==2.13.3
types-pyyaml==6.0.12.20240311
# via responses
typing-extensions==4.10.0
# via chex
# via etils
# via flax
# via orbax-checkpoint
# via pydantic
# via pydantic-core
tzdata==2024.1
@ -223,4 +280,5 @@ wcwidth==0.2.13
xarray==2024.2.0
# via tidy3d
zipp==3.18.0
# via etils
# via importlib-metadata

View File

@ -7,34 +7,44 @@
# all-features: false
# with-sources: false
absl-py==2.1.0
# via chex
# via optax
# via orbax-checkpoint
annotated-types==0.6.0
# via pydantic
boto3==1.23.1
boto3==1.34.123
# via tidy3d
botocore==1.26.10
botocore==1.34.123
# via boto3
# via s3transfer
certifi==2021.10.8
# via requests
charset-normalizer==2.0.10
# via requests
click==8.0.3
chex==0.1.86
# via optax
click==8.1.7
# via dask
# via tidy3d
cloudpickle==3.0.0
# via dask
commonmark==0.9.1
# via rich
contourpy==1.2.0
# via matplotlib
cycler==0.12.1
# via matplotlib
dask==2023.10.1
# via tidy3d
etils==1.9.1
# via orbax-checkpoint
flax==0.8.4
# via tidy3d
fonttools==4.49.0
# via matplotlib
frozendict==2.4.4
fsspec==2024.2.0
# via dask
# via etils
h5netcdf==1.0.2
# via tidy3d
h5py==3.10.0
@ -45,32 +55,55 @@ idna==3.3
importlib-metadata==6.11.0
# via dask
# via tidy3d
importlib-resources==6.4.0
# via etils
jax==0.4.26
# via chex
# via flax
# via optax
# via orbax-checkpoint
jaxlib==0.4.26
# via chex
# via jax
# via optax
# via orbax-checkpoint
jaxtyping==0.2.28
jmespath==1.0.1
# via boto3
# via botocore
joblib==1.4.2
# via tidy3d
kiwisolver==1.4.5
# via matplotlib
llvmlite==0.42.0
# via numba
locket==1.0.0
# via partd
markdown-it-py==3.0.0
# via rich
matplotlib==3.8.3
# via seaborn
# via tidy3d
mdurl==0.1.2
# via markdown-it-py
ml-dtypes==0.4.0
# via jax
# via jaxlib
# via tensorstore
mpmath==1.3.0
# via sympy
msgpack==1.0.8
# via flax
# via orbax-checkpoint
msgspec==0.18.6
nest-asyncio==1.6.0
# via orbax-checkpoint
networkx==3.2
numba==0.59.1
numpy==1.24.3
# via chex
# via contourpy
# via flax
# via h5py
# via jax
# via jaxlib
@ -79,16 +112,24 @@ numpy==1.24.3
# via ml-dtypes
# via numba
# via opt-einsum
# via optax
# via orbax-checkpoint
# via patsy
# via pydantic-tensor
# via scipy
# via seaborn
# via shapely
# via statsmodels
# via tensorstore
# via tidy3d
# via trimesh
# via xarray
opt-einsum==3.3.0
# via jax
optax==0.2.2
# via flax
orbax-checkpoint==0.5.15
# via flax
packaging==24.0
# via dask
# via h5netcdf
@ -98,6 +139,7 @@ packaging==24.0
pandas==2.2.1
# via seaborn
# via statsmodels
# via tidy3d
# via xarray
partd==1.4.1
# via dask
@ -106,10 +148,14 @@ patsy==0.5.6
pillow==10.2.0
# via matplotlib
polars==0.20.26
protobuf==5.27.1
# via orbax-checkpoint
pydantic==2.7.1
# via pydantic-tensor
# via tidy3d
pydantic-core==2.18.2
# via pydantic
pydantic-tensor==0.2.0
pygments==2.17.2
# via rich
pyjwt==2.8.0
@ -126,17 +172,20 @@ pytz==2024.1
# via pandas
pyyaml==6.0.1
# via dask
# via flax
# via orbax-checkpoint
# via responses
# via tidy3d
requests==2.27.1
requests==2.32.3
# via responses
# via tidy3d
responses==0.23.1
# via tidy3d
rich==12.5.0
rich==13.7.1
# via flax
# via tidy3d
rtree==1.2.0
s3transfer==0.5.2
s3transfer==0.10.1
# via boto3
scipy==1.12.0
# via jax
@ -153,12 +202,16 @@ six==1.16.0
statsmodels==0.14.2
# via seaborn
sympy==1.12
tidy3d==2.6.3
tensorstore==0.1.61
# via flax
# via orbax-checkpoint
tidy3d==2.7.0rc2
toml==0.10.2
# via tidy3d
tomli-w==1.0.0
# via msgspec
toolz==0.12.1
# via chex
# via dask
# via partd
trimesh==4.2.0
@ -167,6 +220,10 @@ typeguard==2.13.3
types-pyyaml==6.0.12.20240311
# via responses
typing-extensions==4.10.0
# via chex
# via etils
# via flax
# via orbax-checkpoint
# via pydantic
# via pydantic-core
tzdata==2024.1
@ -178,4 +235,5 @@ urllib3==1.26.8
xarray==2024.2.0
# via tidy3d
zipp==3.18.0
# via etils
# via importlib-metadata

Binary file not shown.

View File

@ -34,6 +34,7 @@ from .bl import (
KeymapItemDef,
ManagedObjName,
PresetName,
PropName,
SocketName,
)
from .bl_types import BLEnumStrEnum
@ -64,6 +65,7 @@ __all__ = [
'KeymapItemDef',
'ManagedObjName',
'PresetName',
'PropName',
'SocketName',
'BLEnumStrEnum',
'BLInstance',

View File

@ -21,8 +21,9 @@ import bpy
####################
# - Blender Strings
####################
BLEnumID = str
SocketName = str
BLEnumID: typ.TypeAlias = str
SocketName: typ.TypeAlias = str
PropName: typ.TypeAlias = str
####################
# - Blender Enums

View File

@ -44,6 +44,10 @@ class OperatorType(enum.StrEnum):
# Node: ExportDataFile
NodeExportDataFile = enum.auto()
# Node: PoleResidueMediumNode
NodeFitDispersiveMedium = enum.auto()
NodeReleaseDispersiveFit = enum.auto()
# Node: Tidy3DWebImporter
NodeLoadCloudSim = enum.auto()

View File

@ -34,6 +34,7 @@ from blender_maxwell.contracts import (
OperatorType,
PanelType,
PresetName,
PropName,
SocketName,
addon,
)
@ -62,8 +63,12 @@ from .sim_types import (
BoundCondType,
DataFileFormat,
NewSimCloudTask,
Realization,
RealizationScalar,
SimAxisDir,
SimFieldPols,
SimMetadata,
SimRealizations,
SimSpaceAxis,
manual_amp_time,
)
@ -92,6 +97,7 @@ __all__ = [
'OperatorType',
'PanelType',
'PresetName',
'PropName',
'SocketName',
'addon',
'Icon',
@ -107,8 +113,12 @@ __all__ = [
'BoundCondType',
'DataFileFormat',
'NewSimCloudTask',
'Realization',
'RealizationScalar',
'SimAxisDir',
'SimFieldPols',
'SimMetadata',
'SimRealizations',
'SimSpaceAxis',
'manual_amp_time',
'NodeCategory',

View File

@ -14,63 +14,101 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import dataclasses
import enum
import functools
import typing as typ
from fractions import Fraction
import bpy
import jax
import numpy as np
import pydantic as pyd
import sympy as sp
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils import sympy_extra as spux
from .socket_types import SocketType
from .unit_systems import UNITS_BLENDER
log = logger.get(__name__)
BL_SOCKET_DESCR_ANNOT_STRING = ':: '
@dataclasses.dataclass(kw_only=True, frozen=True)
class BLSocketInfo:
class BLSocketInfo(pyd.BaseModel):
"""Parsed information fully representing a Blender interface socket, enabling translation of values to a format accepted by the Blender socket.
Notes:
All Blender socket values are considered to implicitly be in the unit system `ct.UNITS_BLENDER`.
For this reason, we do not declare or handle units / unit systems at the interface.
This design also prevents spurious (slow) unit conversions in the value-encoding hot path.
Attributes:
has_support: Whether our system explicitly supports mapping to/from this Blender socket.
is_preview: Whether this Blender socket is only relevant for driving a preview, not for functionality.
socket_type: The corresponding socket type in our system, which the Blender socket can have data pushed from.
size: Identifier for the scalar/1D shape of the Blender socket, impacting if/how values can be pushed.
mathtype: The mathematical type of the Blender socket value.
physical_type: The unit dimension of the Blender socket value.
default_value: The Blender socket's own default value, used as the initial socket value in our system.
bl_isocket_identifier: The internal identifier of the Blender interface socket in the node tree.
This is the **only way** to read/write a particular instance of the socket value through eg. a GeoNodes modifier.
"""
model_config = pyd.ConfigDict(frozen=True, arbitrary_types_allowed=True)
default_value: typ.Any | None
has_support: bool
is_preview: bool
socket_type: SocketType | None
size: spux.NumberSize1D | None
mathtype: spux.MathType | None
physical_type: spux.PhysicalType | None
default_value: spux.ScalarUnitlessRealExpr
bl_isocket_identifier: spux.ScalarUnitlessRealExpr
bl_isocket_identifier: str
@functools.cached_property
def unit(self) -> spux.Unit | None:
"""Deduce the unit of the Blender socket value, by retrieving the Blender units of `self.physical_type`.
If the socket has no unit dimension, this will also be `None`.
"""
if self.physical_type is not None:
return UNITS_BLENDER[self.physical_type]
return None
def encode(
self, raw_value: typ.Any, unit_system: spux.UnitSystem | None
self, raw_value: int | Fraction | float | tuple[int | Fraction | float, ...]
) -> typ.Any:
"""Encode a raw value, given a unit system, to be directly writable to a node socket.
"""Conform a mostly-prepared value, so as to be guaranteed-writable to a node socket.
This encoded form is also guaranteed to support writing to a node socket via a modifier interface.
"""
# Non-Numerical: Passthrough
if unit_system is None or self.physical_type is None:
MT = spux.MathType
# Numerical: Conform to Pure Python Type
if self.mathtype is not None:
# Coerce jax/np Array -> lists
if isinstance(raw_value, np.ndarray | jax.Array):
if self.mathtype is MT.Real and isinstance(raw_value.item(0), int):
return raw_value.astype(float).flatten().tolist()
return raw_value.flatten().tolist()
# Coerce int -> float w/Target is Real
## -> The value - modifier - GN path is more strict than properties.
if self.mathtype is MT.Real and isinstance(raw_value, int):
return float(raw_value)
# Coerce Fraction -> tuple[int, int] w/Target is Rational
if self.mathtype is MT.Rational and isinstance(raw_value, Fraction):
return (raw_value.numerator, raw_value.denominator)
return raw_value
# Numerical: Convert to Pure Python Type
if (
unit_system is not None
and self.physical_type is not spux.PhysicalType.NonPhysical
):
unitless_value = spux.scale_to_unit_system(raw_value, unit_system)
elif isinstance(raw_value, spux.SympyType):
unitless_value = spux.sympy_to_python(raw_value)
else:
unitless_value = raw_value
# Coerce int -> float w/Target is Real
## -> The value - modifier - GN path is more strict than properties.
if self.mathtype is spux.MathType.Real and isinstance(unitless_value, int):
return float(unitless_value)
return unitless_value
# Non-Numerical: Passthrough
return raw_value
class BLSocketType(enum.StrEnum):
@ -403,12 +441,12 @@ class BLSocketType(enum.StrEnum):
## -> The description hint "2D" is the trigger for this.
## -> Ignore the last component to get the effect of "2D".
elif description.startswith('2D'):
default_value = sp.ImmutableMatrix(tuple(bl_default_value)[:2])
default_value = tuple(bl_default_value)[:2]
# 3D/4D: Simple Parse to Sympy Matrix
## -> We don't explicitly check the size.
else:
default_value = sp.ImmutableMatrix(tuple(bl_default_value))
default_value = tuple(bl_default_value)
else:
# Non-Mathematical: Passthrough

View File

@ -14,18 +14,22 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from .category_types import NodeCategory as NC
from .category_types import NodeCategory as NC # noqa: N817
NODE_CAT_LABELS = {
# Analysis/
NC.MAXWELLSIM_ANALYSIS: 'Analysis',
NC.MAXWELLSIM_ANALYSIS_MATH: 'Math',
# Utilities/
NC.MAXWELLSIM_UTILITIES: 'Utilities',
# Inputs/
NC.MAXWELLSIM_INPUTS: 'Inputs',
NC.MAXWELLSIM_INPUTS_SCENE: 'Scene',
NC.MAXWELLSIM_INPUTS_CONSTANTS: 'Constants',
NC.MAXWELLSIM_INPUTS_FILEIMPORTERS: 'File Importers',
NC.MAXWELLSIM_INPUTS_WEBIMPORTERS: 'Web Importers',
# Solvers/
NC.MAXWELLSIM_SOLVERS: 'Solvers',
# Outputs/
NC.MAXWELLSIM_OUTPUTS: 'Outputs',
NC.MAXWELLSIM_OUTPUTS_FILEEXPORTERS: 'File Exporters',
@ -39,15 +43,11 @@ NODE_CAT_LABELS = {
# Structures/
NC.MAXWELLSIM_STRUCTURES: 'Structures',
NC.MAXWELLSIM_STRUCTURES_PRIMITIVES: 'Primitives',
# Bounds/
NC.MAXWELLSIM_BOUNDS: 'Bounds',
NC.MAXWELLSIM_BOUNDS_BOUNDCONDS: 'Conds',
# Monitors/
NC.MAXWELLSIM_MONITORS: 'Monitors',
NC.MAXWELLSIM_MONITORS_PROJECTED: 'Projected',
# Simulations/
NC.MAXWELLSIM_SIMS: 'Simulations',
NC.MAXWELLSIM_SIMS_SIMGRIDAXES: 'Sim Grid Axes',
# Utilities/
NC.MAXWELLSIM_UTILITIES: 'Utilities',
NC.MAXWELLSIM_SIMS_BOUNDCONDFACES: 'BC Faces',
NC.MAXWELLSIM_SIMS_SIMGRIDAXES: 'Grid Axes',
}

View File

@ -27,6 +27,9 @@ class NodeCategory(blender_type_enum.BlenderTypeEnum):
MAXWELLSIM_ANALYSIS = enum.auto()
MAXWELLSIM_ANALYSIS_MATH = enum.auto()
# Utilities/
MAXWELLSIM_UTILITIES = enum.auto()
# Inputs/
MAXWELLSIM_INPUTS = enum.auto()
MAXWELLSIM_INPUTS_SCENE = enum.auto()
@ -34,6 +37,9 @@ class NodeCategory(blender_type_enum.BlenderTypeEnum):
MAXWELLSIM_INPUTS_FILEIMPORTERS = enum.auto()
MAXWELLSIM_INPUTS_WEBIMPORTERS = enum.auto()
# Solvers/
MAXWELLSIM_SOLVERS = enum.auto()
# Outputs/
MAXWELLSIM_OUTPUTS = enum.auto()
MAXWELLSIM_OUTPUTS_FILEEXPORTERS = enum.auto()
@ -51,21 +57,15 @@ class NodeCategory(blender_type_enum.BlenderTypeEnum):
MAXWELLSIM_STRUCTURES = enum.auto()
MAXWELLSIM_STRUCTURES_PRIMITIVES = enum.auto()
# Bounds/
MAXWELLSIM_BOUNDS = enum.auto()
MAXWELLSIM_BOUNDS_BOUNDCONDS = enum.auto()
# Monitors/
MAXWELLSIM_MONITORS = enum.auto()
MAXWELLSIM_MONITORS_PROJECTED = enum.auto()
# Simulations/
MAXWELLSIM_SIMS = enum.auto()
MAXWELLSIM_SIMS_BOUNDCONDFACES = enum.auto()
MAXWELLSIM_SIMS_SIMGRIDAXES = enum.auto()
# Utilities/
MAXWELLSIM_UTILITIES = enum.auto()
@classmethod
def get_tree(cls):
## TODO: Refactor

View File

@ -14,16 +14,22 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import base64
import functools
import io
import typing as typ
import jax
import jax.numpy as jnp
import jaxtyping as jtyp
import numpy as np
import pydantic as pyd
import sympy as sp
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils.jaxarray import JaxArrayBytes
from blender_maxwell.utils.lru_method import method_lru
log = logger.get(__name__)
@ -40,16 +46,28 @@ class ArrayFlow(pyd.BaseModel):
None if unitless.
"""
model_config = pyd.ConfigDict(frozen=True, arbitrary_types_allowed=True)
model_config = pyd.ConfigDict(frozen=True)
values: jtyp.Inexact[jtyp.Array, '...'] ## TODO: Custom field type
unit: spux.Unit | None = None
is_sorted: bool = False
####################
# - Array Access
####################
jax_bytes: JaxArrayBytes ## Immutable jax.Array, anyone? :)
@functools.cached_property
def values(self) -> jax.Array:
"""Return the jax array."""
with io.BytesIO() as memfile:
memfile.write(base64.b64decode(self.jax_bytes.decode('utf-8')))
memfile.seek(0)
return jnp.load(memfile)
####################
# - Computed Properties
####################
@method_lru()
def __len__(self) -> int:
"""Outer length of the contained array."""
return len(self.values)
@ -70,7 +88,7 @@ class ArrayFlow(pyd.BaseModel):
####################
# - Array Features
####################
@property
@functools.cached_property
def realize_array(self) -> jtyp.Shaped[jtyp.Array, '...']:
"""Standardized access to `self.values`."""
return self.values
@ -80,6 +98,13 @@ class ArrayFlow(pyd.BaseModel):
"""Shape of the contained array."""
return self.values.shape
@method_lru(maxsize=32)
def _getitem_index(self, i: int) -> typ.Self | spux.SympyExpr:
value = self.values[i]
if len(value.shape) == 0:
return value * self.unit if self.unit is not None else sp.S(value)
return ArrayFlow(jax_bytes=value, unit=self.unit, is_sorted=self.is_sorted)
def __getitem__(self, subscript: slice) -> typ.Self | spux.SympyExpr:
"""Implement indexing and slicing in a sane way.
@ -88,16 +113,13 @@ class ArrayFlow(pyd.BaseModel):
"""
if isinstance(subscript, slice):
return ArrayFlow(
values=self.values[subscript],
jax_bytes=self.values[subscript],
unit=self.unit,
is_sorted=self.is_sorted,
)
if isinstance(subscript, int):
value = self.values[subscript]
if len(value.shape) == 0:
return value * self.unit if self.unit is not None else sp.S(value)
return ArrayFlow(values=value, unit=self.unit, is_sorted=self.is_sorted)
return self._getitem_index(subscript)
raise NotImplementedError
@ -121,7 +143,7 @@ class ArrayFlow(pyd.BaseModel):
## -> However, too-large ints may cause JAX to suffer from an overflow.
## -> Jax works in 32-bit domain by default, for performance.
## -> While it can be adjusted, that would also have tradeoffs.
## -> Instead, a quick .n() turns all the big-ints into floats.
## -> Instead, a quick float() turns all the big-ints into floats.
## -> Not super satisfying, but hey - it's all numerical anyway.
a = self.mathtype.sp_symbol_a
rescale_expr = (
@ -134,11 +156,12 @@ class ArrayFlow(pyd.BaseModel):
# Return ArrayFlow
return ArrayFlow(
values=values[::-1] if reverse else values,
jax_bytes=values[::-1] if reverse else values,
unit=new_unit,
is_sorted=self.is_sorted,
)
# @method_lru()
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
"""Find the index of the value that is closest to the given value.
@ -159,26 +182,27 @@ class ArrayFlow(pyd.BaseModel):
# BinSearch for "Right IDX"
## >>> self.values[right_idx] > scaled_value
## >>> self.values[right_idx - 1] < scaled_value
right_idx = np.searchsorted(self.values, scaled_value, side='left')
right_idx = jnp.searchsorted(self.values, scaled_value, side='left')
# Case: Right IDX is Boundary
if right_idx == 0:
return right_idx
return int(right_idx)
if right_idx == len(self.values):
return right_idx - 1
return int(right_idx - 1)
# Find Closest of [Right IDX - 1, Right IDX]
left_val = self.values[right_idx - 1]
right_val = self.values[right_idx]
if (scaled_value - left_val) <= (right_val - scaled_value):
return right_idx - 1
return int(right_idx - 1)
return right_idx
return int(right_idx)
####################
# - Unit Transforms
####################
@method_lru()
def correct_unit(self, unit: spux.Unit) -> typ.Self:
"""Simply replace the existing unit with the given one.
@ -186,8 +210,9 @@ class ArrayFlow(pyd.BaseModel):
corrected_unit: The new unit to insert.
**MUST** be associable with a well-defined `PhysicalType`.
"""
return ArrayFlow(values=self.values, unit=unit, is_sorted=self.is_sorted)
return ArrayFlow(jax_bytes=self.values, unit=unit, is_sorted=self.is_sorted)
@method_lru()
def rescale_to_unit(self, new_unit: spux.Unit | None) -> typ.Self:
"""Rescale the `ArrayFlow` to be expressed in the given unit.

View File

@ -19,8 +19,8 @@ import functools
import typing as typ
from blender_maxwell.contracts import BLEnumElement
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils.staticproperty import staticproperty
log = logger.get(__name__)
@ -189,26 +189,32 @@ class FlowKind(enum.StrEnum):
####################
# - Class Methods
####################
## TODO: Remove this (only events uses it).
@classmethod
def scale_to_unit_system(
cls,
kind: typ.Self,
flow_obj: spux.SympyExpr,
self,
flow: typ.Any,
unit_system: spux.UnitSystem,
use_jax_array: bool = True,
):
# log.debug('%s: Scaling "%s" to Unit System', kind, str(flow_obj))
## TODO: Use a hot-path logger.
if kind == FlowKind.Value:
return spux.scale_to_unit_system(
flow_obj,
unit_system,
)
if kind == FlowKind.Range:
return flow_obj.rescale_to_unit_system(unit_system)
"""Perform unit-system scaling per-`FlowKind`."""
match self:
case FlowKind.Value if isinstance(spux.SympyType):
return spux.scale_to_unit_system(
flow,
unit_system,
use_jax_array=use_jax_array,
)
if kind == FlowKind.Params:
return flow_obj.rescale_to_unit_system(unit_system)
case FlowKind.Array | FlowKind.Range:
return flow.rescale_to_unit_system(unit_system)
msg = 'Tried to scale unknown kind'
case FlowKind.Func:
return flow.scale_to_unit_system(unit_system)
case FlowKind.Params:
return flow
case FlowKind.Info:
return flow.scale_to_unit_system(unit_system)
msg = f"Applying unit-system scaling to {self} doesn't make sense"
raise ValueError(msg)

View File

@ -18,15 +18,19 @@ import dataclasses
import functools
import typing as typ
import pydantic as pyd
from blender_maxwell.utils import logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils.frozendict import FrozenDict, frozendict
from blender_maxwell.utils.lru_method import method_lru
from .array import ArrayFlow
from .lazy_range import RangeFlow
log = logger.get(__name__)
LabelArray: typ.TypeAlias = list[str]
LabelArray: typ.TypeAlias = tuple[str, ...]
# IndexArray: Identifies Discrete Dimension Values
## -> ArrayFlow (rat|real): Index by particular, not-guaranteed-uniform index.
@ -36,8 +40,7 @@ LabelArray: typ.TypeAlias = list[str]
IndexArray: typ.TypeAlias = ArrayFlow | RangeFlow | LabelArray | None
@dataclasses.dataclass(frozen=True, kw_only=True)
class InfoFlow:
class InfoFlow(pyd.BaseModel):
"""Contains dimension and output information characterizing the array produced by a parallel `FuncFlow`.
Functionally speaking, `InfoFlow` provides essential mathematical and physical context to raw array data, with terminology adapted from multilinear algebra.
@ -60,13 +63,20 @@ class InfoFlow:
- **Semantic Indexing**: Using `InfoFlow`, it's easy to index and slice arrays using ex. nanometer vacuum wavelengths, instead of arbitrary integers.
"""
model_config = pyd.ConfigDict(frozen=True)
####################
# - Dimensions: Covariant Index
####################
dims: dict[sim_symbols.SimSymbol, IndexArray] = dataclasses.field(
default_factory=dict
dims: FrozenDict[sim_symbols.SimSymbol, IndexArray] = dataclasses.field(
default_factory=frozendict
)
@functools.cached_property
def dims_list(self) -> list[IndexArray]:
"""Return the dimensional symbols as an ordered list."""
return list(self.dims.keys())
# Access
@functools.cached_property
def first_dim(self) -> sim_symbols.SimSymbol | None:
@ -88,12 +98,7 @@ class InfoFlow:
return list(self.dims.keys())[-1]
return None
def dim_by_idx(self, idx: int) -> sim_symbols.SimSymbol | None:
"""Retrieve the dimension associated with a particular index."""
if idx > 0 and idx < len(self.dims) - 1:
return list(self.dims.keys())[idx]
return None
@method_lru()
def dim_by_name(self, dim_name: str, optional: bool = False) -> int | None:
"""The integer axis occupied by the dimension.
@ -110,6 +115,7 @@ class InfoFlow:
raise ValueError(msg)
# Information By-Dim
@method_lru()
def has_idx_cont(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the dim's index is continuous, and therefore index array.
@ -120,16 +126,19 @@ class InfoFlow:
"""
return self.dims[dim] is None
@method_lru()
def has_idx_discrete(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the (rat|real) dim is indexed by an `ArrayFlow` / `RangeFlow`."""
return isinstance(self.dims[dim], ArrayFlow | RangeFlow)
@method_lru()
def has_idx_labels(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the (int) dim is indexed by a `LabelArray`."""
if dim.mathtype is spux.MathType.Integer:
return isinstance(self.dims[dim], list)
return isinstance(self.dims[dim], tuple)
return False
@method_lru()
def is_idx_uniform(self, dim: sim_symbols.SimSymbol) -> bool:
"""Whether the given dim has explicitly uniform indexing.
@ -141,6 +150,7 @@ class InfoFlow:
dim_idx = self.dims[dim]
return isinstance(dim_idx, RangeFlow) and dim_idx.scaling == 'lin'
@method_lru()
def dim_axis(self, dim: sim_symbols.SimSymbol) -> int:
"""The integer axis occupied by the dimension.
@ -158,8 +168,8 @@ class InfoFlow:
####################
## -> Whenever a dimension is deleted, we retain what that index value was.
## -> This proves to be very helpful for clear visualization.
pinned_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = dataclasses.field(
default_factory=dict
pinned_values: FrozenDict[sim_symbols.SimSymbol, spux.SympyExpr] = (
dataclasses.field(default_factory=frozendict)
)
####################
@ -188,7 +198,7 @@ class InfoFlow:
Notes:
Corresponds to `len(raw_data.shape)`, if `raw_data` is the n-dimensional array corresponding to this `InfoFlow`.
"""
return len(self.dims) + self.output_shape_len
return len(self.dims) + self.output.shape_len
@functools.cached_property
def is_scalar(self) -> tuple[spux.MathType, int, int]:
@ -216,6 +226,31 @@ class InfoFlow:
####################
# - Operations: Comparison
####################
# def broadcast(self, other: typ.Self) -> bool:
# """Deduce the output `InfoFlow` by deriving numpy's broadcasting rules for `InfoFlow`, enabling operations between compatible but differently sized `InfoFlow`s."""
# # Broadcast Dimensions
# new_dims = {}
# for _dim_l, _dim_r in itertools.zip_longest(
# reversed(self.dims.keys()),
# reversed(other.dims.keys()),
# fillvalue=None,
# ):
# dim_l = _dim_l if _dim_l is not None else _dim_r
# dim_r = _dim_r if _dim_r is not None else _dim_l
# if dim_l.compare(dim_r) and self.dims[dim_l] == other.dims[dim_r]:
# new_dims |= {dim_l: self.dims[dim_l]}
# # Broadcast Symbols
# return InfoFlow(
# dims=new_dims,
# output=self.output,
# pinned_values=self.pinned_values | other.pinned_values,
# )
# return tuple(reversed(new_shape))
@method_lru()
def compare_dims_identical(self, other: typ.Self) -> bool:
"""Whether that the quantity and properites of all dimension `SimSymbol`s are "identical".
@ -226,6 +261,7 @@ class InfoFlow:
for dim_l, dim_r in zip(self.dims, other.dims, strict=True)
)
@method_lru()
def compare_addable(
self, other: typ.Self, allow_differing_unit: bool = False
) -> bool:
@ -239,6 +275,7 @@ class InfoFlow:
other.output, allow_differing_unit=allow_differing_unit
)
@method_lru()
def compare_multiplicable(self, other: typ.Self) -> bool:
"""Whether the two `InfoFlow`s can be multiplied (elementwise).
@ -251,6 +288,7 @@ class InfoFlow:
or self.compare_dims_identical(other)
)
@method_lru()
def compare_exponentiable(self, other: typ.Self) -> bool:
"""Whether the two `InfoFlow`s can be exponentiated.
@ -265,36 +303,54 @@ class InfoFlow:
or self.compare_dims_identical(other)
)
####################
# - Operations: General Update
####################
@method_lru(maxsize=256)
def _update(self, frozen_kwargs: frozendict) -> typ.Self:
if not frozen_kwargs:
return self
return InfoFlow(**(dict(self) | frozen_kwargs))
def update(self, **kwargs: dict) -> typ.Self:
return self._update(frozendict(kwargs))
####################
# - Operations: Dimensions
####################
@method_lru()
def prepend_dim(
self, dim: sim_symbols.SimSymbol, dim_idx: sim_symbols.SimSymbol
) -> typ.Self:
"""Insert a new dimension at index 0."""
return InfoFlow(
dims={dim: dim_idx} | self.dims,
output=self.output,
pinned_values=self.pinned_values,
)
return self.update(dims=frozendict({dim: dim_idx} | self.dims))
@method_lru()
def append_dim(
self, dim: sim_symbols.SimSymbol, dim_idx: sim_symbols.SimSymbol
) -> typ.Self:
"""Insert a new dimension at index -1."""
return self.update(dims=frozendict(self.dims | {dim: dim_idx}))
@method_lru()
def slice_dim(
self, dim: sim_symbols.SimSymbol, slice_tuple: tuple[int, int, int]
) -> typ.Self:
"""Slice a dimensional array by-index along a particular dimension."""
return InfoFlow(
dims={
_dim: (
dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
if _dim == dim
else dim_idx
)
for _dim, dim_idx in self.dims.items()
},
output=self.output,
pinned_values=self.pinned_values,
return self.update(
dims=frozendict(
{
_dim: (
dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
if _dim == dim
else dim_idx
)
for _dim, dim_idx in self.dims.items()
}
)
)
@method_lru()
def replace_dim(
self,
old_dim: sim_symbols.SimSymbol,
@ -302,48 +358,66 @@ class InfoFlow:
new_dim_idx: IndexArray,
) -> typ.Self:
"""Replace a dimension entirely, in-place, including symbol and index array."""
return InfoFlow(
dims={
(new_dim if _dim == old_dim else _dim): (
new_dim_idx if _dim == old_dim else dim_idx
)
for _dim, dim_idx in self.dims.items()
},
output=self.output,
pinned_values=self.pinned_values,
return self.update(
dims=frozendict(
{
(new_dim if _dim == old_dim else _dim): (
new_dim_idx if _dim == old_dim else dim_idx
)
for _dim, dim_idx in self.dims.items()
}
)
)
@method_lru()
def replace_dims(
self, new_dims: dict[sim_symbols.SimSymbol, IndexArray]
) -> typ.Self:
"""Replace several dimensional indices with new index arrays/ranges."""
return InfoFlow(
dims={
dim: new_dims.get(dim, dim_idx) for dim, dim_idx in self.dim_idx.items()
},
output=self.output,
pinned_values=self.pinned_values,
return self.update(
dims=frozendict(
{
dim: new_dims.get(dim, dim_idx)
for dim, dim_idx in self.dim_idx.items()
}
)
)
@method_lru()
def delete_dim(
self, dim_to_remove: sim_symbols.SimSymbol, pin_idx: int | None = None
) -> typ.Self:
"""Delete a dimension, optionally pinning the value of an index from that dimension."""
new_pin = (
{dim_to_remove: self.dims[dim_to_remove][pin_idx]}
if pin_idx is not None
else {}
)
return InfoFlow(
dims={
dim: dim_idx
for dim, dim_idx in self.dims.items()
if dim != dim_to_remove
},
output=self.output,
pinned_values=self.pinned_values | new_pin,
)
if dim_to_remove in self.dims:
dim_idx = self.dims[dim_to_remove]
# Deduce Pinned Value
if pin_idx is not None:
pin_value = {dim_to_remove: dim_idx[pin_idx]}
elif (
self.has_idx_discrete(dim_to_remove)
or self.has_idx_labels(dim_to_remove)
) and len(dim_idx) == 1:
pin_value = {dim_to_remove: dim_idx[0]}
else:
pin_value = {}
# Delete Dimension
return self.update(
dims=frozendict(
{
dim: dim_idx
for dim, dim_idx in self.dims.items()
if dim != dim_to_remove
}
),
pinned_values=self.pinned_values | pin_value,
)
msg = 'Dimension to delete is not in the InfoFlow dimensions (to delete: {dim_to_remove}, info={self})'
raise ValueError(msg)
@method_lru()
def swap_dimensions(self, dim_0: str, dim_1: str) -> typ.Self:
"""Swap the positions of two dimensions."""
@ -357,10 +431,10 @@ class InfoFlow:
swapped_dim_keys = [name_swapper(dim) for dim in self.dims]
return InfoFlow(
dims={dim_key: self.dims[dim_key] for dim_key in swapped_dim_keys},
output=self.output,
pinned_values=self.pinned_values,
return self.update(
dims=frozendict(
{dim_key: self.dims[dim_key] for dim_key in swapped_dim_keys}
),
)
####################
@ -368,33 +442,58 @@ class InfoFlow:
####################
def update_output(self, **kwargs) -> typ.Self:
"""Passthrough to `SimSymbol.update()` method on `self.output`."""
return InfoFlow(
dims=self.dims,
return self.update(
output=self.output.update(**kwargs),
pinned_values=self.pinned_values,
)
def operate_output(
self,
other: typ.Self,
op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
unit_op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
) -> spux.SympyExpr:
"""Apply an operation between two the values and units of two `InfoFlow`s by reconstructing the properties of the new output `SimSymbol`."""
sym_name = sim_symbols.SimSymbolName.Expr
expr = op(self.output.sp_symbol_phy, other.output.sp_symbol_phy)
unit_expr = unit_op(self.output.unit_factor, other.output.unit_factor)
## TODO: Handle per-cell matrix units?
# def map_output(
# self,
# op: typ.Callable[[spux.SympyExpr], spux.SympyExpr],
# unit_op: typ.Callable[[spux.SympyExpr], spux.SympyExpr],
# new_domain: sp.Set | None = None,
# ) -> spux.SympyExpr:
# """Apply an operation to a single `InfoFlow`s by reconstructing the properties of the new output `SimSymbol`."""
# sym_name = sim_symbols.SimSymbolName.Expr
# expr = op(self.output.sp_symbol_phy)
# unit_expr = unit_op(self.output.unit_factor)
return InfoFlow(
dims=self.dims,
output=sim_symbols.SimSymbol.from_expr(sym_name, expr, unit_expr),
pinned_values=self.pinned_values,
# return self.update(
# output=sim_symbols.SimSymbol.from_expr(
# sym_name, expr, unit_expr, new_domain=new_domain
# ),
# )
def scale_to_unit_system(self, unit_system: spux.UnitSystem) -> typ.Self:
"""Passthrough to `SimSymbol.update()` method on `self.output`."""
return self.update(
output=self.output.scale_to_unit_system(unit_system),
)
# def operate_output(
# self,
# other: typ.Self,
# op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
# unit_op: typ.Callable[[spux.SympyExpr, spux.SympyExpr], spux.SympyExpr],
# new_domain: spux.BlessedSet | None = None,
# ) -> spux.SympyExpr:
# """Apply an operation between two the values and units of two `InfoFlow`s by reconstructing the properties of the new output `SimSymbol`."""
# sym_name = sim_symbols.SimSymbolName.Expr
# expr = op(self.output.sp_symbol_phy, other.output.sp_symbol_phy)
# unit_expr = unit_op(self.output.unit_factor, other.output.unit_factor)
# return self.update(
# output=sim_symbols.SimSymbol.from_expr(
# sym_name,
# expr,
# unit_expr,
# new_domain=new_domain,
# )
# )
####################
# - Operations: Fold
####################
@functools.cached_property
def fold_last_input(self):
"""Fold the last input dimension into the output."""
last_idx = self.dims[self.last_dim]
@ -411,12 +510,13 @@ class InfoFlow:
case (_, _):
raise NotImplementedError ## Not yet :)
return InfoFlow(
dims={
dim: dim_idx
for dim, dim_idx in self.dims.items()
if dim != self.last_dim
},
return self.update(
dims=frozendict(
{
dim: dim_idx
for dim, dim_idx in self.dims.items()
if dim != self.last_dim
}
),
output=new_output,
pinned_values=self.pinned_values,
)

View File

@ -226,8 +226,10 @@ import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils.frozendict import FrozenDict, frozendict
from blender_maxwell.utils.lru_method import method_lru
from .array import ArrayFlow
from .info import InfoFlow
@ -256,8 +258,10 @@ class FuncFlow(pyd.BaseModel):
model_config = pyd.ConfigDict(frozen=True)
func: LazyFunction
func_args: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list)
func_kwargs: dict[str, sim_symbols.SimSymbol] = pyd.Field(default_factory=dict)
func_args: tuple[sim_symbols.SimSymbol, ...] = ()
func_kwargs: FrozenDict[str, sim_symbols.SimSymbol] = pyd.Field(
default_factory=frozendict
)
func_output: sim_symbols.SimSymbol | None = None
supports_jax: bool = False
@ -319,12 +323,11 @@ class FuncFlow(pyd.BaseModel):
####################
# - Realization
####################
@method_lru(maxsize=64)
def realize(
self,
params: ParamsFlow,
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
symbol_values: frozendict[sim_symbols.SimSymbol, spux.SympyExpr] = frozendict(),
disallow_jax: bool = True,
) -> typ.Self:
"""Run the represented function with the best optimization available, given particular choices for all function arguments and for all unrealized symbols.
@ -337,22 +340,21 @@ class FuncFlow(pyd.BaseModel):
"""
if self.supports_jax and not disallow_jax:
return self.func_jax(
*params.scaled_func_args(symbol_values),
**params.scaled_func_kwargs(symbol_values),
*params.scaled_func_args(self.func_args, symbol_values),
**params.scaled_func_kwargs(self.func_kwargs, symbol_values),
)
return self.func(
*params.scaled_func_args(symbol_values),
**params.scaled_func_kwargs(symbol_values),
*params.scaled_func_args(self.func_args, symbol_values),
**params.scaled_func_kwargs(self.func_kwargs, symbol_values),
)
@method_lru(maxsize=64)
def realize_as_data(
self,
info: InfoFlow,
params: ParamsFlow,
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
) -> dict[sim_symbols.SimSymbol, jtyp.Inexact[jtyp.Array, '...']]:
symbol_values: frozendict[sim_symbols.SimSymbol, spux.SympyExpr] = frozendict(),
) -> frozendict[sim_symbols.SimSymbol, jtyp.Inexact[jtyp.Array, '...']]:
"""Realize as an ordered dictionary mapping each realized `self.dims` entry, with the last entry containing all output data as mapped from the `self.output`."""
data = {}
for dim, dim_idx in info.dims.items():
@ -386,9 +388,14 @@ class FuncFlow(pyd.BaseModel):
if info.has_idx_labels(dim):
data |= {dim: dim_idx}
return data | {info.output: self.realize(params, symbol_values=symbol_values)}
return frozendict(
data
| {info.output: self.realize(params, symbol_values=symbol_values)}
| {'pinned': info.pinned_values}
)
def realize_partial(
@method_lru(maxsize=64)
def generate_realizer(
self, params: ParamsFlow
) -> typ.Callable[
[int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
@ -399,7 +406,7 @@ class FuncFlow(pyd.BaseModel):
The units/types/shape/etc. of the returned numerical type conforms to the `SimSymbol` specification of relevant `self.func_args` entries and `self.func_output`.
This function should be used whenever the unrealized result of a `FuncFlow` needs to be used as the argument to another `FuncFlow`.
By using `realize_partial()`, two things are ensured:
By using `generate_realizer()`, two things are ensured:
- Since the function defined in `.compose_within()` must be purely numerical, the usual `.realize()` mechanism can't be used to sweep away the pre-realized symbol values.
- Since this `FuncFlow` is completely consumed, with no symbols / arguments / etc. explicitly surviving, its impact on the data flow can be considered to have been effectively terminated after using this function.
@ -418,24 +425,36 @@ class FuncFlow(pyd.BaseModel):
return self.func(
*[
func_arg_n(*sym_args, *pre_realized_syms)
for func_arg_n in params.func_args_n
for func_arg_n in params.func_args_n(self.func_args)
],
**{
func_arg_name: func_kwarg_n(*sym_args, *pre_realized_syms)
for func_arg_name, func_kwarg_n in params.func_kwargs_n.items()
for func_arg_name, func_kwarg_n in params.func_kwargs_n(
self.func_kwargs
).items()
},
)
return realizer
@method_lru(maxsize=8)
def realize_preview(
self, params: ParamsFlow
) -> typ.Callable[
[int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
jtyp.Inexact[jtyp.Array, '...'],
]:
"""Realize the function value, by replacing unknown symbols with their declared preview values."""
return self.realize(params, symbol_values=params.symbol_preview_values)
####################
# - Operations
####################
def compose_within(
self,
enclosing_func: LazyFunction,
enclosing_func_args: list[sim_symbols.SimSymbol] = (),
enclosing_func_kwargs: dict[str, sim_symbols.SimSymbol] = MappingProxyType({}),
enclosing_func_args: tuple[sim_symbols.SimSymbol, ...] = (),
enclosing_func_kwargs: frozendict[str, sim_symbols.SimSymbol] = frozendict(),
enclosing_func_output: sim_symbols.SimSymbol | None = None,
supports_jax: bool = False,
) -> typ.Self:
@ -480,18 +499,19 @@ class FuncFlow(pyd.BaseModel):
return FuncFlow(
func=lambda *args, **kwargs: enclosing_func(
self.func(
*list(args[: len(self.func_args)]),
*args[: len(self.func_args)],
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
),
*args[len(self.func_args) :],
**{k: v for k, v in kwargs.items() if k not in self.func_kwargs},
),
func_args=self.func_args + list(enclosing_func_args),
func_args=self.func_args + tuple(enclosing_func_args),
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
func_output=enclosing_func_output,
supports_jax=self.supports_jax and supports_jax,
)
@method_lru(maxsize=8)
def __or__(
self,
other: typ.Self,
@ -532,7 +552,7 @@ class FuncFlow(pyd.BaseModel):
def self_func(args, kwargs):
ret = self.func(
*list(args[: len(self.func_args)]),
*args[: len(self.func_args)],
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
)
if not self.is_concatenated:
@ -543,7 +563,7 @@ class FuncFlow(pyd.BaseModel):
func=lambda *args, **kwargs: (
*self_func(args, kwargs),
other.func(
*list(args[len(self.func_args) :]),
*args[len(self.func_args) :],
**{k: v for k, v in kwargs.items() if k in other.func_kwargs},
),
),
@ -553,27 +573,15 @@ class FuncFlow(pyd.BaseModel):
is_concatenated=True,
)
@method_lru(maxsize=16)
def scale_to_unit(self, unit: spux.Unit | None = None) -> typ.Self:
"""Encloses this function in a unit-converting function, whose output is a converted, unitless scalar.
`unit` must be manually guaranteed to be compatible with `self.unit`.
"""
if self.func_output is not None:
# Retrieve Output Unit
output_unit = self.func_output.unit
# Compile Efficient Unit-Conversion Function
a = self.func_output.mathtype.sp_symbol_a
unit_convert_expr = (
spux.scale_to_unit(a * output_unit, unit)
if self.func_output.unit is not None
else a
)
unit_convert_func = sp.lambdify(a, unit_convert_expr.n(), 'jax')
# Compose Unit-Converted FuncFlow
return self.compose_within(
enclosing_func=unit_convert_func,
enclosing_func=spux.unit_scaling_func_n(self.func_output.unit, unit),
supports_jax=True,
enclosing_func_output=self.func_output.update(unit=unit),
)
@ -581,6 +589,7 @@ class FuncFlow(pyd.BaseModel):
msg = f'Tried to scale a FuncFlow to a unit system, but it has no tracked output SimSymbol. ({self})'
raise ValueError(msg)
@method_lru(maxsize=16)
def scale_to_unit_system(
self, unit_system: spux.UnitSystem | None = None
) -> typ.Self:

View File

@ -26,6 +26,8 @@ import sympy as sp
from blender_maxwell.utils import logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils.frozendict import frozendict
from blender_maxwell.utils.lru_method import method_lru
from .array import ArrayFlow
@ -47,6 +49,10 @@ class ScalingMode(enum.StrEnum):
@staticmethod
def to_name(v: typ.Self) -> str:
"""Friendly, single-letter, human-readable column names.
Must be concise, as there is not a lot of header space to contain these.
"""
SM = ScalingMode
return {
SM.Lin: 'Linear',
@ -56,6 +62,7 @@ class ScalingMode(enum.StrEnum):
@staticmethod
def to_icon(_: typ.Self) -> str:
"""No icons."""
return ''
@ -135,6 +142,7 @@ class RangeFlow(pyd.BaseModel):
msg = f'RangeFlow is incompatible with SimSymbol {sym}'
raise ValueError(msg)
@method_lru()
def to_sym(
self,
sym_name: sim_symbols.SimSymbolName,
@ -154,6 +162,16 @@ class RangeFlow(pyd.BaseModel):
cols=1,
).set_domain(start=self.realize_start(), end=self.realize_end())
@functools.cached_property
def symbolic_set(self) -> spux.BlessedSet:
if not self.symbols:
if self.steps == 0:
return spux.BlessedSet(sp.Interval(self.start, self.stop))
return spux.BlessedSet(sp.Range(self.start, self.stop + 1))
msg = 'Cant deduce symbolic set from symbolic RangeFlow'
raise ValueError(msg)
####################
# - Symbols
####################
@ -250,6 +268,7 @@ class RangeFlow(pyd.BaseModel):
####################
# - Methods
####################
@functools.lru_cache(maxsize=16)
@staticmethod
def try_from_array(
array: ArrayFlow, uniformity_tolerance: float = 1e-9
@ -269,8 +288,8 @@ class RangeFlow(pyd.BaseModel):
diffs = jnp.diff(array.values)
if (
jnp.all(jnp.abs(diffs - diffs[0]) < uniformity_tolerance)
and len(array.values) > 2 # noqa: PLR2004
len(array.values) > 2 # noqa: PLR2004
and jnp.all(jnp.abs(diffs - diffs[0]) < uniformity_tolerance)
and array.values[0] < array.values[-1]
and array.is_sorted
):
@ -317,6 +336,7 @@ class RangeFlow(pyd.BaseModel):
symbols=self.symbols,
)
@method_lru(maxsize=16)
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
raise NotImplementedError
@ -468,12 +488,11 @@ class RangeFlow(pyd.BaseModel):
####################
# - Realization
####################
@method_lru(maxsize=16)
def realize_symbols(
self,
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
) -> dict[sp.Symbol, spux.ScalarUnitlessRealExpr]:
symbol_values: frozendict[sim_symbols.SimSymbol, spux.SympyExpr] = frozendict(),
) -> frozendict[sp.Symbol, spux.ScalarUnitlessRealExpr]:
"""Realize **all** input symbols to the `RangeFlow`.
Parameters:
@ -501,35 +520,32 @@ class RangeFlow(pyd.BaseModel):
raise NotImplementedError(msg)
realized_syms |= {sym: v}
return realized_syms
return frozendict(realized_syms)
msg = f'RangeFlow: Not all symbols were given a value during realization (symbols={self.symbols}, symbol_values={symbol_values})'
raise ValueError(msg)
@method_lru(maxsize=16)
def realize_start(
self,
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
symbol_values: frozendict[sim_symbols.SimSymbol, spux.SympyExpr] = frozendict(),
) -> int | float | complex:
"""Realize the start-bound by inserting particular values for each symbol."""
realized_symbols = self.realize_symbols(symbol_values)
return spux.sympy_to_python(self.start.subs(realized_symbols))
@method_lru(maxsize=16)
def realize_stop(
self,
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
symbol_values: frozendict[sim_symbols.SimSymbol, spux.SympyExpr] = frozendict(),
) -> int | float | complex:
"""Realize the stop-bound by inserting particular values for each symbol."""
realized_symbols = self.realize_symbols(symbol_values)
return spux.sympy_to_python(self.stop.subs(realized_symbols))
@method_lru(maxsize=16)
def realize_step_size(
self,
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
symbol_values: frozendict[sim_symbols.SimSymbol, spux.SympyExpr] = frozendict(),
) -> int | float | complex:
"""Realize the stop-bound by inserting particular values for each symbol."""
if self.scaling is not ScalingMode.Lin:
@ -544,11 +560,10 @@ class RangeFlow(pyd.BaseModel):
return int(raw_step_size)
return raw_step_size
@method_lru(maxsize=16)
def realize(
self,
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
symbol_values: frozendict[sim_symbols.SimSymbol, spux.SympyExpr] = frozendict(),
) -> ArrayFlow:
"""Realize the array represented by this `RangeFlow` by realizing each bound, then generating all intermediate values as an array.
@ -561,7 +576,7 @@ class RangeFlow(pyd.BaseModel):
## TODO: Check symbol values for coverage.
return ArrayFlow(
values=self.as_func(
jax_bytes=self.as_func(
*[
spux.scale_to_unit_system(symbol_values[sym])
for sym in self.sorted_symbols
@ -584,7 +599,7 @@ class RangeFlow(pyd.BaseModel):
msg = f'RangeFlow: Cannot use ".realize_array" when symbols are defined (symbols={self.symbols}, RangeFlow={self}'
raise ValueError(msg)
@property
@functools.cached_property
def values(self) -> jtyp.Inexact[jtyp.Array, '...']:
"""Alias for `realize_array.values`."""
return self.realize_array.values
@ -622,6 +637,7 @@ class RangeFlow(pyd.BaseModel):
####################
# - Units
####################
@method_lru(maxsize=32)
def correct_unit(self, corrected_unit: spux.Unit) -> typ.Self:
"""Replaces the unit without rescaling the unitless bounds.
@ -643,6 +659,7 @@ class RangeFlow(pyd.BaseModel):
symbols=self.symbols,
)
@method_lru(maxsize=32)
def rescale_to_unit(self, unit: spux.Unit) -> typ.Self:
"""Replaces the unit, **with** rescaling of the bounds.
@ -673,6 +690,7 @@ class RangeFlow(pyd.BaseModel):
symbols=self.symbols,
)
@method_lru(maxsize=32)
def rescale_to_unit_system(
self, unit_system: spux.UnitSystem | None = None
) -> typ.Self:

View File

@ -23,8 +23,10 @@ import jaxtyping as jtyp
import pydantic as pyd
import sympy as sp
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils.frozendict import FrozenDict, frozendict
from blender_maxwell.utils.lru_method import method_lru
from .array import ArrayFlow
from .expr_info import ExprInfo
@ -42,16 +44,16 @@ class ParamsFlow(pyd.BaseModel):
model_config = pyd.ConfigDict(frozen=True)
arg_targets: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list)
kwarg_targets: dict[str, sim_symbols.SimSymbol] = pyd.Field(default_factory=dict)
func_args: list[spux.SympyExpr] = pyd.Field(default_factory=list)
func_kwargs: dict[str, spux.SympyExpr] = pyd.Field(default_factory=dict)
func_args: tuple[spux.SympyExpr, ...] = pyd.Field(default_factory=tuple)
func_kwargs: FrozenDict[str, spux.SympyExpr] = pyd.Field(default_factory=frozendict)
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
realized_symbols: dict[
realized_symbols: FrozenDict[
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
] = pyd.Field(default_factory=dict)
] = pyd.Field(default_factory=frozendict)
previewed_symbols: FrozenDict[sim_symbols.SimSymbol, spux.SympyExpr] = pyd.Field(
default_factory=frozendict
)
####################
# - Symbols
@ -101,9 +103,10 @@ class ParamsFlow(pyd.BaseModel):
####################
# - JIT'ed Callables for Numerical Function Arguments
####################
@functools.cached_property
@method_lru()
def func_args_n(
self,
arg_targets: tuple[sim_symbols.SimSymbol, ...],
) -> list[
typ.Callable[
[int | float | complex | jtyp.Inexact[jtyp.Array, '...'], ...],
@ -112,7 +115,7 @@ class ParamsFlow(pyd.BaseModel):
]:
"""Callable functions for evaluating each `self.func_args` entry numerically.
Before simplification, each `self.func_args` entry will be conformed to the corresponding (by-index) `SimSymbol` in `self.target_syms`.
Before simplification, each `self.func_args` entry will be conformed to the corresponding (by-index) `SimSymbol` in the passed `arg_targets`.
Notes:
Before using any `sympy` expressions as arguments to the returned callablees, they **must** be fully conformed and scaled to the corresponding `self.symbols` entry using that entry's `SimSymbol.scale()` method.
@ -125,14 +128,13 @@ class ParamsFlow(pyd.BaseModel):
target_sym.conform(func_arg, strip_unit=True),
'jax',
)
for func_arg, target_sym in zip(
self.func_args, self.arg_targets, strict=True
)
for func_arg, target_sym in zip(self.func_args, arg_targets, strict=True)
]
@functools.cached_property
@method_lru()
def func_kwargs_n(
self,
kwarg_targets: frozendict[str, sim_symbols.SimSymbol],
) -> dict[
str,
typ.Callable[
@ -148,7 +150,7 @@ class ParamsFlow(pyd.BaseModel):
return {
key: sp.lambdify(
self.all_sorted_sp_symbols,
self.kwarg_targets[key].conform(func_arg, strip_unit=True),
kwarg_targets[key].conform(func_arg, strip_unit=True),
'jax',
)
for key, func_arg in self.func_kwargs.items()
@ -157,9 +159,24 @@ class ParamsFlow(pyd.BaseModel):
####################
# - Realization
####################
@functools.cached_property
def symbol_preview_values(
self,
) -> frozendict[sim_symbols.SimSymbol, sp.Basic] | None:
"""Provide a dictionary for simplifying all unrealized symbols to their preview values.
Returns `None` if not all unrealized symbols have preview values.
"""
if all(sym.preview_value is not None for sym in self.all_sorted_symbols):
return frozendict(
{sym: sym.preview_value_phy for sym in self.all_sorted_symbols}
)
return None
@method_lru()
def realize_symbols(
self,
symbol_values: dict[
symbol_values: frozendict[
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
] = MappingProxyType({}),
allow_partial: bool = False,
@ -209,11 +226,11 @@ class ParamsFlow(pyd.BaseModel):
####################
# - Realize Arguments
####################
@method_lru(maxsize=8)
def scaled_func_args(
self,
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
arg_targets: tuple[sim_symbols.SimSymbol, ...],
symbol_values: frozendict[sim_symbols.SimSymbol, spux.SympyExpr] = frozendict(),
) -> list[
int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...']
]:
@ -237,15 +254,20 @@ class ParamsFlow(pyd.BaseModel):
Parameters:
symbol_values: Particular values for all symbols in `self.symbols`, which will be conformed and used to compute the function arguments (before they are conformed to `self.target_syms`).
"""
realized_symbols = list(
realized_symbols = tuple(
self.realize_symbols(symbol_values | self.realized_symbols).values()
)
return [func_arg_n(*realized_symbols) for func_arg_n in self.func_args_n]
return [
func_arg_n(*realized_symbols)
for func_arg_n in self.func_args_n(arg_targets)
]
@method_lru(maxsize=8)
def scaled_func_kwargs(
self,
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
) -> dict[
kwarg_targets: frozendict[str, sim_symbols.SimSymbol],
symbol_values: frozendict[spux.Symbol, spux.SympyExpr] = frozendict(),
) -> frozendict[
str, int | float | Fraction | float | complex | jtyp.Shaped[jtyp.Array, '...']
]:
"""Realize correctly conformed numerical arguments for `self.func_kwargs`.
@ -254,14 +276,19 @@ class ParamsFlow(pyd.BaseModel):
"""
realized_symbols = self.realize_symbols(symbol_values | self.realized_symbols)
return {
func_arg_name: func_kwarg_n(**realized_symbols)
for func_arg_name, func_kwarg_n in self.func_kwargs_n.items()
}
return frozendict(
{
func_arg_name: func_kwarg_n(**realized_symbols)
for func_arg_name, func_kwarg_n in self.func_kwargs_n(
kwarg_targets
).items()
}
)
####################
# - Operations
####################
@method_lru(maxsize=8)
def __or__(
self,
other: typ.Self,
@ -272,34 +299,30 @@ class ParamsFlow(pyd.BaseModel):
The next composed function will receive a tuple of two arrays, instead of just one, allowing binary operations to occur.
"""
return ParamsFlow(
arg_targets=self.arg_targets + other.arg_targets,
kwarg_targets=self.kwarg_targets | other.kwarg_targets,
func_args=self.func_args + other.func_args,
func_kwargs=self.func_kwargs | other.func_kwargs,
symbols=self.symbols | other.symbols,
realized_symbols=self.realized_symbols | other.realized_symbols,
)
@method_lru(maxsize=8)
def compose_within(
self,
enclosing_arg_targets: list[sim_symbols.SimSymbol] = (),
enclosing_kwarg_targets: list[sim_symbols.SimSymbol] = (),
enclosing_func_args: list[spux.SympyExpr] = (),
enclosing_func_kwargs: dict[str, spux.SympyExpr] = MappingProxyType({}),
enclosing_func_args: tuple[spux.SympyExpr, ...] = (),
enclosing_func_kwargs: frozendict[str, spux.SympyExpr] = frozendict(),
enclosing_symbols: frozenset[sim_symbols.SimSymbol] = frozenset(),
) -> typ.Self:
return ParamsFlow(
arg_targets=self.arg_targets + list(enclosing_arg_targets),
kwarg_targets=self.kwarg_targets | dict(enclosing_kwarg_targets),
func_args=self.func_args + list(enclosing_func_args),
func_args=self.func_args + tuple(enclosing_func_args),
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
symbols=self.symbols | enclosing_symbols,
realized_symbols=self.realized_symbols,
)
@method_lru(maxsize=8)
def realize_partial(
self,
symbol_values: dict[
symbol_values: frozendict[
sim_symbols.SimSymbol, spux.SympyExpr | RangeFlow | ArrayFlow
],
) -> typ.Self:
@ -316,11 +339,9 @@ class ParamsFlow(pyd.BaseModel):
Raises:
ValueError: If any symbol in `symbol_values`
"""
syms = set(symbol_values.keys())
syms = frozenset(symbol_values.keys())
if syms.issubset(self.symbols) or not syms:
return ParamsFlow(
arg_targets=self.arg_targets,
kwarg_targets=self.kwarg_targets,
func_args=self.func_args,
func_kwargs=self.func_kwargs,
symbols=self.symbols - syms,
@ -333,7 +354,7 @@ class ParamsFlow(pyd.BaseModel):
# - Generate ExprSocketDef
####################
@functools.cached_property
def sym_expr_infos(self) -> dict[str, ExprInfo]:
def sym_expr_infos(self) -> frozendict[str, ExprInfo]:
"""Generate keyword arguments for defining all `ExprSocket`s needed to realize all `self.symbols`.
Many nodes need actual data, and as such, they require that the user select actual values for any symbols in the `ParamsFlow`.

View File

@ -31,9 +31,10 @@ class FlowSignal(enum.StrEnum):
"""
NoFlow = enum.auto()
FlowInitializing = enum.auto()
FlowPending = enum.auto()
NoFlow = enum.auto()
# FlowError = enum.auto()
@classmethod
def all(cls) -> set[typ.Self]:

View File

@ -33,8 +33,12 @@ class NodeType(blender_type_enum.BlenderTypeEnum):
ReduceMath = enum.auto()
TransformMath = enum.auto()
# Inputs
# Utilities
Combine = enum.auto()
WaveConstant = enum.auto()
ViewText = enum.auto()
# Inputs
Scene = enum.auto()
## Inputs / Constants
ExprConstant = enum.auto()
@ -48,6 +52,11 @@ class NodeType(blender_type_enum.BlenderTypeEnum):
DataFileImporter = enum.auto()
Tidy3DFileImporter = enum.auto()
# Solvers
FDTDSolver = enum.auto()
ModeSolver = enum.auto()
EMESolver = enum.auto()
# Outputs
Viewer = enum.auto()
## Outputs / File Exporters
@ -82,7 +91,7 @@ class NodeType(blender_type_enum.BlenderTypeEnum):
DebyeMedium = enum.auto()
## Mediums / Non-Linearities
AddNonLinearity = enum.auto()
ChiThreeSusceptibilityNonLinearity = enum.auto()
ChiThreeSuscepNonLinearity = enum.auto()
TwoPhotonAbsorptionNonLinearity = enum.auto()
KerrNonLinearity = enum.auto()
@ -97,13 +106,6 @@ class NodeType(blender_type_enum.BlenderTypeEnum):
CylinderStructure = enum.auto()
PolySlabStructure = enum.auto()
# Bounds
BoundConds = enum.auto()
## Bounds / Bound Conds
PMLBoundCond = enum.auto()
BlochBoundCond = enum.auto()
AdiabAbsorbBoundCond = enum.auto()
# Monitors
EHFieldMonitor = enum.auto()
PowerFluxMonitor = enum.auto()
@ -115,15 +117,16 @@ class NodeType(blender_type_enum.BlenderTypeEnum):
KSpaceNearFieldProjectionMonitor = enum.auto()
# Sims
Combine = enum.auto()
SimDomain = enum.auto()
FDTDSim = enum.auto()
SimGrid = enum.auto()
## Sims / Sim Grid Axis
AutomaticSimGridAxis = enum.auto()
BoundConds = enum.auto()
## Sims / Bound Conds
PMLBoundCond = enum.auto()
BlochBoundCond = enum.auto()
AdiabAbsorbBoundCond = enum.auto()
## Sims / Grid Axes
AutoSimGridAxis = enum.auto()
ManualSimGridAxis = enum.auto()
UniformSimGridAxis = enum.auto()
ArraySimGridAxis = enum.auto()
# Utilities
Separate = enum.auto()

View File

@ -18,19 +18,22 @@
import dataclasses
import enum
import functools
import typing as typ
from fractions import Fraction
from pathlib import Path
import jax.numpy as jnp
import jaxtyping as jtyp
import numpy as np
import polars as pl
import pydantic as pyd
import tidy3d as td
from blender_maxwell.contracts import BLEnumElement
from blender_maxwell.services import tdcloud
from blender_maxwell.utils import logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .flow_kinds.info import InfoFlow
@ -668,3 +671,56 @@ class DataFileFormat(enum.StrEnum):
E.Txt: True, ## Use # Comments
E.TxtGz: True, ## Same as Txt
}[self]
####################
# - Encode/Decode Metadata
####################
RealizationScalar: typ.TypeAlias = int | float
Realization: typ.TypeAlias = (
RealizationScalar
| tuple[RealizationScalar, ...]
| tuple[tuple[RealizationScalar, ...], ...]
)
class SimRealizations(pyd.BaseModel):
"""Encodes the realized values of symbols that were used to generate a particular simulation."""
model_config = pyd.ConfigDict(frozen=True)
syms: tuple[sim_symbols.SimSymbol, ...] = ()
vals: tuple[Realization, ...] = ()
@pyd.model_validator(mode='after')
def syms_vals_eq_length(self) -> typ.Self:
"""Ensure that `self.syms` and `self.vals` are of equal length."""
if len(self.syms) != len(self.vals):
msg = f"'syms' and 'vals' are of differing length (syms={self.syms}, vals={self.vals})"
raise ValueError(msg)
return self
class SimMetadata(pyd.BaseModel):
"""Encodes simulation metadata."""
model_config = pyd.ConfigDict(frozen=True)
sim_metadata_version: str = '0.1.0'
realizations: SimRealizations = SimRealizations()
@staticmethod
def from_sim(sim: td.Simulation | td.SimulationData) -> typ.Self:
"""Deduce simulation metadata from a simulation / simulation data."""
if 'sim_metadata_version' in sim.attrs:
## TODO: Semantic versioning comparison
return SimMetadata(**sim.attrs)
return SimMetadata()
@functools.cached_property
def syms_vals(
self,
) -> tuple[tuple[sim_symbols.SimSymbol, ...], tuple[Realization, ...]]:
"""Deduce simulation metadata from a simulation / simulation data."""
return (self.realizations.syms, self.realizations.vals)

View File

@ -27,6 +27,7 @@ Attributes:
import typing as typ
import sympy.physics.units as spu
from frozendict import frozendict
from blender_maxwell.utils import sympy_extra as spux
@ -34,48 +35,52 @@ from blender_maxwell.utils import sympy_extra as spux
# - Unit Systems
####################
_PT: typ.TypeAlias = spux.PhysicalType
UNITS_BLENDER: spux.UnitSystem = spux.UNITS_SI | {
# Global
_PT.Time: spu.picosecond,
_PT.Freq: spux.terahertz,
_PT.AngFreq: spu.radian * spux.terahertz,
# Cartesian
_PT.Length: spu.micrometer,
_PT.Area: spu.micrometer**2,
_PT.Volume: spu.micrometer**3,
# Energy
_PT.PowerFlux: spu.watt / spu.um**2,
# Electrodynamics
_PT.CurrentDensity: spu.ampere / spu.um**2,
_PT.Conductivity: spu.siemens / spu.um,
_PT.EField: spu.volt / spu.um,
_PT.HField: spu.ampere / spu.um,
# Mechanical
_PT.Vel: spu.um / spu.second,
_PT.Accel: spu.um / spu.second,
_PT.Mass: spu.microgram,
_PT.Force: spux.micronewton,
# Luminal
# Optics
} ## TODO: Load (dynamically?) from addon preferences
UNITS_BLENDER: spux.UnitSystem = spux.UNITS_SI | frozendict(
{
# Global
_PT.Time: spu.picosecond,
_PT.Freq: spux.terahertz,
_PT.AngFreq: spu.radian * spux.terahertz,
# Cartesian
_PT.Length: spu.micrometer,
_PT.Area: spu.micrometer**2,
_PT.Volume: spu.micrometer**3,
# Energy
_PT.PowerFlux: spu.watt / spu.um**2,
# Electrodynamics
_PT.CurrentDensity: spu.ampere / spu.um**2,
_PT.Conductivity: spu.siemens / spu.um,
_PT.EField: spu.volt / spu.um,
_PT.HField: spu.ampere / spu.um,
# Mechanical
_PT.Vel: spu.um / spu.second,
_PT.Accel: spu.um / spu.second,
_PT.Mass: spu.microgram,
_PT.Force: spux.micronewton,
# Luminal
# Optics
}
) ## TODO: Load (dynamically?) from addon preferences
UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | {
# Global
# Cartesian
_PT.Length: spu.um,
_PT.Area: spu.um**2,
_PT.Volume: spu.um**3,
# Mechanical
_PT.Vel: spu.um / spu.second,
_PT.Accel: spu.um / spu.second,
# Energy
_PT.PowerFlux: spu.watt / spu.um**2,
# Electrodynamics
_PT.CurrentDensity: spu.ampere / spu.um**2,
_PT.Conductivity: spu.siemens / spu.um,
_PT.EField: spu.volt / spu.um,
_PT.HField: spu.ampere / spu.um,
# Luminal
# Optics
## NOTE: w/o source normalization, EField/HField/Modal amps are * 1/Hz
}
UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | frozendict(
{
# Global
# Cartesian
_PT.Length: spu.um,
_PT.Area: spu.um**2,
_PT.Volume: spu.um**3,
# Mechanical
_PT.Vel: spu.um / spu.second,
_PT.Accel: spu.um / spu.second,
# Energy
_PT.PowerFlux: spu.watt / spu.um**2,
# Electrodynamics
_PT.CurrentDensity: spu.ampere / spu.um**2,
_PT.Conductivity: spu.siemens / spu.um,
_PT.EField: spu.volt / spu.um,
_PT.HField: spu.ampere / spu.um,
# Luminal
# Optics
## NOTE: w/o source normalization, EField/HField/Modal amps are * 1/Hz
}
)

View File

@ -18,6 +18,7 @@ import contextlib
import bmesh
import bpy
import jax
import numpy as np
from blender_maxwell.utils import logger
@ -127,8 +128,18 @@ class ManagedBLMesh(base.ManagedObj):
####################
# - BLMesh Management
####################
def bl_object(self, location: tuple[float, float, float] = (0, 0, 0)):
"""Returns the managed blender object."""
def bl_object(
self, location: np.ndarray | jax.Array | tuple[float, float, float] = (0, 0, 0)
):
"""Returns the managed blender object, centered at the given location."""
if isinstance(location, np.ndarray | jax.Array):
if isinstance(location.item(0), int):
center = tuple(location.astype(float).flatten().tolist())
else:
center = tuple(location.flatten().tolist())
else:
center = tuple([float(el) for el in location])
# Create Object w/Appropriate Data Block
if not (bl_object := bpy.data.objects.get(self.name)):
log.info(
@ -143,7 +154,7 @@ class ManagedBLMesh(base.ManagedObj):
)
managed_collection().objects.link(bl_object)
for i, coord in enumerate(location):
for i, coord in enumerate(center):
if bl_object.location[i] != coord:
bl_object.location[i] = coord

View File

@ -19,8 +19,9 @@
import typing as typ
import bpy
import jax
import numpy as np
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from .. import bl_socket_map
@ -41,13 +42,10 @@ class ModifierAttrsNODES(typ.TypedDict):
Attributes:
node_group: The GeoNodes group to use in the modifier.
unit_system: The unit system used by the GeoNodes output.
Generally, `ct.UNITS_BLENDER` is a good choice.
inputs: Values to associate with each GeoNodes interface socket name.
"""
node_group: bpy.types.GeometryNodeTree
unit_system: UnitSystem
inputs: dict[ct.SocketName, typ.Any]
@ -122,7 +120,6 @@ def write_modifier_geonodes(
## -> TODO: A special case isn't clean enough.
bl_modifier[iface_id] = socket_infos[socket_name].encode(
raw_value=modifier_attrs['inputs'][socket_name],
unit_system=modifier_attrs['unit_system'],
)
modifier_altered = True
## TODO: More fine-grained alterations?
@ -230,7 +227,7 @@ class ManagedBLModifier(base.ManagedObj):
self,
modifier_type: ct.BLModifierType,
modifier_attrs: ModifierAttrs,
location: tuple[float, float, float] = (0, 0, 0),
location: np.ndarray | jax.Array | tuple[float, float, float] = (0, 0, 0),
):
"""Creates a new modifier for the current `bl_object`.

View File

@ -0,0 +1,186 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Declares `ManagedBLText`."""
import time
import typing as typ
import bpy
import matplotlib.axis as mpl_ax
import numpy as np
from blender_maxwell.utils import image_ops, logger
from .. import contracts as ct
from . import base
log = logger.get(__name__)
AREA_TYPE = 'IMAGE_EDITOR'
SPACE_TYPE = 'IMAGE_EDITOR'
####################
# - Managed BL Image
####################
class ManagedBLText(base.ManagedObj):
"""Represents a Blender Image datablock, encapsulating various useful interactions with it.
Attributes:
name: The name of the image.
"""
managed_obj_type = ct.ManagedObjType.ManagedBLText
_bl_image_name: str
def __init__(self, name: str, prev_name: str | None = None):
if prev_name is not None:
self._bl_image_name = prev_name
else:
self._bl_image_name = name
self.name = name
@property
def name(self):
return self._bl_image_name
@name.setter
def name(self, value: str):
log.info(
'Changing ManagedBLText from "%s" to "%s"',
self.name,
value,
)
existing_bl_image = bpy.data.images.get(self.name)
# No Existing Image: Set Value to Name
if existing_bl_image is None:
self._bl_image_name = value
# Existing Image: Rename to New Name
else:
existing_bl_image.name = value
self._bl_image_name = value
# Check: Blender Rename -> Synchronization Error
## -> We can't do much else than report to the user & free().
if existing_bl_image.name != self._bl_image_name:
log.critical(
'BLImage: Failed to set name of %s to %s, as %s already exists.'
)
self._bl_image_name = existing_bl_image.name
self.free()
def free(self):
bl_image = bpy.data.images.get(self.name)
if bl_image is not None:
log.debug('Freeing ManagedBLText "%s"', self.name)
bpy.data.images.remove(bl_image)
####################
# - Managed Object Management
####################
def bl_image(
self,
width_px: int,
height_px: int,
color_model: typ.Literal['RGB', 'RGBA'],
dtype: typ.Literal['uint8', 'float32'],
):
"""Returns the managed blender image.
If the requested image properties are different from the image's, then delete the old image make a new image with correct geometry.
"""
channels = 4 if color_model == 'RGBA' else 3
# Remove Image (if mismatch)
bl_image = bpy.data.images.get(self.name)
if bl_image is not None and (
bl_image.size[0] != width_px
or bl_image.size[1] != height_px
or bl_image.channels != channels
or bl_image.is_float ^ (dtype == 'float32')
):
self.free()
# Create Image w/Geometry (if none exists)
bl_image = bpy.data.images.get(self.name)
if bl_image is None:
bl_image = bpy.data.images.new(
self.name,
width=width_px,
height=height_px,
float_buffer=dtype == 'float32',
)
# Enable Fake User
bl_image.use_fake_user = True
return bl_image
####################
# - Editor UX Manipulation
####################
@classmethod
def preview_area(cls) -> bpy.types.Area | None:
"""Deduces a Blender UI area that can be used for image preview.
Returns:
A Blender UI area, if an appropriate one is visible; else `None`,
"""
valid_areas = [
area for area in bpy.context.screen.areas if area.type == AREA_TYPE
]
if valid_areas:
return valid_areas[0]
return None
@classmethod
def preview_space(cls) -> bpy.types.SpaceProperties | None:
"""Deduces a Blender UI space, within `self.preview_area`, that can be used for image preview.
Returns:
A Blender UI space within `self.preview_area`, if it isn't None; else, `None`.
"""
preview_area = cls.preview_area()
if preview_area is not None:
valid_spaces = [
space for space in preview_area.spaces if space.type == SPACE_TYPE
]
if valid_spaces:
return valid_spaces[0]
return None
return None
####################
# - Methods
####################
def bl_select(self) -> None:
"""Selects the image by loading it into an on-screen UI area/space.
Notes:
The image must already be available, else nothing will happen.
"""
bl_image = bpy.data.images.get(self.name)
if bl_image is not None:
self.preview_space().image = bl_image
@classmethod
def hide_preview(cls) -> None:
"""Deselects the image by loading `None` into the on-screen UI area/space."""
cls.preview_space().image = None

View File

@ -15,6 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import enum
import functools
import typing as typ
import jax.lax as jlax
@ -36,6 +37,8 @@ class FilterOperation(enum.StrEnum):
PinLen1: Remove a len(1) dimension.
Pin: Remove a len(n) dimension by selecting a particular index.
Swap: Swap the positions of two dimensions.
ZScore15: The Z-score threshold of values
ZScore30: The peak-to-peak along an axis.
"""
# Slice
@ -47,14 +50,19 @@ class FilterOperation(enum.StrEnum):
Pin = enum.auto()
PinIdx = enum.auto()
# Dimension
# Swizzle
Swap = enum.auto()
# Axis Filter
ZScore15 = enum.auto()
ZScore30 = enum.auto()
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
"""A human-readable UI-oriented name for a physical type."""
FO = FilterOperation
return {
# Slice
@ -64,15 +72,20 @@ class FilterOperation(enum.StrEnum):
FO.PinLen1: 'a[0] → a',
FO.Pin: 'a[v] ⇝ a',
FO.PinIdx: 'a[i] → a',
# Reinterpret
# Swizzle
FO.Swap: 'a₁ ↔ a₂',
# Axis Filter
# FO.ZScore15: 'a[v₁:v₂] ∈ σ[1.5]',
# FO.ZScore30: 'a[v₁:v₂] ∈ σ[1.5]',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> str:
def to_icon(_: typ.Self) -> str:
"""No icons."""
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
"""Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`."""
FO = FilterOperation
return (
str(self),
@ -82,54 +95,83 @@ class FilterOperation(enum.StrEnum):
i,
)
@staticmethod
def bl_enum_elements(info: ct.InfoFlow) -> list[ct.BLEnumElement]:
"""Generate a list of guaranteed-valid operations based on the passed `InfoFlow`s.
Returns a `bpy.props.EnumProperty.items`-compatible list.
"""
return [
operation.bl_enum_element(i)
for i, operation in enumerate(FilterOperation.by_info(info))
]
####################
# - Ops from Info
####################
@staticmethod
def by_info(info: ct.InfoFlow) -> list[typ.Self]:
FO = FilterOperation
operations = []
ops = []
# Slice
if info.dims:
operations.append(FO.SliceIdx)
# Slice
ops += [FO.SliceIdx]
# Pin
## PinLen1
## -> There must be a dimension with length 1.
if 1 in [dim_idx for dim_idx in info.dims.values() if dim_idx is not None]:
operations.append(FO.PinLen1)
# Pin
## PinLen1
## -> There must be a dimension with length 1.
if 1 in [
len(dim_idx) for dim_idx in info.dims.values() if dim_idx is not None
]:
ops += [FO.PinLen1]
## Pin | PinIdx
## -> There must be a dimension, full stop.
if info.dims:
operations += [FO.Pin, FO.PinIdx]
# Pin
## -> There must be a dimension, full stop.
ops += [FO.Pin, FO.PinIdx]
# Reinterpret
## Swap
## -> There must be at least two dimensions.
if len(info.dims) >= 2: # noqa: PLR2004
operations.append(FO.Swap)
# Swizzle
## Swap
## -> There must be at least two dimensions to swap between.
if len(info.dims) >= 2: # noqa: PLR2004
ops += [FO.Swap]
return operations
# Axis Filter
## ZScore
## -> Subjectively, it makes little sense with less than 5 numbers.
## -> Mathematically valid (I suppose) for 2. But not so useful.
# if any(
# (dim.has_idx_discrete(dim) or dim.has_idx_labels(dim))
# and len(dim_idx) > 5 # noqa: PLR2004
# for dim, dim_idx in info.dims.items()
# ):
# ops += [FO.ZScore15, FO.ZScore30]
return ops
####################
# - Computed Properties
####################
@property
@functools.cached_property
def func_args(self) -> list[sim_symbols.SimSymbol]:
FO = FilterOperation
return {
# Pin
FO.Pin: [sim_symbols.idx(None)],
FO.PinIdx: [sim_symbols.idx(None)],
# Swizzle
## -> Swap: JAX requires that swap dims be baked into the function.
# Axis Filter
# FO.ZScore15: [sim_symbols.idx(None)],
# FO.ZScore30: [sim_symbols.idx(None)],
}.get(self, [])
####################
# - Methods
####################
@property
@functools.cached_property
def num_dim_inputs(self) -> None:
"""Number of dimensions required as inputs to the operation's function."""
FO = FilterOperation
return {
# Slice
@ -139,11 +181,15 @@ class FilterOperation(enum.StrEnum):
FO.PinLen1: 1,
FO.Pin: 1,
FO.PinIdx: 1,
# Reinterpret
# Swizzle
FO.Swap: 2,
# Axis Filter
# FO.ZScore15: 1,
# FO.ZScore30: 1,
}[self]
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
"""The valid dimensions that can be selected between, fo each of the"""
FO = FilterOperation
match self:
# Slice
@ -171,30 +217,20 @@ class FilterOperation(enum.StrEnum):
case FO.Swap:
return info.dims
# TODO: ZScore
return []
def are_dims_valid(
self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None
) -> bool:
"""Check whether the given dimension inputs are valid in the context of this operation, and of the information."""
if self.num_dim_inputs == 1:
return dim_0 in self.valid_dims(info)
if self.num_dim_inputs == 2: # noqa: PLR2004
valid_dims = self.valid_dims(info)
return dim_0 in valid_dims and dim_1 in valid_dims
return False
####################
# - UI
# - Implementations
####################
def jax_func(
self,
axis_0: int | None,
axis_1: int | None,
axis_1: int | None = None,
slice_tuple: tuple[int, int, int] | None = None,
):
"""Implements the identified filtering using `jax`."""
FO = FilterOperation
return {
# Pin
@ -210,13 +246,65 @@ class FilterOperation(enum.StrEnum):
FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
# Dimension
FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1),
# TODO: Axis Filters
## -> The jnp.compress() function is ideal for this kind of thing.
## -> The difficulty is that jit() requires output size to be known.
## -> One can set the size= parameter of compress.
## -> But how do we determine that?
}[self]
####################
# - Transforms
####################
def transform_func(
self,
func: ct.FuncFlow,
axis_0: int,
axis_1: int | None = None,
slice_tuple: tuple[int, int, int] | None = None,
) -> ct.FuncFlow | None:
"""Transform input function according to the current operation and output info characterization."""
FO = FilterOperation
match self:
# Slice
case FO.Slice | FO.SliceIdx if axis_0 is not None:
return func.compose_within(
self.jax_func(axis_0, slice_tuple=slice_tuple),
enclosing_func_output=func.func_output,
supports_jax=True,
)
# Pin
case FO.PinLen1 if axis_0 is not None:
return func.compose_within(
self.jax_func(axis_0),
enclosing_func_output=func.func_output,
supports_jax=True,
)
case FO.Pin | FO.PinIdx if axis_0 is not None:
return func.compose_within(
self.jax_func(axis_0),
enclosing_func_args=[sim_symbols.idx(None)],
enclosing_func_output=func.func_output,
supports_jax=True,
)
# Swizzle
case FO.Swap if axis_0 is not None and axis_1 is not None:
return func.compose_within(
self.jax_func(axis_0, axis_1),
enclosing_func_output=func.func_output,
supports_jax=True,
)
return None
def transform_info(
self,
info: ct.InfoFlow,
dim_0: sim_symbols.SimSymbol,
dim_1: sim_symbols.SimSymbol,
dim_1: sim_symbols.SimSymbol | None = None,
pin_idx: int | None = None,
slice_tuple: tuple[int, int, int] | None = None,
):
@ -225,9 +313,10 @@ class FilterOperation(enum.StrEnum):
FO.Slice: lambda: info.slice_dim(dim_0, slice_tuple),
FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple),
# Pin
FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=0),
FO.Pin: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
# Reinterpret
FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
# TODO: Axis Filters
}[self]()

View File

@ -14,11 +14,15 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements map operations for the `MapNode`."""
import enum
import typing as typ
import jax.numpy as jnp
import jaxtyping as jtyp
import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
@ -27,6 +31,9 @@ from .. import contracts as ct
log = logger.get(__name__)
MT = spux.MathType
PT = spux.PhysicalType
class MapOperation(enum.StrEnum):
"""Valid operations for the `MapMathNode`.
@ -54,7 +61,8 @@ class MapOperation(enum.StrEnum):
SvdVals: Compute the singular values vector of the input matrix.
Inv: Compute the inverse matrix of the input matrix.
Tra: Compute the transpose matrix of the input matrix.
Qr: Compute the QR-factorized matrices of the input matrix.
QR_Q: Compute the QR-factorized matrices of the input matrix, and extract the 'Q' component.
QR_R: Compute the QR-factorized matrices of the input matrix, and extract the 'R' component.
Chol: Compute the Cholesky-factorized matrices of the input matrix.
Svd: Compute the SVD-factorized matrices of the input matrix.
"""
@ -64,6 +72,7 @@ class MapOperation(enum.StrEnum):
Imag = enum.auto()
Abs = enum.auto()
Sq = enum.auto()
Reciprocal = enum.auto()
Sqrt = enum.auto()
InvSqrt = enum.auto()
Cos = enum.auto()
@ -85,16 +94,17 @@ class MapOperation(enum.StrEnum):
SvdVals = enum.auto()
Inv = enum.auto()
Tra = enum.auto()
Qr = enum.auto()
Chol = enum.auto()
Svd = enum.auto()
QR_Q = enum.auto()
QR_R = enum.auto()
# Chol = enum.auto()
# Svd = enum.auto()
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
"""A human-readable UI-oriented name for a physical type."""
"""A human-readable UI-oriented name."""
MO = MapOperation
return {
# By Number
@ -103,6 +113,7 @@ class MapOperation(enum.StrEnum):
MO.Abs: '|v|',
MO.Sq: '',
MO.Sqrt: '√v',
MO.Reciprocal: '1/v',
MO.InvSqrt: '1/√v',
MO.Cos: 'cos v',
MO.Sin: 'sin v',
@ -123,9 +134,10 @@ class MapOperation(enum.StrEnum):
MO.SvdVals: 'svdvals V',
MO.Inv: 'V⁻¹',
MO.Tra: 'Vt',
MO.Qr: 'qr V',
MO.Chol: 'chol V',
MO.Svd: 'svd V',
MO.QR_Q: 'qr[Q] V',
MO.QR_R: 'qr[R] V',
# MO.Chol: 'chol V',
# MO.Svd: 'svd V',
}[value]
@staticmethod
@ -144,75 +156,105 @@ class MapOperation(enum.StrEnum):
i,
)
####################
# - Ops from Shape
####################
@staticmethod
def by_expr_info(info: ct.InfoFlow) -> list[typ.Self]:
## TODO: By info, not shape.
## TODO: Check valid domains/mathtypes for some functions.
MO = MapOperation
element_ops = [
MO.Real,
MO.Imag,
MO.Abs,
MO.Sq,
MO.Sqrt,
MO.InvSqrt,
MO.Cos,
MO.Sin,
MO.Tan,
MO.Acos,
MO.Asin,
MO.Atan,
MO.Sinc,
def bl_enum_elements(info: ct.InfoFlow) -> list[ct.BLEnumElement]:
"""Generate a list of guaranteed-valid operations based on the passed `InfoFlow`.
Returns a `bpy.props.EnumProperty.items`-compatible list.
"""
return [
operation.bl_enum_element(i)
for i, operation in enumerate(MapOperation.from_info(info))
]
match (info.output.rows, info.output.cols):
case (1, 1):
return element_ops
####################
# - Derivation
####################
@staticmethod
def from_info(info: ct.InfoFlow) -> list[typ.Self]:
"""Derive valid mapping operations from the `InfoFlow` of the operand."""
MO = MapOperation
ops = []
case (_, 1):
return [*element_ops, MO.Norm2]
# Real/Imag
if info.output.mathtype is MT.Complex:
ops += [MO.Real, MO.Imag]
case (rows, cols) if rows == cols:
## TODO: Check hermitian/posdef for cholesky.
## - Can we even do this with just the output symbol approach?
return [
*element_ops,
MO.Det,
MO.Cond,
MO.NormFro,
MO.Rank,
MO.Diag,
MO.EigVals,
MO.SvdVals,
MO.Inv,
MO.Tra,
MO.Qr,
MO.Chol,
MO.Svd,
]
# Absolute Value
ops += [MO.Abs]
case (rows, cols):
return [
*element_ops,
MO.Cond,
MO.NormFro,
MO.Rank,
MO.SvdVals,
MO.Inv,
MO.Tra,
MO.Svd,
]
# Square
ops += [MO.Sq]
return []
# Reciprocal
if info.output.is_nonzero:
ops += [MO.Reciprocal]
# Square Root (Principal)
ops += [MO.Sqrt]
# Inverse Sqrt
if info.output.is_nonzero:
ops += [MO.InvSqrt]
# Cos/Sin/Tan/Sinc
if info.output.unit == spu.radian:
ops += [MO.Cos, MO.Sin, MO.Tan, MO.Sinc]
# Inverse Cos/Sin/Tan
## -> Presume complex-extensions that aren't limited.
if info.output.physical_type is PT.NonPhysical and info.output.unit is None:
ops += [MO.Acos, MO.Asin, MO.Atan]
# By Vector
if info.output.shape_len == 1:
ops += [MO.Norm2]
# By Matrix
if info.output.shape_len == 2: # noqa: PLR2004
if info.output.rows == info.output.cols:
ops += [MO.Det]
# Square Matrix
if info.output.rows == info.output.cols:
# Det
ops += [MO.Det]
# Diag
ops += [MO.Diag]
# Inv
ops += [MO.Inv]
# Cond
ops += [MO.Cond]
# NormFro
ops += [MO.NormFro]
# Rank
ops += [MO.Rank]
# EigVals
ops += [MO.EigVals]
# SvdVals
ops += [MO.EigVals]
# Tra
ops += [MO.Tra]
# QR
ops += [MO.QR_Q, MO.QR_R]
return ops
####################
# - Function Properties
# - Implementations
####################
@property
def sp_func(self):
"""Implement the mapping operation for sympy expressions."""
MO = MapOperation
return {
# By Number
@ -221,6 +263,7 @@ class MapOperation(enum.StrEnum):
MO.Abs: lambda expr: sp.Abs(expr),
MO.Sq: lambda expr: expr**2,
MO.Sqrt: lambda expr: sp.sqrt(expr),
MO.Reciprocal: lambda expr: 1 / expr,
MO.InvSqrt: lambda expr: 1 / sp.sqrt(expr),
MO.Cos: lambda expr: sp.cos(expr),
MO.Sin: lambda expr: sp.sin(expr),
@ -240,19 +283,25 @@ class MapOperation(enum.StrEnum):
MO.Rank: lambda expr: expr.rank(),
# Matrix -> Vec
MO.Diag: lambda expr: expr.diagonal(),
MO.EigVals: lambda expr: sp.Matrix(list(expr.eigenvals().keys())),
MO.EigVals: lambda expr: sp.ImmutableMatrix(list(expr.eigenvals().keys())),
MO.SvdVals: lambda expr: expr.singular_values(),
# Matrix -> Matrix
MO.Inv: lambda expr: expr.inv(),
MO.Tra: lambda expr: expr.T,
# Matrix -> Matrices
MO.Qr: lambda expr: expr.QRdecomposition(),
MO.Chol: lambda expr: expr.cholesky(),
MO.Svd: lambda expr: expr.singular_value_decomposition(),
MO.QR_Q: lambda expr: expr.QRdecomposition()[0],
MO.QR_R: lambda expr: expr.QRdecomposition()[1],
# MO.Chol: lambda expr: expr.cholesky(),
# MO.Svd: lambda expr: expr.singular_value_decomposition(),
}[self]
@property
def jax_func(self):
def jax_func(
self,
) -> typ.Callable[
[jtyp.Shaped[jtyp.Array, '...'], int], jtyp.Shaped[jtyp.Array, '...']
]:
"""Implements the identified mapping using `jax`."""
MO = MapOperation
return {
# By Number
@ -261,6 +310,7 @@ class MapOperation(enum.StrEnum):
MO.Abs: lambda expr: jnp.abs(expr),
MO.Sq: lambda expr: jnp.square(expr),
MO.Sqrt: lambda expr: jnp.sqrt(expr),
MO.Reciprocal: lambda expr: 1 / expr,
MO.InvSqrt: lambda expr: 1 / jnp.sqrt(expr),
MO.Cos: lambda expr: jnp.cos(expr),
MO.Sin: lambda expr: jnp.sin(expr),
@ -286,80 +336,231 @@ class MapOperation(enum.StrEnum):
MO.Inv: lambda expr: jnp.linalg.inv(expr),
MO.Tra: lambda expr: jnp.matrix_transpose(expr),
# Matrix -> Matrices
MO.Qr: lambda expr: jnp.linalg.qr(expr),
MO.Chol: lambda expr: jnp.linalg.cholesky(expr),
MO.Svd: lambda expr: jnp.linalg.svd(expr),
MO.QR_Q: lambda expr: jnp.linalg.qr(expr)[0],
MO.QR_R: lambda expr: jnp.linalg.qr(expr, mode='r'),
# MO.Chol: lambda expr: jnp.linalg.cholesky(expr),
# MO.Svd: lambda expr: jnp.linalg.svd(expr),
}[self]
####################
# - Transforms: FlowKind
####################
def transform_func(self, func: ct.FuncFlow) -> ct.FuncFlow:
"""Transform input function according to the current operation and output info characterization."""
return func.compose_within(
self.jax_func,
enclosing_func_output=self.transform_output(func.func_output),
supports_jax=True,
)
def transform_info(self, info: ct.InfoFlow):
"""Transform the `InfoFlow` characterizing the output."""
return info.update(output=self.transform_output(info.output))
def transform_params(self, params: ct.ParamsFlow):
"""Transform the incoming function parameters to include output arguments."""
return params
####################
# - Transforms: Symbolic
####################
def transform_output(self, sym: sim_symbols.SimSymbol): # noqa: PLR0911
"""Transform the `SimSymbol` characterizing the output."""
MO = MapOperation
return {
dm = sym.domain
match self:
# By Number
MO.Real: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Imag: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Abs: lambda: info.update_output(mathtype=spux.MathType.Real),
MO.Sq: lambda: info,
MO.Sqrt: lambda: info,
MO.InvSqrt: lambda: info,
MO.Cos: lambda: info,
MO.Sin: lambda: info,
MO.Tan: lambda: info,
MO.Acos: lambda: info,
MO.Asin: lambda: info,
MO.Atan: lambda: info,
MO.Sinc: lambda: info,
# By Vector
MO.Norm2: lambda: info.update_output(
mathtype=spux.MathType.Real,
rows=1,
cols=1,
# Interval
interval_finite_re=(0, sim_symbols.float_max),
interval_inf=(False, True),
interval_closed=(True, False),
),
case MO.Real:
return sym.update(
mathtype=MT.Real,
domain=dm.real,
)
case MO.Imag:
return sym.update(
mathtype=MT.Real,
domain=dm.imag,
)
case MO.Abs:
return sym.update(
mathtype=MT.Real,
domain=dm.abs,
)
case MO.Sq:
return sym.update(
domain=dm**2,
)
case MO.Reciprocal:
orig_unit = sym.unit
new_unit = 1 / orig_unit if orig_unit is not None else None
new_phy_type = PT.from_unit(new_unit, optional=True)
return sym.update(
physical_type=new_phy_type,
unit=new_unit,
domain=dm.reciprocal,
)
case MO.Sqrt:
## TODO: Complex -> Real MathType
return sym.update(domain=dm ** sp.Rational(1, 2))
case MO.InvSqrt:
## TODO: Complex -> Real MathType
return sym.update(domain=(dm ** sp.Rational(1, 2)).reciprocal)
case MO.Cos:
return sym.update(
physical_type=PT.NonPhysical,
unit=None,
domain=dm.cos,
)
case MO.Sin:
return sym.update(
physical_type=PT.NonPhysical,
unit=None,
domain=dm.sin,
)
case MO.Tan:
return sym.update(
physical_type=PT.NonPhysical,
unit=None,
domain=dm.tan,
)
case MO.Acos:
return sym.update(
mathtype=MT.Complex if sym.mathtype is MT.Complex else MT.Real,
physical_type=PT.Angle,
unit=spu.radian,
domain=dm.acos,
)
case MO.Asin:
return sym.update(
mathtype=MT.Complex if sym.mathtype is MT.Complex else MT.Real,
physical_type=PT.Angle,
unit=spu.radian,
domain=dm.asin,
)
case MO.Atan:
return sym.update(
mathtype=MT.Complex if sym.mathtype is MT.Complex else MT.Real,
physical_type=PT.Angle,
unit=spu.radian,
domain=dm.atan,
)
case MO.Sinc:
return sym.update(
physical_type=PT.NonPhysical,
unit=None,
domain=dm.sinc,
)
# By Vector/Covector
case MO.Norm2:
size = max([sym.rows, sym.cols])
return sym.update(
mathtype=MT.Real,
rows=1,
cols=1,
domain=(size * dm**2) ** sp.Rational(1, 2),
)
# By Matrix
MO.Det: lambda: info.update_output(
rows=1,
cols=1,
),
MO.Cond: lambda: info.update_output(
mathtype=spux.MathType.Real,
rows=1,
cols=1,
physical_type=spux.PhysicalType.NonPhysical,
unit=None,
),
MO.NormFro: lambda: info.update_output(
mathtype=spux.MathType.Real,
rows=1,
cols=1,
# Interval
interval_finite_re=(0, sim_symbols.float_max),
interval_inf=(False, True),
interval_closed=(True, False),
),
MO.Rank: lambda: info.update_output(
mathtype=spux.MathType.Integer,
rows=1,
cols=1,
physical_type=spux.PhysicalType.NonPhysical,
unit=None,
# Interval
interval_finite_re=(0, sim_symbols.int_max),
interval_inf=(False, True),
interval_closed=(True, False),
),
# Matrix -> Vector ## TODO: ALL OF THESE
MO.Diag: lambda: info,
MO.EigVals: lambda: info,
MO.SvdVals: lambda: info,
# Matrix -> Matrix ## TODO: ALL OF THESE
MO.Inv: lambda: info,
MO.Tra: lambda: info,
# Matrix -> Matrices ## TODO: ALL OF THESE
MO.Qr: lambda: info,
MO.Chol: lambda: info,
MO.Svd: lambda: info,
}[self]()
case MO.Det:
## -> NOTE: Determinant only valid for square matrices.
size = sym.rows
orig_unit = sym.unit
new_unit = orig_unit**size if orig_unit is not None else None
_new_phy_type = PT.from_unit(new_unit, optional=True)
new_phy_type = (
_new_phy_type if _new_phy_type is not None else PT.NonPhysical
)
return sym.update(
physical_type=new_phy_type,
unit=new_unit,
rows=1,
cols=1,
domain=(size * dm**2) ** sp.Rational(1, 2),
)
case MO.Cond:
return sym.update(
mathtype=MT.Real,
physical_type=PT.NonPhysical,
unit=None,
rows=1,
cols=1,
domain=spux.BlessedSet(sp.Interval(1, sp.oo)),
)
case MO.NormFro:
return sym.update(
mathtype=MT.Real,
rows=1,
cols=1,
domain=(sym.rows * sym.cols * abs(dm) ** 2) ** sp.Rational(1, 2),
)
case MO.Rank:
return sym.update(
mathtype=MT.Integer,
physical_type=PT.NonPhysical,
unit=None,
rows=1,
cols=1,
domain=spux.BlessedSet(sp.Range(0, min([sym.rows, sym.cols]) + 1)),
)
case MO.Diag:
return sym.update(cols=1)
case MO.EigVals:
## TODO: Gershgorin circle theorem?
return spux.BlessedSet(sp.Complexes)
case MO.SvdVals:
## TODO: Domain bound on singular values?
## -- We might also consider a 'nonzero singvals' operation.
## -- Since singular values can be zero just fine.
return sym.update(
mathtype=MT.Real,
cols=1,
domain=spux.BlessedSet(sp.Interval(0, sp.oo)),
)
case MO.Inv:
## -> Defined: Square non-singular matrices.
orig_unit = sym.unit
new_unit = 1 / orig_unit if orig_unit is not None else None
new_phy_type = PT.from_unit(new_unit, optional=True)
return sym.update(
physical_type=new_phy_type,
unit=new_unit,
domain=sym.mathtype.symbolic_set,
)
case MO.Tra:
return sym.update(
rows=sym.cols,
cols=sym.rows,
)
case MO.QR_Q:
return sym.update(
mathtype=MT.Complex if sym.mathtype is MT.Complex else MT.Real,
physical_type=PT.NonPhysical,
unit=None,
cols=min([sym.rows, sym.cols]),
domain=(
spux.BlessedSet(spux.ComplexRegion(sp.Interval(-1, 1) ** 2))
if sym.mathtype is MT.Complex
else spux.BlessedSet(sp.Interval(-1, 1))
),
)
case MO.QR_R:
return sym

View File

@ -14,12 +14,14 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements transform operations for the `MapNode`."""
import enum
import functools
import typing as typ
import jax.numpy as jnp
import sympy as sp
import sympy.physics.quantum as spq
import sympy.physics.units as spu
from blender_maxwell.utils import logger, sim_symbols
@ -29,6 +31,113 @@ from .. import contracts as ct
log = logger.get(__name__)
MT = spux.MathType
PT = spux.PhysicalType
# @functools.lru_cache(maxsize=1024)
# def expand_shapes(
# shape_l: tuple[int, ...], shape_r: tuple[int, ...]
# ) -> tuple[tuple[int, ...], tuple[int, ...]]:
# """Transform each shape to two new shapes, whose lengths are identical, and for which operations between are well-defined, but which occupies the same amount of memory."""
# axes = dict(
# reversed(
# list(
# itertools.zip_longest(reversed(shape_l), reversed(shape_r), fillvalue=1)
# )
# )
# )
#
# return (tuple(axes.keys()), tuple(axes.values()))
#
#
# @functools.lru_cache(maxsize=1024)
# def broadcast_shape(
# expanded_shape_l: tuple[int, ...], expanded_shape_r: tuple[int, ...]
# ) -> tuple[int, ...] | None:
# """Deduce a common shape that an object of both expanded shapes can be broadcast to."""
# new_shape = []
# for ax_l, ax_r in itertools.zip_longest(
# expanded_shape_l, expanded_shape_r, fillvalue=1
# ):
# if ax_l == 1 or ax_r == 1 or ax_l == ax_r: # noqa: PLR1714
# new_shape.append(max([ax_l, ax_r]))
# else:
# return None
#
# return tuple(new_shape)
#
#
# def broadcast_to_shape(
# M: sp.NDimArray, compatible_shape: tuple[int, ...]
# ) -> spux.SympyType:
# """Conform an array with expanded shape to the given shape, expanding any axes that need expanding."""
# L = M
#
# incremental_shape = ()
# for orig_ax, new_ax in reversed(zip(M.shape, compatible_shape, strict=True)):
# incremental_shape = (new_ax, *incremental_shape)
# if orig_ax == 1 and new_ax > 1:
# _L = sp.flatten(L) if L.shape == () else L.tolist()
#
# L = sp.ImmutableDenseNDimArray(_L * new_ax).reshape(*compatible_shape)
#
# return L
#
#
# def sp_operation(op, lhs: spux.SympyType, rhs: spux.SympyType) -> spux.SympyType | None:
# if not isinstance(lhs, sp.MatrixBase | sp.NDimArray) and not isinstance(
# lhs, sp.MatrixBase | sp.NDimArray
# ):
# return op(lhs, rhs)
#
# # Deduce Expanded L/R Arrays
# ## -> This conforms the shape of both operands to broadcastable shape.
# ## -> The actual memory usage from doing this remains identical.
# _L = sp.ImmutableDenseNDimArray(lhs)
# _R = sp.ImmutableDenseNDimArray(rhs)
# expanded_shape_l, expanded_shape_r = expand_shapes(_L.shape, _R.shape)
# _L = _L.reshape(*expanded_shape_l)
# _R = _R.reshape(*expanded_shape_r)
#
# # Broadcast Expanded L/R Arrays
# ## -> Expanded dimensions will be conformed to the max of each.
# ## -> This conforms the shape of both operands to broadcastable shape.
# broadcasted_shape = broadcast_to_shape(expanded_shape_l, expanded_shape_r)
# if broadcasted_shape is None:
# return None
#
# L = broadcast_to_shape(_L, broadcasted_shape)
# R = broadcast_to_shape(_R, broadcasted_shape)
#
# # Run Elementwise Operation
# ## -> Elementwise operations can now cleanly run between both operands.
# output = op(L, R)
# if output.shape in [1, 2]:
# return sp.ImmutableMatrix(output.tomatrix())
# return output
#
#
# def hadamard_product(lhs: spux.SympyType, rhs: spux.SympyType) -> spux.SympyType | None:
# match (isinstance(lhs, sp.MatrixBase), isinstance(rhs, sp.MatrixBase)):
# case (False, False):
# return lhs * rhs
#
# case (True, False):
# return lhs.applyfunc(lambda el: el * rhs)
#
# case (False, True):
# return rhs.applyfunc(lambda el: lhs * el)
#
# case (True, True) if lhs.shape == rhs.shape:
# common_shape = lhs.shape
# return sp.ImmutableMatrix(
# *common_shape, lambda i, j: lhs[i, j] ** rhs[i, j]
# )
#
# msg = f'Incompatible lhs and rhs for hadamard power: {lhs} | {rhs}'
# raise ValueError(msg)
def hadamard_power(lhs: spux.SympyType, rhs: spux.SympyType) -> spux.SympyType:
"""Implement the Hadamard Power.
@ -37,8 +146,7 @@ def hadamard_power(lhs: spux.SympyType, rhs: spux.SympyType) -> spux.SympyType:
"""
match (isinstance(lhs, sp.MatrixBase), isinstance(rhs, sp.MatrixBase)):
case (False, False):
msg = f"Hadamard Power for two scalars is valid, but shouldn't be used - use normal power instead: {lhs} | {rhs}"
raise ValueError(msg)
return lhs**rhs
case (True, False):
return lhs.applyfunc(lambda el: el**rhs)
@ -52,9 +160,8 @@ def hadamard_power(lhs: spux.SympyType, rhs: spux.SympyType) -> spux.SympyType:
*common_shape, lambda i, j: lhs[i, j] ** rhs[i, j]
)
case _:
msg = f'Incompatible lhs and rhs for hadamard power: {lhs} | {rhs}'
raise ValueError(msg)
msg = f'Incompatible lhs and rhs for hadamard power: {lhs} | {rhs}'
raise ValueError(msg)
class BinaryOperation(enum.StrEnum):
@ -74,7 +181,6 @@ class BinaryOperation(enum.StrEnum):
VecVecOuter: Vector-vector outer product.
LinSolve: Solve a linear system.
LsqSolve: Minimize error of an underdetermined linear system.
VecMatOuter: Vector-matrix outer product.
MatMatDot: Matrix-matrix dot product.
"""
@ -100,9 +206,6 @@ class BinaryOperation(enum.StrEnum):
LinSolve = enum.auto()
LsqSolve = enum.auto()
# Vector | Matrix
VecMatOuter = enum.auto()
# Matrix | Matrix
MatMatDot = enum.auto()
@ -132,17 +235,20 @@ class BinaryOperation(enum.StrEnum):
# Matrix | Vector
BO.LinSolve: '𝐋 𝐫',
BO.LsqSolve: 'argminₓ∥𝐋𝐱𝐫∥₂',
# Vector | Matrix
BO.VecMatOuter: '𝐋𝐫',
# Matrix | Matrix
BO.MatMatDot: '𝐋 · 𝐑',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> str:
def to_icon(_: typ.Self) -> str:
"""No icons."""
return ''
@functools.cached_property
def name(self) -> str:
"""No icons."""
return BinaryOperation.to_name(self)
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
"""Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`."""
BO = BinaryOperation
@ -154,8 +260,9 @@ class BinaryOperation(enum.StrEnum):
i,
)
@staticmethod
def bl_enum_elements(
self, info_l: ct.InfoFlow, info_r: ct.InfoFlow
info_l: ct.InfoFlow, info_r: ct.InfoFlow
) -> list[ct.BLEnumElement]:
"""Generate a list of guaranteed-valid operations based on the passed `InfoFlow`s.
@ -163,14 +270,14 @@ class BinaryOperation(enum.StrEnum):
"""
return [
operation.bl_enum_element(i)
for i, operation in enumerate(BinaryOperation.by_infos(info_l, info_r))
for i, operation in enumerate(BinaryOperation.from_infos(info_l, info_r))
]
####################
# - Ops from Shape
####################
@staticmethod
def by_infos(info_l: ct.InfoFlow, info_r: ct.InfoFlow) -> list[typ.Self]:
def from_infos(info_l: ct.InfoFlow, info_r: ct.InfoFlow) -> list[typ.Self]: # noqa: C901, PLR0912
"""Deduce valid binary operations from the shapes of the inputs."""
BO = BinaryOperation
ops = []
@ -190,14 +297,15 @@ class BinaryOperation(enum.StrEnum):
ops += [BO.Mul, BO.Div]
case (ordl, ordr, True) if ordl == 0 and ordr > 0:
ops += [BO.Mul]
case (ordl, ordr, True) if ordl > 0 and ordr > 0:
case (ordl, ordr, _) if ordl > 0 and ordr > 0:
## TODO: _ is not correct
ops += [BO.HadamMul, BO.HadamDiv]
case (ordl, ordr, False) if ordl == 0 and ordr == 0:
ops += [BO.Mul]
case (ordl, ordr, False) if ordl > 0 and ordr == 0:
ops += [BO.Mul]
case (ordl, ordr, True) if ordl == 0 and ordr > 0:
case (ordl, ordr, False) if ordl == 0 and ordr > 0:
ops += [BO.Mul]
case (ordl, ordr, False) if ordl > 0 and ordr > 0:
ops += [BO.HadamMul]
@ -212,7 +320,7 @@ class BinaryOperation(enum.StrEnum):
case (ordl, ordr, _) if ordl == 0 and ordr == 0:
ops += [BO.Pow]
case (ordl, ordr, spux.MathType.Integer) if (
case (ordl, ordr, MT.Integer) if (
ordl > 0 and ordr == 0 and info_l.output.rows == info_l.output.cols
):
ops += [BO.Pow, BO.HadamPow]
@ -220,29 +328,30 @@ class BinaryOperation(enum.StrEnum):
case _:
ops += [BO.HadamPow]
# Atan2
if (
(
info_l.output.mathtype is not MT.Complex
and info_r.output.mathtype is not MT.Complex
)
and (
info_l.output.physical_type is PT.Length
and info_r.output.physical_type is PT.Length
)
or (
info_l.output.physical_type is PT.NonPhysical
and info_l.output.unit is None
and info_r.output.physical_type is PT.NonPhysical
and info_r.output.unit is None
)
):
ops += [BO.Atan2]
# Operations by-Output Length
match (
info_l.output.shape_len,
info_r.output.shape_len,
):
# Number | Number
case (0, 0) if info_l.is_scalar and info_r.is_scalar:
# atan2: PhysicalType Must Both be Length | NonPhysical
## -> atan2() produces radians from Cartesian coordinates.
## -> This wouldn't make sense on non-Length / non-Unitless.
if (
info_l.output.physical_type is spux.PhysicalType.Length
and info_r.output.physical_type is spux.PhysicalType.Length
) or (
info_l.output.physical_type is spux.PhysicalType.NonPhysical
and info_l.output.unit is None
and info_r.output.physical_type is spux.PhysicalType.NonPhysical
and info_r.output.unit is None
):
ops += [BO.Atan2]
return ops
# Vector | Vector
case (1, 1) if info_l.compare_dims_identical(info_r):
outl = info_l.output
@ -274,19 +383,11 @@ class BinaryOperation(enum.StrEnum):
## -> Works great element-wise.
## -> Enforce that both are 3x1 or 1x3.
## -> See https://docs.sympy.org/latest/modules/matrices/matrices.html#sympy.matrices.matrices.MatrixBase.cross
if (outl.rows == 3 and outr.rows == 3) or (
outl.cols == 3 and outl.cols == 3
if (outl.rows == 3 and outr.rows == 3) or ( # noqa: PLR2004
outl.cols == 3 and outl.cols == 3 # noqa: PLR2004
):
ops += [BO.Cross]
# Vector | Matrix
## -> We can't do per-element outer product.
## -> However, it's still super useful on its own.
case (1, 2) if info_l.compare_dims_identical(
info_r
) and info_l.order == 1 and info_r.order == 2:
ops += [BO.VecMatOuter]
# Matrix | Vector
case (2, 1) if info_l.compare_dims_identical(info_r):
# Mat-Vec Dot: Enforce RHS Column Vector
@ -309,17 +410,20 @@ class BinaryOperation(enum.StrEnum):
"""Deduce an appropriate sympy-based function that implements the binary operation for symbolic inputs."""
BO = BinaryOperation
## TODO: Make this compatible with sp.Matrix inputs
## BODO: Make this compatible with sp.Matrix inputs
return {
# Number | Number
BO.Mul: lambda exprs: exprs[0] * exprs[1],
BO.Div: lambda exprs: exprs[0] / exprs[1],
BO.Pow: lambda exprs: exprs[0] ** exprs[1],
BO.Pow: lambda exprs: hadamard_power(exprs[0], exprs[1]),
# Elements | Elements
BO.Add: lambda exprs: exprs[0] + exprs[1],
BO.Sub: lambda exprs: exprs[0] - exprs[1],
BO.HadamMul: lambda exprs: exprs[0].multiply_elementwise(exprs[1]),
BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
BO.HadamDiv: lambda exprs: exprs[0].multiply_elementwise(
exprs[1].applyfunc(lambda el: 1 / el)
),
BO.HadamPow: lambda exprs: hadamard_power(exprs[0], exprs[1]),
BO.Atan2: lambda exprs: sp.atan2(exprs[1], exprs[0]),
# Vector | Vector
BO.VecVecDot: lambda exprs: (exprs[0].T @ exprs[1])[0],
@ -328,27 +432,27 @@ class BinaryOperation(enum.StrEnum):
# Matrix | Vector
BO.LinSolve: lambda exprs: exprs[0].solve(exprs[1]),
BO.LsqSolve: lambda exprs: exprs[0].solve_least_squares(exprs[1]),
# Vector | Matrix
BO.VecMatOuter: lambda exprs: spq.TensorProduct(exprs[0], exprs[1]),
# Matrix | Matrix
BO.MatMatDot: lambda exprs: exprs[0] @ exprs[1],
}[self]
@property
def unit_func(self):
def scalar_sp_func(self):
"""The binary function to apply to both unit expressions, in order to deduce the unit expression of the output."""
BO = BinaryOperation
## TODO: Make this compatible with sp.Matrix inputs
## BODO: Make this compatible with sp.Matrix inputs
return {
# Number | Number
BO.Mul: BO.Mul.sp_func,
BO.Div: BO.Div.sp_func,
BO.Pow: BO.Pow.sp_func,
# Elements | Elements
BO.Add: BO.Add.sp_func,
BO.Sub: BO.Sub.sp_func,
BO.Add: lambda exprs: exprs[0],
BO.Sub: lambda exprs: exprs[0],
BO.HadamMul: BO.Mul.sp_func,
BO.HadamDiv: BO.Div.sp_func,
BO.HadamPow: BO.Pow.sp_func,
# BO.HadamPow: lambda exprs: sp.HadamardPower(exprs[0], exprs[1]),
BO.Atan2: lambda _: spu.radian,
# Vector | Vector
@ -360,8 +464,6 @@ class BinaryOperation(enum.StrEnum):
## -> Therefore, A \ b must have the units [b]/[A].
BO.LinSolve: lambda exprs: exprs[1] / exprs[0],
BO.LsqSolve: lambda exprs: exprs[1] / exprs[0],
# Vector | Matrix
BO.VecMatOuter: BO.Mul.sp_func,
# Matrix | Matrix
BO.MatMatDot: BO.Mul.sp_func,
}[self]
@ -369,7 +471,7 @@ class BinaryOperation(enum.StrEnum):
@property
def jax_func(self):
"""Deduce an appropriate jax-based function that implements the binary operation for array inputs."""
## TODO: Scale the units of one side to the other.
## BODO: Scale the units of one side to the other.
BO = BinaryOperation
return {
@ -380,11 +482,9 @@ class BinaryOperation(enum.StrEnum):
# Elements | Elements
BO.Add: lambda exprs: exprs[0] + exprs[1],
BO.Sub: lambda exprs: exprs[0] - exprs[1],
BO.HadamMul: lambda exprs: exprs[0].multiply_elementwise(exprs[1]),
BO.HadamDiv: lambda exprs: exprs[0].multiply_elementwise(
exprs[1].applyfunc(lambda el: 1 / el)
),
BO.HadamPow: lambda exprs: hadamard_power(exprs[0], exprs[1]),
BO.HadamMul: lambda exprs: exprs[0] * exprs[1],
BO.HadamDiv: lambda exprs: exprs[0] / exprs[1],
BO.HadamPow: lambda exprs: exprs[0] ** exprs[1],
BO.Atan2: lambda exprs: jnp.atan2(exprs[1], exprs[0]),
# Vector | Vector
BO.VecVecDot: lambda exprs: jnp.linalg.vecdot(exprs[0], exprs[1]),
@ -393,8 +493,6 @@ class BinaryOperation(enum.StrEnum):
# Matrix | Vector
BO.LinSolve: lambda exprs: jnp.linalg.solve(exprs[0], exprs[1]),
BO.LsqSolve: lambda exprs: jnp.linalg.lstsq(exprs[0], exprs[1]),
# Vector | Matrix
BO.VecMatOuter: lambda exprs: jnp.outer(exprs[0], exprs[1]),
# Matrix | Matrix
BO.MatMatDot: lambda exprs: jnp.matmul(exprs[0], exprs[1]),
}[self]
@ -415,62 +513,172 @@ class BinaryOperation(enum.StrEnum):
else:
norm_func_r = func_r
return (func_l, norm_func_r).compose_within(
return (func_l | norm_func_r).compose_within(
self.jax_func,
enclosing_func_output=self.transform_outputs(
func_l.func_output, norm_func_r.func_output
func_l.func_output, func_r.func_output
),
supports_jax=True,
)
def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow):
"""Deduce the output information by using `self.sp_func` to operate on the two output `SimSymbol`s, then capturing the information associated with the resulting expression.
def transform_infos(self, info_l: ct.InfoFlow, info_r: ct.InfoFlow) -> ct.InfoFlow:
"""Transform the `InfoFlow` characterizing the output."""
if len(info_l.dims) == 0:
dims = info_r.dims
elif len(info_r.dims) == 0:
dims = info_l.dims
else:
dims = info_l.dims
Warnings:
`self` MUST be an element of `BinaryOperation.by_infos(info_l, info_r).
If not, bad things will happen.
"""
return info_l.operate_output(
info_r,
lambda a, b: self.sp_func([a, b]),
lambda a, b: self.unit_func([a, b]),
return ct.InfoFlow(
dims=dims,
output=self.transform_outputs(info_l.output, info_r.output),
pinned_values=info_l.pinned_values | info_r.pinned_values,
)
def transform_params(
self, params_l: ct.ParamsFlow, params_r: ct.ParamsFlow
) -> ct.ParamsFlow:
"""Aggregate the incoming function parameters for the output."""
return params_l | params_r
####################
# - InfoFlow Transform
# - Other Transforms
####################
def transform_outputs(
self, output_l: sim_symbols.SimSymbol, output_r: sim_symbols.SimSymbol
self, sym_l: sim_symbols.SimSymbol, sym_r: sim_symbols.SimSymbol
) -> sim_symbols.SimSymbol:
# TO = TransformOperation
return None
# match self:
# # Number | Number
# case TO.Mul:
# return
# case TO.Div:
# case TO.Pow:
BO = BinaryOperation
# # Elements | Elements
# Add = enum.auto()
# Sub = enum.auto()
# HadamMul = enum.auto()
# HadamPow = enum.auto()
# HadamDiv = enum.auto()
# Atan2 = enum.auto()
if sym_l.sym_name == sym_r.sym_name:
name = sym_l.sym_name
else:
name = sim_symbols.SimSymbolName.Expr
# # Vector | Vector
# VecVecDot = enum.auto()
# Cross = enum.auto()
# VecVecOuter = enum.auto()
dm_l = sym_l.domain
dm_r = sym_r.domain
match self:
case BO.Mul | BO.Div | BO.Pow | BO.HadamDiv:
# dm = self.scalar_sp_func([dm_l, dm_r])
unit_factor = self.scalar_sp_func(
[sym_l.unit_factor, sym_r.unit_factor]
)
unit = unit_factor if spux.uses_units(unit_factor) else None
physical_type = PT.from_unit(unit, optional=True, optional_nonphy=True)
# # Matrix | Vector
# LinSolve = enum.auto()
# LsqSolve = enum.auto()
mathtype = MT.combine(
MT.from_symbolic_set(dm_l.bset), MT.from_symbolic_set(dm_r.bset)
)
return sim_symbols.SimSymbol(
sym_name=name,
mathtype=mathtype,
physical_type=physical_type,
unit=unit,
rows=max([sym_l.rows, sym_r.rows]),
cols=max([sym_l.cols, sym_r.cols]),
depths=tuple(
[
max([dp_l, dp_r])
for dp_l, dp_r in zip(
sym_l.depths, sym_r.depths, strict=True
)
]
),
domain=spux.BlessedSet(mathtype.symbolic_set),
)
# # Vector | Matrix
# VecMatOuter = enum.auto()
case BO.Add | BO.Sub:
fac_r_unit_to_l_unit = sp.S(spux.scaling_factor(sym_l.unit, sym_r.unit))
# # Matrix | Matrix
# MatMatDot = enum.auto()
dm = self.scalar_sp_func([dm_l, dm_r * fac_r_unit_to_l_unit])
unit_factor = self.scalar_sp_func(
[sym_l.unit_factor, sym_r.unit_factor]
)
unit = unit_factor if spux.uses_units(unit_factor) else None
physical_type = PT.from_unit(unit, optional=True, optional_nonphy=True)
return sym_l.update(
sym_name=name,
mathtype=MT.from_symbolic_set(dm.bset),
physical_type=physical_type,
unit=None if unit_factor == 1 else unit_factor,
domain=dm,
)
case BO.HadamMul | BO.HadamPow:
# fac_r_unit_to_l_unit = sp.S(spux.scaling_factor(sym_l.unit, sym_r.unit))
mathtype = MT.combine(
MT.from_symbolic_set(dm_l.bset), MT.from_symbolic_set(dm_r.bset)
)
# dm = self.scalar_sp_func([dm_l, dm_r * fac_r_unit_to_l_unit])
unit_factor = self.scalar_sp_func(
[sym_l.unit_factor, sym_r.unit_factor]
)
unit = unit_factor if spux.uses_units(unit_factor) else None
physical_type = PT.from_unit(unit, optional=True, optional_nonphy=True)
return sym_l.update(
sym_name=name,
mathtype=mathtype,
physical_type=physical_type,
unit=None if unit_factor == 1 else unit_factor,
domain=spux.BlessedSet(mathtype.symbolic_set),
)
case BO.Atan2:
dm = dm_l.atan2(dm_r)
return sym_l.update(
sym_name=name,
mathtype=MT.from_symbolic_set(dm.bset),
physical_type=PT.Angle,
unit=spu.radian,
domain=dm,
)
case BO.VecVecDot:
_dm = dm_l * dm_r
dm = _dm + _dm
return sym_l.update(
sym_name=name,
domain=dm,
rows=1,
cols=1,
)
case BO.Cross:
_dm = dm_l * dm_r
dm = _dm + _dm
return sym_l.update(
sym_name=name,
domain=dm,
)
case BO.VecVecOuter:
dm = dm_l * dm_r
return sym_l.update(
sym_name=name,
domain=dm,
rows=max([sym_l.rows, sym_r.rows]),
cols=max([sym_l.cols, sym_r.cols]),
)
case BO.LinSolve | BO.LsqSolve:
mathtype = MT.combine(
MT.from_symbolic_set(dm_l.bset), MT.from_symbolic_set(dm_r.bset)
).symbolic_set
dm = spux.BlessedSet(mathtype.symbolic_set)
return sym_r.update(mathtype=mathtype, domain=dm)
case BO.MatMatDot:
mathtype = MT.combine(
MT.from_symbolic_set(dm_l.bset), MT.from_symbolic_set(dm_r.bset)
).symbolic_set
dm = spux.BlessedSet(mathtype.symbolic_set)
return sym_r.update(mathtype=mathtype, domain=dm)

View File

@ -18,6 +18,7 @@ import enum
import typing as typ
import jax.numpy as jnp
import jaxtyping as jtyp
import sympy as sp
from blender_maxwell.utils import logger, sim_symbols
@ -27,8 +28,30 @@ from .. import contracts as ct
log = logger.get(__name__)
MT = spux.MathType
PT = spux.PhysicalType
class ReduceOperation(enum.StrEnum):
"""Valid operations for the `ReduceMathNode`.
Attributes:
Count: The number of discrete elements along an axis.
Mean: The average along an axis.
Std: The standard deviation along an axis.
Var: The variance along an axis.
Min: The minimum value along an axis.
Q25: The `25%` quantile along an axis.
Medium: The `50%` quantile along an axis.
Q75: The `75%` quantile along an axis.
Max: The `75%` quantile along an axis.
P2P: The peak-to-peak range along an axis.
Z15ToZ15: The range between z-scores of 1.5 in each direction.
Z30ToZ30: The range between z-scores of 3.0 in each direction.
Sum: The sum along an axis.
Prod: The product along an axis.
"""
# Summary
Count = enum.auto()
@ -37,15 +60,15 @@ class ReduceOperation(enum.StrEnum):
Std = enum.auto()
Var = enum.auto()
StdErr = enum.auto()
Min = enum.auto()
Q25 = enum.auto()
Median = enum.auto()
Q75 = enum.auto()
Max = enum.auto()
Mode = enum.auto()
P2P = enum.auto()
Z15ToZ15 = enum.auto()
Z30ToZ30 = enum.auto()
# Reductions
Sum = enum.auto()
@ -61,22 +84,27 @@ class ReduceOperation(enum.StrEnum):
return {
# Summary
RO.Count: '# [a]',
RO.Mode: 'mode [a]',
# Statistics
RO.Mean: 'μ [a]',
RO.Std: 'σ [a]',
RO.Var: 'σ² [a]',
RO.StdErr: 'stderr [a]',
RO.Min: 'min [a]',
RO.Q25: 'q₂₅ [a]',
RO.Median: 'median [a]',
RO.Q75: 'q₇₅ [a]',
RO.Min: 'max [a]',
RO.Max: 'max [a]',
RO.P2P: 'p2p [a]',
RO.Z15ToZ15: 'σ[1.5] [a]',
RO.Z30ToZ30: 'σ[3.0] [a]',
# Reductions
RO.Sum: 'sum [a]',
RO.Prod: 'prod [a]',
}[value]
@property
def name(self) -> str:
return ReduceOperation.to_name(self)
@staticmethod
def to_icon(_: typ.Self) -> str:
"""No icons."""
@ -93,24 +121,165 @@ class ReduceOperation(enum.StrEnum):
i,
)
@staticmethod
def bl_enum_elements(info: ct.InfoFlow) -> list[ct.BLEnumElement]:
"""Generate a list of guaranteed-valid operations based on the passed `InfoFlow`s.
Returns a `bpy.props.EnumProperty.items`-compatible list.
"""
return [
operation.bl_enum_element(i)
for i, operation in enumerate(ReduceOperation.from_info(info))
]
####################
# - Derivation
####################
@staticmethod
def from_info(info: ct.InfoFlow) -> list[typ.Self]:
"""Derive valid reduction operations from the `InfoFlow` of the operand."""
pass
RO = ReduceOperation
ops = []
if info.dims and any(
info.has_idx_discrete(dim) or info.has_idx_labels(dim) for dim in info.dims
):
# Summary
ops += [RO.Count]
# Statistics
ops += [
RO.Mean,
RO.Std,
RO.Var,
RO.Min,
RO.Q25,
RO.Median,
RO.Q75,
RO.Max,
RO.P2P,
RO.Z15ToZ15,
RO.Z30ToZ30,
]
# Reductions
ops += [RO.Sum, RO.Prod]
## I know, they can be combined.
## But they may one day need more checks.
return ops
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
"""Valid dimensions that can be reduced."""
return [
dim
for dim in info.dims
if info.has_idx_discrete(dim) or info.has_idx_labels(dim)
]
####################
# - Composable Functions
####################
@property
def jax_func(self):
def jax_func(
self,
) -> typ.Callable[
[jtyp.Shaped[jtyp.Array, '...'], int], jtyp.Shaped[jtyp.Array, '...']
]:
"""Implements the identified reduction using `jax`."""
RO = ReduceOperation
return {}[self]
return {
# Summary
RO.Count: lambda el, axis: el.shape[axis],
# RO.Mode: lambda el, axis: jsc.stats.mode(el, axis=axis).
# Statistics
RO.Mean: lambda el, axis: jnp.mean(el, axis=axis),
RO.Std: lambda el, axis: jnp.std(el, axis=axis),
RO.Var: lambda el, axis: jnp.var(el, axis=axis),
RO.Min: lambda el, axis: jnp.min(el, axis=axis),
RO.Q25: lambda el, axis: jnp.quantile(el, 0.25, axis=axis),
RO.Median: lambda el, axis: jnp.median(el, axis=axis),
RO.Q75: lambda el, axis: jnp.quantile(el, 0.75, axis=axis),
RO.Max: lambda el, axis: jnp.max(el, axis=axis),
RO.P2P: lambda el, axis: jnp.ptp(el, axis=axis),
RO.Z15ToZ15: lambda el, axis: 2 * (3 / 2) * jnp.std(el, axis=axis),
RO.Z30ToZ30: lambda el, axis: 2 * 3 * jnp.std(el, axis=axis),
# Statistics
RO.Sum: lambda el, axis: jnp.sum(el, axis=axis),
RO.Prod: lambda el, axis: jnp.prod(el, axis=axis),
}[self]
####################
# - Transforms
####################
def transform_info(self, info: ct.InfoFlow):
pass
def transform_func(self, func: ct.InfoFlow):
"""Transform the lazy `FuncFlow` to reduce the input."""
return func.compose_within(
self.jax_func,
enclosing_func_args=(sim_symbols.idx(None),),
enclosing_func_output=self.transform_output(func.func_output),
supports_jax=True,
)
def transform_info(self, info: ct.InfoFlow, dim: sim_symbols.SimSymbol):
"""Transform the characterizing `InfoFlow` of the reduced operand."""
return info.delete_dim(dim).update(output=self.transform_output(info.output))
def transform_params(self, params: ct.ParamsFlow, axis: int) -> None:
"""Transform the characterizing `InfoFlow` of the reduced operand."""
return params.compose_within(
enclosing_func_args=(sp.Integer(axis),),
)
def transform_output(self, sym: sim_symbols.SimSymbol) -> sim_symbols.SimSymbol:
"""Transform the domain of the output symbol.
Parameters:
dom: Symbolic set representing the original output symbol's domain.
info: Characterization of the original expression.
dim: Dimension symbol being reduced away.
"""
RO = ReduceOperation
match self:
# Summary
case RO.Count:
return sym.update(
sym_name=sim_symbols.SimSymbolName.Count,
mathtype=MT.Integer,
physical_type=PT.NonPhysical,
unit=None,
rows=1,
cols=1,
domain=spux.BlessedSet(sp.Naturals0),
)
# Statistics
case (
RO.Mean
| RO.Std
| RO.Var
| RO.Min
| RO.Q25
| RO.Median
| RO.Q75
| RO.Max
| RO.P2P
| RO.Z15ToZ15
| RO.Z30ToZ30
):
## -> Stats are enclosed by the original domain.
return sym
# Reductions
case RO.Sum:
return sym.update(
domain=spux.BlessedSet(sym.mathtype.symbolic_set),
)
case RO.Prod:
return sym.update(
domain=spux.BlessedSet(sym.mathtype.symbolic_set),
)

View File

@ -19,6 +19,7 @@ import typing as typ
import jax.numpy as jnp
import jaxtyping as jtyp
import sympy as sp
from blender_maxwell.utils import logger, sci_constants, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
@ -66,14 +67,12 @@ class TransformOperation(enum.StrEnum):
FT1D = enum.auto()
InvFT1D = enum.auto()
# TODO: Affine
## TODO
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
"""A human-readable UI-oriented name."""
TO = TransformOperation
return {
# Covariant Transform
@ -99,9 +98,11 @@ class TransformOperation(enum.StrEnum):
@staticmethod
def to_icon(_: typ.Self) -> str:
"""No icons."""
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
"""Given an integer index, generate an element that conforms to the requirements of `bpy.props.EnumProperty.items`."""
TO = TransformOperation
return (
str(self),
@ -111,6 +112,17 @@ class TransformOperation(enum.StrEnum):
i,
)
@staticmethod
def bl_enum_elements(info: ct.InfoFlow) -> list[ct.BLEnumElement]:
"""Generate a list of guaranteed-valid operations based on the passed `InfoFlow`s.
Returns a `bpy.props.EnumProperty.items`-compatible list.
"""
return [
operation.bl_enum_element(i)
for i, operation in enumerate(TransformOperation.by_info(info))
]
####################
# - Methods
####################
@ -300,14 +312,16 @@ class TransformOperation(enum.StrEnum):
physical_type=physical_type,
unit=unit,
),
ct.RangeFlow.try_from_array(ct.ArrayFlow(values=data_col, unit=unit)),
ct.RangeFlow.try_from_array(
ct.ArrayFlow(jax_bytes=data_col, unit=unit)
),
).slice_dim(info.last_dim, (1, len(info.dims[info.last_dim]), 1)),
# Fold
TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output(
mathtype=spux.MathType.Complex
),
TO.DimToVec: lambda: info.fold_last_input(),
TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(),
TO.DimToVec: lambda: info.fold_last_input,
TO.DimsToMat: lambda: info.fold_last_input.fold_last_input,
# Fourier
TO.FT1D: lambda: info.replace_dim(
dim,
@ -333,4 +347,38 @@ class TransformOperation(enum.StrEnum):
info.dims[dim].bound_inv_fourier_transform,
],
),
}[self]()
}[self]().update_output(
domain=self.transform_output_domain(info.output.domain, info, dim)
)
def transform_output_domain(
self, dom: spux.BlessedSet, info: ct.InfoFlow, dim: sim_symbols.SimSymbol | None
) -> sp.Set:
"""Transform the domain of the output symbol.
Parameters:
dom: Symbolic set representing the original output symbol's domain.
info: Characterization of the original expression.
dim: Dimension symbol being reduced away.
"""
TO = TransformOperation
match self:
# Summary
case (
TO.FreqToVacWL
| TO.VacWLToFreq
| TO.ConvertIdxUnit
| TO.SetIdxUnit
| TO.FirstColToFirstIdx
| TO.DimToVec
| TO.DimsToMat
| TO.IntDimToComplex
):
return dom
case TO.FT1D if dim is not None:
return info[dim].bound_fourier_transform.symbolic_set
case TO.InvFT1D if dim is not None:
return info[dim].bound_inv_fourier_transform.symbolic_set

View File

@ -16,35 +16,38 @@
from . import (
analysis,
bounds,
inputs,
mediums,
monitors,
outputs,
simulations,
solvers,
sources,
structures,
utilities,
)
BL_REGISTER = [
*analysis.BL_REGISTER,
*utilities.BL_REGISTER,
*inputs.BL_REGISTER,
*solvers.BL_REGISTER,
*outputs.BL_REGISTER,
*sources.BL_REGISTER,
*mediums.BL_REGISTER,
*structures.BL_REGISTER,
*bounds.BL_REGISTER,
*monitors.BL_REGISTER,
*simulations.BL_REGISTER,
]
BL_NODES = {
**analysis.BL_NODES,
**utilities.BL_NODES,
**inputs.BL_NODES,
**solvers.BL_NODES,
**outputs.BL_NODES,
**sources.BL_NODES,
**mediums.BL_NODES,
**structures.BL_NODES,
**bounds.BL_NODES,
**monitors.BL_NODES,
**simulations.BL_NODES,
}

View File

@ -24,9 +24,9 @@ import bpy
import jax.numpy as jnp
import sympy.physics.units as spu
import tidy3d as td
import xarray
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import sockets
@ -36,33 +36,44 @@ log = logger.get(__name__)
TDMonitorData: typ.TypeAlias = td.components.data.monitor_data.MonitorData
RealizedSymsVals: typ.TypeAlias = tuple[sim_symbols.SimSymbol, ...], tuple[typ.Any, ...]
SimDataArray: typ.TypeAlias = dict[RealizedSymsVals, td.SimulationData]
SimDataArrayInfo: typ.TypeAlias = dict[RealizedSymsVals, typ.Any]
FK = ct.FlowKind
FS = ct.FlowSignal
####################
# - Monitor Label Arrays
# - Monitor Labelling
####################
def valid_monitor_attrs(sim_data: td.SimulationData, monitor_name: str) -> list[str]:
def valid_monitor_attrs(
example_sim_data: td.SimulationData, monitor_name: str
) -> tuple[str, ...]:
"""Retrieve the valid attributes of `sim_data.monitor_data' from a valid `sim_data` of type `td.SimulationData`.
Parameters:
monitor_type: The name of the monitor type, with the 'Data' prefix removed.
"""
monitor_data = sim_data.monitor_data[monitor_name]
monitor_type = monitor_data.type
monitor_data = example_sim_data.monitor_data[monitor_name]
monitor_type = monitor_data.type.removesuffix('Data')
match monitor_type:
case 'Field' | 'FieldTime' | 'Mode':
## TODO: flux, poynting, intensity
return [
field_component
for field_component in ['Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz']
if getattr(monitor_data, field_component, None) is not None
]
return tuple(
[
field_component
for field_component in ['Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz']
if getattr(monitor_data, field_component, None) is not None
]
)
case 'Permittivity':
return ['eps_xx', 'eps_yy', 'eps_zz']
return ('eps_xx', 'eps_yy', 'eps_zz')
case 'Flux' | 'FluxTime':
return ['flux']
return ('flux',)
case (
'FieldProjectionAngle'
@ -70,140 +81,246 @@ def valid_monitor_attrs(sim_data: td.SimulationData, monitor_name: str) -> list[
| 'FieldProjectionKSpace'
| 'Diffraction'
):
return [
return (
'Er',
'Etheta',
'Ephi',
'Hr',
'Htheta',
'Hphi',
]
def extract_info(monitor_data, monitor_attr: str) -> ct.InfoFlow | None: # noqa: PLR0911
"""Extract an InfoFlow encapsulating raw data contained in an attribute of the given monitor data."""
xarr = getattr(monitor_data, monitor_attr, None)
if xarr is None:
return None
def mk_idx_array(axis: str) -> ct.RangeFlow | ct.ArrayFlow:
return ct.RangeFlow.try_from_array(
ct.ArrayFlow(
values=xarr.get_index(axis).values,
unit=symbols[axis].unit,
is_sorted=True,
)
raise TypeError
####################
# - Extract InfoFlow
####################
MONITOR_SYMBOLS: dict[str, sim_symbols.SimSymbol] = {
# Field Label
'EH*': sim_symbols.sim_axis_idx(None),
# Cartesian
'x': sim_symbols.space_x(spu.micrometer),
'y': sim_symbols.space_y(spu.micrometer),
'z': sim_symbols.space_z(spu.micrometer),
# Spherical
'r': sim_symbols.ang_r(spu.micrometer),
'theta': sim_symbols.ang_theta(spu.radian),
'phi': sim_symbols.ang_phi(spu.radian),
# Freq|Time
'f': sim_symbols.freq(spu.hertz),
't': sim_symbols.t(spu.second),
# Power Flux
'flux': sim_symbols.flux(spu.watt),
# Wavevector
'ux': sim_symbols.dir_x(spu.watt),
'uy': sim_symbols.dir_y(spu.watt),
# Diffraction Orders
'orders_x': sim_symbols.diff_order_x(None),
'orders_y': sim_symbols.diff_order_y(None),
# Cartesian Fields
'field': sim_symbols.field_e(spu.volt / spu.micrometer), ## TODO: H???
'field_e': sim_symbols.field_e(spu.volt / spu.micrometer),
'field_h': sim_symbols.field_h(spu.ampere / spu.micrometer),
}
def _mk_idx_array(xarr: xarray.DataArray, axis: str) -> ct.RangeFlow | ct.ArrayFlow:
return ct.RangeFlow.try_from_array(
ct.ArrayFlow(
jax_bytes=xarr.get_index(axis).values,
unit=MONITOR_SYMBOLS[axis].unit,
is_sorted=True,
)
)
# Compute InfoFlow from XArray
symbols = {
# Cartesian
'x': sim_symbols.space_x(spu.micrometer),
'y': sim_symbols.space_y(spu.micrometer),
'z': sim_symbols.space_z(spu.micrometer),
# Spherical
'r': sim_symbols.ang_r(spu.micrometer),
'theta': sim_symbols.ang_theta(spu.radian),
'phi': sim_symbols.ang_phi(spu.radian),
# Freq|Time
'f': sim_symbols.freq(spu.hertz),
't': sim_symbols.t(spu.second),
# Power Flux
'flux': sim_symbols.flux(spu.watt),
# Cartesian Fields
'Ex': sim_symbols.field_ex(spu.volt / spu.micrometer),
'Ey': sim_symbols.field_ey(spu.volt / spu.micrometer),
'Ez': sim_symbols.field_ez(spu.volt / spu.micrometer),
'Hx': sim_symbols.field_hx(spu.volt / spu.micrometer),
'Hy': sim_symbols.field_hy(spu.volt / spu.micrometer),
'Hz': sim_symbols.field_hz(spu.volt / spu.micrometer),
# Spherical Fields
'Er': sim_symbols.field_er(spu.volt / spu.micrometer),
'Etheta': sim_symbols.ang_theta(spu.volt / spu.micrometer),
'Ephi': sim_symbols.field_ez(spu.volt / spu.micrometer),
'Hr': sim_symbols.field_hr(spu.volt / spu.micrometer),
'Htheta': sim_symbols.field_hy(spu.volt / spu.micrometer),
'Hphi': sim_symbols.field_hz(spu.volt / spu.micrometer),
# Wavevector
'ux': sim_symbols.dir_x(spu.watt),
'uy': sim_symbols.dir_y(spu.watt),
# Diffraction Orders
'orders_x': sim_symbols.diff_order_x(None),
'orders_y': sim_symbols.diff_order_y(None),
}
match monitor_data.type:
def output_symbol_by_type(monitor_type: str) -> sim_symbols.SimSymbol:
match monitor_type:
case 'Field' | 'FieldProjectionCartesian' | 'Permittivity' | 'Mode':
return MONITOR_SYMBOLS['field_e']
case 'FieldTime':
return MONITOR_SYMBOLS['field']
case 'Flux':
return MONITOR_SYMBOLS['flux']
case 'FluxTime':
return MONITOR_SYMBOLS['flux']
case 'FieldProjectionAngle':
return MONITOR_SYMBOLS['field']
case 'FieldProjectionKSpace':
return MONITOR_SYMBOLS['field']
case 'Diffraction':
return MONITOR_SYMBOLS['field']
return None
def _extract_info(
example_xarr: xarray.DataArray,
monitor_type: str,
monitor_attrs: tuple[str, ...],
batch_dims: dict[sim_symbols.SimSymbol, ct.RangeFlow | ct.ArrayFlow],
) -> ct.InfoFlow | None:
log.debug([monitor_type, monitor_attrs, batch_dims])
mk_idx_array = functools.partial(_mk_idx_array, example_xarr)
match monitor_type:
case 'Field' | 'FieldProjectionCartesian' | 'Permittivity' | 'Mode':
return ct.InfoFlow(
dims={
symbols['x']: mk_idx_array('x'),
symbols['y']: mk_idx_array('y'),
symbols['z']: mk_idx_array('z'),
symbols['f']: mk_idx_array('f'),
dims=batch_dims
| {
MONITOR_SYMBOLS['EH*']: monitor_attrs,
MONITOR_SYMBOLS['x']: mk_idx_array('x'),
MONITOR_SYMBOLS['y']: mk_idx_array('y'),
MONITOR_SYMBOLS['z']: mk_idx_array('z'),
MONITOR_SYMBOLS['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
output=MONITOR_SYMBOLS['field_e'],
)
case 'FieldTime':
return ct.InfoFlow(
dims={
symbols['x']: mk_idx_array('x'),
symbols['y']: mk_idx_array('y'),
symbols['z']: mk_idx_array('z'),
symbols['t']: mk_idx_array('t'),
dims=batch_dims
| {
MONITOR_SYMBOLS['EH*']: monitor_attrs,
MONITOR_SYMBOLS['x']: mk_idx_array('x'),
MONITOR_SYMBOLS['y']: mk_idx_array('y'),
MONITOR_SYMBOLS['z']: mk_idx_array('z'),
MONITOR_SYMBOLS['t']: mk_idx_array('t'),
},
output=symbols[monitor_attr],
output=MONITOR_SYMBOLS['field'],
)
case 'Flux':
return ct.InfoFlow(
dims={
symbols['f']: mk_idx_array('f'),
dims=batch_dims
| {
MONITOR_SYMBOLS['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
output=MONITOR_SYMBOLS['flux'],
)
case 'FluxTime':
return ct.InfoFlow(
dims={
symbols['t']: mk_idx_array('t'),
dims=batch_dims
| {
MONITOR_SYMBOLS['t']: mk_idx_array('t'),
},
output=symbols[monitor_attr],
output=MONITOR_SYMBOLS['flux'],
)
case 'FieldProjectionAngle':
return ct.InfoFlow(
dims={
symbols['r']: mk_idx_array('r'),
symbols['theta']: mk_idx_array('theta'),
symbols['phi']: mk_idx_array('phi'),
symbols['f']: mk_idx_array('f'),
dims=batch_dims
| {
MONITOR_SYMBOLS['EH*']: monitor_attrs,
MONITOR_SYMBOLS['r']: mk_idx_array('r'),
MONITOR_SYMBOLS['theta']: mk_idx_array('theta'),
MONITOR_SYMBOLS['phi']: mk_idx_array('phi'),
MONITOR_SYMBOLS['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
output=MONITOR_SYMBOLS['field'],
)
case 'FieldProjectionKSpace':
return ct.InfoFlow(
dims={
symbols['ux']: mk_idx_array('ux'),
symbols['uy']: mk_idx_array('uy'),
symbols['r']: mk_idx_array('r'),
symbols['f']: mk_idx_array('f'),
dims=batch_dims
| {
MONITOR_SYMBOLS['EH*']: monitor_attrs,
MONITOR_SYMBOLS['ux']: mk_idx_array('ux'),
MONITOR_SYMBOLS['uy']: mk_idx_array('uy'),
MONITOR_SYMBOLS['r']: mk_idx_array('r'),
MONITOR_SYMBOLS['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
output=MONITOR_SYMBOLS['field'],
)
case 'Diffraction':
return ct.InfoFlow(
dims={
symbols['orders_x']: mk_idx_array('orders_x'),
symbols['orders_y']: mk_idx_array('orders_y'),
symbols['f']: mk_idx_array('f'),
dims=batch_dims
| {
MONITOR_SYMBOLS['EH*']: monitor_attrs,
MONITOR_SYMBOLS['orders_x']: mk_idx_array('orders_x'),
MONITOR_SYMBOLS['orders_y']: mk_idx_array('orders_y'),
MONITOR_SYMBOLS['f']: mk_idx_array('f'),
},
output=symbols[monitor_attr],
output=MONITOR_SYMBOLS['field'],
)
return None
raise TypeError
def extract_monitor_xarrs(
monitor_datas: dict[RealizedSymsVals, typ.Any], monitor_attrs: tuple[str, ...]
) -> dict[RealizedSymsVals, ct.InfoFlow]:
return {
syms_vals: {
monitor_attr: getattr(monitor_data, monitor_attr, None)
for monitor_attr in monitor_attrs
}
for syms_vals, monitor_data in monitor_datas.items()
}
def extract_info(
monitor_datas: dict[RealizedSymsVals, typ.Any], monitor_attrs: tuple[str, ...]
) -> dict[RealizedSymsVals, ct.InfoFlow]:
"""Extract an InfoFlow describing monitor data from a batch of simulations."""
# Extract Dimension from Batched Values
## -> Comb the data to expose each symbol's realized values as an array.
## -> Each symbol: array can then become a dimension.
## -> These are the "batch dimensions", which allows indexing across sims.
## -> The retained sim symbol provides semantic index coordinates.
example_syms_vals = next(iter(monitor_datas.keys()))
syms = example_syms_vals[0] if example_syms_vals != () else ()
vals_per_sym_pos = (
[vals for _, vals in monitor_datas] if example_syms_vals != () else []
)
_batch_dims = dict(
zip(
syms,
zip(*vals_per_sym_pos, strict=True),
strict=True,
)
)
batch_dims = {
sym: ct.RangeFlow.try_from_array(
ct.ArrayFlow(
jax_bytes=vals,
unit=sym.unit,
is_sorted=True,
)
)
for sym, vals in _batch_dims.items()
}
# Extract Example Monitor Data | XArray
## -> We presume all monitor attributes have the exact same dims + output.
## -> Because of this, we only need one "example" xarray.
## -> This xarray will be used to extract dimensional coordinates...
## -> ...Presuming that these coords will generalize.
example_monitor_data = next(iter(monitor_datas.values()))
monitor_datas_xarrs = extract_monitor_xarrs(monitor_datas, monitor_attrs)
# Extract XArray for Each Monitor Attribute
example_monitor_data_xarrs = next(iter(monitor_datas_xarrs.values()))
example_xarr = next(iter(example_monitor_data_xarrs.values()))
# Extract InfoFlow of First
## -> All of the InfoFlows should be identical...
## -> ...Apart from the batched dimensions.
return _extract_info(
example_xarr,
example_monitor_data.type.removesuffix('Data'),
monitor_attrs,
batch_dims,
)
####################
@ -224,73 +341,139 @@ class ExtractDataNode(base.MaxwellSimNode):
bl_label = 'Extract'
input_socket_sets: typ.ClassVar = {
'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(),
'Single': {
'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(),
},
'Batch': {
'Sim Datas': sockets.MaxwellFDTDSimDataSocketDef(active_kind=FK.Array),
},
}
# output_sockets: typ.ClassVar = {
# 'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
# }
output_socket_sets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
'Single': {
'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
'Log': sockets.StringSocketDef(),
},
'Batch': {
'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
'Logs': sockets.StringSocketDef(active_kind=FK.Array),
},
}
####################
# - Properties: Monitor Name
# - Properties: Sim Datas
####################
@events.on_value_changed(
socket_name='Sim Data',
input_sockets={'Sim Data'},
input_sockets_optional={'Sim Data': True},
socket_name={'Sim Data': FK.Value, 'Sim Datas': FK.Array},
)
def on_sim_data_changed(self, input_sockets) -> None: # noqa: D102
has_sim_data = not ct.FlowSignal.check(input_sockets['Sim Data'])
if has_sim_data:
self.sim_data = bl_cache.Signal.InvalidateCache
def on_sim_datas_changed(self) -> None: # noqa: D102
self.sim_datas = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def sim_data(self) -> td.SimulationData | None:
@bl_cache.cached_bl_property(depends_on={'active_socket_set'})
def sim_datas(self) -> list[td.SimulationData] | None:
"""Extracts the simulation data from the input socket.
Return:
Either the simulation data, if available, or None.
"""
sim_data = self._compute_input(
'Sim Data', kind=ct.FlowKind.Value, optional=True
)
has_sim_data = not ct.FlowSignal.check(sim_data)
if has_sim_data:
return sim_data
## TODO: Check that syms are identical for all (aka. that we have a batch)
if self.active_socket_set == 'Single':
sim_data = self._compute_input('Sim Data', kind=FK.Value)
has_sim_data = not FS.check(sim_data)
if has_sim_data:
# Embedded Symbolic Realizations
## -> ['realizations'] contains a 2-tuple
## -> First should be the dict-dump of a SimSymbol.
## -> Second should be either a value, or a list of values.
if 'realizations' in sim_data.attrs:
raw_realizations = sim_data.attrs['realizations']
syms_vals = {
sim_symbols.SimSymbol(**raw_sym): raw_val
if not isinstance(raw_val, tuple | list)
else jnp.array(raw_val)
for raw_sym, raw_val in raw_realizations
}
return {syms_vals: sim_data}
# No Embedded Realizations
return {(): sim_data}
if self.active_socket_set == 'Batch':
_sim_datas = self._compute_input('Sim Datas', kind=FK.Value)
has_sim_datas = not FS.check(_sim_datas)
if has_sim_datas:
sim_datas = {}
for sim_data in sim_datas:
# Embedded Symbolic Realizations
## -> ['realizations'] contains a 2-tuple
## -> First should be the dict-dump of a SimSymbol.
## -> Second should be either a value, or a list of values.
if 'realizations' in sim_data.attrs:
raw_realizations = sim_data.attrs['realizations']
syms = {
sim_symbols.SimSymbol(**raw_sym): raw_val
if not isinstance(raw_val, tuple | list)
else jnp.array(raw_val)
for raw_sym, raw_val in raw_realizations
}
sim_datas |= {syms_vals: sim_data}
# No Embedded Realizations
sim_datas |= {(): sim_data}
return None
@bl_cache.cached_bl_property(depends_on={'sim_data'})
def sim_data_monitor_nametype(self) -> dict[str, str] | None:
"""Dictionary from monitor names on `self.sim_data` to their associated type name (with suffix 'Data' removed).
@bl_cache.cached_bl_property(depends_on={'sim_datas'})
def example_sim_data(self) -> list[td.SimulationData] | None:
"""Extracts a single, example simulation data from the input socket.
All simulation datas share certain properties, ex. names and types of monitors.
Therefore, we may often only need an example simulation data object.
Return:
Either the simulation data, if available, or None.
"""
if self.sim_datas:
return next(iter(self.sim_datas.values()))
return None
####################
# - Properties: Monitor Name
####################
@bl_cache.cached_bl_property(depends_on={'example_sim_data'})
def monitor_types(self) -> dict[str, str] | None:
"""Dictionary from monitor names on `self.sim_datas` to their associated type name (with suffix 'Data' removed).
Return:
The name to type of monitors in the simulation data.
"""
if self.sim_data is not None:
if self.example_sim_data is not None:
return {
monitor_name: monitor_data.type.removesuffix('Data')
for monitor_name, monitor_data in self.sim_data.monitor_data.items()
for monitor_name, monitor_data in self.example_sim_data.monitor_data.items()
}
return None
monitor_name: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_monitor_names(),
cb_depends_on={'sim_data_monitor_nametype'},
cb_depends_on={'monitor_types'},
)
def search_monitor_names(self) -> list[ct.BLEnumElement]:
"""Compute valid values for `self.monitor_attr`, for a dynamic `EnumProperty`.
Notes:
Should be reset (via `self.monitor_attr`) with (after) `self.sim_data_monitor_nametype`, `self.monitor_data_attrs`, and (implicitly) `self.monitor_type`.
Should be reset (via `self.monitor_attr`) with (after) `self.sim_data_monitor_nametypes`, `self.monitor_data_attrs`, and (implicitly) `self.monitor_type`.
See `bl_cache.BLField` for more on dynamic `EnumProperty`.
Returns:
Valid `self.monitor_attr` in a format compatible with dynamic `EnumProperty`.
"""
if self.sim_data_monitor_nametype is not None:
if self.monitor_types is not None:
return [
(
monitor_name,
@ -300,12 +483,32 @@ class ExtractDataNode(base.MaxwellSimNode):
i,
)
for i, (monitor_name, monitor_type) in enumerate(
self.sim_data_monitor_nametype.items()
self.monitor_types.items()
)
]
return []
####################
# - Properties: Monitor Information
####################
@bl_cache.cached_bl_property(depends_on={'sim_datas', 'monitor_name'})
def monitor_datas(self) -> SimDataArrayInfo | None:
"""Extract the currently selected monitor's data from all simulation datas in the batch."""
if self.sim_datas is not None and self.monitor_name is not None:
return {
syms_vals: sim_data.monitor_data.get(self.monitor_name)
for syms_vals, sim_data in self.sim_datas.items()
}
return None
@bl_cache.cached_bl_property(depends_on={'example_sim_data', 'monitor_name'})
def valid_monitor_attrs(self) -> SimDataArrayInfo | None:
"""Valid attributes of the monitor, from the example sim data under the presumption that the entire batch shares the same attribute validity."""
if self.example_sim_data is not None and self.monitor_name is not None:
return valid_monitor_attrs(self.example_sim_data, self.monitor_name)
return None
####################
# - UI
####################
@ -315,9 +518,7 @@ class ExtractDataNode(base.MaxwellSimNode):
Notes:
Called by Blender to determine the text to place in the node's header.
"""
has_sim_data = self.sim_data_monitor_nametype is not None
if has_sim_data:
if self.monitor_name is not None:
return f'Extract: {self.monitor_name}'
return self.bl_label
@ -335,114 +536,181 @@ class ExtractDataNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Func,
kind=FK.Func,
# Loaded
props={'monitor_name'},
input_sockets={'Sim Data'},
input_socket_kinds={'Sim Data': ct.FlowKind.Value},
props={'monitor_datas', 'valid_monitor_attrs'},
)
def compute_expr(
self, props: dict, input_sockets: dict
) -> ct.FuncFlow | ct.FlowSignal:
sim_data = input_sockets['Sim Data']
monitor_name = props['monitor_name']
def compute_extracted_data_func(self, props) -> ct.FuncFlow | FS:
"""Aggregates the selected monitor's data across all batched symbolic realizations, into a single FuncFlow."""
monitor_datas = props['monitor_datas']
valid_monitor_attrs = props['valid_monitor_attrs']
has_sim_data = not ct.FlowSignal.check(sim_data)
if monitor_datas is not None and valid_monitor_attrs is not None:
monitor_datas_xarrs = extract_monitor_xarrs(
monitor_datas, valid_monitor_attrs
)
if has_sim_data and monitor_name is not None:
monitor_data = sim_data.get(monitor_name)
if monitor_data is not None:
# Extract Valid Index Labels
## -> The first output axis will be integer-indexed.
## -> Each integer will have a string label.
## -> Those string labels explain the integer as ex. Ex, Ey, Hy.
idx_labels = valid_monitor_attrs(sim_data, monitor_name)
example_monitor_data = next(iter(monitor_datas.values()))
monitor_type = example_monitor_data.type.removesuffix('Data')
output_sym = output_symbol_by_type(monitor_type)
# Extract Info
## -> We only need the output symbol.
## -> All labelled outputs have the same output SimSymbol.
info = extract_info(monitor_data, idx_labels[0])
# Stack Inner Dimensions: components | *
## -> Each realization maps to exactly one xarray.
## -> We extract its data, and wrap it into a FuncFlow.
## -> This represents the "inner" dimensions, with components|data.
## -> We then attach this singular FuncFlow to that realization.
inner_funcs = {}
for syms_vals, attr_xarrs in monitor_datas_xarrs.items():
# XArray Capture Function
## -> We can't generally capture a loop variable inline.
## -> By making a new function, we get a new scope.
def _xarr_values(xarr):
return lambda: xarr.values
# Generate FuncFlow Per Index Label
## -> We extract each XArray as an attribute of monitor_data.
## -> We then bind its values into a unique func_flow.
## -> This lets us 'stack' then all along the first axis.
func_flows = []
for idx_label in idx_labels:
xarr = getattr(monitor_data, idx_label)
func_flows.append(
ct.FuncFlow(
func=lambda xarr=xarr: xarr.values,
supports_jax=True,
)
# Bind XArray Values into FuncFlows
## -> Each monitor component has an xarray.
funcs = [
ct.FuncFlow(
func=_xarr_values(xarr),
func_output=output_sym,
supports_jax=True,
)
for xarr in attr_xarrs.values()
]
log.critical(['FUNCS', funcs])
# Single Component: No Stacking of Dimensions - *
if len(funcs) == 1:
inner_funcs[syms_vals] = funcs[0]
# Many Components: Stack Dimensions - components | *
else:
inner_funcs[syms_vals] = functools.reduce(
lambda a, b: a | b, funcs
).compose_within(
lambda els: jnp.stack(els, axis=0),
enclosing_func_output=output_sym,
)
# Concatenate and Stack Unified FuncFlow
## -> First, 'reduce' lets us __or__ all the FuncFlows together.
## -> Then, 'compose_within' lets us stack them along axis=0.
## -> The "new" axis=0 is int-indexed axis w/idx_labels labels!
return functools.reduce(lambda a, b: a | b, func_flows).compose_within(
lambda data: jnp.stack(data, axis=0),
func_output=info.output,
)
return ct.FlowSignal.FlowPending
return ct.FlowSignal.FlowPending
# Stack Batch Dims: vals0 | vals1 | ... | valsN | components | *
## -> We stack the inner-dimensional object together, backwards.
## -> Each stack prepends a new dimension.
## -> Here, everything is integer-indexed.
## -> But in the InfoFlow, a similar process contextualizes idxs.
example_syms_vals = next(iter(monitor_datas.keys()))
syms = example_syms_vals[0] if example_syms_vals != () else ()
outer_funcs = inner_funcs
log.critical(['INNER FUNCS', inner_funcs])
for _, axis in reversed(list(enumerate(syms))):
log.critical([axis, outer_funcs])
new_outer_funcs = {}
# Collect Funcs Along *vals[axis] | ...
## -> Grab ONLY up to 'axis' syms_vals.
## -> '|' all functions that share axis-deficient syms_vals.
## -> Thus, we '|' functions along the last axis.
for unreduced_syms_vals, func in outer_funcs:
reduced_syms_vals = unreduced_syms_vals[:axis]
if reduced_syms_vals in new_outer_funcs:
new_outer_funcs[reduced_syms_vals] = (
new_outer_funcs[reduced_syms_vals] | func
)
else:
new_outer_funcs[reduced_syms_vals] = func
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Params,
input_sockets={'Sim Data'},
input_socket_kinds={'Sim Data': ct.FlowKind.Params},
)
def compute_data_params(self, input_sockets) -> ct.ParamsFlow:
"""Declare an empty `Data:Params`, to indicate the start of a function-composition pipeline.
# Aggregate All Collected Funcs
## -> Any functions that went through a | are stacked.
## -> Otherwise, just add a len=1 dimension.
new_reduced_outer_funcs = {
reduced_syms_vals: (
combined_func.compose_within(
lambda els: jnp.stack(els, axis=0),
enclosing_func_output=output_sym,
)
if combined_func.is_concatenated ## Went through a |
else combined_func.compose_within(
lambda el: jnp.expand_dims(el, axis=0),
enclosing_func_output=output_sym,
)
)
for reduced_syms_vals, combined_func in new_outer_funcs.items()
}
Returns:
A completely empty `ParamsFlow`, ready to be composed.
"""
sim_params = input_sockets['Sim Data']
has_sim_params = not ct.FlowSignal.check(sim_params)
if has_sim_params:
return sim_params
return ct.ParamsFlow()
# Reset Outer Funcs to Axis-Deficient Reduction
## -> This effectively removes + aggregates the last axis.
## -> When the loop is done, only {(): val} will be left.
outer_funcs = new_reduced_outer_funcs
return next(iter(outer_funcs.values()))
return FS.FlowPending
####################
# - FlowKind.Info
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Info,
kind=FK.Info,
# Loaded
props={'monitor_name'},
input_sockets={'Sim Data'},
input_socket_kinds={'Sim Data': ct.FlowKind.Value},
props={'monitor_datas', 'valid_monitor_attrs'},
)
def compute_extracted_data_info(self, props, input_sockets) -> ct.InfoFlow:
def compute_extracted_data_info(self, props) -> ct.InfoFlow | FS:
"""Declare `Data:Info` by manually selecting appropriate axes, units, etc. for each monitor type.
Returns:
Information describing the `Data:Func`, if available, else `ct.FlowSignal.FlowPending`.
Information describing the `Data:Func`, if available, else `FS.FlowPending`.
"""
sim_data = input_sockets['Sim Data']
monitor_name = props['monitor_name']
monitor_datas = props['monitor_datas']
valid_monitor_attrs = props['valid_monitor_attrs']
has_sim_data = not ct.FlowSignal.check(sim_data)
if monitor_datas is not None and valid_monitor_attrs is not None:
return extract_info(monitor_datas, valid_monitor_attrs)
return FS.FlowPending
if not has_sim_data or monitor_name is None:
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Expr',
kind=FK.Params,
)
def compute_params(self) -> ct.ParamsFlow:
"""Declare an empty `Data:Params`, to indicate the start of a function-composition pipeline.
# Extract Data
## -> All monitor_data.<idx_label> have the exact same InfoFlow.
## -> So, just construct an InfoFlow w/prepended labelled dimension.
monitor_data = sim_data.get(monitor_name)
idx_labels = valid_monitor_attrs(sim_data, monitor_name)
info = extract_info(monitor_data, idx_labels[0])
Returns:
A completely empty `ParamsFlow`, ready to be composed.
"""
return ct.ParamsFlow()
return info.prepend_dim(sim_symbols.idx, idx_labels)
####################
# - Log: FlowKind.Value|Array
####################
@events.computes_output_socket(
'Log',
kind=FK.Value,
# Loaded
props={'sim_datas'},
)
def compute_extracted_log(self, props) -> str | FS:
"""Extract the log from a single simulation that ran."""
sim_datas = props['sim_datas']
if sim_datas is not None and len(sim_datas) == 1:
sim_data = next(iter(sim_datas.values()))
if sim_data.log is not None:
return sim_data.log
return FS.FlowPending
@events.computes_output_socket(
'Log',
kind=FK.Array,
# Loaded
props={'sim_datas'},
)
def compute_extracted_logs(self, props) -> dict[RealizedSymsVals, str] | FS:
"""Extract the log from all simulation that ran in the batch."""
sim_datas = props['sim_datas']
if sim_datas is not None and sim_datas:
return {
syms_vals: sim_data.log if sim_data.log is not None else ''
for syms_vals, sim_data in sim_datas.items()
}
return FS.FlowPending
####################

View File

@ -14,19 +14,19 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import filter_math, map_math, operate_math, transform_math
from . import filter_math, map_math, operate_math, reduce_math, transform_math
BL_REGISTER = [
*operate_math.BL_REGISTER,
*map_math.BL_REGISTER,
*filter_math.BL_REGISTER,
# *reduce_math.BL_REGISTER,
*reduce_math.BL_REGISTER,
*transform_math.BL_REGISTER,
]
BL_NODES = {
**operate_math.BL_NODES,
**map_math.BL_NODES,
**filter_math.BL_NODES,
# **reduce_math.BL_NODES,
**reduce_math.BL_NODES,
**transform_math.BL_NODES,
}

View File

@ -22,13 +22,20 @@ import typing as typ
import bpy
import sympy as sp
from blender_maxwell.utils import bl_cache, sim_symbols
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import math_system, sockets
from ... import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
FO = math_system.FilterOperation
MT = spux.MathType
class FilterMathNode(base.MaxwellSimNode):
r"""Applies a function that operates on the shape of the array.
@ -52,10 +59,10 @@ class FilterMathNode(base.MaxwellSimNode):
bl_label = 'Filter Math'
input_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
'Expr': sockets.ExprSocketDef(active_kind=FK.Func, show_func_ui=False),
}
output_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
}
####################
@ -63,50 +70,43 @@ class FilterMathNode(base.MaxwellSimNode):
####################
@events.on_value_changed(
# Trigger
socket_name={'Expr'},
socket_name={'Expr': FK.Info},
# Loaded
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info},
input_sockets_optional={'Expr': True},
inscks_kinds={'Expr': FK.Info},
input_sockets_optional={'Expr'},
# Flow
## -> See docs in TransformMathNode
stop_propagation=True,
)
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
info_pending = ct.FlowSignal.check_single(
input_sockets['Expr'], ct.FlowSignal.FlowPending
)
def on_input_expr_changed(self, input_sockets) -> None: # noqa: D102
info = input_sockets['Expr']
has_info = not FS.check(info)
info_pending = FS.check_single(info, FS.FlowPending)
if has_info and not info_pending:
self.expr_info = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def expr_info(self) -> ct.InfoFlow | None:
info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True)
has_info = not ct.FlowSignal.check(info)
"""Retrieve the input expression's `InfoFlow`."""
info = self._compute_input('Expr', kind=FK.Info)
has_info = not FS.check(info)
if has_info:
return info
return None
####################
# - Properties: Operation
####################
operation: math_system.FilterOperation = bl_cache.BLField(
operation: FO = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_info'},
)
def search_operations(self) -> list[ct.BLEnumElement]:
"""Determine all valid operations from the input expression."""
if self.expr_info is not None:
return [
operation.bl_enum_element(i)
for i, operation in enumerate(
math_system.FilterOperation.by_info(self.expr_info)
)
]
return FO.bl_enum_elements(self.expr_info)
return []
####################
@ -122,6 +122,7 @@ class FilterMathNode(base.MaxwellSimNode):
)
def search_dims(self) -> list[ct.BLEnumElement]:
"""Determine all valid dimensions from the input expression."""
if self.expr_info is not None and self.operation is not None:
return [
(dim.name, dim.name_pretty, dim.name, '', i)
@ -131,16 +132,32 @@ class FilterMathNode(base.MaxwellSimNode):
@bl_cache.cached_bl_property(depends_on={'active_dim_0'})
def dim_0(self) -> sim_symbols.SimSymbol | None:
"""The first currently active dimension, if any is selected; otherwise `None`."""
if self.expr_info is not None and self.active_dim_0 is not None:
return self.expr_info.dim_by_name(self.active_dim_0)
return None
@bl_cache.cached_bl_property(depends_on={'active_dim_1'})
def dim_1(self) -> sim_symbols.SimSymbol | None:
"""The second currently active dimension, if any is selected; otherwise `None`."""
if self.expr_info is not None and self.active_dim_1 is not None:
return self.expr_info.dim_by_name(self.active_dim_1)
return None
@bl_cache.cached_bl_property(depends_on={'dim_0'})
def axis_0(self) -> sim_symbols.SimSymbol | None:
"""The first currently active axis, derived from `self.dim_0`."""
if self.expr_info is not None and self.dim_0 is not None:
return self.expr_info.dim_axis(self.dim_0)
return None
@bl_cache.cached_bl_property(depends_on={'dim_1'})
def axis_1(self) -> sim_symbols.SimSymbol | None:
"""The second currently active dimension, if any is selected; otherwise `None`."""
if self.expr_info is not None and self.active_dim_1 is not None:
return self.expr_info.dim_axis(self.dim_1)
return None
####################
# - Properties: Slice
####################
@ -149,8 +166,12 @@ class FilterMathNode(base.MaxwellSimNode):
####################
# - UI
####################
def draw_label(self):
FO = math_system.FilterOperation
def draw_label(self): # noqa: PLR0911
"""Show the active filter operation in the node's header label.
Notes:
Called by Blender to determine the text to place in the node's header.
"""
match self.operation:
# Slice
case FO.SliceIdx:
@ -163,10 +184,8 @@ class FilterMathNode(base.MaxwellSimNode):
case FO.Pin:
return f'Filter: Pin {self.active_dim_0}[...]'
case FO.PinIdx:
pin_idx_axis = self._compute_input(
'Axis', kind=ct.FlowKind.Value, optional=True
)
has_pin_idx_axis = not ct.FlowSignal.check(pin_idx_axis)
pin_idx_axis = self._compute_input('Index', kind=FK.Value)
has_pin_idx_axis = not FS.check(pin_idx_axis)
if has_pin_idx_axis:
return f'Filter: Pin {self.active_dim_0}[{pin_idx_axis}]'
return self.bl_label
@ -179,6 +198,11 @@ class FilterMathNode(base.MaxwellSimNode):
return self.bl_label
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
"""Draw the user interfaces of the node's properties inside of the node itself.
Parameters:
layout: UI target for drawing.
"""
layout.prop(self, self.blfields['operation'], text='')
if self.operation is not None:
@ -190,7 +214,7 @@ class FilterMathNode(base.MaxwellSimNode):
row.prop(self, self.blfields['active_dim_0'], text='')
row.prop(self, self.blfields['active_dim_1'], text='')
if self.operation is math_system.FilterOperation.SliceIdx:
if self.operation is FO.SliceIdx:
layout.prop(self, self.blfields['slice_tuple'], text='')
####################
@ -198,220 +222,190 @@ class FilterMathNode(base.MaxwellSimNode):
####################
@events.on_value_changed(
# Trigger
socket_name='Expr',
prop_name={'operation', 'dim_0', 'dim_1'},
socket_name={'Expr': FK.Info},
prop_name={'operation', 'dim_0'},
# Loaded
props={'operation', 'dim_0', 'dim_1'},
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info},
props={'operation', 'dim_0'},
inscks_kinds={'Expr': FK.Info},
input_sockets_optional={'Expr'},
)
def on_pin_factors_changed(self, props: dict, input_sockets: dict):
"""Synchronize loose input sockets to match the dimension-pinning method declared in `self.operation`.
To "pin" an axis, a particular index must be chosen to "extract".
One might choose axes of length 1 ("squeeze"), choose a particular index, or choose a coordinate that maps to a particular index.
Those last two options requires more information from the user: Which index?
Which coordinate?
To answer these questions, we create an appropriate loose input socket containing this data, so the user can make their decision.
"""
def on_pin_factors_changed(self, props, input_sockets) -> None:
"""Synchronize loose input sockets to match the dimension-pinning method declared in `self.operation`."""
info = input_sockets['Expr']
has_info = not ct.FlowSignal.check(info)
if not has_info:
return
has_info = not FS.check(info)
dim_0 = props['dim_0']
# Loose Sockets: Pin Dim by-Value
## -> Works with continuous / discrete indexes.
## -> The user will be given a socket w/correct mathtype, unit, etc. .
if (
props['operation'] is math_system.FilterOperation.Pin
and dim_0 is not None
and (info.has_idx_cont(dim_0) or info.has_idx_discrete(dim_0))
):
dim = dim_0
current_bl_socket = self.loose_input_sockets.get('Value')
if (
current_bl_socket is None
or current_bl_socket.active_kind != ct.FlowKind.Value
or current_bl_socket.size is not spux.NumberSize1D.Scalar
or current_bl_socket.physical_type != dim.physical_type
or current_bl_socket.mathtype != dim.mathtype
operation = props['operation']
match operation:
case FO.Pin if (
has_info
and dim_0 is not None
and (info.has_idx_cont(dim_0) or info.has_idx_discrete(dim_0))
):
self.loose_input_sockets = {
'Value': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Value,
physical_type=dim.physical_type,
mathtype=dim.mathtype,
default_unit=dim.unit,
active_kind=FK.Value,
**dim_0.expr_info,
),
}
# Loose Sockets: Pin Dim by-Value
## -> Works with discrete points / labelled integers.
elif (
props['operation'] is math_system.FilterOperation.PinIdx
and dim_0 is not None
and (info.has_idx_discrete(dim_0) or info.has_idx_labels(dim_0))
):
dim = dim_0
current_bl_socket = self.loose_input_sockets.get('Axis')
if (
current_bl_socket is None
or current_bl_socket.active_kind != ct.FlowKind.Value
or current_bl_socket.size is not spux.NumberSize1D.Scalar
or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical
or current_bl_socket.mathtype != spux.MathType.Integer
case FO.PinIdx if (
has_info
and dim_0 is not None
and (info.has_idx_labels(dim_0) or info.has_idx_discrete(dim_0))
):
self.loose_input_sockets = {
'Axis': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Value,
mathtype=spux.MathType.Integer,
)
'Index': sockets.ExprSocketDef(
active_kind=FK.Value,
output_name=dim_0.expr_info['output_name'],
mathtype=MT.Integer,
abs_min=0,
abs_max=len(info.dims[dim_0]) - 1,
),
}
# No Loose Value: Remove Input Sockets
elif self.loose_input_sockets:
self.loose_input_sockets = {}
case _ if self.loose_input_sockets:
self.loose_input_sockets = {}
####################
# - FlowKind.Value|Func
# - FlowKind.Func
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Func,
props={'operation', 'dim_0', 'dim_1', 'slice_tuple'},
input_sockets={'Expr'},
input_socket_kinds={'Expr': {ct.FlowKind.Func, ct.FlowKind.Info}},
kind=FK.Func,
# Loaded
props={'operation', 'axis_0', 'axis_1', 'slice_tuple'},
inscks_kinds={'Expr': FK.Func},
)
def compute_lazy_func(self, props: dict, input_sockets: dict):
operation = props['operation']
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
info = input_sockets['Expr'][ct.FlowKind.Info]
def compute_func(self, props, input_sockets) -> None:
"""Filter operation on lazy-defined input expression."""
lazy_func = input_sockets['Expr']
has_lazy_func = not ct.FlowSignal.check(lazy_func)
has_info = not ct.FlowSignal.check(info)
dim_0 = props['dim_0']
dim_1 = props['dim_1']
axis_0 = props['axis_0']
axis_1 = props['axis_1']
slice_tuple = props['slice_tuple']
if (
has_lazy_func
and has_info
and operation is not None
and operation.are_dims_valid(info, dim_0, dim_1)
):
axis_0 = info.dim_axis(dim_0) if dim_0 is not None else None
axis_1 = info.dim_axis(dim_1) if dim_1 is not None else None
return lazy_func.compose_within(
operation.jax_func(axis_0, axis_1, slice_tuple=slice_tuple),
enclosing_func_args=operation.func_args,
enclosing_func_output=info.output,
supports_jax=True,
operation = props['operation']
if operation is not None:
new_func = operation.transform_func(
lazy_func, axis_0, axis_1=axis_1, slice_tuple=slice_tuple
)
return ct.FlowSignal.FlowPending
if new_func is not None:
return new_func
return FS.FlowPending
####################
# - FlowKind.Info
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Info,
kind=FK.Info,
# Loaded
props={
'dim_0',
'dim_1',
'operation',
'slice_tuple',
},
input_sockets={'Expr', 'Dim'},
input_socket_kinds={
'Expr': ct.FlowKind.Info,
'Dim': {ct.FlowKind.Func, ct.FlowKind.Params, ct.FlowKind.Info},
inscks_kinds={
'Expr': FK.Info,
'Value': {FK.Func, FK.Params},
'Index': {FK.Func, FK.Params},
},
input_sockets_optional={'Dim': True},
input_sockets_optional={'Index', 'Value'},
)
def compute_info(self, props, input_sockets) -> ct.InfoFlow:
operation = props['operation']
"""Transform `InfoFlow` based on the current filtering operation."""
info = input_sockets['Expr']
has_info = not ct.FlowSignal.check(info)
# Dimension(s)
dim_0 = props['dim_0']
dim_1 = props['dim_1']
slice_tuple = props['slice_tuple']
if has_info and operation is not None:
return operation.transform_info(info, dim_0, dim_1, slice_tuple=slice_tuple)
return ct.FlowSignal.FlowPending
operation = props['operation']
match operation:
# Slice
case FO.Slice | FO.SliceIdx if dim_0 is not None:
slice_tuple = props['slice_tuple']
return operation.transform_info(info, dim_0, slice_tuple=slice_tuple)
# Pin
case FO.PinLen1 if dim_0 is not None:
return operation.transform_info(info, dim_0)
case FO.Pin if dim_0 is not None:
pinned_value = events.realize_known(
input_sockets['Value'], conformed=True
)
if pinned_value is not None:
nearest_idx_to_value = info.dims[dim_0].nearest_idx_of(
pinned_value, require_sorted=True
)
return operation.transform_info(
info, dim_0, pin_idx=nearest_idx_to_value
)
case FO.PinIdx if dim_0 is not None:
pinned_idx = int(events.realize_known(input_sockets['Index']))
if pinned_idx is not None:
return operation.transform_info(info, dim_0, pin_idx=pinned_idx)
# Swizzle
case FO.Swap if dim_0 is not None and dim_1 is not None:
return operation.transform_info(info, dim_0, dim_1)
return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Params,
kind=FK.Params,
# Loaded
props={'dim_0', 'dim_1', 'operation'},
input_sockets={'Expr', 'Value', 'Axis'},
input_socket_kinds={
'Expr': {ct.FlowKind.Info, ct.FlowKind.Params},
inscks_kinds={
'Value': {FK.Func, FK.Params},
'Index': {FK.Func, FK.Params},
'Expr': {FK.Info, FK.Params},
},
input_sockets_optional={'Value': True, 'Axis': True},
input_sockets_optional={'Value', 'Index'},
)
def compute_params(self, props: dict, input_sockets: dict) -> ct.ParamsFlow:
operation = props['operation']
info = input_sockets['Expr'][ct.FlowKind.Info]
params = input_sockets['Expr'][ct.FlowKind.Params]
def compute_params(self, props, input_sockets) -> ct.ParamsFlow:
"""Compute tracked function argument parameters of input parameters."""
info = input_sockets['Expr'][FK.Info]
params = input_sockets['Expr'][FK.Params]
has_info = not ct.FlowSignal.check(info)
has_params = not ct.FlowSignal.check(params)
# Dimension(s)
dim_0 = props['dim_0']
dim_1 = props['dim_1']
if (
has_info
and has_params
and operation is not None
and operation.are_dims_valid(info, dim_0, dim_1)
):
# Retrieve Pinned Value
pinned_value = input_sockets['Value']
has_pinned_value = not ct.FlowSignal.check(pinned_value)
pinned_axis = input_sockets['Axis']
has_pinned_axis = not ct.FlowSignal.check(pinned_axis)
operation = props['operation']
match operation:
# *
case FO.Slice | FO.SliceIdx | FO.PinLen1 | FO.Swap:
return params
# Pin by-Value: Compute Nearest IDX
## -> Presume a sorted index array to be able to use binary search.
if (
props['operation'] is math_system.FilterOperation.Pin
and has_pinned_value
):
nearest_idx_to_value = info.dims[dim_0].nearest_idx_of(
pinned_value, require_sorted=True
# Pin
case FO.Pin:
pinned_value = events.realize_known(
input_sockets['Value'], conformed=True
)
if pinned_value is not None:
nearest_idx_to_value = info.dims[dim_0].nearest_idx_of(
pinned_value, require_sorted=True
)
return params.compose_within(
enclosing_func_args=(sp.Integer(nearest_idx_to_value),),
)
return params.compose_within(
enclosing_arg_targets=[sim_symbols.idx(None)],
enclosing_func_args=[sp.S(nearest_idx_to_value)],
)
case FO.PinIdx:
pinned_idx = events.realize_known(input_sockets['Index'])
if pinned_idx is not None:
return params.compose_within(
enclosing_func_args=(sp.Integer(pinned_idx),),
)
# Pin by-Index
if (
props['operation'] is math_system.FilterOperation.PinIdx
and has_pinned_axis
):
return params.compose_within(
enclosing_arg_targets=[sim_symbols.idx(None)],
enclosing_func_args=[sp.S(pinned_axis)],
)
return params
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################

View File

@ -21,6 +21,7 @@ import typing as typ
import bpy
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import math_system, sockets
@ -28,6 +29,11 @@ from ... import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MO = math_system.MapOperation
MT = spux.MathType
class MapMathNode(base.MaxwellSimNode):
r"""Applies a function by-structure to the data.
@ -102,7 +108,7 @@ class MapMathNode(base.MaxwellSimNode):
The name and type of the available symbol is clearly shown, and most valid `sympy` expressions that you would expect to work, should work.
Use of expressions generally imposes no performance penalty: Just like the baked-in operations, it is compiled to a high-performance `jax` function.
Thus, it participates in the `ct.FlowKind.Func` composition chain.
Thus, it participates in the `FK.Func` composition chain.
Attributes:
@ -113,69 +119,76 @@ class MapMathNode(base.MaxwellSimNode):
bl_label = 'Map Math'
input_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
}
output_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
}
####################
# - Properties
# - Properties: Incoming InfoFlow
####################
@events.on_value_changed(
# Trigger
socket_name={'Expr'},
socket_name={'Expr': FK.Info},
# Loaded
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info},
input_sockets_optional={'Expr': True},
inscks_kinds={'Expr': FK.Info},
input_sockets_optional={'Expr'},
# Flow
## -> See docs in TransformMathNode
stop_propagation=True,
)
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
info_pending = ct.FlowSignal.check_single(
input_sockets['Expr'], ct.FlowSignal.FlowPending
)
def on_input_expr_changed(self, input_sockets) -> None: # noqa: D102
info = input_sockets['Expr']
has_info = not FS.check(info)
info_pending = FS.check_single(info, FS.FlowPending)
if has_info and not info_pending:
self.expr_info = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def expr_info(self) -> ct.InfoFlow | None:
info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True)
has_info = not ct.FlowSignal.check(info)
"""Retrieve the input expression's `InfoFlow`."""
info = self._compute_input('Expr', kind=FK.Info)
has_info = not FS.check(info)
if has_info:
return info
return None
operation: math_system.MapOperation = bl_cache.BLField(
####################
# - Property: Operation
####################
operation: MO = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_info'},
)
def search_operations(self) -> list[ct.BLEnumElement]:
"""Retrieve valid operations based on the input `InfoFlow`."""
if self.expr_info is not None:
return [
operation.bl_enum_element(i)
for i, operation in enumerate(
math_system.MapOperation.by_expr_info(self.expr_info)
)
]
return MO.bl_enum_elements(self.expr_info)
return []
####################
# - UI
####################
def draw_label(self):
"""Show the current operation (if any) in the node's header label.
Notes:
Called by Blender to determine the text to place in the node's header.
"""
if self.operation is not None:
return 'Map: ' + math_system.MapOperation.to_name(self.operation)
return 'Map: ' + MO.to_name(self.operation)
return self.bl_label
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
"""Draw the user interfaces of the node's properties inside of the node itself.
Parameters:
layout: UI target for drawing.
"""
layout.prop(self, self.blfields['operation'], text='')
####################
@ -183,91 +196,81 @@ class MapMathNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Value,
kind=FK.Value,
# Loaded
props={'operation'},
input_sockets={'Expr'},
inscks_kinds={'Expr': FK.Value},
)
def compute_value(self, props, input_sockets) -> ct.ValueFlow | ct.FlowSignal:
operation = props['operation']
def compute_value(self, props, input_sockets) -> ct.ValueFlow | FS:
"""Mapping operation on symbolic input expression."""
expr = input_sockets['Expr']
has_expr_value = not ct.FlowSignal.check(expr)
# Compute Sympy Function
## -> The operation enum directly provides the appropriate function.
if has_expr_value and operation is not None:
operation = props['operation']
if operation is not None:
return operation.sp_func(expr)
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Expr',
kind=FK.Func,
# Loaded
kind=ct.FlowKind.Func,
props={'operation'},
input_sockets={'Expr'},
input_socket_kinds={
'Expr': ct.FlowKind.Func,
inscks_kinds={
'Expr': FK.Func,
},
output_sockets={'Expr'},
output_socket_kinds={'Expr': ct.FlowKind.Info},
)
def compute_func(
self, props, input_sockets, output_sockets
) -> ct.FuncFlow | ct.FlowSignal:
expr = input_sockets['Expr']
output_info = output_sockets['Expr']
has_expr = not ct.FlowSignal.check(expr)
has_output_info = not ct.FlowSignal.check(output_info)
def compute_func(self, props, input_sockets) -> ct.FuncFlow | FS:
"""Mapping operation on lazy-defined input expression."""
func = input_sockets['Expr']
operation = props['operation']
if has_expr and operation is not None:
return expr.compose_within(
operation.jax_func,
enclosing_func_output=output_info.output,
supports_jax=True,
)
return ct.FlowSignal.FlowPending
if operation is not None:
return operation.transform_func(func)
return FS.FlowPending
####################
# - FlowKind.Info
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Info,
kind=FK.Info,
# Loaded
props={'operation'},
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info},
inscks_kinds={
'Expr': FK.Info,
},
)
def compute_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
operation = props['operation']
def compute_info(self, props, input_sockets) -> ct.InfoFlow:
"""Transform the info chracterization of the input."""
info = input_sockets['Expr']
has_info = not ct.FlowSignal.check(info)
if has_info and operation is not None:
operation = props['operation']
if operation is not None:
return operation.transform_info(info)
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Params,
kind=FK.Params,
# Loaded
props={'operation'},
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Params},
input_socket_kinds={'Expr': FK.Params},
)
def compute_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal:
has_params = not ct.FlowSignal.check(input_sockets['Expr'])
if has_params:
return input_sockets['Expr']
return ct.FlowSignal.FlowPending
def compute_params(self, props, input_sockets) -> ct.ParamsFlow | FS:
"""Transform the parameters of the input."""
params = input_sockets['Expr']
operation = props['operation']
if operation is not None:
return operation.transform_params(params)
return FS.FlowPending
####################

View File

@ -24,6 +24,7 @@ import typing as typ
import bpy
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import math_system, sockets
@ -31,6 +32,11 @@ from ... import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
BO = math_system.BinaryOperation
MT = spux.MathType
class OperateMathNode(base.MaxwellSimNode):
r"""Applies a binary function between two expressions.
@ -46,13 +52,11 @@ class OperateMathNode(base.MaxwellSimNode):
bl_label = 'Operate Math'
input_sockets: typ.ClassVar = {
'Expr L': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
'Expr R': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
'Expr L': sockets.ExprSocketDef(active_kind=FK.Func),
'Expr R': sockets.ExprSocketDef(active_kind=FK.Func),
}
output_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Func, show_info_columns=True
),
'Expr': sockets.ExprSocketDef(active_kind=FK.Func, show_info_columns=True),
}
####################
@ -60,25 +64,21 @@ class OperateMathNode(base.MaxwellSimNode):
####################
@events.on_value_changed(
# Trigger
socket_name={'Expr L', 'Expr R'},
socket_name={'Expr L': FK.Info, 'Expr R': FK.Info},
# Loaded
input_sockets={'Expr L', 'Expr R'},
input_socket_kinds={'Expr L': ct.FlowKind.Info, 'Expr R': ct.FlowKind.Info},
input_sockets_optional={'Expr L': True, 'Expr R': True},
inscks_kinds={'Expr L': FK.Info, 'Expr R': FK.Info},
input_sockets_optional={'Expr L', 'Expr R'},
# Flow
## -> See docs in TransformMathNode
stop_propagation=True,
)
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
has_info_l = not ct.FlowSignal.check(input_sockets['Expr L'])
has_info_r = not ct.FlowSignal.check(input_sockets['Expr R'])
def on_input_exprs_changed(self, input_sockets) -> None:
"""Queue an update of the cached expression infos whenever data changed."""
has_info_l = not FS.check(input_sockets['Expr L'])
has_info_r = not FS.check(input_sockets['Expr R'])
info_l_pending = ct.FlowSignal.check_single(
input_sockets['Expr L'], ct.FlowSignal.FlowPending
)
info_r_pending = ct.FlowSignal.check_single(
input_sockets['Expr R'], ct.FlowSignal.FlowPending
)
info_l_pending = FS.check_single(input_sockets['Expr L'], FS.FlowPending)
info_r_pending = FS.check_single(input_sockets['Expr R'], FS.FlowPending)
if has_info_l and has_info_r and not info_l_pending and not info_r_pending:
self.expr_infos = bl_cache.Signal.InvalidateCache
@ -86,21 +86,20 @@ class OperateMathNode(base.MaxwellSimNode):
@bl_cache.cached_bl_property()
def expr_infos(self) -> tuple[ct.InfoFlow, ct.InfoFlow] | None:
"""Computed `InfoFlow`s of both expressions."""
info_l = self._compute_input('Expr L', kind=ct.FlowKind.Info)
info_r = self._compute_input('Expr R', kind=ct.FlowKind.Info)
info_l = self._compute_input('Expr L', kind=FK.Info)
info_r = self._compute_input('Expr R', kind=FK.Info)
has_info_l = not ct.FlowSignal.check(info_l)
has_info_r = not ct.FlowSignal.check(info_r)
has_info_l = not FS.check(info_l)
has_info_r = not FS.check(info_r)
if has_info_l and has_info_r:
return (info_l, info_r)
return None
####################
# - Property: Operation
####################
operation: math_system.BinaryOperation = bl_cache.BLField(
operation: BO = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_infos'},
)
@ -108,7 +107,7 @@ class OperateMathNode(base.MaxwellSimNode):
def search_operations(self) -> list[ct.BLEnumElement]:
"""Retrieve valid operations based on the input `InfoFlow`s."""
if self.expr_infos is not None:
return math_system.BinaryOperation.bl_enum_elements(*self.expr_infos)
return BO.bl_enum_elements(*self.expr_infos)
return []
####################
@ -121,12 +120,12 @@ class OperateMathNode(base.MaxwellSimNode):
Called by Blender to determine the text to place in the node's header.
"""
if self.operation is not None:
return 'Op: ' + math_system.BinaryOperation.to_name(self.operation)
return 'Op: ' + BO.to_name(self.operation)
return self.bl_label
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
"""Draw node properties in the node.
"""Draw properties in the node.
Parameters:
col: UI target for drawing.
@ -138,70 +137,59 @@ class OperateMathNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Value,
kind=FK.Value,
# Loaded
props={'operation'},
input_sockets={'Expr L', 'Expr R'},
input_socket_kinds={
'Expr L': ct.FlowKind.Value,
'Expr R': ct.FlowKind.Value,
inscks_kinds={
'Expr L': FK.Value,
'Expr R': FK.Value,
},
)
def compute_value(self, props: dict, input_sockets: dict):
def compute_value(self, props, input_sockets) -> ct.InfoFlow | FS:
"""Binary operation on two symbolic input expressions."""
expr_l = input_sockets['Expr L']
expr_r = input_sockets['Expr R']
has_expr_l_value = not ct.FlowSignal.check(expr_l)
has_expr_r_value = not ct.FlowSignal.check(expr_r)
operation = props['operation']
if has_expr_l_value and has_expr_r_value and operation is not None:
if operation is not None:
return operation.sp_func([expr_l, expr_r])
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Func,
kind=FK.Func,
# Loaded
props={'operation'},
input_sockets={'Expr L', 'Expr R'},
input_socket_kinds={
'Expr L': ct.FlowKind.Func,
'Expr R': ct.FlowKind.Func,
inscks_kinds={
'Expr L': FK.Func,
'Expr R': FK.Func,
},
output_sockets={'Expr'},
output_socket_kinds={'Expr': ct.FlowKind.Info},
)
def compute_func(self, props, input_sockets, output_sockets):
def compute_func(self, props, input_sockets) -> ct.InfoFlow | FS:
"""Binary operation on two lazy-defined input expressions."""
expr_l = input_sockets['Expr L']
expr_r = input_sockets['Expr R']
output_info = output_sockets['Expr']
has_expr_l = not ct.FlowSignal.check(expr_l)
has_expr_r = not ct.FlowSignal.check(expr_r)
has_output_info = not ct.FlowSignal.check(output_info)
operation = props['operation']
if operation is not None and has_expr_l and has_expr_r and has_output_info:
if operation is not None:
return self.operation.transform_funcs(expr_l, expr_r)
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
# - FlowKind.Info
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Info,
kind=FK.Info,
# Loaded
props={'operation'},
input_sockets={'Expr L', 'Expr R'},
input_socket_kinds={
'Expr L': ct.FlowKind.Info,
'Expr R': ct.FlowKind.Info,
'Expr L': FK.Info,
'Expr R': FK.Info,
},
)
def compute_info(self, props, input_sockets) -> ct.InfoFlow:
@ -209,44 +197,43 @@ class OperateMathNode(base.MaxwellSimNode):
info_l = input_sockets['Expr L']
info_r = input_sockets['Expr R']
has_info_l = not ct.FlowSignal.check(info_l)
has_info_r = not ct.FlowSignal.check(info_r)
has_info_l = not FS.check(info_l)
has_info_r = not FS.check(info_r)
operation = props['operation']
if (
has_info_l and has_info_r and operation is not None
# and operation in BO.by_infos(info_l, info_r)
has_info_l
and has_info_r
and operation is not None
and operation in BO.from_infos(info_l, info_r)
):
return operation.transform_infos(info_l, info_r)
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Params,
kind=FK.Params,
# Loaded
props={'operation'},
input_sockets={'Expr L', 'Expr R'},
input_socket_kinds={
'Expr L': ct.FlowKind.Params,
'Expr R': ct.FlowKind.Params,
'Expr L': FK.Params,
'Expr R': FK.Params,
},
)
def compute_params(self, props, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
def compute_params(self, props, input_sockets) -> ct.ParamsFlow | FS:
"""Merge the lazy input parameters."""
params_l = input_sockets['Expr L']
params_r = input_sockets['Expr R']
has_params_l = not ct.FlowSignal.check(params_l)
has_params_r = not ct.FlowSignal.check(params_r)
operation = props['operation']
if has_params_l and has_params_r and operation is not None:
return params_l | params_r
return ct.FlowSignal.FlowPending
if operation is not None:
return operation.transform_params(params_l, params_r)
return FS.FlowPending
####################

View File

@ -14,126 +14,234 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Declares `ReduceMathNode`."""
import enum
import typing as typ
import bpy
import jax
import jax.numpy as jnp
import sympy as sp
import numpy as np
from blender_maxwell.utils import logger
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from .... import contracts as ct
from .... import sockets
from .... import math_system, sockets
from ... import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
RO = math_system.ReduceOperation
class ReduceMathNode(base.MaxwellSimNode):
r"""Applies a function to the array as a whole, with arbitrary results.
The shape, type, and interpretation of the input/output data is dynamically shown.
Attributes:
operation: Operation to apply to the input.
"""
node_type = ct.NodeType.ReduceMath
bl_label = 'Reduce Math'
input_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
'Axis': sockets.IntegerNumberSocketDef(),
}
input_socket_sets: typ.ClassVar = {
'By Axis': {
'Axis': sockets.IntegerNumberSocketDef(),
},
'Expr': {
'Reducer': sockets.ExprSocketDef(
symbols=[sp.Symbol('a'), sp.Symbol('b')],
default_expr=sp.Symbol('a') + sp.Symbol('b'),
),
'Axis': sockets.IntegerNumberSocketDef(),
},
'Expr': sockets.ExprSocketDef(active_kind=FK.Func, show_func_ui=False),
}
output_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
}
####################
# - Properties
# - Properties: Expr InfoFlow
####################
operation: bpy.props.EnumProperty(
name='Op',
description='Operation to reduce the input axis with',
items=lambda self, _: self.search_operations(),
update=lambda self, context: self.on_prop_changed('operation', context),
@events.on_value_changed(
# Trigger
socket_name={'Expr': FK.Info},
# Loaded
inscks_kinds={'Expr': FK.Info},
input_sockets_optional={'Expr'},
# Flow
## -> Expr wants to emit DataChanged, which is usually fine.
## -> However, this node sets `expr_info`, which causes DC to emit.
## -> One action should emit one DataChanged pipe.
## -> Therefore, defer responsibility for DataChanged to self.expr_info.
# stop_propagation=True,
)
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
has_info = not FS.check(input_sockets['Expr'])
info_pending = FS.check_single(input_sockets['Expr'], FS.FlowPending)
if has_info and not info_pending:
self.expr_info = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def expr_info(self) -> ct.InfoFlow | None:
"""Retrieve the input expression's `InfoFlow`."""
info = self._compute_input('Expr', kind=FK.Info)
has_info = not FS.check(info)
if has_info:
return info
return None
####################
# - Properties: Operation
####################
operation: RO = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_info'},
)
def search_operations(self) -> list[tuple[str, str, str]]:
items = []
if self.active_socket_set == 'By Axis':
items += [
# Accumulation
('SUM', 'Sum', 'sum(*, N, *) -> (*, 1, *)'),
('PROD', 'Prod', 'prod(*, N, *) -> (*, 1, *)'),
('MIN', 'Axis-Min', '(*, N, *) -> (*, 1, *)'),
('MAX', 'Axis-Max', '(*, N, *) -> (*, 1, *)'),
('P2P', 'Peak-to-Peak', '(*, N, *) -> (*, 1 *)'),
# Stats
('MEAN', 'Mean', 'mean(*, N, *) -> (*, 1, *)'),
('MEDIAN', 'Median', 'median(*, N, *) -> (*, 1, *)'),
('STDDEV', 'Std Dev', 'stddev(*, N, *) -> (*, 1, *)'),
('VARIANCE', 'Variance', 'var(*, N, *) -> (*, 1, *)'),
# Dimension Reduction
('SQUEEZE', 'Squeeze', '(*, 1, *) -> (*, *)'),
]
else:
items += [('NONE', 'None', 'No operations...')]
def search_operations(self) -> list[ct.BLEnumElement]:
"""Retrieve valid operations based on the input `InfoFlow`."""
if self.expr_info is not None:
return RO.bl_enum_elements(self.expr_info)
return []
return items
####################
# - Properties: Dimension Selection
####################
active_dim: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_dims(),
cb_depends_on={'operation', 'expr_info'},
)
def search_dims(self) -> list[ct.BLEnumElement]:
"""Search valid dimensions for reduction."""
if self.expr_info is not None and self.operation is not None:
return [
(dim.name, dim.name_pretty, dim.name, '', i)
for i, dim in enumerate(self.operation.valid_dims(self.expr_info))
]
return []
@bl_cache.cached_bl_property(depends_on={'expr_info', 'active_dim'})
def dim(self) -> sim_symbols.SimSymbol | None:
"""Deduce the valid dimension."""
if self.expr_info is not None and self.active_dim is not None:
return self.expr_info.dim_by_name(self.active_dim, optional=True)
return None
@bl_cache.cached_bl_property(depends_on={'dim_0'})
def axis(self) -> int | None:
"""The first currently active axis, derived from `self.dim_0`."""
if self.expr_info is not None and self.dim is not None:
return self.expr_info.dim_axis(self.dim)
return None
####################
# - UI
####################
def draw_label(self):
"""Show the active reduce operation in the node's header label.
Notes:
Called by Blender to determine the text to place in the node's header.
"""
if self.operation is not None:
if self.dim is not None:
return self.operation.name.replace('[a]', f'[{self.dim.name_pretty}]')
return self.operation.name
return self.bl_label
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
if self.active_socket_set != 'Axis Expr':
layout.prop(self, 'operation')
"""Draw the user interfaces of the node's properties inside of the node itself.
Parameters:
layout: UI target for drawing.
"""
layout.prop(self, self.blfields['operation'], text='')
layout.prop(self, self.blfields['active_dim'], text='')
####################
# - Compute
# - Compute: Array
####################
@events.computes_output_socket(
'Data',
props={'active_socket_set', 'operation'},
input_sockets={'Data', 'Axis', 'Reducer'},
input_socket_kinds={'Reducer': ct.FlowKind.Func},
input_sockets_optional={'Reducer': True},
'Expr',
kind=FK.Array,
# Loaded
outscks_kinds={
'Expr': {FK.Func, FK.Params},
},
)
def compute_data(self, props: dict, input_sockets: dict):
if props['active_socket_set'] == 'By Axis':
# Simple Accumulation
if props['operation'] == 'SUM':
return jnp.sum(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'PROD':
return jnp.prod(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'MIN':
return jnp.min(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'MAX':
return jnp.max(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'P2P':
return jnp.p2p(input_sockets['Data'], axis=input_sockets['Axis'])
def compute_array(self, output_sockets) -> ct.ArrayFlow | FS:
"""Realize an `ArrayFlow` containing the array."""
array = events.realize_known(output_sockets['Expr'])
if array is not None:
return ct.ArrayFlow(
jax_bytes=(
array
if isinstance(array, np.ndarray | jax.Array)
else jnp.array(array)
),
unit=output_sockets['Expr'][FK.Func].func_output.unit,
is_sorted=True,
)
return FS.FlowPending
# Stats
if props['operation'] == 'MEAN':
return jnp.mean(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'MEDIAN':
return jnp.median(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'STDDEV':
return jnp.std(input_sockets['Data'], axis=input_sockets['Axis'])
if props['operation'] == 'VARIANCE':
return jnp.var(input_sockets['Data'], axis=input_sockets['Axis'])
####################
# - Compute: Func
####################
@events.computes_output_socket(
'Expr',
kind=FK.Func,
# Loaded
props={'operation'},
inscks_kinds={
'Expr': FK.Func,
},
)
def compute_func(self, props, input_sockets) -> ct.FuncFlow | FS:
"""Transform the input `FuncFlow` depending on the reduce operation."""
func = input_sockets['Expr']
# Dimension Reduction
if props['operation'] == 'SQUEEZE':
return jnp.squeeze(input_sockets['Data'], axis=input_sockets['Axis'])
operation = props['operation']
if operation is not None:
return operation.transform_func(func)
return FS.FlowPending
if props['active_socket_set'] == 'Expr':
ufunc = jnp.ufunc(input_sockets['Reducer'], nin=2, nout=1)
return ufunc.reduce(input_sockets['Data'], axis=input_sockets['Axis'])
####################
# - FlowKind.Info
####################
@events.computes_output_socket(
'Expr',
kind=FK.Info,
# Loaded
props={'operation', 'dim', 'expr_info'},
)
def compute_info(self, props) -> ct.InfoFlow | FS:
"""Transform the input `InfoFlow` depending on the reduce operation."""
info = props['expr_info']
dim = props['dim']
msg = 'Operation invalid'
raise ValueError(msg)
operation = props['operation']
if operation is not None:
return operation.transform_info(info, dim)
return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Expr',
kind=FK.Params,
# Loaded
props={'operation', 'axis'},
inscks_kinds={'Expr': FK.Params},
)
def compute_params(self, props, input_sockets) -> ct.ParamsFlow | FS:
"""Transform the input `InfoFlow` depending on the reduce operation."""
params = input_sockets['Expr']
axis = props['axis']
operation = props['operation']
if operation is not None and axis is not None:
return operation.transform_params(params, axis)
return FS.FlowPending
####################

View File

@ -31,6 +31,9 @@ from ... import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
class TransformMathNode(base.MaxwellSimNode):
r"""Applies a function to the array as a whole, with arbitrary results.
@ -49,10 +52,10 @@ class TransformMathNode(base.MaxwellSimNode):
bl_label = 'Transform Math'
input_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
}
output_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
}
####################
@ -62,9 +65,8 @@ class TransformMathNode(base.MaxwellSimNode):
# Trigger
socket_name={'Expr'},
# Loaded
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info},
input_sockets_optional={'Expr': True},
inscks_kinds={'Expr': FK.Info},
input_sockets_optional={'Expr'},
# Flow
## -> Expr wants to emit DataChanged, which is usually fine.
## -> However, this node sets `expr_info`, which causes DC to emit.
@ -73,18 +75,16 @@ class TransformMathNode(base.MaxwellSimNode):
stop_propagation=True,
)
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
info_pending = ct.FlowSignal.check_single(
input_sockets['Expr'], ct.FlowSignal.FlowPending
)
has_info = not FS.check(input_sockets['Expr'])
info_pending = FS.check_single(input_sockets['Expr'], FS.FlowPending)
if has_info and not info_pending:
self.expr_info = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def expr_info(self) -> ct.InfoFlow | None:
info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True)
has_info = not ct.FlowSignal.check(info)
info = self._compute_input('Expr', kind=FK.Info)
has_info = not FS.check(info)
if has_info:
return info
@ -100,12 +100,7 @@ class TransformMathNode(base.MaxwellSimNode):
def search_operations(self) -> list[ct.BLEnumElement]:
if self.expr_info is not None:
return [
operation.bl_enum_element(i)
for i, operation in enumerate(
math_system.TransformOperation.by_info(self.expr_info)
)
]
return math_system.TransformOperation.bl_enum_elements(self.expr_info)
return []
####################
@ -283,29 +278,27 @@ class TransformMathNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Func,
kind=FK.Func,
# Loaded
props={'operation', 'dim'},
input_sockets={'Expr'},
input_socket_kinds={
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info},
'Expr': {FK.Func, FK.Info},
},
output_sockets={'Expr'},
output_socket_kinds={'Expr': ct.FlowKind.Info},
output_socket_kinds={'Expr': FK.Info},
)
def compute_func(
self, props, input_sockets, output_sockets
) -> ct.FuncFlow | ct.FlowSignal:
def compute_func(self, props, input_sockets, output_sockets) -> ct.FuncFlow | FS:
"""Transform the input `InfoFlow` depending on the transform operation."""
TO = math_system.TransformOperation
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
info = input_sockets['Expr'][ct.FlowKind.Info]
lazy_func = input_sockets['Expr'][FK.Func]
info = input_sockets['Expr'][FK.Info]
output_info = output_sockets['Expr']
has_info = not ct.FlowSignal.check(info)
has_lazy_func = not ct.FlowSignal.check(lazy_func)
has_output_info = not ct.FlowSignal.check(output_info)
has_info = not FS.check(info)
has_lazy_func = not FS.check(lazy_func)
has_output_info = not FS.check(output_info)
operation = props['operation']
if operation is not None and has_lazy_func and has_info and has_output_info:
@ -318,7 +311,7 @@ class TransformMathNode(base.MaxwellSimNode):
enclosing_func_output=output_info.output,
supports_jax=True,
)
return ct.FlowSignal.FlowPending
return FS.FlowPending
case _:
return lazy_func.compose_within(
@ -327,30 +320,28 @@ class TransformMathNode(base.MaxwellSimNode):
supports_jax=True,
)
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
# - FlowKind.Info
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Info,
kind=FK.Info,
# Loaded
props={'operation', 'dim', 'new_name', 'new_unit', 'new_physical_type'},
input_sockets={'Expr'},
input_socket_kinds={
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
},
input_socket_kinds={'Expr': {FK.Func, FK.Info, FK.Params}},
)
def compute_info( # noqa: PLR0911
self, props: dict, input_sockets: dict
) -> ct.InfoFlow | typ.Literal[ct.FlowSignal.FlowPending]:
) -> ct.InfoFlow | typ.Literal[FS.FlowPending]:
"""Transform the input `InfoFlow` depending on the transform operation."""
TO = math_system.TransformOperation
operation = props['operation']
info = input_sockets['Expr'][ct.FlowKind.Info]
info = input_sockets['Expr'][FK.Info]
has_info = not ct.FlowSignal.check(info)
has_info = not FS.check(info)
if has_info and operation is not None:
# Retrieve Properties
dim = props['dim']
@ -359,11 +350,11 @@ class TransformMathNode(base.MaxwellSimNode):
new_physical_type = props['new_physical_type']
# Retrieve Expression Data
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
params = input_sockets['Expr'][ct.FlowKind.Params]
lazy_func = input_sockets['Expr'][FK.Func]
params = input_sockets['Expr'][FK.Params]
has_lazy_func = not ct.FlowSignal.check(lazy_func)
has_params = not ct.FlowSignal.check(lazy_func)
has_lazy_func = not FS.check(lazy_func)
has_params = not FS.check(lazy_func)
# Match Pattern by Operation
match operation:
@ -376,7 +367,7 @@ class TransformMathNode(base.MaxwellSimNode):
and new_unit in physical_type.valid_units
):
return operation.transform_info(info, dim=dim, unit=new_unit)
return ct.FlowSignal.FlowPending
return FS.FlowPending
case TO.FreqToVacWL if dim is not None and new_unit is not None and new_unit in spux.PhysicalType.Length.valid_units:
return operation.transform_info(info, dim=dim, unit=new_unit)
@ -430,28 +421,28 @@ class TransformMathNode(base.MaxwellSimNode):
case TO.FT1D | TO.InvFT1D if dim is not None:
return operation.transform_info(info, dim=dim)
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Params,
kind=FK.Params,
# Loaded
props={'operation'},
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Params},
input_socket_kinds={'Expr': FK.Params},
)
def compute_params(self, props, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
def compute_params(self, props, input_sockets) -> ct.ParamsFlow | FS:
operation = props['operation']
params = input_sockets['Expr']
has_params = not ct.FlowSignal.check(params)
has_params = not FS.check(params)
if has_params and operation is not None:
return params
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################

View File

@ -22,6 +22,7 @@ import jaxtyping as jtyp
import matplotlib.axis as mpl_ax
import sympy as sp
import sympy.physics.units as spu
from frozendict import frozendict
from blender_maxwell.utils import bl_cache, image_ops, logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
@ -32,6 +33,9 @@ from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
class VizMode(enum.StrEnum):
"""Available visualization modes.
@ -218,7 +222,7 @@ class VizNode(base.MaxwellSimNode):
####################
input_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Func,
active_kind=FK.Func,
default_symbols=[sym_x_um],
default_value=sp.exp(-(x_um**2)),
),
@ -239,26 +243,24 @@ class VizNode(base.MaxwellSimNode):
socket_name={'Expr'},
# Loaded
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info},
input_socket_kinds={'Expr': FK.Info},
input_sockets_optional={'Expr': True},
# Flow
## -> See docs in TransformMathNode
stop_propagation=True,
)
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
has_info = not FS.check(input_sockets['Expr'])
info_pending = ct.FlowSignal.check_single(
input_sockets['Expr'], ct.FlowSignal.FlowPending
)
info_pending = FS.check_single(input_sockets['Expr'], FS.FlowPending)
if has_info and not info_pending:
self.expr_info = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def expr_info(self) -> ct.InfoFlow | None:
info = self._compute_input('Expr', kind=ct.FlowKind.Info)
if not ct.FlowSignal.check(info):
info = self._compute_input('Expr', kind=FK.Info)
if not FS.check(info):
return info
return None
@ -349,32 +351,28 @@ class VizNode(base.MaxwellSimNode):
####################
@events.on_value_changed(
# Trigger
socket_name='Expr',
socket_name={'Expr': {FK.Info, FK.Params}},
run_on_init=True,
# Loaded
input_sockets={'Expr'},
input_socket_kinds={'Expr': {ct.FlowKind.Info, ct.FlowKind.Params}},
input_sockets_optional={'Expr': True},
inscks_kinds={'Expr': {FK.Info, FK.Params}},
input_sockets_optional={'Expr'},
)
def on_any_changed(self, input_sockets: dict):
info = input_sockets['Expr'][ct.FlowKind.Info]
params = input_sockets['Expr'][ct.FlowKind.Params]
has_info = not ct.FlowSignal.check(info)
has_params = not ct.FlowSignal.check(params)
def on_expr_changed(self, input_sockets: dict):
"""Declare sockets for realizing all unknown symbols."""
info = input_sockets['Expr'][FK.Info]
params = input_sockets['Expr'][FK.Params]
# Declare Loose Sockets that Realize Symbols
## -> This happens if Params contains not-yet-realized symbols.
if has_info and has_params and params.symbols:
if params.symbols:
if set(self.loose_input_sockets) != {sym.name for sym in params.symbols}:
self.loose_input_sockets = {
sym.name: sockets.ExprSocketDef(
**(
expr_info
| {
'active_kind': ct.FlowKind.Range
if sym in info.dims
else ct.FlowKind.Value
'active_kind': FK.Value,
'use_value_range_swapper': sym in info.dims,
}
)
)
@ -389,25 +387,13 @@ class VizNode(base.MaxwellSimNode):
#####################
@events.computes_output_socket(
'Preview',
kind=ct.FlowKind.Previews,
kind=FK.Previews,
# Loaded
props={
'sim_node_name',
'viz_mode',
'viz_target',
'colormap',
'plot_width',
'plot_height',
'plot_dpi',
},
input_sockets={'Expr'},
input_socket_kinds={
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
},
all_loose_input_sockets=True,
props={'sim_node_name', 'expr_info'},
)
def compute_previews(self, props, input_sockets, loose_input_sockets):
def compute_previews(self, props):
"""Needed for the plot to regenerate in the viewer."""
# info = props['info']
return ct.PreviewsFlow(bl_image_name=props['sim_node_name'])
#####################
@ -424,22 +410,19 @@ class VizNode(base.MaxwellSimNode):
'plot_dpi',
},
input_sockets={'Expr'},
input_socket_kinds={
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
},
input_socket_kinds={'Expr': {FK.Func, FK.Info, FK.Params}},
all_loose_input_sockets=True,
stop_propagation=True,
)
def on_show_plot(
self, managed_objs, props, input_sockets, loose_input_sockets
) -> None:
log.debug('Show Plot')
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
info = input_sockets['Expr'][ct.FlowKind.Info]
params = input_sockets['Expr'][ct.FlowKind.Params]
lazy_func = input_sockets['Expr'][FK.Func]
info = input_sockets['Expr'][FK.Info]
params = input_sockets['Expr'][FK.Params]
has_info = not ct.FlowSignal.check(info)
has_params = not ct.FlowSignal.check(params)
has_info = not FS.check(info)
has_params = not FS.check(params)
plot = managed_objs['plot']
viz_mode = props['viz_mode']
@ -452,11 +435,13 @@ class VizNode(base.MaxwellSimNode):
data = lazy_func.realize_as_data(
info,
params,
symbol_values={
sym: loose_input_sockets[sym.name] for sym in params.sorted_symbols
},
symbol_values=frozendict(
{
sym: loose_input_sockets[sym.name]
for sym in params.sorted_symbols
}
),
)
## TODO: CACHE entries that don't change, PLEASEEE
# Match Viz Type & Perform Visualization
## -> Viz Target determines how to plot.

View File

@ -27,9 +27,9 @@ from collections import defaultdict
from types import MappingProxyType
import bpy
import sympy as sp
from blender_maxwell.utils import bl_cache, bl_instance, logger
from blender_maxwell.utils import sympy_extra as spux
from .. import contracts as ct
from .. import managed_objs as _managed_objs
@ -39,6 +39,12 @@ from . import presets as _presets
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
FE = ct.FlowEvent
MT = spux.MathType
PT = spux.PhysicalType
####################
# - Types
####################
@ -82,7 +88,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
input_socket_sets: typ.ClassVar[dict[str, Sockets]] = MappingProxyType({})
output_socket_sets: typ.ClassVar[dict[str, Sockets]] = MappingProxyType({})
managed_obj_types: typ.ClassVar[ManagedObjs] = MappingProxyType({})
managed_obj_types: typ.ClassVar[dict[str, ManagedObjs]] = MappingProxyType({})
presets: typ.ClassVar[dict[str, Preset]] = MappingProxyType({})
## __init_subclass__ Computed
@ -168,11 +174,12 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
event_methods = [
method
for attr_name in dir(cls)
if hasattr(method := getattr(cls, attr_name), 'event')
and method.event in set(ct.FlowEvent)
## Forbidding blfields prevents triggering __get__ on bl_property
if hasattr(method := getattr(cls, attr_name), 'identifier')
and isinstance(method.identifier, str)
and method.identifier == events.EVENT_METHOD_IDENTIFIER
## We must not trigger __get__ on any blfields here.
]
event_methods_by_event = {event: [] for event in set(ct.FlowEvent)}
event_methods_by_event = {event: [] for event in set(FE)}
for method in event_methods:
event_methods_by_event[method.event].append(method)
@ -216,16 +223,16 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
# (Re)Construct Managed Objects
## -> Due to 'prev_name', the new MObjs will be renamed on construction
managed_objs = props['managed_objs']
managed_obj_types = props['managed_obj_types']
self.managed_objs = {
mobj_name: mobj_type(
self.sim_node_name,
self.sim_node_name + (f'_{i}' if i > 0 else ''),
prev_name=(
props['managed_objs'][mobj_name].name
if mobj_name in props['managed_objs']
else None
managed_objs[mobj_name].name if mobj_name in managed_objs else None
),
)
for mobj_name, mobj_type in props['managed_obj_types'].items()
for i, (mobj_name, mobj_type) in enumerate(managed_obj_types.items())
}
@events.on_value_changed(prop_name='active_socket_set')
@ -447,12 +454,18 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
## Only Trigger: If method loads no socket data, it runs.
## Optional: If method optional-loads socket, it runs.
triggered_event_methods = [
event_method
for event_method in self.filtered_event_methods_by_event(
ct.FlowEvent.DataChanged, (active_sckname, None, None)
value_changed_method
for value_changed_method in self.event_methods_by_event[
FE.DataChanged
]
if value_changed_method.callback_info.should_run(
None,
active_sckname,
frozenset(FK),
active_sckname in self.loose_input_sockets,
)
if active_sckname
not in event_method.callback_info.must_load_sockets
and active_sckname
in value_changed_method.callback_info.optional_sockets_kinds
]
for event_method in triggered_event_methods:
event_method(self)
@ -464,6 +477,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
self.compute_output.invalidate(
input_socket_name=active_sckname,
kind=...,
unit_system=...,
)
# Update Sockets
@ -511,6 +525,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
self.compute_output.invalidate(
input_socket_name=active_sckname,
kind=...,
unit_system=...,
)
def _add_new_active_sockets(self):
@ -575,88 +590,59 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
if i != current_idx_of_socket:
self.outputs.move(current_idx_of_socket, i)
####################
# - Event Methods
####################
@property
def _event_method_filter_by_event(self) -> dict[ct.FlowEvent, typ.Callable]:
"""Compute a map of FlowEvents, to a function that filters its event methods.
The returned filter functions are hard-coded, and must always return a `bool`.
They may use attributes of `self`, always return `True` or `False`, or something different.
Notes:
This is an internal method; you probably want `self.filtered_event_methods_by_event`.
Returns:
The map of `ct.FlowEvent` to a function that can determine whether any `event_method` should be run.
"""
return {
ct.FlowEvent.EnableLock: lambda *_: True,
ct.FlowEvent.DisableLock: lambda *_: True,
ct.FlowEvent.DataChanged: lambda event_method, socket_name, prop_names, _: (
(
socket_name
and socket_name in event_method.callback_info.on_changed_sockets
)
or (
prop_names
and any(
prop_name in event_method.callback_info.on_changed_props
for prop_name in prop_names
)
)
or (
socket_name
and event_method.callback_info.on_any_changed_loose_input
and socket_name in self.loose_input_sockets
)
),
# Non-Triggered
ct.FlowEvent.OutputRequested: lambda output_socket_method,
output_socket_name,
_,
kind: (
kind == output_socket_method.callback_info.kind
and (
output_socket_name
== output_socket_method.callback_info.output_socket_name
)
),
ct.FlowEvent.ShowPlot: lambda *_: True,
}
def filtered_event_methods_by_event(
self,
event: ct.FlowEvent,
_filter: tuple[ct.SocketName, str],
) -> list[typ.Callable]:
"""Return all event methods that should run, given the context provided by `_filter`.
The inclusion decision is made by the internal property `self._event_method_filter_by_event`.
Returns:
All `event_method`s that should run, as callable objects (they can be run using `event_method(self)`).
"""
return [
event_method
for event_method in self.event_methods_by_event[event]
if self._event_method_filter_by_event[event](event_method, *_filter)
]
####################
# - Compute: Input Socket
####################
@bl_cache.keyed_cache(
exclude={'self', 'optional'},
encode={'unit_system'},
exclude={'self'},
)
def compute_prop(
self,
prop_name: ct.PropName,
unit_system: spux.UnitSystem | None = None,
) -> typ.Any:
"""Computes the data of a property, with relevant unit system scaling.
Some properties return a `sympy` expression, which needs to be conformed to some unit system before it can be used.
For these cases,
When no unit system is in use, the cache of `compute_prop` is a transparent layer on top of the `BLField` cache, taking no extra memory.
Warnings:
**MUST** be invalidated whenever a property changed.
If not, then "phantom" values will be produced.
Parameters:
prop_name: The name of the property to compute the value of.
unit_system: The unit system to convert it to, if any.
Returns:
The property value, possibly scaled to the unit system.
"""
# Retrieve Unit System and Property
if hasattr(self, prop_name):
prop_value = getattr(self, prop_name)
else:
msg = f'The node {self.sim_node_name} has no property {prop_name}.'
raise ValueError
if unit_system is not None:
if isinstance(prop_value, spux.SympyType):
return spux.scale_to_unit_system(prop_value)
msg = f'Cannot scale property {prop_name}={prop_value} (type={type(prop_value)} to a unit system, since it is not a sympy object (unit_system={unit_system})'
raise ValueError(msg)
return prop_value
@bl_cache.keyed_cache(
exclude={'self'},
)
def _compute_input(
self,
input_socket_name: ct.SocketName,
kind: ct.FlowKind = ct.FlowKind.Value,
unit_system: dict[ct.SocketType, sp.Expr] | None = None,
optional: bool = False,
kind: FK = FK.Value,
unit_system: spux.UnitSystem | None = None,
) -> typ.Any:
"""Computes the data of an input socket, following links if needed.
@ -667,15 +653,16 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
input_socket_name: The name of the input socket to compute the value of.
It must be currently active.
kind: The data flow kind to compute.
unit_system: The unit system to scale the computed input socket to.
"""
bl_socket = self.inputs.get(input_socket_name)
if bl_socket is not None:
if bl_socket.instance_id:
if kind is ct.FlowKind.Previews:
if kind is FK.Previews:
return bl_socket.compute_data(kind=kind)
return (
ct.FlowKind.scale_to_unit_system(
FK.scale_to_unit_system(
kind,
bl_socket.compute_data(kind=kind),
unit_system,
@ -687,91 +674,107 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
# No Socket Instance ID
## -> Indicates that socket_def.preinit() has not yet run.
## -> Anyone needing results will need to wait on preinit().
return ct.FlowSignal.FlowInitializing
return FS.FlowInitializing
if kind is ct.FlowKind.Previews:
if kind is FK.Previews:
return ct.PreviewsFlow()
return ct.FlowSignal.NoFlow
return FS.NoFlow
####################
# - Compute Event: Output Socket
####################
@bl_cache.keyed_cache(
exclude={'self', 'optional'},
exclude={'self'},
)
def compute_output(
self,
output_socket_name: ct.SocketName,
kind: ct.FlowKind = ct.FlowKind.Value,
optional: bool = False,
) -> typ.Any:
kind: FK = FK.Value,
unit_system: spux.UnitSystem | None = None,
) -> typ.Any | FS:
"""Computes the value of an output socket.
Parameters:
output_socket_name: The name declaring the output socket, for which this method computes the output.
kind: The FlowKind to use when computing the output socket value.
unit_system: The unit system to scale the computed output socket to.
Returns:
The value of the output socket, as computed by the dedicated method
registered using the `@computes_output_socket` decorator.
"""
# Previews: Aggregate All Input Sockets
## -> All PreviewsFlows on all input sockets are combined.
## -> Output Socket Methods can add additional PreviewsFlows.
if kind is ct.FlowKind.Previews:
input_previews = functools.reduce(
lambda a, b: a | b,
[
self._compute_input(
socket_name,
kind=ct.FlowKind.Previews,
)
for socket_name in [bl_socket.name for bl_socket in self.inputs]
],
ct.PreviewsFlow(),
)
# No Output Socket: No Flow
## -> All PreviewsFlows on all input sockets are combined.
## -> Output Socket Methods can add additional PreviewsFlows.
if self.outputs.get(output_socket_name) is None:
return ct.FlowSignal.NoFlow
output_socket_methods = self.filtered_event_methods_by_event(
ct.FlowEvent.OutputRequested,
(output_socket_name, None, kind),
log.debug(
'[%s] Computing Output (socket_name=%s, socket_kinds=%s, unit_system=%s)',
self.sim_node_name,
str(output_socket_name),
str(kind),
str(unit_system),
)
# Exactly One Output Socket Method
## -> All PreviewsFlows on all input sockets are combined.
## -> Output Socket Methods can add additional PreviewsFlows.
if len(output_socket_methods) == 1:
res = output_socket_methods[0](self)
bl_socket = self.outputs.get(output_socket_name)
if bl_socket is not None:
if bl_socket.instance_id:
# Previews: Computed Aggregated Input Sockets
## -> All sockets w/output get Previews from all inputs.
## -> The user can also assign certain sockets.
if kind is FK.Previews:
input_previews = functools.reduce(
lambda a, b: a | b,
[
self._compute_input(
socket_name,
kind=FK.Previews,
)
for socket_name in [
bl_socket.name for bl_socket in self.inputs
]
],
ct.PreviewsFlow(),
)
# Res is PreviewsFlow: Concatenate
## -> This will add the elements within the returned PreviewsFluw.
if kind is ct.FlowKind.Previews and not ct.FlowSignal.check(res):
input_previews |= res
# Retrieve Valid Output Socket Method
## -> We presume that there is exactly one method per socket|kind.
## -> We presume that there is exactly one method per socket|kind.
outsck_methods = [
method
for method in self.event_methods_by_event[FE.OutputRequested]
if method.callback_info.should_run(output_socket_name, kind)
]
if len(outsck_methods) != 1:
if kind is FK.Previews:
return input_previews
return FS.NoFlow
return res
outsck_method = outsck_methods[0]
# > One Output Socket Method: Error
if len(output_socket_methods) > 1:
msg = (
f'More than one method found for ({output_socket_name}, {kind.value!s}.'
)
raise RuntimeError(msg)
# Compute Flow w/Output Socket Method
flow = outsck_method(self)
has_flow = not FS.check(flow)
if kind is ct.FlowKind.Previews:
return input_previews
return ct.FlowSignal.NoFlow
if kind is FK.Previews:
if has_flow:
return input_previews | flow
return input_previews
# *: Compute Flow
## -> Perform unit-system scaling (maybe)
## -> Otherwise, return flow (even if FlowSignal).
if has_flow and unit_system is not None:
return kind.scale_to_unit_system(flow, unit_system)
return flow
# No Socket Instance ID
## -> Indicates that socket_def.preinit() has not yet run.
## -> Anyone needing results will need to wait on preinit().
return FS.FlowInitializing
return FS.NoFlow
####################
# - Plot
####################
def compute_plot(self):
plot_methods = self.filtered_event_methods_by_event(ct.FlowEvent.ShowPlot, ())
"""Run all `on_show_plot` event methods."""
plot_methods = self.event_methods_by_event[FE.ShowPlot]
for plot_method in plot_methods:
plot_method(self)
@ -782,8 +785,8 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
self,
method_info: events.InfoOutputRequested,
input_socket_name: ct.SocketName | None,
input_socket_kinds: set[ct.FlowKind] | None,
prop_names: set[str] | None,
input_socket_kinds: set[FK] | None,
prop_names: set[ct.PropName] | None,
) -> bool:
return (
prop_names is not None
@ -795,14 +798,14 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
input_socket_kinds is None
or (
isinstance(
_kind := method_info.depon_input_socket_kinds.get(
input_socket_name, ct.FlowKind.Value
_kind := method_info.depon_input_sockets_kinds.get(
input_socket_name, FK.Value
),
set,
)
and input_socket_kinds.intersection(_kind)
)
or _kind == ct.FlowKind.Value
or _kind == FK.Value
or _kind in input_socket_kinds
)
or (
@ -815,9 +818,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
@bl_cache.cached_bl_property()
def output_socket_invalidates(
self,
) -> dict[
tuple[ct.SocketName, ct.FlowKind], set[tuple[ct.SocketName, ct.FlowKind]]
]:
) -> dict[tuple[ct.SocketName, FK], set[tuple[ct.SocketName, FK]]]:
"""Deduce which output socket | `FlowKind` combos are altered in response to a given output socket | `FlowKind` combo.
Returns:
@ -830,9 +831,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
# Iterate ALL Methods that Compute Output Sockets
## -> We call it the "altered method".
## -> Our approach will be to deduce what relies on it.
output_requested_methods = self.event_methods_by_event[
ct.FlowEvent.OutputRequested
]
output_requested_methods = self.event_methods_by_event[FE.OutputRequested]
for altered_method in output_requested_methods:
altered_info = altered_method.callback_info
altered_key = (altered_info.output_socket_name, altered_info.kind)
@ -859,12 +858,15 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
is_same_kind = (
altered_info.kind
is (
_kind := invalidated_info.depon_output_socket_kinds.get(
_kind := invalidated_info.depon_output_sockets_kinds.get(
altered_info.output_socket_name
)
)
or (isinstance(_kind, set) and altered_info.kind in _kind)
or altered_info.kind is ct.FlowKind.Value
or (
isinstance(_kind, set | frozenset)
and altered_info.kind in _kind
)
or altered_info.kind is FK.Value
)
# Check Success: Add Invalidated (name,kind) to Altered Set
@ -881,14 +883,14 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
def trigger_event(
self,
event: ct.FlowEvent,
event: FE,
socket_name: ct.SocketName | None = None,
socket_kinds: set[ct.FlowKind] | None = None,
prop_names: set[str] | None = None,
socket_kinds: set[FK] | None = None,
prop_names: set[ct.PropName] | None = None,
) -> None:
"""Recursively triggers events forwards or backwards along the node tree, allowing nodes in the update path to react.
Use `events` decorators to define methods that react to particular `ct.FlowEvent`s.
Use `events` decorators to define methods that react to particular `FE`s.
Notes:
This can be an unpredictably heavy function, depending on the node graph topology.
@ -901,6 +903,14 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
socket_name: The input socket that was altered, if any, in order to trigger this event.
pop_name: The property that was altered, if any, in order to trigger this event.
"""
if socket_kinds is not None:
socket_kinds = frozenset(socket_kinds)
if prop_names is not None:
prop_names = frozenset(prop_names)
## -> Track actual alterations per output socket|kind.
altered_outscks_kinds: dict[ct.SocketName, set[FK]] = defaultdict(set)
# log.debug(
# '[%s] [%s] Triggered (socket_name=%s, socket_kinds=%s, prop_names=%s)',
# self.sim_node_name,
@ -910,15 +920,11 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
# str(prop_names),
# )
# Invalidate Caches on DataChanged
## -> socket_kinds MUST NOT be None
## -> Trigger direction is always 'forwards' for DataChanged
## -> Track which FlowKinds are actually altered per-output-socket.
altered_socket_kinds: dict[ct.SocketName, set[ct.FlowKind]] = defaultdict(set)
if event is ct.FlowEvent.DataChanged:
# Event: DataChanged
if event is FE.DataChanged:
in_sckname = socket_name
# Clear Input Socket Cache(s)
# Input Socket: Clear Cache(s)
## -> The input socket cache for each altered FlowKinds is cleared.
## -> Since it's non-persistent, it will be lazily re-filled.
if in_sckname is not None:
@ -935,80 +941,76 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
unit_system=...,
)
# Clear Output Socket Cache(s)
for output_socket_method in self.event_methods_by_event[
ct.FlowEvent.OutputRequested
]:
# Output Sockets: Clear Cache(s)
for output_socket_method in self.event_methods_by_event[FE.OutputRequested]:
# Determine Consequences of Changed (Socket|Kind) / Prop
## -> Each '@computes_output_socket' declares data to load.
## -> Compare what was changed to what each output socket needs.
## -> IF what is needed, was changed, THEN:
## --- The output socket needs recomputing.
## -> Ask if the altered data should cause it to reload.
method_info = output_socket_method.callback_info
if self._should_recompute_output_socket(
method_info, socket_name, socket_kinds, prop_names
if method_info.should_recompute(
prop_names,
socket_name,
socket_kinds,
socket_name in self.loose_input_sockets,
):
out_sckname = method_info.output_socket_name
out_kind = method_info.kind
# log.debug(
# '![%s] Clear Output Socket Cache (%s, %s)',
# self.sim_node_name,
# out_sckname,
# out_kind,
# )
self.compute_output.invalidate(
output_socket_name=out_sckname,
kind=out_kind,
unit_system=...,
)
altered_socket_kinds[out_sckname].add(out_kind)
altered_outscks_kinds[out_sckname].add(out_kind)
# Invalidate Dependent Output Sockets
## -> Other outscks may depend on the altered outsck.
## -> The property 'output_socket_invalidates' encodes this.
## -> The property 'output_socket_invalidates' encodes this.
# Recurse: Output-Output Dependencies
## -> Outscks may depend on each other.
## -> Pre-build an invalidation map per-node.
cleared_outscks_kinds = self.output_socket_invalidates.get(
(out_sckname, out_kind)
)
if cleared_outscks_kinds is not None:
if cleared_outscks_kinds:
for dep_out_sckname, dep_out_kind in cleared_outscks_kinds:
# log.debug(
# '!![%s] Clear Output Socket Cache (%s, %s)',
# self.sim_node_name,
# out_sckname,
# out_kind,
# )
self.compute_output.invalidate(
output_socket_name=dep_out_sckname,
kind=dep_out_kind,
unit_system=...,
)
altered_socket_kinds[dep_out_sckname].add(dep_out_kind)
altered_outscks_kinds[dep_out_sckname].add(dep_out_kind)
# Clear Output Socket Cache(s)
## -> We aggregate it manually, so it needs a special invl.
## -> See self.compute_output()
if socket_kinds is not None and ct.FlowKind.Previews in socket_kinds:
# Any Preview Change -> All Output Previews Regenerate
## -> All output sockets aggregate all input socket previews.
if socket_kinds is not None and FK.Previews in socket_kinds:
for out_sckname in self.outputs.keys(): # noqa: SIM118
self.compute_output.invalidate(
output_socket_name=out_sckname,
kind=ct.FlowKind.Previews,
kind=FK.Previews,
unit_system=...,
)
altered_socket_kinds[out_sckname].add(ct.FlowKind.Previews)
altered_outscks_kinds[out_sckname].add(FK.Previews)
# Run Triggered Event Methods
## -> A triggered event method may request to stop propagation.
stop_propagation = False
triggered_event_methods = self.filtered_event_methods_by_event(
event, (socket_name, prop_names, None)
)
for event_method in triggered_event_methods:
stop_propagation |= event_method.stop_propagation
# log.debug(
# '![%s] Running: %s',
# self.sim_node_name,
# str(event_method.callback_info),
# )
event_method(self)
# Run 'on_value_changed' Callbacks
## -> These event methods specifically respond to DataChanged.
stop_propagation = False
for value_changed_method in (
method
for method in self.event_methods_by_event[FE.DataChanged]
if method.callback_info.should_run(
prop_names,
socket_name,
socket_kinds,
socket_name in self.loose_input_sockets,
)
):
stop_propagation |= value_changed_method.callback_info.stop_propagation
value_changed_method(self)
if stop_propagation:
return
elif event is FE.EnableLock or event is FE.DisableLock:
for lock_method in self.event_methods_by_event[event]:
lock_method(self)
# Propagate Event
## -> If 'stop_propagation' was tripped, don't propagate.
@ -1016,36 +1018,35 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
## -> Each FlowEvent decides whether to flow forwards/backwards.
## -> The trigger chain goes node/socket/socket/node/socket/...
## -> Unlinked sockets naturally stop the propagation.
if not stop_propagation:
direc = ct.FlowEvent.flow_direction[event]
for bl_socket in self._bl_sockets(direc=direc):
# DataChanged: Propagate Altered SocketKinds
## -> Only altered FlowKinds for the socket will propagate.
## -> In this way, we guarantee no extraneous (noop) flow.
if event is ct.FlowEvent.DataChanged:
if bl_socket.name in altered_socket_kinds:
# log.debug(
# '![%s] [%s] Propagating (direction=%s, altered_socket_kinds=%s)',
# self.sim_node_name,
# event,
# direc,
# altered_socket_kinds[bl_socket.name],
# )
bl_socket.trigger_event(
event, socket_kinds=altered_socket_kinds[bl_socket.name]
)
## -> Otherwise, do nothing - guarantee no extraneous flow.
# Propagate Normally
else:
direc = FE.flow_direction[event]
for bl_socket in self._bl_sockets(direc=direc):
# DataChanged: Propagate Altered SocketKinds
## -> Only altered FlowKinds for the socket will propagate.
## -> In this way, we guarantee no extraneous (noop) flow.
if event is FE.DataChanged:
if bl_socket.name in altered_outscks_kinds:
# log.debug(
# '![%s] [%s] Propagating (direction=%s)',
# '![%s] [%s] Propagating (direction=%s, altered_socket_kinds=%s)',
# self.sim_node_name,
# event,
# direc,
# altered_socket_kinds[bl_socket.name],
# )
bl_socket.trigger_event(event)
bl_socket.trigger_event(
event, socket_kinds=altered_outscks_kinds[bl_socket.name]
)
## -> Otherwise, do nothing - guarantee no extraneous flow.
# Propagate Normally
else:
# log.debug(
# '![%s] [%s] Propagating (direction=%s)',
# self.sim_node_name,
# event,
# direc,
# )
bl_socket.trigger_event(event)
####################
# - Property Event: On Update
@ -1068,13 +1069,19 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
if prop_name in self.blfields:
cleared_blfields = self.clear_blfields_after(prop_name)
for prop_name, _ in cleared_blfields:
self.compute_prop.invalidate(
prop_name=prop_name,
unit_system=...,
)
# log.debug(
# '%s (Node): Set of Cleared BLFields: %s',
# self.bl_label,
# str(cleared_blfields),
# )
self.trigger_event(
ct.FlowEvent.DataChanged,
FE.DataChanged,
prop_names={prop_name for prop_name, _ in cleared_blfields},
)
@ -1204,7 +1211,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
## -> Compromise: Users explicitly say 'run_on_init' in @on_value_changed
for event_method in [
event_method
for event_method in self.event_methods_by_event[ct.FlowEvent.DataChanged]
for event_method in self.event_methods_by_event[FE.DataChanged]
if event_method.callback_info.run_on_init
]:
event_method(self)
@ -1244,7 +1251,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
## -> Copying a node ~ re-initializing the new node.
for event_method in [
event_method
for event_method in self.event_methods_by_event[ct.FlowEvent.DataChanged]
for event_method in self.event_methods_by_event[FE.DataChanged]
if event_method.callback_info.run_on_init
]:
event_method(self)
@ -1271,7 +1278,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
bl_socket.is_linked and bl_socket.locked
for bl_socket in self.inputs.values()
):
self.trigger_event(ct.FlowEvent.DisableLock)
self.trigger_event(FE.DisableLock)
# Free Managed Objects
for managed_obj in self.managed_objs.values():

View File

@ -14,82 +14,439 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import dataclasses
import functools
import inspect
import typing as typ
import uuid
from collections import defaultdict
from fractions import Fraction
from types import MappingProxyType
from blender_maxwell.utils import sympy_extra as spux
import bpy
import jax
import numpy as np
import pydantic as pyd
import sympy as sp
from blender_maxwell.utils import logger
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils.frozendict import FrozenDict, frozendict
from blender_maxwell.utils.lru_method import method_lru
from .. import contracts as ct
log = logger.get(__name__)
ManagedObjName: typ.TypeAlias = str
UnitSystemID = str
FK = ct.FlowKind
FS = ct.FlowSignal
EVENT_METHOD_IDENTIFIER: str = str(uuid.uuid4()) ## Changes on startup.
####################
# - Event Callback Information
####################
@dataclasses.dataclass(kw_only=True, frozen=True)
class InfoDataChanged:
class CallbackInfo(pyd.BaseModel):
"""Base for information associated with an event method callback."""
model_config = pyd.ConfigDict(frozen=True)
# def parse_for_method(self, node):
class InfoDataChanged(CallbackInfo):
"""Information for determining whether a particular `DataChanged` callback should run."""
model_config = pyd.ConfigDict(frozen=True)
on_changed_props: frozenset[ct.PropName]
on_changed_sockets_kinds: FrozenDict[ct.SocketName, frozenset[FK]]
on_any_changed_loose_socket: bool
run_on_init: bool
on_changed_sockets: set[ct.SocketName]
on_changed_props: set[str]
on_any_changed_loose_input: set[str]
must_load_sockets: set[str]
optional_sockets: frozenset[ct.SocketName]
stop_propagation: bool = False
####################
# - Computed Properties
####################
@functools.cached_property
def on_changed_sockets(self) -> frozenset[ct.SocketName]:
"""Input sockets with a `FlowKind` for which the method will run."""
return frozenset(self.on_changed_sockets_kinds.keys())
@functools.cached_property
def optional_sockets_kinds(self) -> frozendict[ct.SocketName, frozenset[FK]]:
"""Input `socket|kind`s for which the method can run, even when only a `FlowSignal` is available."""
return {
changed_socket: kinds
for changed_socket, kinds in self.on_changed_sockets_kinds
if changed_socket in self.optional_sockets
}
####################
# - Methods
####################
@method_lru(maxsize=2048)
def should_run(
self,
changed_props: frozenset[str] | None,
changed_socket: ct.SocketName | None,
changed_kinds: frozenset[FK] | None = None,
socket_is_loose: bool = False,
):
"""Deduce whether this method should run in response to a particular set of changed inputs."""
prop_triggered = changed_props is not None and any(
changed_prop in self.on_changed_props for changed_prop in changed_props
)
socket_triggered = (
changed_socket is not None
and changed_kinds is not None
and changed_socket in self.on_changed_sockets
and any(
changed_kind in self.on_changed_sockets_kinds[changed_socket]
for changed_kind in changed_kinds
)
)
loose_socket_triggered = (
socket_is_loose
and changed_socket is not None
and self.on_any_changed_loose_socket
)
return socket_triggered or prop_triggered or loose_socket_triggered
@dataclasses.dataclass(kw_only=True, frozen=True)
class InfoOutputRequested:
class InfoOutputRequested(CallbackInfo):
"""Information for determining which output socket method should run."""
model_config = pyd.ConfigDict(frozen=True)
output_socket_name: ct.SocketName
kind: ct.FlowKind
kind: FK
depon_props: set[str]
depon_input_sockets: set[ct.SocketName]
depon_input_socket_kinds: dict[ct.SocketName, ct.FlowKind | set[ct.FlowKind]]
depon_props: frozenset[str]
depon_input_sockets_kinds: FrozenDict[ct.SocketName, FK | frozenset[FK]]
depon_output_sockets_kinds: FrozenDict[ct.SocketName, FK | frozenset[FK]]
depon_all_loose_input_sockets: bool
depon_output_sockets: set[ct.SocketName]
depon_output_socket_kinds: dict[ct.SocketName, ct.FlowKind | set[ct.FlowKind]]
depon_all_loose_output_sockets: bool
####################
# - Computed Properties
####################
@functools.cached_property
def depon_input_sockets(self) -> frozenset[ct.SocketName]:
"""The input sockets depended on by this output socket method."""
return frozenset(self.depon_input_sockets_kinds.keys())
@functools.cached_property
def depon_output_sockets(self) -> frozenset[ct.SocketName]:
"""The output sockets depended on by this output socket method."""
return frozenset(self.depon_output_sockets_kinds.keys())
####################
# - Methods
####################
@method_lru(maxsize=2048)
def should_run(
self,
requested_socket: ct.SocketName,
requested_kind: FK,
):
"""Deduce whether this method can compute the requested socket and kind."""
return (
requested_kind is self.kind and requested_socket == self.output_socket_name
)
@method_lru(maxsize=2048)
def should_recompute(
self,
changed_props: frozenset[str] | None,
changed_socket: ct.SocketName | None,
changed_kinds: frozenset[FK] | None = None,
socket_is_loose: bool = False,
):
"""Deduce whether this method needs to be recomputed after a change in a particular set of changed inputs."""
prop_altered = changed_props is not None and any(
changed_prop in self.depon_props for changed_prop in changed_props
)
socket_altered = (
changed_socket is not None
and changed_kinds is not None
and changed_socket in self.depon_input_sockets
and any(
kind in self.depon_input_sockets_kinds[changed_socket]
for kind in changed_kinds
)
)
loose_socket_altered = (
socket_is_loose
and changed_socket is not None
and self.depon_all_loose_input_sockets
)
return prop_altered or socket_altered or loose_socket_altered
EventCallbackInfo: typ.TypeAlias = InfoDataChanged | InfoOutputRequested
####################
# - Event Decorator
# - Node Parsers
####################
ManagedObjName: typ.TypeAlias = str
PropName: typ.TypeAlias = str
def parse_node_mobjs(node: bpy.types.Node, mobjs: frozenset[ManagedObjName]) -> typ.Any:
"""Retrieve the given managed objects."""
return {mobj_name: node.managed_objs[mobj_name] for mobj_name in mobjs}
def event_decorator( # noqa: PLR0913
def parse_node_props(
node: bpy.types.Node,
props: frozenset[ct.PropName],
prop_unit_systems: frozendict[ct.PropName, spux.UnitSystem | None],
) -> typ.Any:
"""Compute the values of the given property names, w/optional scaling to a unit system.
Raises:
ValueError: If a unit system is specified for the property value, but the property value is not a `sympy` type.
"""
return frozendict(
{
prop: node.compute_prop(prop, unit_system=prop_unit_systems.get(prop))
for prop in props
}
)
def parse_node_sck(
node: bpy.types.Node,
direc: typ.Literal['input', 'output'],
sck: ct.SocketName,
_kind: FK | None,
unit_system: spux.UnitSystem | None,
) -> typ.Any:
"""Compute a single value for `sck|kind|unit_system`.
Parameters:
node: The node to parse a socket value from.
direc: Whether the socket to parse is an input or output socket.
sck: The name of the socket to parse.
_kind: The `FlowKind` of the socket to parse.
When `None`, use the `.active_kind` attribute of the socket to determine what to parse.
unit_system: The unit system with which to scale the socket value.
Returns:
The value of the socket over the given `FlowKind` lane, potentially scaled to a unitless scalar (if requested) in a manner specific to the `FlowKind`.
"""
# Deduce Kind
## -> _kind=None denotes "use active_kind of socket".
kind = node._bl_sockets(direc=direc)[sck].active_kind if _kind is None else _kind # noqa: SLF001
# Compute Socket Value
if direc == 'input':
return node._compute_input(sck, kind=kind, unit_system=unit_system) # noqa: SLF001
if direc == 'output':
return node.compute_output(sck, kind=kind, unit_system=unit_system)
raise TypeError
def parse_node_scks_kinds(
node: bpy.types.Node,
direc: typ.Literal['input', 'output'],
scks_kinds: frozendict[ct.SocketName, frozenset[FK] | FK | None],
scks_unit_system: spux.UnitSystem
| frozendict[ct.SocketName, spux.UnitSystem | None],
) -> (
frozendict[ct.SocketName, typ.Any]
| frozendict[ct.SocketName, frozenset[typ.Any]]
| None
):
"""Retrieve the values for given input sockets and kinds, w/optional scaling to a unit system.
In general, unless the socket name is specified in `scks_optional`, then whenever `FlowSignal` is encountered while computing sockets, the function will return `None` immediately.
This process is "short-circuit", which is a partial optimization causing an immediate return before any other computing any other sockets.
"""
# Compute Socket Values
## -> Every time we run `compute()`, we might encounter a FlowSignal.
## -> If so, and the socket is 'optional', we let it be.
## -> If not, and the socket is not 'optional', we return immediately.
computed_scks = {}
for sck, kinds in scks_kinds.items():
# Extract Unit System
if isinstance(scks_unit_system, dict | frozendict):
unit_system = scks_unit_system.get(sck)
else:
unit_system = scks_unit_system
flow_values = {}
for kind in kinds if isinstance(kinds, set | frozenset) else {kinds}:
flow = parse_node_sck(node, direc, sck, kind, unit_system)
flow_values[kind] = flow
if len(flow_values) == 1:
computed_scks[sck] = next(iter(flow_values.values()))
else:
computed_scks[sck] = frozendict(flow_values)
return frozendict(computed_scks)
####################
# - Utilities
####################
def freeze_pytree(
pytree: None
| str
| int
| Fraction
| float
| complex
| set
| frozenset
| list
| tuple
| dict
| frozendict,
):
"""Conform an arbitrarily nested containers into their immutable equivalent."""
if isinstance(pytree, set | frozenset):
return frozenset(freeze_pytree(el) for el in pytree)
if isinstance(pytree, list | tuple):
return tuple(freeze_pytree(el) for el in pytree)
if isinstance(pytree, np.ndarray | jax.Array):
return tuple(freeze_pytree(el) for el in pytree.tolist())
if isinstance(pytree, dict | frozendict):
return frozendict({k: freeze_pytree(v) for k, v in pytree.items()})
if isinstance(pytree, None | str | int | Fraction | float | complex):
return pytree
raise TypeError
def realize_known(
sck: frozendict[typ.Literal[FK.Func, FK.Params], ct.FuncFlow | ct.ParamsFlow],
freeze: bool = False,
conformed: bool = False,
) -> int | float | tuple[int | float] | None:
"""Realize a concrete preview-value from a `FuncFlow` and a `ParamsFlow`, when there are no unrealized symbols in `ParamsFlow`.
It often happens that we absolutely need an _unrealized_ value for a node to work, eg. when producing a `Value` from the middle of a fully-lazy chain.
Several complications arise when doing so, not least of which is how to handle the case where there are still-unrealized symbols.
This method encapsulates all of these complexities into a single call, whose availability can be handled with a simple `None`-check.
Parameters:
sck: Mapping from dictionaries
Examples:
Within an event method depending on a socket `Socket: {FK.Func, FK.Params, ...}`, a realized value can
Generally accessed by calling `event.realize_known
Returns `None` when there are unrealized symbols, or either `Func` or `Params` is a `FlowSignal`.
"""
has_func = not FS.check(sck[FK.Func])
has_params = not FS.check(sck[FK.Params])
if has_func and has_params:
realized = sck[FK.Func].realize(sck[FK.Params])
if freeze:
return freeze_pytree(realized)
if conformed:
func_output = sck[FK.Func].func_output
if func_output is not None:
return func_output.conform(realized)
return realized
return realized
return None
def realize_preview(
sck: frozendict[typ.Literal[FK.Func, FK.Params], ct.FuncFlow | ct.ParamsFlow],
) -> int | float | tuple[int | float]:
"""Realize a concrete preview-value from a `FuncFlow` and a `ParamsFlow`.
This particular operation is widely used in `on_value_changed` methods that update previews, since they must intercept both the function and parameter flows in order to respect ex. partially relized symbols and units.
Usually, when such a thing happens, we support it in `event_decorator`.
But in this case, the designs required were not resonating quite right - adding either too much specific complexity to input/output/loose fields, requiring too much "magic", etc. .
Why not just ask users to intercept intercept the `Value` output?
Several reasons.
Firstly, it alone _absolutely cannot_ handle unrealized symbols, which while reasonable for a structure that should be **fully realized**, is not quite the desired functionality with preview-oriented workflows: In particular, we want we want to use the the stand-in `SimSymbol.preview_value_phy`, even though it's not always "accurate".
Secondly, constructing the full `Value` output is slow, and introduces superfluous preview-dependencies.
In this situation, the best balance is to provide this utility function.
The user needs to deconstruct the eg. `input_sockets` parameter _anyway_ at least once; with this function, what would otherwise be an unweildy piece of realization logic is cleanly encapsulated.
"""
return sck[FK.Func].realize_preview(sck[FK.Params])
def mk_sockets_kinds(
sockets: dict[ct.SocketName, frozenset[FK]]
| dict[ct.SocketName, FK]
| ct.SocketName
| None,
default_kinds: frozenset[FK] = frozenset(FK),
) -> frozendict[ct.SocketName, frozenset[FK]]:
"""Normalize the given parameters to a standardized type."""
# Deduce Triggered Socket -> SocketKinds
## -> Normalize all valid inputs to frozendict[SocketName, set[FlowKind]].
if sockets is None:
return {}
if isinstance(sockets, dict | frozendict):
sockets_kinds = {
socket: (
kinds if isinstance(kinds := _kinds, set | frozenset) else {_kinds}
)
for socket, _kinds in sockets.items()
}
else:
sockets_kinds = {
socket: default_kinds
for socket in (
sockets if isinstance(sockets, set | frozenset) else [sockets]
)
}
return frozendict(sockets_kinds)
####################
# - General Event Callbacks
####################
def event_decorator( # noqa: C901, PLR0913, PLR0915
event: ct.FlowEvent,
callback_info: EventCallbackInfo | None,
stop_propagation: bool = False,
# Request Data for Callback
managed_objs: set[ManagedObjName] = frozenset(),
props: set[PropName] = frozenset(),
input_sockets: set[ct.SocketName] = frozenset(),
input_sockets_optional: dict[ct.SocketName, bool] = MappingProxyType({}),
input_socket_kinds: dict[
ct.SocketName, ct.FlowKind | set[ct.FlowKind]
] = MappingProxyType({}),
output_sockets: set[ct.SocketName] = frozenset(),
output_sockets_optional: dict[ct.SocketName, bool] = MappingProxyType({}),
output_socket_kinds: dict[
ct.SocketName, ct.FlowKind | set[ct.FlowKind]
] = MappingProxyType({}),
# Loading: Internal Data
managed_objs: frozenset[ManagedObjName] = frozenset(),
props: frozenset[ct.PropName] = frozenset(),
# Loading: Input Sockets
input_sockets: frozenset[ct.SocketName] = frozenset(),
input_socket_kinds: frozendict[ct.SocketName, FK | frozenset[FK]] = frozendict(),
inscks_kinds: frozendict[ct.SocketName, FK | frozenset[FK]] | None = None,
input_sockets_optional: frozenset[ct.SocketName] = frozendict(),
# Loading: Output Sockets
output_sockets: frozenset[ct.SocketName] = frozenset(),
output_socket_kinds: frozendict[ct.SocketName, FK | frozenset[FK]] = frozendict(),
outscks_kinds: frozendict[ct.SocketName, FK | frozenset[FK]] | None = None,
output_sockets_optional: frozenset[ct.SocketName] = frozenset(),
# Loading: Loose Sockets
all_loose_input_sockets: bool = False,
loose_input_sockets_kind: frozenset[FK] | FK | None = None,
all_loose_output_sockets: bool = False,
# Request Unit System Scaling
unit_systems: dict[UnitSystemID, spux.UnitSystem] = MappingProxyType({}),
scale_input_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}),
scale_output_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}),
loose_output_sockets_kind: frozenset[FK] | FK | None = None,
# Loading: Unit System Scaling
scale_props: frozendict[ct.PropName, spux.UnitSystem] = frozendict(),
scale_input_sockets: frozendict[ct.SocketName, spux.UnitSystem] | None = None,
scale_output_sockets: frozendict[ct.SocketName, spux.UnitSystem] | None = None,
scale_loose_input_sockets: spux.UnitSystem | None = None,
scale_loose_output_sockets: spux.UnitSystem | None = None,
):
"""Low-level decorator declaring a special "event method" of `MaxwellSimNode`, which is able to handle `ct.FlowEvent`s passing through.
@ -101,25 +458,22 @@ def event_decorator( # noqa: PLR0913
event: A name describing which event the decorator should respond to.
callback_info: A dictionary that provides the caller with additional per-`event` information.
This might include parameters to help select the most appropriate method(s) to respond to an event with, or events to take after running the callback.
stop_propagation: Whether or stop propagating the event through the graph after encountering this method.
Other methods defined on the same node will still run.
managed_objs: Set of `managed_objs` to retrieve, then pass to the decorated method.
props: Set of `props` to compute, then pass to the decorated method.
input_sockets: Set of `input_sockets` to compute, then pass to the decorated method.
input_sockets_optional: Whether an input socket is required to exist.
When True, lack of socket will produce `ct.FlowSignal.NoFlow`, instead of throwing an error.
input_socket_kinds: The `ct.FlowKind` to compute per-input-socket.
If an input socket isn't specified, it defaults to `ct.FlowKind.Value`.
input_sockets_optional: Allow the method will run even if one of these input socket values are a `ct.FlowSignal`.
input_socket_kinds: The `FK` to compute per-input-socket.
If an input socket isn't specified, it defaults to `FK.Value`.
output_sockets: Set of `output_sockets` to compute, then pass to the decorated method.
output_sockets_optional: Whether an output socket is required to exist.
output_sockets_optional: Allow the method will run even if one of these output socket values are a `ct.FlowSignal`.
When True, lack of socket will produce `ct.FlowSignal.NoFlow`, instead of throwing an error.
output_socket_kinds: The `ct.FlowKind` to compute per-output-socket.
If an output socket isn't specified, it defaults to `ct.FlowKind.Value`.
output_socket_kinds: The `FK` to compute per-output-socket.
If an output socket isn't specified, it defaults to `FK.Value`.
all_loose_input_sockets: Whether to compute all loose input sockets and pass them to the decorated method.
Used when the names of the loose input sockets are unknown, but all of their values are needed.
all_loose_output_sockets: Whether to compute all loose output sockets and pass them to the decorated method.
Used when the names of the loose output sockets are unknown, but all of their values are needed.
unit_systems: String identifiers under which to load a unit system, made available to the method.
scale_props: A mapping of input sockets to unit system string idenfiers, which causes the output of that input socket to be scaled to the given unit system.
scale_input_sockets: A mapping of input sockets to unit system string idenfiers, which causes the output of that input socket to be scaled to the given unit system.
This greatly simplifies the conformance of particular sockets to particular unit systems, when the socket value must be used in a unit-unaware manner.
scale_output_sockets: A mapping of output sockets to unit system string idenfiers, which causes the output of that input socket to be scaled to the given unit system.
@ -130,204 +484,152 @@ def event_decorator( # noqa: PLR0913
"""
req_params = (
{'self'}
| ({'props'} if props else set())
| ({'managed_objs'} if managed_objs else set())
| ({'input_sockets'} if input_sockets else set())
| ({'output_sockets'} if output_sockets else set())
| ({'loose_input_sockets'} if all_loose_input_sockets else set())
| ({'loose_output_sockets'} if all_loose_output_sockets else set())
| ({'unit_systems'} if unit_systems else set())
| ({'props'} if props else frozenset())
| ({'managed_objs'} if managed_objs else frozenset())
| ({'input_sockets'} if input_sockets or inscks_kinds else frozenset())
| ({'output_sockets'} if output_sockets or outscks_kinds else frozenset())
| ({'loose_input_sockets'} if all_loose_input_sockets else frozenset())
| ({'loose_output_sockets'} if all_loose_output_sockets else frozenset())
)
# TODO: Check that all Unit System IDs referenced are also defined in 'unit_systems'.
# Simplify I/O Under Naming
if inscks_kinds is None:
inscks_kinds = frozendict(
{
socket: input_socket_kinds.get(socket, FK.Value)
for socket in input_sockets
}
)
inscks_unit_system = (
frozendict(scale_input_sockets) if scale_input_sockets is not None else None
)
inscks_optional = frozenset(input_sockets_optional)
if outscks_kinds is None:
outscks_kinds = frozendict(
{
output_socket: output_socket_kinds.get(output_socket, FK.Value)
for output_socket in output_sockets
}
)
outscks_unit_system = (
frozendict(scale_output_sockets) if scale_output_sockets is not None else None
)
outscks_optional = frozenset(output_sockets_optional)
## TODO: More ex. introspective checks and such, to make it really hard to write invalid methods.
# TODO: Check Function Annotation Validity
## - socket capabilities
def decorator(method: typ.Callable) -> typ.Callable:
def decorator(method: typ.Callable) -> typ.Callable: # noqa: C901
# Check Function Signature Validity
func_sig = set(inspect.signature(method).parameters.keys())
func_sig = frozenset(inspect.signature(method).parameters.keys())
## Too Few Arguments
## -> Too Few Arguments
if func_sig != req_params and func_sig.issubset(req_params):
msg = f'Decorated method {method.__name__} is missing arguments {req_params - func_sig}'
## Too Many Arguments
## -> Too Many Arguments
if func_sig != req_params and func_sig.issuperset(req_params):
msg = f'Decorated method {method.__name__} has superfluous arguments {func_sig - req_params}'
raise ValueError(msg)
def decorated(node):
method_kw_args = {} ## Keyword Arguments for Decorated Method
# Unit Systems
method_kw_args |= {'unit_systems': unit_systems} if unit_systems else {}
# Properties
method_kw_args |= (
{'props': {prop_name: getattr(node, prop_name) for prop_name in props}}
if props
else {}
)
def decorated(node: bpy.types.Node): # noqa: C901, PLR0912
method_kwargs = defaultdict(dict)
# Managed Objects
method_kw_args |= (
{
'managed_objs': {
managed_obj_name: node.managed_objs[managed_obj_name]
for managed_obj_name in managed_objs
}
}
if managed_objs
else {}
)
if managed_objs:
method_kwargs['managed_objs'] = {}
for mobj_name, mobj in parse_node_mobjs(node, managed_objs).items():
method_kwargs['managed_objs'] |= {mobj_name: mobj}
# Properties
if props:
method_kwargs['props'] = {}
for prop, value in parse_node_props(node, props, scale_props).items():
method_kwargs['props'] |= {prop: value}
# Sockets
## Input Sockets
method_kw_args |= (
{
'input_sockets': {
input_socket_name: node._compute_input(
input_socket_name,
kind=_kind,
unit_system=(
unit_systems.get(
scale_input_sockets.get(input_socket_name)
)
),
optional=input_sockets_optional.get(
input_socket_name, False
),
)
if not isinstance(
_kind := input_socket_kinds.get(
input_socket_name, ct.FlowKind.Value
),
set,
)
else {
kind: node._compute_input(
input_socket_name,
kind=kind,
unit_system=unit_systems.get(
scale_input_sockets.get(input_socket_name)
),
optional=input_sockets_optional.get(
input_socket_name, False
),
)
for kind in _kind
}
for input_socket_name in input_sockets
}
}
if input_sockets
else {}
)
if inscks_kinds:
method_kwargs['input_sockets'] = {}
for insck, flow in parse_node_scks_kinds(
node,
'input',
inscks_kinds,
inscks_unit_system,
).items():
has_flow = not FS.check(flow)
if has_flow or insck in inscks_optional:
method_kwargs['input_sockets'] |= {insck: flow}
else:
flow_signal = flow
return flow_signal # noqa: RET504
## Output Sockets
def _g_output_socket(output_socket_name: ct.SocketName, kind: ct.FlowKind):
if scale_output_sockets.get(output_socket_name) is None:
return node.compute_output(
output_socket_name,
kind=kind,
optional=output_sockets_optional.get(output_socket_name, False),
)
return ct.FlowKind.scale_to_unit_system(
kind,
node.compute_output(
output_socket_name,
kind=kind,
optional=output_sockets_optional.get(output_socket_name, False),
),
unit_systems.get(scale_output_sockets.get(output_socket_name)),
)
method_kw_args |= (
{
'output_sockets': {
output_socket_name: _g_output_socket(output_socket_name, _kind)
if not isinstance(
_kind := output_socket_kinds.get(
output_socket_name, ct.FlowKind.Value
),
set,
)
else {
kind: _g_output_socket(output_socket_name, kind)
for kind in _kind
}
for output_socket_name in output_sockets
}
}
if output_sockets
else {}
)
if outscks_kinds:
method_kwargs['output_sockets'] = {}
for outsck, flow in parse_node_scks_kinds(
node,
'output',
outscks_kinds,
outscks_unit_system,
).items():
has_flow = not FS.check(flow)
if has_flow or outsck in outscks_optional:
method_kwargs['output_sockets'] |= {outsck: flow}
else:
flow_signal = flow
return flow_signal # noqa: RET504
# Loose Sockets
## -> Determined by the active_kind of each loose input socket.
method_kw_args |= (
{
'loose_input_sockets': {
input_socket_name: node._compute_input(
input_socket_name,
kind=node.inputs[input_socket_name].active_kind,
)
for input_socket_name in node.loose_input_sockets
}
if all_loose_input_sockets:
method_kwargs['loose_input_sockets'] = {}
loose_inscks = frozenset(node.loose_input_sockets.keys())
loose_inscks_kinds = {
loose_insck: loose_input_sockets_kind
for loose_insck in loose_inscks
}
if all_loose_input_sockets
else {}
)
## Compute All Loose Output Sockets
method_kw_args |= (
{
'loose_output_sockets': {
output_socket_name: node.compute_output(
output_socket_name,
kind=node.outputs[output_socket_name].active_kind,
)
for output_socket_name in node.loose_output_sockets
}
for loose_insck, flow in parse_node_scks_kinds(
node,
'input',
loose_inscks_kinds,
scale_loose_input_sockets,
).items():
method_kwargs['loose_input_sockets'] |= {loose_insck: flow}
if all_loose_output_sockets:
method_kwargs['loose_output_sockets'] = {}
loose_outscks = frozenset(node.loose_output_sockets.keys())
loose_outscks_kinds = {
loose_outsck: loose_output_sockets_kind
for loose_outsck in loose_outscks
}
if all_loose_output_sockets
else {}
)
# Propagate Initialization
## If there is a FlowInitializing, then the method would fail.
## Therefore, propagate FlowInitializing if found.
if any(
ct.FlowSignal.check_single(value, ct.FlowSignal.FlowInitializing)
for sockets in [
method_kw_args.get('input_sockets', {}),
method_kw_args.get('loose_input_sockets', {}),
method_kw_args.get('output_sockets', {}),
method_kw_args.get('loose_output_sockets', {}),
]
for value in sockets.values()
):
return ct.FlowSignal.FlowInitializing
for loose_outsck, flow in parse_node_scks_kinds(
node,
'input',
loose_outscks_kinds,
scale_loose_output_sockets,
).items():
method_kwargs['loose_output_sockets'] |= {loose_outsck: flow}
# Call Method
return method(
node,
**method_kw_args,
**method_kwargs,
)
# Set Decorated Attributes and Return
## TODO: Fix Introspection + Documentation
# decorated.__name__ = method.__name__
# Wrap Decorated Method Attributes
## -> We can't just @wraps(), since the call signature changed.
decorated.__name__ = method.__name__
# decorated.__module__ = method.__module__
# decorated.__qualname__ = method.__qualname__
decorated.__doc__ = method.__doc__
## Add Spice
# Add Spice
decorated.identifier = EVENT_METHOD_IDENTIFIER
decorated.event = event
decorated.callback_info = callback_info
decorated.stop_propagation = stop_propagation
return decorated
@ -335,11 +637,12 @@ def event_decorator( # noqa: PLR0913
####################
# - Simplified Event Callbacks
# - Specific Event Callbacks
####################
def on_enable_lock(
**kwargs,
):
"""Declare a method that reacts to the enabling of the interface lock."""
return event_decorator(
event=ct.FlowEvent.EnableLock,
callback_info=None,
@ -350,6 +653,7 @@ def on_enable_lock(
def on_disable_lock(
**kwargs,
):
"""Declare a method that reacts to the disabling of the interface lock."""
return event_decorator(
event=ct.FlowEvent.DisableLock,
callback_info=None,
@ -359,28 +663,50 @@ def on_disable_lock(
## TODO: Consider changing socket_name and prop_name to more obvious names.
def on_value_changed(
socket_name: set[ct.SocketName] | ct.SocketName | None = None,
prop_name: set[str] | str | None = None,
prop_name: frozenset[str] | str | None = None,
socket_name: frozendict[ct.SocketName, frozenset[FK]]
| frozendict[ct.SocketName, FK]
| frozenset[ct.SocketName]
| ct.SocketName
| None = None,
socket_name_kinds: frozenset[FK] | FK = frozenset(FK),
any_loose_input_socket: bool = False,
run_on_init: bool = False,
stop_propagation: bool = False,
**kwargs,
):
"""Declare a method that reacts to a change in node data.
Can be configured to react to changes in:
- Any particular input `socket|kind`s.
- Any `loose_socket|active_kind`.
- Any particular property.
In addition, the method can be configured to run when its node is created and initialized for the first time.
"""
return event_decorator(
event=ct.FlowEvent.DataChanged,
callback_info=InfoDataChanged(
# Triggers
run_on_init=run_on_init,
on_changed_sockets=(
socket_name if isinstance(socket_name, set) else {socket_name}
# Trigger: Props
on_changed_props=(
prop_name
if isinstance(prop_name, set | frozenset)
else ({prop_name} if prop_name is not None else set())
),
on_changed_props=(prop_name if isinstance(prop_name, set) else {prop_name}),
on_any_changed_loose_input=any_loose_input_socket,
# Loaded
must_load_sockets={
# Trigger: Sockets
on_changed_sockets_kinds=mk_sockets_kinds(socket_name, socket_name_kinds),
# Trigger: Loose Sockets
on_any_changed_loose_socket=any_loose_input_socket,
# Trigger: Init
run_on_init=run_on_init,
# Hints
optional_sockets={
socket_name
for socket_name in kwargs.get('input_sockets', {})
if socket_name not in kwargs.get('input_sockets_optional', {})
for socket_name in kwargs.get('input_sockets', set())
if socket_name in kwargs.get('input_sockets_optional', set())
},
stop_propagation=stop_propagation,
),
**kwargs,
)
@ -388,45 +714,63 @@ def on_value_changed(
def computes_output_socket(
output_socket_name: ct.SocketName | None,
kind: ct.FlowKind = ct.FlowKind.Value,
kind: FK = FK.Value,
**kwargs,
):
"""Declare a method used to compute the value of a particular output `socket|kind`.
The method's dependencies on properties, input/output `socket|kind`s, and loose input/output `socket|active_kind`s are recorded in its `callback_info`, such that the associated output socket cache can be appropriately invalidated whenever an input dependency changes.
"""
input_socket_kinds = kwargs.get('input_socket_kinds', {})
depon_inscks_kinds = {
insck: input_socket_kinds.get(insck, FK.Value)
for insck in kwargs.get('input_sockets', set())
} | kwargs.get('inscks_kinds', {})
output_sockets_kinds = kwargs.get('output_sockets_kinds', {})
depon_outscks_kinds = {
outsck: output_sockets_kinds.get(outsck, FK.Value)
for outsck in kwargs.get('output_sockets', set())
} | kwargs.get('outscks_kinds', {})
log.debug(
[
output_socket_name,
kind,
kwargs.get('props', set()),
mk_sockets_kinds(depon_inscks_kinds),
mk_sockets_kinds(depon_outscks_kinds),
kwargs.get('all_loose_input_sockets', False),
kwargs.get('all_loose_output_sockets', False),
]
)
return event_decorator(
event=ct.FlowEvent.OutputRequested,
callback_info=InfoOutputRequested(
output_socket_name=output_socket_name,
kind=kind,
# Dependency: Props
depon_props=kwargs.get('props', set()),
depon_input_sockets=kwargs.get('input_sockets', set()),
depon_input_socket_kinds=kwargs.get('input_socket_kinds', {}),
depon_output_sockets=kwargs.get('output_sockets', set()),
depon_output_socket_kinds=kwargs.get('output_socket_kinds', {}),
# Dependency: Input Sockets
depon_input_sockets_kinds=mk_sockets_kinds(depon_inscks_kinds),
# Dependency: Output Sockets
depon_output_sockets_kinds=mk_sockets_kinds(depon_outscks_kinds),
# Dependency: Loose Sockets
depon_all_loose_input_sockets=kwargs.get('all_loose_input_sockets', False),
depon_all_loose_output_sockets=kwargs.get(
'all_loose_output_sockets', False
),
),
**kwargs, ## stop_propagation has no effect.
)
def on_show_preview(
**kwargs,
):
return event_decorator(
event=ct.FlowEvent.ShowPreview,
callback_info={},
**kwargs,
)
def on_show_plot(
stop_propagation: bool = True,
**kwargs,
):
"""Declare a method that reacts to a request to show a plot."""
return event_decorator(
event=ct.FlowEvent.ShowPlot,
callback_info={},
stop_propagation=stop_propagation,
callback_info=None,
**kwargs,
)

View File

@ -18,21 +18,16 @@ from . import (
constants,
file_importers,
scene,
wave_constant,
web_importers,
)
# from . import file_importers
BL_REGISTER = [
*wave_constant.BL_REGISTER,
*scene.BL_REGISTER,
*constants.BL_REGISTER,
*web_importers.BL_REGISTER,
*file_importers.BL_REGISTER,
*scene.BL_REGISTER,
]
BL_NODES = {
**wave_constant.BL_NODES,
**scene.BL_NODES,
**constants.BL_NODES,
**web_importers.BL_NODES,

View File

@ -20,6 +20,9 @@ from .... import contracts as ct
from .... import sockets
from ... import base, events
FK = ct.FlowKind
FS = ct.FlowSignal
class BlenderConstantNode(base.MaxwellSimNode):
node_type = ct.NodeType.BlenderConstant
@ -47,7 +50,7 @@ class BlenderConstantNode(base.MaxwellSimNode):
####################
# - Callbacks
####################
@events.computes_output_socket('Value', input_sockets={'Value'})
@events.computes_output_socket('Value', kind=FK.Params, input_sockets={'Value'})
def compute_value(self, input_sockets) -> typ.Any:
return input_sockets['Value']

View File

@ -16,55 +16,55 @@
import typing as typ
import sympy as sp
from .... import contracts as ct
from .... import sockets
from ... import base, events
FK = ct.FlowKind
class ExprConstantNode(base.MaxwellSimNode):
"""An expression constant."""
node_type = ct.NodeType.ExprConstant
bl_label = 'Expr Constant'
input_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Func,
active_kind=FK.Func,
show_name_selector=True,
),
}
output_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Func,
active_kind=FK.Func,
show_info_columns=True,
),
}
## TODO: Allow immediately realizing any symbol, or just passing it along.
## TODO: Alter output physical_type when the input PhysicalType changes.
####################
# - FlowKinds
####################
@events.computes_output_socket(
# Trigger
'Expr',
kind=ct.FlowKind.Value,
kind=FK.Value,
# Loaded
input_sockets={'Expr'},
inscks_kinds={'Expr': FK.Value},
)
def compute_value(self, input_sockets: dict) -> typ.Any:
"""Compute the symbolic expression value."""
return input_sockets['Expr']
@events.computes_output_socket(
# Trigger
'Expr',
kind=ct.FlowKind.Func,
kind=FK.Func,
# Loaded
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Func},
inscks_kinds={'Expr': FK.Func},
)
def compute_lazy_func(self, input_sockets: dict) -> typ.Any:
"""Compute the lazy expression function."""
return input_sockets['Expr']
####################
@ -73,23 +73,25 @@ class ExprConstantNode(base.MaxwellSimNode):
@events.computes_output_socket(
# Trigger
'Expr',
kind=ct.FlowKind.Info,
kind=FK.Info,
# Loaded
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info},
input_socket_kinds={'Expr': FK.Info},
)
def compute_info(self, input_sockets: dict) -> typ.Any:
"""Compute the tracking information flow."""
return input_sockets['Expr']
@events.computes_output_socket(
# Trigger
'Expr',
kind=ct.FlowKind.Params,
kind=FK.Params,
# Loaded
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Params},
input_socket_kinds={'Expr': FK.Params},
)
def compute_params(self, input_sockets: dict) -> typ.Any:
"""Compute the fnuction parameters."""
return input_sockets['Expr']

View File

@ -14,11 +14,13 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `ScientificConstantNode`."""
import typing as typ
import bpy
import sympy as sp
import sympy.physics.units as spu
from frozendict import frozendict
from blender_maxwell.utils import bl_cache, sci_constants, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
@ -55,6 +57,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
self,
edit_text: str,
):
"""Search for all valid scientific constants."""
return [
name
for name in sci_constants.SCI_CONSTANTS
@ -82,7 +85,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
self.sci_constant_name,
self.sci_constant,
unit,
is_constant=True,
new_domain=spux.BlessedSet(sp.FiniteSet(self.sci_constant)),
)
return None
@ -91,11 +94,13 @@ class ScientificConstantNode(base.MaxwellSimNode):
# - UI
####################
def draw_label(self):
"""Match the node's header label to the active scientific constant."""
if self.sci_constant_str:
return f'Const: {self.sci_constant_str}'
return self.bl_label
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
"""Provide a node user-interface allowing the user to specify a scientific constant."""
col.prop(self, self.blfields['sci_constant_str'], text='')
row = col.row(align=True)
@ -139,7 +144,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
row.label(text=f'{self.sci_constant_info["uncertainty"]}')
####################
# - Output
# - FlowKind.Value
####################
@events.computes_output_socket(
'Expr',
@ -148,8 +153,9 @@ class ScientificConstantNode(base.MaxwellSimNode):
def compute_value(self, props) -> typ.Any:
sci_constant = props['sci_constant']
sci_constant_sym = props['sci_constant_sym']
use_symbol = props['use_symbol']
if props['use_symbol'] and sci_constant_sym is not None:
if use_symbol and sci_constant_sym is not None:
return sci_constant_sym.sp_symbol
if sci_constant is not None:
@ -157,6 +163,9 @@ class ScientificConstantNode(base.MaxwellSimNode):
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Func,
@ -178,19 +187,9 @@ class ScientificConstantNode(base.MaxwellSimNode):
)
return ct.FlowSignal.FlowPending
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Info,
props={'sci_constant_sym'},
)
def compute_info(self, props: dict) -> typ.Any:
"""Simple `FuncFlow` that computes the symbol value, with output units tracked correctly."""
sci_constant_sym = props['sci_constant_sym']
if sci_constant_sym is not None:
return ct.InfoFlow(output=sci_constant_sym)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Params,
@ -206,12 +205,30 @@ class ScientificConstantNode(base.MaxwellSimNode):
func_args=[sci_constant_sym.sp_symbol],
symbols={sci_constant_sym},
).realize_partial(
{
sci_constant_sym: sci_constant,
}
frozendict(
{
sci_constant_sym: sci_constant,
}
)
)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Info
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Info,
props={'sci_constant_sym'},
)
def compute_info(self, props: dict) -> typ.Any:
"""Simple `FuncFlow` that computes the symbol value, with output units tracked correctly."""
sci_constant_sym = props['sci_constant_sym']
if sci_constant_sym is not None:
return ct.InfoFlow(output=sci_constant_sym)
return ct.FlowSignal.FlowPending
####################
# - Blender Registration

View File

@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `SymbolConstantNode`."""
import enum
import typing as typ
from fractions import Fraction
@ -32,6 +34,8 @@ log = logger.get(__name__)
class SymbolConstantNode(base.MaxwellSimNode):
"""A symbol usable as itself, or as a symbol."""
node_type = ct.NodeType.SymbolConstant
bl_label = 'Symbol'
@ -87,7 +91,7 @@ class SymbolConstantNode(base.MaxwellSimNode):
return None
@property
@bl_cache.cached_bl_property(depends_on={'unit'})
def unit_factor(self) -> spux.Unit | None:
"""Like `self.unit`, except `1` instead of `None` when unitless."""
return sp.Integer(1) if self.unit is None else self.unit
@ -95,6 +99,8 @@ class SymbolConstantNode(base.MaxwellSimNode):
####################
# - Domain
####################
exclude_zero: bool = bl_cache.BLField(False)
interval_finite_z: tuple[int, int] = bl_cache.BLField((0, 1))
interval_finite_q: tuple[tuple[int, int], tuple[int, int]] = bl_cache.BLField(
((0, 1), (1, 1))
@ -146,22 +152,27 @@ class SymbolConstantNode(base.MaxwellSimNode):
'preview_value_q',
'preview_value_re',
'preview_value_im',
'unit_factor',
}
)
def preview_value(
def preview_value_phy(
self,
) -> int | Fraction | float | complex:
"""Return the appropriate finite interval from the UI, as guided by `self.mathtype`."""
MT = spux.MathType
match self.mathtype:
case MT.Integer:
return self.preview_value_z
preview_value = sp.Integer(self.preview_value_z)
case MT.Rational:
return Fraction(*self.preview_value_q)
preview_value = sp.Rational(*self.preview_value_q)
case MT.Real:
return self.preview_value_re
preview_value = sp.Float(self.preview_value_re, 4)
case MT.Complex:
return complex(self.preview_value_re, self.preview_value_im)
preview_value = sp.Float(self.preview_value_re) + sp.I * sp.Float(
self.preview_value_im
)
return preview_value * self.unit_factor
@bl_cache.cached_bl_property(
depends_on={
@ -179,11 +190,27 @@ class SymbolConstantNode(base.MaxwellSimNode):
"""Deduce the domain specified in the UI."""
MT = spux.MathType
match self.mathtype:
case MT.Integer | MT.Real | MT.Rational:
return sim_symbols.mk_interval(
self.interval_finite,
self.interval_inf,
self.interval_closed,
case MT.Integer:
start = (
self.interval_inf[0]
if self.interval_inf[0]
else self.interval_finite[0]
)
end = (
self.interval_inf[1]
if self.interval_inf[1]
else self.interval_finite[1]
)
return spux.BlessedSet(sp.Range(start, end, 1))
case MT.Real | MT.Rational:
return spux.BlessedSet(
sim_symbols.mk_interval(
self.interval_finite,
self.interval_inf,
self.interval_closed,
)
)
case MT.Complex:
@ -198,7 +225,11 @@ class SymbolConstantNode(base.MaxwellSimNode):
self.interval_inf_im,
self.interval_closed_im,
)
return sp.ComplexRegion(domain_re, domain_im, polar=False)
return spux.BlessedSet(spux.ComplexRegion(domain_re, domain_im))
@bl_cache.cached_bl_property(depends_on={'domain'})
def is_nonzero(self) -> bool:
return 0 in self.domain
####################
# - Computed Properties
@ -211,7 +242,6 @@ class SymbolConstantNode(base.MaxwellSimNode):
'unit',
'size',
'domain',
'preview_value',
}
)
def symbol(self) -> sim_symbols.SimSymbol:
@ -224,7 +254,6 @@ class SymbolConstantNode(base.MaxwellSimNode):
rows=self.size.rows,
cols=self.size.cols,
domain=self.domain,
preview_value=self.preview_value,
)
####################
@ -268,6 +297,14 @@ class SymbolConstantNode(base.MaxwellSimNode):
row.label(text='Domain - Closure')
row = col.row(align=True)
match self.mathtype:
case spux.MathType.Rational | spux.MathType.Real:
row.prop(self, self.blfields['interval_closed'], text='')
case spux.MathType.Complex:
row.prop(self, self.blfields['interval_closed'], text='')
row.prop(self, self.blfields['interval_closed_im'], text='𝕀')
if self.mathtype is spux.MathType.Complex:
row.prop(self, self.blfields['interval_closed'], text='')
row.prop(self, self.blfields['interval_closed_im'], text='𝕀')
@ -353,7 +390,6 @@ class SymbolConstantNode(base.MaxwellSimNode):
)
def compute_info(self, props) -> typ.Any:
return ct.InfoFlow(
dims={props['symbol']: None},
output=props['symbol'],
)
@ -362,14 +398,14 @@ class SymbolConstantNode(base.MaxwellSimNode):
'Expr',
kind=ct.FlowKind.Params,
# Loaded
props={'symbol'},
props={'symbol', 'preview_value_phy'},
)
def compute_params(self, props) -> typ.Any:
sym = props['symbol']
return ct.ParamsFlow(
arg_targets=[sym],
func_args=[sym.sp_symbol],
symbols={sym},
previewed_symbols={sym: props['preview_value_phy']},
)

View File

@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `DataFileImporterNode`."""
import enum
import typing as typ
from pathlib import Path
@ -31,11 +33,15 @@ from ... import base, events
log = logger.get(__name__)
FK = ct.FlowKind
####################
# - Node
####################
class DataFileImporterNode(base.MaxwellSimNode):
"""Import expression data from a file."""
node_type = ct.NodeType.DataFileImporter
bl_label = 'Data File Importer'
@ -43,7 +49,7 @@ class DataFileImporterNode(base.MaxwellSimNode):
'File Path': sockets.FilePathSocketDef(),
}
output_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
}
####################
@ -54,7 +60,7 @@ class DataFileImporterNode(base.MaxwellSimNode):
socket_name={'File Path'},
# Loaded
input_sockets={'File Path'},
input_socket_kinds={'File Path': ct.FlowKind.Value},
input_socket_kinds={'File Path': FK.Value},
input_sockets_optional={'File Path': True},
# Flow
## -> See docs in TransformMathNode
@ -69,9 +75,7 @@ class DataFileImporterNode(base.MaxwellSimNode):
@bl_cache.cached_bl_property()
def file_path(self) -> Path:
"""Retrieve the input file path."""
file_path = self._compute_input(
'File Path', kind=ct.FlowKind.Value, optional=True
)
file_path = self._compute_input('File Path', kind=FK.Value, optional=True)
has_file_path = not ct.FlowSignal.check(file_path)
if has_file_path:
return file_path
@ -99,7 +103,7 @@ class DataFileImporterNode(base.MaxwellSimNode):
)
def expr_info(self) -> ct.InfoFlow | None:
"""Retrieve the output expression's `InfoFlow`."""
info = self.compute_output('Expr', kind=ct.FlowKind.Info)
info = self.compute_output('Expr', kind=FK.Info)
has_info = not ct.FlowSignal.check(info)
if has_info:
return info
@ -120,6 +124,15 @@ class DataFileImporterNode(base.MaxwellSimNode):
cb_depends_on={'output_physical_type'},
)
def search_units(self, physical_type: spux.PhysicalType) -> list[ct.BLEnumElement]:
"""Determine valid units based on the given physical type."""
if physical_type is not spux.PhysicalType.NonPhysical:
return [
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
for i, unit in enumerate(physical_type.valid_units)
]
return []
dim_0_name: sim_symbols.SimSymbolName = bl_cache.BLField(
sim_symbols.SimSymbolName.LowerA
)
@ -139,14 +152,6 @@ class DataFileImporterNode(base.MaxwellSimNode):
sim_symbols.SimSymbolName.LowerF
)
def search_units(self, physical_type: spux.PhysicalType) -> list[ct.BLEnumElement]:
if physical_type is not spux.PhysicalType.NonPhysical:
return [
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
for i, unit in enumerate(physical_type.valid_units)
]
return []
####################
# - UI
####################
@ -196,10 +201,15 @@ class DataFileImporterNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Func,
kind=FK.Func,
# Loaded
props={'output_name', 'output_mathtype', 'output_physical_type', 'output_unit'},
input_sockets={'File Path'},
props={
'output_name',
'output_mathtype',
'output_physical_type',
'output_unit',
},
inscks_kinds={'File Path': FK.Func},
)
def compute_func(self, props, input_sockets) -> td.Simulation:
"""Declare a lazy, composable function that returns the loaded data.
@ -244,7 +254,7 @@ class DataFileImporterNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Params,
kind=FK.Params,
)
def compute_params(self) -> ct.ParamsFlow:
"""Declare an empty `Data:Params`, to indicate the start of a function-composition pipeline.
@ -256,12 +266,12 @@ class DataFileImporterNode(base.MaxwellSimNode):
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Info,
kind=FK.Info,
# Loaded
props={'output_name', 'output_mathtype', 'output_physical_type', 'output_unit'}
| {f'dim_{i}_name' for i in range(6)},
output_sockets={'Expr'},
output_socket_kinds={'Expr': ct.FlowKind.Func},
output_socket_kinds={'Expr': FK.Func},
)
def compute_info(self, props, output_sockets) -> ct.InfoFlow:
"""Declare an `InfoFlow` based on the data shape.
@ -284,9 +294,7 @@ class DataFileImporterNode(base.MaxwellSimNode):
dims = {
sim_symbols.idx(None).update(
sym_name=props[f'dim_{i}_name'],
interval_finite_z=(0, elements),
interval_inf=(False, False),
interval_closed=(True, True),
domain=spux.BlessedSet(sp.Range(0, elements)),
): [str(j) for j in range(elements)]
for i, elements in enumerate(shape)
}

View File

@ -14,14 +14,15 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import enum
import functools
import typing as typ
from pathlib import Path
import bpy
import tidy3d as td
import tidy3d.plugins.dispersion as td_dispersion
from blender_maxwell.utils import logger
from blender_maxwell.utils import bl_cache, logger
from .... import contracts as ct
from .... import managed_objs, sockets
@ -29,35 +30,94 @@ from ... import base, events
log = logger.get(__name__)
VALID_FILE_EXTS = {
'SIMULATION_DATA': {
'.hdf5.gz',
'.hdf5',
},
'SIMULATION': {
'.hdf5.gz',
'.hdf5',
'.json',
'.yaml',
},
'MEDIUM': {
'.hdf5.gz',
'.hdf5',
'.json',
'.yaml',
},
'EXPERIM_DISP_MEDIUM': {
'.txt',
},
}
FK = ct.FlowKind
FS = ct.FlowSignal
CACHE = {}
class ValidTDFileExts(enum.StrEnum):
"""Valid importable Tidy3D file extensions."""
SimData = enum.auto()
Sim = enum.auto()
Medium = enum.auto()
@staticmethod
def to_name(value: typ.Self) -> str:
VFE = ValidTDFileExts
return {
VFE.SimData: 'Sim Data',
VFE.Sim: 'Sim',
VFE.Medium: 'Medium',
}[value]
@staticmethod
def to_icon(_: typ.Self) -> str:
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
return (
str(self),
ValidTDFileExts.to_name(self),
ValidTDFileExts.to_name(self),
ValidTDFileExts.to_icon(self),
i,
)
####################
# - Properties
####################
@functools.cached_property
def valid_exts(self) -> set[str]:
VFE = ValidTDFileExts
match self:
case VFE.SimData:
return {
'.hdf5.gz',
'.hdf5',
}
case VFE.Sim | VFE.Medium:
return {
'.hdf5.gz',
'.hdf5',
'.json',
'.yaml',
}
raise TypeError
@functools.cached_property
def td_type(self) -> set[str]:
"""The corresponding Tidy3D type."""
VFE = ValidTDFileExts
return {
VFE.SimData: td.SimulationData,
VFE.Sim: td.Simulation,
VFE.Medium: td.Medium,
}[self]
####################
# - Methods
####################
def is_path_compatible(self, path: Path) -> bool:
ext_matches = ''.join(path.suffixes) in self.valid_exts
return ext_matches and path.is_file()
def load(self, file_path: Path) -> typ.Any:
VFE = ValidTDFileExts
match self:
case VFE.SimData | VFE.Sim | VFE.Medium:
return self.td_type.from_file(str(file_path))
raise TypeError
####################
# - Node
####################
class Tidy3DFileImporterNode(base.MaxwellSimNode):
"""Import a simulation design or analysis element from a Tidy3D object."""
node_type = ct.NodeType.Tidy3DFileImporter
bl_label = 'Tidy3D File Importer'
@ -71,185 +131,74 @@ class Tidy3DFileImporterNode(base.MaxwellSimNode):
####################
# - Properties
####################
## TODO: More automatic determination of which file type is in use :)
tidy3d_type: bpy.props.EnumProperty(
name='Tidy3D Type',
description='Type of Tidy3D object to load',
items=[
(
'SIMULATION_DATA',
'Sim Data',
'Data from Completed Tidy3D Simulation',
),
('SIMULATION', 'Sim', 'Tidy3D Simulation'),
('MEDIUM', 'Medium', 'A Tidy3D Medium'),
(
'EXPERIM_DISP_MEDIUM',
'Experim Disp Medium',
'A pole-residue fit of experimental dispersive medium data, described by a .txt file specifying wl, n, k',
),
],
default='SIMULATION_DATA',
update=lambda self, context: self.on_prop_changed('tidy3d_type', context),
)
disp_fit__min_poles: bpy.props.IntProperty(
name='min Poles',
description='Min. # poles to fit to the experimental dispersive medium data',
default=1,
)
disp_fit__max_poles: bpy.props.IntProperty(
name='max Poles',
description='Max. # poles to fit to the experimental dispersive medium data',
default=5,
)
## TODO: Bool of whether to fit eps_inf, with conditional choice of eps_inf as socket
disp_fit__tolerance_rms: bpy.props.FloatProperty(
name='Max RMS',
description='The RMS error threshold, below which the fit should be considered converged',
default=0.001,
precision=5,
)
## TODO: "AdvanceFastFitterParam" options incl. loss_bounds, weights, show_progress, show_unweighted_rms, relaxed, smooth, logspacing, numiters, passivity_num_iters, and slsqp_constraint_scale
tidy3d_type: ValidTDFileExts = bl_cache.BLField(ValidTDFileExts.SimData)
####################
# - UI
####################
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
col.prop(self, 'tidy3d_type', text='')
if self.tidy3d_type == 'EXPERIM_DISP_MEDIUM':
row = col.row(align=True)
row.alignment = 'CENTER'
row.label(text='Pole-Residue Fit')
col.prop(self, 'disp_fit__min_poles')
col.prop(self, 'disp_fit__max_poles')
col.prop(self, 'disp_fit__tolerance_rms')
####################
# - Event Methods: Output Data
####################
def _compute_sim_data_for(
self, output_socket_name: str, file_path: Path
) -> td.components.base.Tidy3dBaseModel:
return {
'Sim': td.Simulation,
'Sim Data': td.SimulationData,
'Medium': td.Medium,
}[output_socket_name].from_file(str(file_path))
@events.computes_output_socket(
'Sim',
input_sockets={'File Path'},
)
def compute_sim(self, input_sockets: dict) -> td.Simulation:
return self._compute_sim_data_for('Sim', input_sockets['File Path'])
@events.computes_output_socket(
'Sim Data',
input_sockets={'File Path'},
)
def compute_sim_data(self, input_sockets: dict) -> td.SimulationData:
return self._compute_sim_data_for('Sim Data', input_sockets['File Path'])
@events.computes_output_socket(
'Medium',
input_sockets={'File Path'},
)
def compute_medium(self, input_sockets: dict) -> td.Medium:
return self._compute_sim_data_for('Medium', input_sockets['File Path'])
####################
# - Event Methods: Output Data | Dispersive Media
####################
@events.computes_output_socket(
'Experim Disp Medium',
input_sockets={'File Path'},
)
def compute_experim_disp_medium(self, input_sockets: dict) -> td.Medium:
if CACHE.get(self.bl_label) is not None:
log.debug('Reusing Cached Dispersive Medium')
return CACHE[self.bl_label]['model']
log.info('Loading Experimental Data')
dispersion_fitter = td_dispersion.FastDispersionFitter.from_file(
str(input_sockets['File Path'])
)
log.info('Computing Fast Dispersive Fit of Experimental Data...')
pole_residue_medium, rms_error = dispersion_fitter.fit(
min_num_poles=self.disp_fit__min_poles,
max_num_poles=self.disp_fit__max_poles,
tolerance_rms=self.disp_fit__tolerance_rms,
)
log.info('Fit Succeeded w/RMS "%s"!', f'{rms_error:.5f}')
# Populate Cache
CACHE[self.bl_label] = {}
CACHE[self.bl_label]['model'] = pole_residue_medium
CACHE[self.bl_label]['fitter'] = dispersion_fitter
CACHE[self.bl_label]['rms_error'] = rms_error
return pole_residue_medium
####################
# - Event Methods: Setup Output Socket
####################
@events.on_value_changed(
socket_name='File Path',
prop_name='tidy3d_type',
input_sockets={'File Path'},
run_on_init=True,
# Loaded
props={'tidy3d_type'},
)
def on_file_changed(self, input_sockets: dict, props: dict):
if CACHE.get(self.bl_label) is not None:
del CACHE[self.bl_label]
file_ext = ''.join(input_sockets['File Path'].suffixes)
if not (
input_sockets['File Path'].is_file()
and file_ext in VALID_FILE_EXTS[props['tidy3d_type']]
):
self.loose_output_sockets = {}
else:
self.loose_output_sockets = {
'SIMULATION_DATA': {
'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(),
},
'SIMULATION': {'Sim': sockets.MaxwellFDTDSimSocketDef()},
'MEDIUM': {'Medium': sockets.MaxwellMediumSocketDef()},
'EXPERIM_DISP_MEDIUM': {
'Experim Disp Medium': sockets.MaxwellMediumSocketDef()
},
}[props['tidy3d_type']]
def on_file_changed(self, props) -> None:
self.loose_output_sockets = {
ValidTDFileExts.SimData: {
'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(),
},
ValidTDFileExts.Sim: {'Sim': sockets.MaxwellFDTDSimSocketDef()},
ValidTDFileExts.Medium: {'Medium': sockets.MaxwellMediumSocketDef()},
}[props['tidy3d_type']]
####################
# - Event Methods: Plot
# - FlowKind.Value
####################
@events.on_show_plot(
managed_objs={'plot'},
def _compute_td_obj(
self, props, input_sockets
) -> td.components.base.Tidy3dBaseModel:
tidy3d_type = props['tidy3d_type']
file_path = input_sockets['File Path']
if tidy3d_type.is_path_compatible(file_path):
return tidy3d_type.load(file_path)
return FS.FlowPending
@events.computes_output_socket(
'Sim',
kind=FK.Value,
# Loaded
props={'tidy3d_type'},
inscks_kinds={'File Path': FK.Value},
)
def on_show_plot(
self,
props: dict,
managed_objs: dict,
):
"""When the filetype is 'Experimental Dispersive Medium', plot the computed model against the input data."""
if props['tidy3d_type'] == 'EXPERIM_DISP_MEDIUM':
# Populate Cache
if CACHE.get(self.bl_label) is None:
model_medium = self.compute_experim_disp_medium()
disp_fitter = CACHE[self.bl_label]['fitter']
else:
model_medium = CACHE[self.bl_label]['model']
disp_fitter = CACHE[self.bl_label]['fitter']
def compute_sim(self, props, input_sockets) -> td.Simulation:
return self._compute_td_obj(props, input_sockets)
# Plot
managed_objs['plot'].mpl_plot_to_image(
lambda ax: disp_fitter.plot(
medium=model_medium,
ax=ax,
),
bl_select=True,
)
@events.computes_output_socket(
'Sim Data',
kind=FK.Value,
# Loaded
props={'tidy3d_type'},
inscks_kinds={'File Path': FK.Value},
)
def compute_sim_data(self, props, input_sockets) -> td.SimulationData:
return self._compute_td_obj(props, input_sockets)
@events.computes_output_socket(
'Medium',
kind=FK.Value,
# Loaded
props={'tidy3d_type'},
inscks_kinds={'File Path': FK.Value},
)
def compute_medium(self, props, input_sockets: dict) -> td.Medium:
return self._compute_td_obj(props, input_sockets)
####################

View File

@ -14,11 +14,11 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import tidy3d_web_runner
from . import tidy3d_web_importer
BL_REGISTER = [
*tidy3d_web_runner.BL_REGISTER,
*tidy3d_web_importer.BL_REGISTER,
]
BL_NODES = {
**tidy3d_web_runner.BL_NODES,
**tidy3d_web_importer.BL_NODES,
}

View File

@ -0,0 +1,217 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `Tidy3DWebImporterNode`."""
import typing as typ
import bpy
import tidy3d as td
from blender_maxwell.services import tdcloud
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
SimDataArray: typ.TypeAlias = dict[
tuple[sim_symbols.SimSymbol, ...], tuple[typ.Any, ...], td.SimulationData
]
SimDataArrayInfo: typ.TypeAlias = dict[
tuple[sim_symbols.SimSymbol, ...], tuple[typ.Any, ...], typ.Any
]
####################
# - Node
####################
class Tidy3DWebImporterNode(base.MaxwellSimNode):
"""Retrieve a simulation w/data from the Tidy3D cloud service."""
node_type = ct.NodeType.Tidy3DWebImporter
bl_label = 'Tidy3D Web Importer'
input_sockets: typ.ClassVar = {
'Preview Sim': sockets.MaxwellFDTDSimSocketDef(),
}
input_socket_sets: typ.ClassVar = {
'Single': {
'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
should_exist=True,
),
},
'Batch': {
'Cloud Tasks': sockets.Tidy3DCloudTaskSocketDef(
active_kind=FK.Array,
should_exist=True,
),
},
}
output_socket_sets: typ.ClassVar = {
'Single': {
'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(),
},
'Batch': {
'Sim Datas': sockets.MaxwellFDTDSimDataSocketDef(
active_kind=FK.Array,
),
},
}
####################
# - Properties: Cloud Tasks -> Sim Datas
####################
@events.on_value_changed(
# Trigger
socket_name={'Cloud Task': FK.Value, 'Cloud Tasks': FK.Array},
)
def on_cloud_tasks_changed(self) -> None: # noqa: D102
self.cloud_tasks = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property(depends_on={'active_socket_set'})
def cloud_tasks(self) -> list[tdcloud.CloudTask] | None:
"""Retrieve the current cloud tasks from the input.
If one can't be loaded, return None.
"""
if self.active_socket_set == 'Single':
cloud_task_single = self._compute_input(
'Cloud Task',
kind=FK.Value,
)
has_cloud_task_single = not FS.check(cloud_task_single)
if has_cloud_task_single:
return [cloud_task_single]
if self.active_socket_set == 'Batch':
cloud_task_array = self._compute_input(
'Cloud Tasks',
kind=FK.Array,
)
has_cloud_task_array = not FS.check(cloud_task_array)
if has_cloud_task_array:
return cloud_task_array
return None
@bl_cache.cached_bl_property(depends_on={'cloud_tasks'})
def task_infos(self) -> list[tdcloud.CloudTaskInfo | None] | None:
"""Retrieve the current cloud task information from the input socket.
If it can't be loaded, return None.
"""
if self.cloud_tasks is not None:
task_infos = [
tdcloud.TidyCloudTasks.task_info(cloud_task.task_id)
for cloud_task in self.cloud_tasks
]
if task_infos:
return task_infos
return None
@bl_cache.cached_bl_property(depends_on={'cloud_tasks', 'task_infos'})
def sim_datas(self) -> SimDataArray | None:
"""Retrieve the simulation data of the current cloud task from the input socket.
If it can't be loaded, return None.
"""
cloud_tasks = self.cloud_tasks
task_infos = self.task_infos
if (
cloud_tasks is not None
and task_infos is not None
and all(task_info is not None for task_info in task_infos)
):
sim_datas = {}
for cloud_task, task_info in [
(cloud_task, task_info)
for cloud_task, task_info in zip(cloud_tasks, task_infos, strict=True)
if task_info.status == 'success'
]:
sim_data = tdcloud.TidyCloudTasks.download_task_sim_data(
cloud_task,
task_info.disk_cache_path(ct.addon.prefs().addon_cache_path),
)
if sim_data is not None:
sim_metadata = ct.SimMetadata.from_sim(sim_data)
sim_datas |= {sim_metadata.syms_vals: sim_data}
if sim_datas:
return sim_datas
return None
####################
# - UI
####################
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout):
"""Draw information about the cloud connection."""
tdcloud.draw_cloud_status(layout)
####################
# - FlowKind.Value
####################
@events.computes_output_socket(
'Sim Data',
kind=FK.Value,
# Loaded
props={'sim_datas'},
)
def compute_sim_data(self, props) -> td.SimulationData | FS:
"""A single simulation data object, when there only is one."""
sim_datas = props['sim_datas']
if sim_datas is not None and len(sim_datas) == 1:
return next(iter(sim_datas.values()))
return FS.FlowPending
####################
# - FlowKind.Array
####################
@events.computes_output_socket(
'Sim Datas',
kind=FK.Array,
# Loaded
props={'sim_datas'},
)
def compute_sim_datas(self, props) -> SimDataArray | FS:
"""All simulation data objects, for when there are more than one.
Generally part of the same batch.
"""
sim_datas = props['sim_datas']
if sim_datas is not None and len(sim_datas) > 1:
return sim_datas
return FS.FlowPending
####################
# - Blender Registration
####################
BL_REGISTER = [
Tidy3DWebImporterNode,
]
BL_NODES = {
ct.NodeType.Tidy3DWebImporter: (ct.NodeCategory.MAXWELLSIM_INPUTS_WEBIMPORTERS)
}

View File

@ -1,276 +0,0 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import typing as typ
import bpy
import tidy3d as td
from blender_maxwell.services import tdcloud
from blender_maxwell.utils import bl_cache, logger
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
####################
# - Operators
####################
class RunSimulation(bpy.types.Operator):
"""Run a Tidy3D simulation accessible from a `Tidy3DWebImporterNode`."""
bl_idname = ct.OperatorType.NodeRunSimulation
bl_label = 'Run Sim'
bl_description = 'Run the currently tracked simulation task'
@classmethod
def poll(cls, context):
return (
# Check Tidy3D Cloud
tdcloud.IS_AUTHENTICATED
# Check Tidy3DWebImporterNode is Accessible
and hasattr(context, 'node')
and hasattr(context.node, 'node_type')
and context.node.node_type == ct.NodeType.Tidy3DWebImporter
# Check Task is Runnable
and context.node.is_task_runnable
)
def execute(self, context):
node = context.node
node.cloud_task.submit()
return {'FINISHED'}
class ReloadTrackedTask(bpy.types.Operator):
"""Reload information of the selected task in a `Tidy3DWebImporterNode`."""
bl_idname = ct.OperatorType.NodeReloadTrackedTask
bl_label = 'Reload Tracked Tidy3D Cloud Task'
bl_description = 'Reload the currently tracked simulation task'
@classmethod
def poll(cls, context):
return (
# Check Tidy3D Cloud
tdcloud.IS_AUTHENTICATED
# Check Tidy3DWebImporterNode is Accessible
and hasattr(context, 'node')
and hasattr(context.node, 'node_type')
and context.node.node_type == ct.NodeType.Tidy3DWebImporter
)
def execute(self, context):
node = context.node
tdcloud.TidyCloudTasks.update_task(node.cloud_task)
return {'FINISHED'}
####################
# - Node
####################
class Tidy3DWebImporterNode(base.MaxwellSimNode):
node_type = ct.NodeType.Tidy3DWebImporter
bl_label = 'Tidy3D Web Runner/Importer'
input_sockets: typ.ClassVar = {
'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
should_exist=True, ## Ensure it is never NewSimCloudTask
),
}
output_sockets: typ.ClassVar = {
'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(),
}
####################
# - Computed (Cached)
####################
@property
def cloud_task(self) -> tdcloud.CloudTask | None:
"""Retrieve the current cloud task from the input socket.
If one can't be loaded, return None.
"""
cloud_task = self._compute_input(
'Cloud Task',
kind=ct.FlowKind.Value,
)
has_cloud_task = not ct.FlowSignal.check(cloud_task)
if has_cloud_task:
return cloud_task
return None
@property
def task_info(self) -> tdcloud.CloudTaskInfo | None:
"""Retrieve the current cloud task information from the input socket.
If it can't be loaded, return None.
"""
cloud_task = self.cloud_task
if cloud_task is None:
return None
# Retrieve Task Info
task_info = tdcloud.TidyCloudTasks.task_info(cloud_task.task_id)
if task_info is None:
return None
return task_info
@property
def sim_data(self) -> td.Simulation | None:
"""Retrieve the simulation data of the current cloud task from the input socket.
If it can't be loaded, return None.
"""
task_info = self.task_info
if task_info is None:
return None
if task_info.status == 'success':
# Download Sim Data
## -> self.cloud_task really shouldn't be able to be None here.
## -> So, we check it by applying the Ostrich method.
sim_data = tdcloud.TidyCloudTasks.download_task_sim_data(
self.cloud_task,
tdcloud.TidyCloudTasks.task_info(
self.cloud_task.task_id
).disk_cache_path(ct.addon.prefs().addon_cache_path),
)
if sim_data is None:
return None
return sim_data
return None
####################
# - Computed (Uncached)
####################
@property
def is_task_runnable(self) -> bool:
"""Checks whether all conditions are satisfied to be able to actually run a simulation."""
if self.task_info is not None:
return self.task_info.status == 'draft'
return False
####################
# - UI
####################
def draw_operators(self, context, layout):
# Row: Run Sim Buttons
row = layout.row(align=True)
row.operator(
ct.OperatorType.NodeRunSimulation,
text='Run Sim',
)
def draw_info(self, context, layout):
# Connection Info
auth_icon = 'CHECKBOX_HLT' if tdcloud.IS_AUTHENTICATED else 'CHECKBOX_DEHLT'
conn_icon = 'CHECKBOX_HLT' if tdcloud.IS_ONLINE else 'CHECKBOX_DEHLT'
row = layout.row()
row.alignment = 'CENTER'
row.label(text='Cloud Status')
box = layout.box()
split = box.split(factor=0.85)
## Split: Left Column
col = split.column(align=False)
col.label(text='Authed')
col.label(text='Connected')
## Split: Right Column
col = split.column(align=False)
col.label(icon=auth_icon)
col.label(icon=conn_icon)
# Cloud Task Info
if self.task_info is not None:
# Header
row = layout.row()
row.alignment = 'CENTER'
row.label(text='Task Info')
# Task Run Progress
# row = layout.row(align=True)
# row.progress(
# factor=0.0,
# type='BAR',
# text=f'Status: {self.task_info.status.capitalize()}',
# )
row.operator(
ct.OperatorType.NodeReloadTrackedTask,
text='',
icon='FILE_REFRESH',
)
# Task Information
box = layout.box()
split = box.split(factor=0.4)
## Split: Left Column
col = split.column(align=False)
col.label(text='Status')
col.label(text='Real Cost')
## Split: Right Column
cost_real = (
f'{self.task_info.cost_real:.2f}'
if self.task_info.cost_real is not None
else 'TBD'
)
col = split.column(align=False)
col.alignment = 'RIGHT'
col.label(text=self.task_info.status.capitalize())
col.label(text=f'{cost_real} creds')
####################
# - Output Methods
####################
@events.computes_output_socket(
'Sim Data',
props={'sim_data'},
input_sockets={'Cloud Task'}, ## Keep to respect dependency chains.
)
def compute_sim_data(
self, props, input_sockets
) -> td.SimulationData | ct.FlowSignal:
if props['sim_data'] is None:
return ct.FlowSignal.FlowPending
return props['sim_data']
####################
# - Blender Registration
####################
BL_REGISTER = [
RunSimulation,
ReloadTrackedTask,
Tidy3DWebImporterNode,
]
BL_NODES = {
ct.NodeType.Tidy3DWebImporter: (ct.NodeCategory.MAXWELLSIM_INPUTS_WEBIMPORTERS)
}

View File

@ -1,110 +0,0 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import sympy as sp
from .. import contracts as ct
from .. import sockets
from . import base
class KitchenSinkNode(base.MaxwellSimNode):
node_type = ct.NodeType.KitchenSink
bl_label = 'Kitchen Sink'
# bl_icon = ...
####################
# - Sockets
####################
input_sockets = {
'Static Data': sockets.AnySocketDef(),
}
input_socket_sets = {
'Basic': {
'Any': sockets.AnySocketDef(),
'Bool': sockets.BoolSocketDef(),
'FilePath': sockets.FilePathSocketDef(),
'Text': sockets.TextSocketDef(),
},
'Number': {
'Integer': sockets.IntegerNumberSocketDef(),
'Rational': sockets.RationalNumberSocketDef(),
'Real': sockets.RealNumberSocketDef(),
'Complex': sockets.ComplexNumberSocketDef(),
},
'Vector': {
'Real 2D': sockets.Real2DVectorSocketDef(),
'Real 3D': sockets.Real3DVectorSocketDef(
default_value=sp.Matrix([0.0, 0.0, 0.0])
),
'Complex 2D': sockets.Complex2DVectorSocketDef(),
'Complex 3D': sockets.Complex3DVectorSocketDef(),
},
'Physical': {
'Time': sockets.PhysicalTimeSocketDef(),
# "physical_point_2d": sockets.PhysicalPoint2DSocketDef(),
'Angle': sockets.PhysicalAngleSocketDef(),
'Length': sockets.PhysicalLengthSocketDef(),
'Area': sockets.PhysicalAreaSocketDef(),
'Volume': sockets.PhysicalVolumeSocketDef(),
'Point 3D': sockets.PhysicalPoint3DSocketDef(),
##"physical_size_2d": sockets.PhysicalSize2DSocketDef(),
'Size 3D': sockets.PhysicalSize3DSocketDef(),
'Mass': sockets.PhysicalMassSocketDef(),
'Speed': sockets.PhysicalSpeedSocketDef(),
'Accel Scalar': sockets.PhysicalAccelScalarSocketDef(),
'Force Scalar': sockets.PhysicalForceScalarSocketDef(),
# "physical_accel_3dvector": sockets.PhysicalAccel3DVectorSocketDef(),
##"physical_force_3dvector": sockets.PhysicalForce3DVectorSocketDef(),
'Pol': sockets.PhysicalPolSocketDef(),
'Freq': sockets.PhysicalFreqSocketDef(),
},
'Blender': {
'Object': sockets.BlenderObjectSocketDef(),
'Collection': sockets.BlenderCollectionSocketDef(),
'Image': sockets.BlenderImageSocketDef(),
'GeoNodes': sockets.BlenderGeoNodesSocketDef(),
'Text': sockets.BlenderTextSocketDef(),
},
'Maxwell': {
'Source': sockets.MaxwellSourceSocketDef(),
'Temporal Shape': sockets.MaxwellTemporalShapeSocketDef(),
'Medium': sockets.MaxwellMediumSocketDef(),
'Medium Non-Linearity': sockets.MaxwellMediumNonLinearitySocketDef(),
'Structure': sockets.MaxwellStructureSocketDef(),
'Bound Box': sockets.MaxwellBoundBoxSocketDef(),
'Bound Face': sockets.MaxwellBoundFaceSocketDef(),
'Monitor': sockets.MaxwellMonitorSocketDef(),
'FDTD Sim': sockets.MaxwellFDTDSimSocketDef(),
'Sim Grid': sockets.MaxwellSimGridSocketDef(),
'Sim Grid Axis': sockets.MaxwellSimGridAxisSocketDef(),
},
}
output_sockets = {
'Static Data': sockets.AnySocketDef(),
}
output_socket_sets = input_socket_sets
####################
# - Blender Registration
####################
BL_REGISTER = [
KitchenSinkNode,
]
BL_NODES = {ct.NodeType.KitchenSink: (ct.NodeCategory.MAXWELLSIM_INPUTS)}

View File

@ -14,20 +14,17 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import library_medium
# from . import pec_medium
# from . import isotropic_medium
# from . import anisotropic_medium
#
# from . import triple_sellmeier_medium
# from . import sellmeier_medium
# from . import pole_residue_medium
# from . import drude_medium
# from . import drude_lorentz_medium
# from . import debye_medium
#
# from . import non_linearities
from . import library_medium, non_linearities, pole_residue_medium
BL_REGISTER = [
*library_medium.BL_REGISTER,
@ -37,12 +34,12 @@ BL_REGISTER = [
#
# *triple_sellmeier_medium.BL_REGISTER,
# *sellmeier_medium.BL_REGISTER,
# *pole_residue_medium.BL_REGISTER,
*pole_residue_medium.BL_REGISTER,
# *drude_medium.BL_REGISTER,
# *drude_lorentz_medium.BL_REGISTER,
# *debye_medium.BL_REGISTER,
#
# *non_linearities.BL_REGISTER,
*non_linearities.BL_REGISTER,
]
BL_NODES = {
**library_medium.BL_NODES,
@ -52,10 +49,10 @@ BL_NODES = {
#
# **triple_sellmeier_medium.BL_NODES,
# **sellmeier_medium.BL_NODES,
# **pole_residue_medium.BL_NODES,
**pole_residue_medium.BL_NODES,
# **drude_medium.BL_NODES,
# **drude_lorentz_medium.BL_NODES,
# **debye_medium.BL_NODES,
#
# **non_linearities.BL_NODES,
**non_linearities.BL_NODES,
}

View File

@ -14,7 +14,10 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `LibraryMediumNode`."""
import enum
import functools
import typing as typ
import bpy
@ -36,8 +39,14 @@ log = logger.get(__name__)
_mat_lib_iter = iter(td.material_library)
_mat_key = ''
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
class VendoredMedium(enum.StrEnum):
"""Static enum of all mediums vendored with the Tidy3D client library."""
# Declare StrEnum of All Tidy3D Mediums
## -> This is a 'for ... in ...', which uses globals as loop variables.
## -> It's a bit of a hack, but very effective.
@ -53,16 +62,23 @@ class VendoredMedium(enum.StrEnum):
@staticmethod
def to_name(v: typ.Self) -> str:
"""UI method to get the name of a vendored medium."""
return td.material_library[v].name
@functools.cached_property
def name(self) -> str:
"""Name of the vendored medium."""
return VendoredMedium.to_name(self)
@staticmethod
def to_icon(_: typ.Self) -> str:
"""No icon."""
return ''
####################
# - Medium Properties
####################
@property
@functools.cached_property
def tidy3d_medium_item(self) -> Tidy3DMediumItem:
"""Extracts the Tidy3D "Medium Item", which encapsulates all the provided experimental variants."""
return td.material_library[self]
@ -70,12 +86,12 @@ class VendoredMedium(enum.StrEnum):
####################
# - Medium Variant Properties
####################
@property
@functools.cached_property
def medium_variants(self) -> set[Tidy3DMediumVariant]:
"""Extracts the list of medium variants, each corresponding to a particular experiment in the literature."""
return self.tidy3d_medium_item.variants
@property
@functools.cached_property
def default_medium_variant(self) -> Tidy3DMediumVariant:
"""Extracts the "default" medium variant, as selected by Tidy3D."""
return self.medium_variants[self.tidy3d_medium_item.default]
@ -83,7 +99,7 @@ class VendoredMedium(enum.StrEnum):
####################
# - Enum Helper
####################
@property
@functools.cached_property
def variants_as_bl_enum_elements(self) -> list[ct.BLEnumElement]:
"""Computes a list of variants in a format suitable for use in a dynamic `EnumProperty`.
@ -105,6 +121,8 @@ class VendoredMedium(enum.StrEnum):
class LibraryMediumNode(base.MaxwellSimNode):
"""A pre-defined medium sourced from a particular experiment in the literature."""
node_type = ct.NodeType.LibraryMedium
bl_label = 'Library Medium'
@ -113,7 +131,7 @@ class LibraryMediumNode(base.MaxwellSimNode):
####################
input_sockets: typ.ClassVar = {}
output_sockets: typ.ClassVar = {
'Medium': sockets.MaxwellMediumSocketDef(),
'Medium': sockets.MaxwellMediumSocketDef(active_kind=FK.Func),
'Valid Freqs': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Range,
physical_type=spux.PhysicalType.Freq,
@ -172,7 +190,7 @@ class LibraryMediumNode(base.MaxwellSimNode):
"""
return spu.convert_to(
sp.Matrix([sp.nsimplify(el) for el in self.medium.frequency_range])
sp.ImmutableMatrix([sp.nsimplify(el) for el in self.medium.frequency_range])
* spu.hertz,
spux.terahertz,
)
@ -180,7 +198,7 @@ class LibraryMediumNode(base.MaxwellSimNode):
@bl_cache.cached_bl_property(depends_on={'freq_range'})
def wl_range(self) -> sp.Expr:
"""Deduce the vacuum wavelength range as a unit-aware (nanometer, for convenience) column vector."""
return sp.Matrix(
return sp.ImmutableMatrix(
self.freq_range.applyfunc(
lambda el: spu.convert_to(
sci_constants.vac_speed_of_light / el, spu.nanometer
@ -216,13 +234,16 @@ class LibraryMediumNode(base.MaxwellSimNode):
# - UI
####################
def draw_label(self) -> str:
"""Show the active medium in the node label."""
return f'Medium: {self.vendored_medium}'
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
"""Dropdowns for a medium, and a particular experimental variant."""
layout.prop(self, self.blfields['vendored_medium'], text='')
layout.prop(self, self.blfields['variant_name'], text='')
def draw_info(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
"""Draw information about a perticular variant of a particular medium."""
box = col.box()
row = box.row(align=True)
@ -243,48 +264,81 @@ class LibraryMediumNode(base.MaxwellSimNode):
box.operator('wm.url_open', text='Link to Data').url = self.data_url
####################
# - Output
# - FlowKind.Value
####################
@events.computes_output_socket(
'Medium',
kind=FK.Value,
# Loaded
props={'medium'},
)
def compute_medium(self, props) -> sp.Expr:
return props['medium']
def compute_medium_value(self, props) -> td.Medium | FS:
"""Directly produce the medium."""
medium = props['medium']
if medium is not None:
return medium
return FS.FlowSignal
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Valid Freqs',
kind=ct.FlowKind.Range,
props={'freq_range'},
'Medium',
kind=FK.Func,
# Loaded
outscks_kinds={'Medium': FK.Value},
)
def compute_valid_freqs_lazy(self, props) -> sp.Expr:
return ct.RangeFlow(
start=spu.scale_to_unit(['freq_range'][0], spux.THz),
stop=spu.scale_to_unit(props['freq_range'][1], spux.THz),
scaling=ct.ScalingMode.Lin,
unit=spux.THz,
)
def compute_medium_func(self, output_sockets) -> ct.FuncFlow:
"""Simply bake `Value` into a function."""
return ct.FuncFlow(func=lambda: output_sockets['Medium'])
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Valid WLs',
kind=ct.FlowKind.Range,
props={'wl_range'},
'Medium',
kind=FK.Params,
)
def compute_valid_wls_lazy(self, props) -> sp.Expr:
return ct.RangeFlow(
start=spu.scale_to_unit(['wl_range'][0], spu.nm),
stop=spu.scale_to_unit(['wl_range'][0], spu.nm),
scaling=ct.ScalingMode.Lin,
unit=spu.nm,
)
def compute_medium_params(self) -> ct.ParamsFlow:
"""Return empty parameters for completeness."""
return ct.ParamsFlow()
####################
# - FlowKind.Range
####################
# @events.computes_output_socket(
# 'Valid Freqs',
# kind=ct.FlowKind.Range,
# props={'freq_range'},
# )
# def compute_valid_freqs_lazy(self, props) -> sp.Expr:
# return ct.RangeFlow(
# start=spu.scale_to_unit(['freq_range'][0], spux.THz),
# stop=spu.scale_to_unit(props['freq_range'][1], spux.THz),
# scaling=ct.ScalingMode.Lin,
# unit=spux.THz,
# )
# @events.computes_output_socket(
# 'Valid WLs',
# kind=ct.FlowKind.Range,
# props={'wl_range'},
# )
# def compute_valid_wls_lazy(self, props) -> sp.Expr:
# return ct.RangeFlow(
# start=spu.scale_to_unit(['wl_range'][0], spu.nm),
# stop=spu.scale_to_unit(['wl_range'][0], spu.nm),
# scaling=ct.ScalingMode.Lin,
# unit=spu.nm,
# )
####################
# - Preview
####################
## TODO: Move medium preview to a viz node of some kind
@events.on_show_plot(
managed_objs={'plot'},
props={'medium'},
stop_propagation=True,
)
def on_show_plot(
self,
@ -293,7 +347,9 @@ class LibraryMediumNode(base.MaxwellSimNode):
):
managed_objs['plot'].mpl_plot_to_image(
lambda ax: props['medium'].plot(props['medium'].frequency_range, ax=ax),
bl_select=True,
width_inches=6.0,
height_inches=3.0,
dpi=150,
)
## TODO: Plot based on Wl, not freq.

View File

@ -18,18 +18,17 @@ from . import (
add_non_linearity,
chi_3_susceptibility_non_linearity,
kerr_non_linearity,
two_photon_absorption_non_linearity,
)
BL_REGISTER = [
*add_non_linearity.BL_REGISTER,
*chi_3_susceptibility_non_linearity.BL_REGISTER,
*kerr_non_linearity.BL_REGISTER,
*two_photon_absorption_non_linearity.BL_REGISTER,
# *two_photon_absorption_non_linearity.BL_REGISTER,
]
BL_NODES = {
**add_non_linearity.BL_NODES,
**chi_3_susceptibility_non_linearity.BL_NODES,
**kerr_non_linearity.BL_NODES,
**two_photon_absorption_non_linearity.BL_NODES,
# **two_photon_absorption_non_linearity.BL_NODES,
}

View File

@ -14,8 +14,179 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `AddNonLinearity`."""
import functools
import typing as typ
import tidy3d as td
from blender_maxwell.utils import logger
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
class AddNonLinearity(base.MaxwellSimNode):
"""Add non-linearities to a medium, increasing the range of effects that it can encapsulate."""
node_type = ct.NodeType.AddNonLinearity
bl_label = 'Add Non-Linearity'
input_sockets: typ.ClassVar = {
'Medium': sockets.MaxwellMediumSocketDef(active_kind=FK.Func),
'Iterations': sockets.ExprSocketDef(
mathtype=MT.Integer,
default_value=5,
abs_min=1,
),
}
output_sockets: typ.ClassVar = {
'Medium': sockets.MaxwellMediumSocketDef(active_kind=FK.Func),
}
####################
# - Events
####################
@events.on_value_changed(
any_loose_input_socket=True,
run_on_init=True,
)
def on_inputs_changed(self) -> None:
"""Always create one extra loose input socket off the end of the last linked loose socket."""
# Deduce SocketDef
## -> Cheat by retrieving the class from the output sockets.
SocketDef = sockets.MaxwellMediumNonLinearitySocketDef
## TODO: Move this code to events, so it can be shared w/Combine
# Deduce Current "Filled"
## -> The first linked socket from the end bounds the "filled" region.
## -> The length of that region, plus one, will be the new amount.
total_loose_inputs = len(self.loose_input_sockets)
reverse_linked_idxs = [
i
for i, bl_socket in enumerate(reversed(self.inputs.values()))
if i < total_loose_inputs and bl_socket.is_linked
]
current_filled = total_loose_inputs - (
reverse_linked_idxs[0] if reverse_linked_idxs else total_loose_inputs
)
new_amount = current_filled + 1
# Deduce SocketDef | Current Amount
self.loose_input_sockets = {
'#0': SocketDef(active_kind=FK.Func),
} | {f'#{i}': SocketDef(active_kind=FK.Func) for i in range(1, new_amount)}
####################
# - FlowKind.Value
####################
@events.computes_output_socket(
'Medium',
kind=FK.Value,
# Loaded
outscks_kinds={
'Medium': {FK.Func, FK.Params},
},
)
def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
value = events.realize_known(output_sockets['Medium'])
if value is not None:
return value
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Medium',
kind=FK.Func,
# Loaded
inscks_kinds={
'Medium': FK.Func,
'Iterations': FK.Func,
},
all_loose_input_sockets=True,
loose_input_sockets_kind=FK.Func,
)
def compute_func(self, input_sockets, loose_input_sockets) -> ct.ParamsFlow | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
medium = input_sockets['Medium']
iterations = input_sockets['Iterations']
funcs = [
non_linearity
for non_linearity in loose_input_sockets.values()
if not FS.check(non_linearity)
]
if funcs:
non_linearities = functools.reduce(
lambda a, b: a | b,
funcs,
)
return (medium | iterations | non_linearities).compose_within(
lambda els: els[0].updated_copy(
nonlinear_spec=td.NonlinearSpec(
num_iters=els[1],
models=els[2] if isinstance(els[2], tuple) else [els[2]],
)
)
)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Medium',
kind=ct.FlowKind.Params,
# Loaded
inscks_kinds={
'Medium': FK.Params,
'Iterations': FK.Params,
},
all_loose_input_sockets=True,
loose_input_sockets_kind=FK.Params,
)
def compute_params(self, input_sockets, loose_input_sockets) -> td.Box:
"""Aggregate the function parameters needed by the box."""
medium = input_sockets['Medium']
iterations = input_sockets['Iterations']
funcs = [
non_linearity
for non_linearity in loose_input_sockets.values()
if not FS.check(non_linearity)
]
if funcs:
non_linearities = functools.reduce(
lambda a, b: a | b,
funcs,
)
return medium | iterations | non_linearities
return ct.FlowSignal.FlowPending
####################
# - Blender Registration
####################
BL_REGISTER = []
BL_NODES = {}
BL_REGISTER = [
AddNonLinearity,
]
BL_NODES = {
ct.NodeType.AddNonLinearity: (ct.NodeCategory.MAXWELLSIM_MEDIUMS_NONLINEARITIES)
}

View File

@ -14,8 +14,132 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `ChiThreeSuscepNonLinearity`."""
import typing as typ
import bpy
import tidy3d as td
from blender_maxwell.utils import logger
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
class ChiThreeSuscepNonLinearity(base.MaxwellSimNode):
r"""An instantaneous non-linear susceptibility described by a $\chi_3$ parameter.
The model of field-component does not permit component interactions; therefore, it is only valid when the electric field is predominantly polarized along an axis-aligned direction.
Additionally, strong non-linearities may suffer from divergence issues, since an iterative local method is used to resolve the relation.
"""
node_type = ct.NodeType.ChiThreeSuscepNonLinearity
bl_label = 'Chi3 Non-Linearity'
input_sockets: typ.ClassVar = {
'χ₃': sockets.ExprSocketDef(active_kind=FK.Value),
}
output_sockets: typ.ClassVar = {
'Non-Linearity': sockets.MaxwellMediumNonLinearitySocketDef(
active_kind=FK.Func
),
}
####################
# - UI
####################
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
"""Draw the user interfaces of the node's reported info.
Parameters:
layout: UI target for drawing.
"""
box = layout.box()
row = box.row()
row.alignment = 'CENTER'
row.label(text='Interpretation')
# Split
split = box.split(factor=0.4, align=False)
## LHS: Parameter Names
col = split.column()
col.alignment = 'RIGHT'
col.label(text='χ₃:')
## RHS: Parameter Units
col = split.column()
col.label(text='um² / V²')
####################
# - FlowKind.Value
####################
@events.computes_output_socket(
'Non-Linearity',
kind=FK.Value,
# Loaded
outscks_kinds={
'Non-Linearity': {FK.Func, FK.Params},
},
)
def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
r"""The value realizes the output function w/output parameters."""
value = events.realize_known(output_sockets['Non-Linearity'])
if value is not None:
return value
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Non-Linearity',
kind=FK.Func,
# Loaded
inscks_kinds={
'χ₃': FK.Func,
},
)
def compute_func(self, input_sockets) -> ct.FuncFlow:
r"""The function encloses the $\chi_3$ parameter in the nonlinear susceptibility."""
chi_3 = input_sockets['χ₃']
return chi_3.compose_within(
lambda _chi_3: td.NonlinearSusceptibility(chi3=_chi_3)
)
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Non-Linearity',
kind=FK.Params,
# Loaded
inscks_kinds={
'χ₃': FK.Params,
},
)
def compute_params(self, input_sockets) -> ct.FuncFlow:
r"""The function parameters of the non-linearity are identical to that of the $\chi_3$ parameter."""
return input_sockets['χ₃']
####################
# - Blender Registration
####################
BL_REGISTER = []
BL_NODES = {}
BL_REGISTER = [
ChiThreeSuscepNonLinearity,
]
BL_NODES = {
ct.NodeType.ChiThreeSuscepNonLinearity: (
ct.NodeCategory.MAXWELLSIM_MEDIUMS_NONLINEARITIES
)
}

View File

@ -14,8 +14,131 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `KerrNonLinearity`."""
import typing as typ
import bpy
import tidy3d as td
from blender_maxwell.utils import logger
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
class KerrNonLinearity(base.MaxwellSimNode):
r"""An instantaneous non-linear susceptibility described by a $\chi_3$ parameter.
The model of field-component does not permit component interactions; therefore, it is only valid when the electric field is predominantly polarized along an axis-aligned direction.
Additionally, strong non-linearities may suffer from divergence issues, since an iterative local method is used to resolve the relation.
"""
node_type = ct.NodeType.KerrNonLinearity
bl_label = 'Kerr Non-Linearity'
input_sockets: typ.ClassVar = {
'n₂': sockets.ExprSocketDef(
active_kind=FK.Value,
mathtype=MT.Complex,
),
}
output_sockets: typ.ClassVar = {
'Non-Linearity': sockets.MaxwellMediumNonLinearitySocketDef(
active_kind=FK.Func
),
}
####################
# - UI
####################
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
"""Draw the user interfaces of the node's reported info.
Parameters:
layout: UI target for drawing.
"""
box = layout.box()
row = box.row()
row.alignment = 'CENTER'
row.label(text='Interpretation')
# Split
split = box.split(factor=0.4, align=False)
## LHS: Parameter Names
col = split.column()
col.alignment = 'RIGHT'
col.label(text='n₂:')
## RHS: Parameter Units
col = split.column()
col.label(text='um² / W')
####################
# - FlowKind.Value
####################
@events.computes_output_socket(
'Non-Linearity',
kind=FK.Value,
# Loaded
outscks_kinds={
'Non-Linearity': {FK.Func, FK.Params},
},
)
def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""The value realizes the output function w/output parameters."""
value = events.realize_known(output_sockets['Non-Linearity'])
if value is not None:
return value
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Non-Linearity',
kind=FK.Func,
# Loaded
inscks_kinds={
'n₂': FK.Func,
},
)
def compute_func(self, input_sockets) -> ct.FuncFlow:
r"""The function encloses the $\chi_3$ parameter in the nonlinear susceptibility."""
n2 = input_sockets['n₂']
return n2.compose_within(lambda _n2: td.KerrNonlinearity(n2=_n2))
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Non-Linearity',
kind=FK.Params,
# Loaded
inscks_kinds={
'n₂': FK.Params,
},
)
def compute_params(self, input_sockets) -> ct.FuncFlow:
r"""The function parameters of the non-linearity are identical to that of the $\chi_3$ parameter."""
return input_sockets['n₂']
####################
# - Blender Registration
####################
BL_REGISTER = []
BL_NODES = {}
BL_REGISTER = [
KerrNonLinearity,
]
BL_NODES = {
ct.NodeType.KerrNonLinearity: (ct.NodeCategory.MAXWELLSIM_MEDIUMS_NONLINEARITIES)
}

View File

@ -14,8 +14,319 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import typing as typ
import bpy
import sympy.physics.units as spu
import tidy3d as td
import tidy3d.plugins.dispersion as td_dispersion
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets
from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
PT = spux.PhysicalType
VALID_URL_PREFIXES = {
'https://refractiveindex.info',
}
####################
# - Operators
####################
class FitPoleResidueMedium(bpy.types.Operator):
"""Trigger fitting of a dispersive medium to a Pole-Residue model, and store it on a `PoleResidueMediumnode`."""
bl_idname = ct.OperatorType.NodeFitDispersiveMedium
bl_label = 'Fit Dispersive Medium from Input'
bl_description = (
'Fit the dispersive medium specified by the `PoleResidueMediumNode`.'
)
@classmethod
def poll(cls, context):
return (
# Check Tidy3DWebExporter is Accessible
hasattr(context, 'node')
and hasattr(context.node, 'node_type')
and context.node.node_type == ct.NodeType.PoleResidueMedium
# Check Medium is Fittable
and context.node.fitter is not None
and context.node.fitted_medium is None
and context.node.fitted_rms_error is None
)
def execute(self, context):
node = context.node
try:
pole_residue_medium, rms_error = node.fitter.fit(
min_num_poles=node.fit_min_poles,
max_num_poles=node.fit_max_poles,
tolerance_rms=node.fit_tolerance_rms,
)
except: # noqa: E722
self.report(
{'ERROR'}, "Couldn't perform PoleResidue data fit - check inputs."
)
return {'FINISHED'}
else:
node.fitted_medium = pole_residue_medium
node.fitted_rms_error = float(rms_error)
for bl_socket in node.inputs:
bl_socket.trigger_event(ct.FlowEvent.EnableLock)
return {'FINISHED'}
class ReleasePoleResidueFit(bpy.types.Operator):
"""Release a previous fit of a dispersive medium to a Pole-Residue model, from a `PoleResidueMediumnode`."""
bl_idname = ct.OperatorType.NodeReleaseDispersiveFit
bl_label = 'Release Dispersive Medium fit'
bl_description = (
'Release the dispersive medium fit from the `PoleResidueMediumNode`.'
)
@classmethod
def poll(cls, context):
return (
# Check Tidy3DWebExporter is Accessible
hasattr(context, 'node')
and hasattr(context.node, 'node_type')
and context.node.node_type == ct.NodeType.PoleResidueMedium
# Check Medium is Fittable
and context.node.fitted_medium is not None
and context.node.fitted_rms_error is not None
)
def execute(self, context):
node = context.node
node.fitted_medium = None
node.fitted_rms_error = None
for bl_socket in node.inputs:
bl_socket.trigger_event(ct.FlowEvent.DisableLock)
return {'FINISHED'}
####################
# - Node
####################
class PoleResidueMediumNode(base.MaxwellSimNode):
"""A dispersive medium described by a pole-residue model."""
node_type = ct.NodeType.PoleResidueMedium
bl_label = 'Pole Residue Medium'
input_socket_sets: typ.ClassVar = {
'Fit URL': {
'URL': sockets.StringSocketDef(),
},
'Fit Data': {
'Expr': sockets.ExprSocketDef(active_kind=FK.Func),
},
}
output_sockets: typ.ClassVar = {
'Medium': sockets.MaxwellMediumSocketDef(),
}
managed_obj_types: typ.ClassVar = {
'plot': managed_objs.ManagedBLImage,
}
####################
# - Properties
####################
fit_min_poles: int = bl_cache.BLField(1)
fit_max_poles: int = bl_cache.BLField(5)
fit_tolerance_rms: float = bl_cache.BLField(0.001)
fitted_medium: td.PoleResidue | None = bl_cache.BLField(None)
fitted_rms_error: float | None = bl_cache.BLField(None)
## TODO: Bool of whether to fit eps_inf, with conditional choice of eps_inf as socket
## TODO: "AdvanceFastFitterParam" options incl. loss_bounds, weights, show_progress, show_unweighted_rms, relaxed, smooth, logspacing, numiters, passivity_num_iters, and slsqp_constraint_scale
####################
# - Data Fitting
####################
@events.on_value_changed(
socket_name={'Expr': {FK.Func, FK.Params, FK.Info}, 'URL': FK.Value},
stop_propagation=True,
)
def on_expr_changed(self) -> None:
"""Respond to changes in `Func`, `Params`, and `Info` to invalidate `self.fitter`."""
self.fitter = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property(depends_on={'active_socket_set'})
def fitter(self) -> td_dispersion.FastDispersionFitter | None:
"""Compute a `FastDispersionFitter`, which can be used to initiate a data-fit."""
match self.active_socket_set:
case 'Fit Data':
func = self._compute_input('Expr', kind=FK.Func)
info = self._compute_input('Expr', kind=FK.Info)
params = self._compute_input('Expr', kind=FK.Params)
has_info = not FS.check(info)
expr = events.realize_known({FK.Func: func, FK.Params: params})
if (
expr is not None
and has_info
and len(info.dims == 1)
and info.first_dim
and info.first_dim.physical_type is PT.Length
):
return td_dispersion.FastDispersionFitter(
wvl_um=info.dims[info.first_dim]
.rescale_to_unit(spu.micrometer)
.values,
n_data=expr.real,
k_data=expr.imag,
)
return None
case 'Fit URL':
url = self._compute_input('URL', kind=FK.Value)
has_url = not FS.check(url)
if has_url and any(
url.startswith(valid_prefix) for valid_prefix in VALID_URL_PREFIXES
):
return None
# return td_dispersion.FastDispersionFitter.from_url(url)
return None
raise TypeError
####################
# - UI
####################
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
"""Draw loaded properties."""
# Fit/Release Operator
row = col.row(align=True)
row.operator(
ct.OperatorType.NodeFitDispersiveMedium,
text='Fit Medium',
)
if self.fitted_medium is not None:
row.operator(
ct.OperatorType.NodeReleaseDispersiveFit,
icon='LOOP_BACK',
text='',
)
# Fit Parameters / Fit Info
if self.fitted_medium is None:
row = col.row(align=True)
row.alignment = 'CENTER'
row.label(text='min|max|tol')
col.prop(self, self.blfields['fit_min_poles'])
col.prop(self, self.blfields['fit_max_poles'])
col.prop(self, self.blfields['fit_tolerance_rms'])
else:
box = col.box()
row = box.row(align=True)
row.alignment = 'CENTER'
row.label(text='Fit Info')
split = box.split(factor=0.4)
col = split.column()
row = col.row()
row.label(text='RMS Err')
col = split.column()
row = col.row()
row.label(text=f'{self.fitted_rms_error:.4f}')
####################
# - FlowKind.Value
####################
@events.computes_output_socket(
'Medium',
kind=FK.Value,
# Loaded
props={'fitted_medium'},
)
def compute_fitted_medium_value(self, props) -> td.Medium | FS:
"""Return the fitted medium."""
fitted_medium = props['fitted_medium']
if fitted_medium is not None:
return fitted_medium
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Medium',
kind=FK.Func,
# Loaded
outscks_kinds={'Medium': FK.Value},
)
def compute_fitted_medium_func(self, output_sockets) -> td.Medium | FS:
"""Return the fitted medium as a function with that medium baked in."""
fitted_medium = output_sockets['Medium']
return ct.FuncFlow(
func=lambda: fitted_medium,
)
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Medium',
kind=FK.Params,
)
def compute_fitted_medium_params(self) -> td.Medium | FS:
"""Declare no function parameters."""
return ct.ParamsFlow()
####################
# - Event Methods: Plot
####################
@events.on_show_plot(
managed_objs={'plot'},
# Loaded
props={'fitter', 'fitted_medium'},
)
def on_show_plot(self, props, managed_objs):
"""When the filetype is 'Experimental Dispersive Medium', plot the computed model against the input data."""
fitter = props['fitter']
fitted_medium = props['fitted_medium']
if fitter is not None and fitted_medium is not None:
managed_objs['plot'].mpl_plot_to_image(
lambda ax: props['fitter'].plot(
medium=props['fitted_medium'],
ax=ax,
),
width_inches=6.0,
height_inches=3.0,
dpi=150,
)
return FS.FlowPending
####################
# - Blender Registration
####################
BL_REGISTER = []
BL_NODES = {}
BL_REGISTER = [
FitPoleResidueMedium,
ReleasePoleResidueFit,
PoleResidueMediumNode,
]
BL_NODES = {ct.NodeType.PoleResidueMedium: (ct.NodeCategory.MAXWELLSIM_MEDIUMS)}

View File

@ -31,6 +31,11 @@ from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
PT = spux.PhysicalType
class EHFieldMonitorNode(base.MaxwellSimNode):
"""Node providing for the monitoring of electromagnetic fields within a given planar region or volume."""
@ -45,26 +50,27 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
input_sockets: typ.ClassVar = {
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
physical_type=spux.PhysicalType.Length,
physical_type=PT.Length,
),
'Size': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
physical_type=spux.PhysicalType.Length,
default_value=sp.Matrix([1, 1, 1]),
physical_type=PT.Length,
default_value=sp.ImmutableMatrix([1, 1, 1]),
abs_min=0,
abs_min_closed=False,
),
'Stride': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
mathtype=spux.MathType.Integer,
default_value=sp.Matrix([10, 10, 10]),
abs_min=0,
mathtype=MT.Integer,
default_value=sp.ImmutableMatrix([10, 10, 10]),
abs_min=1,
),
}
input_socket_sets: typ.ClassVar = {
'Freq Domain': {
'Freqs': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Range,
physical_type=spux.PhysicalType.Freq,
active_kind=FK.Range,
physical_type=PT.Freq,
default_unit=spux.THz,
default_min=374.7406, ## 800nm
default_max=1498.962, ## 200nm
@ -73,21 +79,22 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
},
'Time Domain': {
't Range': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Range,
physical_type=spux.PhysicalType.Time,
active_kind=FK.Range,
physical_type=PT.Time,
default_unit=spu.picosecond,
default_min=0,
default_max=10,
default_steps=0,
default_steps=2,
),
't Stride': sockets.ExprSocketDef(
mathtype=spux.MathType.Integer,
mathtype=MT.Integer,
default_value=100,
abs_min=1,
),
},
}
output_sockets: typ.ClassVar = {
'Monitor': sockets.MaxwellMonitorSocketDef(active_kind=ct.FlowKind.Func),
'Monitor': sockets.MaxwellMonitorSocketDef(active_kind=FK.Func),
}
managed_obj_types: typ.ClassVar = {
@ -103,6 +110,11 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
"""Draw the field-selector in the node UI.
Parameters:
layout: UI target for drawing.
"""
layout.prop(self, self.blfields['fields'], expand=True)
####################
@ -110,186 +122,134 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Monitor',
kind=ct.FlowKind.Value,
kind=FK.Value,
# Loaded
output_sockets={'Monitor'},
output_socket_kinds={'Monitor': {ct.FlowKind.Func, ct.FlowKind.Params}},
outscks_kinds={
'Monitor': {FK.Func, FK.Params},
},
)
def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
output_func = output_sockets['Monitor'][ct.FlowKind.Func]
output_params = output_sockets['Monitor'][ct.FlowKind.Params]
has_output_func = not ct.FlowSignal.check(output_func)
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_func and has_output_params and not output_params.symbols:
return output_func.realize(output_params, disallow_jax=True)
return ct.FlowSignal.FlowPending
def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Realizes the output function w/output parameters."""
value = events.realize_known(output_sockets['Monitor'])
if value is not None:
return value
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Monitor',
kind=ct.FlowKind.Func,
kind=FK.Func,
# Loaded
props={'active_socket_set', 'sim_node_name', 'fields'},
input_sockets={
'Center',
'Size',
'Stride',
'Freqs',
't Range',
't Stride',
inscks_kinds={
'Center': FK.Func,
'Size': FK.Func,
'Stride': FK.Func,
'Freqs': FK.Func,
't Range': FK.Func,
't Stride': FK.Func,
},
input_socket_kinds={
'Center': ct.FlowKind.Func,
'Size': ct.FlowKind.Func,
'Stride': ct.FlowKind.Func,
'Freqs': ct.FlowKind.Func,
't Range': ct.FlowKind.Func,
't Stride': ct.FlowKind.Func,
input_sockets_optional={'Freqs', 't Range', 't Stride'},
scale_input_sockets={
'Center': ct.UNITS_TIDY3D,
'Size': ct.UNITS_TIDY3D,
'Freqs': ct.UNITS_TIDY3D,
't Range': ct.UNITS_TIDY3D,
},
)
def compute_func(self, props, input_sockets) -> td.FieldMonitor:
"""Lazily assembles the FieldMonitor from the input functions."""
center = input_sockets['Center']
size = input_sockets['Size']
stride = input_sockets['Stride']
has_center = not ct.FlowSignal.check(center)
has_size = not ct.FlowSignal.check(size)
has_stride = not ct.FlowSignal.check(stride)
freqs = input_sockets['Freqs']
t_range = input_sockets['t Range']
t_stride = input_sockets['t Stride']
if has_center and has_size and has_stride:
name = props['sim_node_name']
fields = props['fields']
sim_node_name = props['sim_node_name']
fields = props['fields']
common_func_flow = (
center.scale_to_unit_system(ct.UNITS_TIDY3D)
| size.scale_to_unit_system(ct.UNITS_TIDY3D)
| stride
)
common_func_flow = center | size | stride
match props['active_socket_set']:
case 'Freq Domain' if not FS.check(freqs):
return (common_func_flow | freqs).compose_within(
lambda els: td.FieldMonitor(
name=sim_node_name,
center=els[0].flatten().tolist(),
size=els[1].flatten().tolist(),
interval_space=els[2].flatten().tolist(),
freqs=els[3].flatten(),
fields=fields,
)
)
match props['active_socket_set']:
case 'Freq Domain':
freqs = input_sockets['Freqs']
has_freqs = not ct.FlowSignal.check(freqs)
if has_freqs:
return (
common_func_flow
| freqs.scale_to_unit_system(ct.UNITS_TIDY3D)
).compose_within(
lambda els: td.FieldMonitor(
center=els[0].flatten().tolist(),
size=els[1].flatten().tolist(),
name=name,
interval_space=els[2].flatten().tolist(),
freqs=els[3].flatten(),
fields=fields,
)
)
case 'Time Domain':
t_range = input_sockets['t Range']
t_stride = input_sockets['t Stride']
has_t_range = not ct.FlowSignal.check(t_range)
has_t_stride = not ct.FlowSignal.check(t_stride)
if has_t_range and has_t_stride:
return (
common_func_flow
| t_range.scale_to_unit_system(ct.UNITS_TIDY3D)
| t_stride.scale_to_unit_system(ct.UNITS_TIDY3D)
).compose_within(
lambda els: td.FieldTimeMonitor(
center=els[0].flatten().tolist(),
size=els[1].flatten().tolist(),
name=name,
interval_space=els[2].flatten().tolist(),
start=els[3][0],
stop=els[3][-1],
interval=els[4],
fields=fields,
)
)
return ct.FlowSignal.FlowPending
case 'Time Domain' if not FS.check(t_range) and not FS.check(t_stride):
return (common_func_flow | t_range | t_stride).compose_within(
lambda els: td.FieldTimeMonitor(
name=sim_node_name,
center=els[0].flatten().tolist(),
size=els[1].flatten().tolist(),
interval_space=els[2].flatten().tolist(),
start=els[3][0],
stop=els[3][-1],
interval=els[4],
fields=fields,
)
)
return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Monitor',
kind=ct.FlowKind.Params,
kind=FK.Params,
# Loaded
props={'active_socket_set'},
input_sockets={
'Center',
'Size',
'Stride',
'Freqs',
't Range',
't Stride',
},
input_socket_kinds={
'Center': ct.FlowKind.Params,
'Size': ct.FlowKind.Params,
'Stride': ct.FlowKind.Params,
'Freqs': ct.FlowKind.Params,
't Range': ct.FlowKind.Params,
't Stride': ct.FlowKind.Params,
inscks_kinds={
'Center': FK.Params,
'Size': FK.Params,
'Stride': FK.Params,
'Freqs': FK.Params,
't Range': FK.Params,
't Stride': FK.Params,
},
input_sockets_optional={'Freqs', 't Range', 't Stride'},
)
def compute_params(self, props, input_sockets) -> None:
"""Lazily assembles the FieldMonitor from the input functions."""
center = input_sockets['Center']
size = input_sockets['Size']
stride = input_sockets['Stride']
has_center = not ct.FlowSignal.check(center)
has_size = not ct.FlowSignal.check(size)
has_stride = not ct.FlowSignal.check(stride)
freqs = input_sockets['Freqs']
t_range = input_sockets['t Range']
t_stride = input_sockets['t Stride']
if has_center and has_size and has_stride:
common_params = center | size | stride
match props['active_socket_set']:
case 'Freq Domain':
freqs = input_sockets['Freqs']
has_freqs = not ct.FlowSignal.check(freqs)
common_params = center | size | stride
match props['active_socket_set']:
case 'Freq Domain' if not FS.check(freqs):
return common_params | freqs
if has_freqs:
return common_params | freqs
case 'Time Domain':
t_range = input_sockets['t Range']
t_stride = input_sockets['t Stride']
has_t_range = not ct.FlowSignal.check(t_range)
has_t_stride = not ct.FlowSignal.check(t_stride)
if has_t_range and has_t_stride:
return common_params | t_range | t_stride
return ct.FlowSignal.FlowPending
case 'Time Domain' if not FS.check(t_range) and not FS.check(t_stride):
return common_params | t_range | t_stride
return FS.FlowPending
####################
# - Preview
####################
@events.computes_output_socket(
'Monitor',
kind=ct.FlowKind.Previews,
kind=FK.Previews,
# Loaded
props={'sim_node_name'},
output_sockets={'Monitor'},
output_socket_kinds={'Monitor': ct.FlowKind.Params},
)
def compute_previews(self, props, output_sockets):
output_params = output_sockets['Monitor']
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_params and not output_params.symbols:
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
return ct.PreviewsFlow()
def compute_previews(self, props):
"""Mark the monitor as participating in the preview."""
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed(
# Trigger
@ -297,32 +257,31 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
run_on_init=True,
# Loaded
managed_objs={'modifier'},
input_sockets={'Center', 'Size'},
output_sockets={'Monitor'},
output_socket_kinds={'Monitor': ct.FlowKind.Params},
inscks_kinds={
'Center': {FK.Func, FK.Params},
'Size': {FK.Func, FK.Params},
},
scale_input_sockets={
'Center': ct.UNITS_BLENDER,
'Size': ct.UNITS_BLENDER,
},
)
def on_previewable_changed(self, managed_objs, input_sockets, output_sockets):
center = input_sockets['Center']
size = input_sockets['Size']
output_params = output_sockets['Monitor']
def on_previewable_changed(self, managed_objs, input_sockets):
"""Push changes in the inputs to the center / size."""
center = events.realize_preview(input_sockets['Center'])
size = events.realize_preview(input_sockets['Size'])
has_center = not ct.FlowSignal.check(center)
has_size = not ct.FlowSignal.check(size)
has_output_params = not ct.FlowSignal.check(output_params)
if has_center and has_size and has_output_params and not output_params.symbols:
# Push Input Values to GeoNodes Modifier
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.MonitorEHField),
'unit_system': ct.UNITS_BLENDER,
'inputs': {
'Size': size,
},
# Push Input Values to GeoNodes Modifier
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.MonitorEHField),
'inputs': {
'Size': size,
},
location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
)
},
location=center,
)
####################

View File

@ -14,6 +14,9 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `PowerFluxMonitorNode`."""
import itertools
import typing as typ
import bpy
@ -30,6 +33,12 @@ from ... import managed_objs, sockets
from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
ALL_3D_DIRS = {
ax + sgn for ax, sgn in itertools.product(set(ct.SimSpaceAxis), set(ct.SimAxisDir))
}
class PowerFluxMonitorNode(base.MaxwellSimNode):
@ -50,20 +59,21 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
'Size': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
physical_type=spux.PhysicalType.Length,
default_value=sp.Matrix([1, 1, 1]),
default_value=sp.ImmutableMatrix([1, 1, 1]),
abs_min=0,
abs_min_closed=False,
),
'Stride': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
mathtype=spux.MathType.Integer,
default_value=sp.Matrix([10, 10, 10]),
abs_min=0,
default_value=sp.ImmutableMatrix([10, 10, 10]),
abs_min=1,
),
}
input_socket_sets: typ.ClassVar = {
'Freq Domain': {
'Freqs': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Range,
active_kind=FK.Range,
physical_type=spux.PhysicalType.Freq,
default_unit=spux.THz,
default_min=374.7406, ## 800nm
@ -73,7 +83,7 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
},
'Time Domain': {
't Range': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Range,
active_kind=FK.Range,
physical_type=spux.PhysicalType.Time,
default_unit=spu.picosecond,
default_min=0,
@ -83,16 +93,15 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
't Stride': sockets.ExprSocketDef(
mathtype=spux.MathType.Integer,
default_value=100,
abs_min=1,
),
},
}
output_socket_sets: typ.ClassVar = {
'Freq Domain': {'Freq Monitor': sockets.MaxwellMonitorSocketDef()},
'Time Domain': {'Time Monitor': sockets.MaxwellMonitorSocketDef()},
output_sockets: typ.ClassVar = {
'Monitor': sockets.MaxwellMonitorSocketDef(),
}
managed_obj_types: typ.ClassVar = {
'mesh': managed_objs.ManagedBLMesh,
'modifier': managed_objs.ManagedBLModifier,
}
@ -109,6 +118,7 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
"""Draw the properties of the node."""
# 2D Monitor
if 0 in self._compute_input('Size'):
layout.prop(self, self.blfields['direction_2d'], expand=True)
@ -125,67 +135,174 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
row.prop(self, self.blfields['include_3d_z'], expand=True)
####################
# - Event Methods: Computation
# - FlowKind.Value
####################
@events.computes_output_socket(
'Freq Monitor',
props={'sim_node_name', 'direction_2d'},
input_sockets={
'Center',
'Size',
'Stride',
'Freqs',
},
input_socket_kinds={
'Freqs': ct.FlowKind.Range,
},
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
scale_input_sockets={
'Center': 'Tidy3DUnits',
'Size': 'Tidy3DUnits',
'Freqs': 'Tidy3DUnits',
'Monitor',
kind=FK.Value,
# Loaded
output_sockets={'Monitor'},
output_socket_kinds={
'Monitor': {FK.Func, FK.Params},
},
)
def compute_freq_monitor(
self,
input_sockets: dict,
props: dict,
unit_systems: dict,
) -> td.FieldMonitor:
log.info(
'Computing FluxMonitor (name="%s") with center="%s", size="%s"',
props['sim_node_name'],
input_sockets['Center'],
input_sockets['Size'],
)
return td.FluxMonitor(
center=input_sockets['Center'],
size=input_sockets['Size'],
name=props['sim_node_name'],
interval_space=(1, 1, 1),
freqs=input_sockets['Freqs'].realize_array.values,
normal_dir=props['direction_2d'].plus_or_minus,
)
def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Realizes the output function w/output parameters."""
value = events.realize_known(output_sockets['Monitor'])
if value is not None:
return value
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Monitor',
kind=FK.Func,
# Loaded
props={'active_socket_set', 'sim_node_name', 'direction_2d'},
inscks_kinds={
'Center': FK.Func,
'Size': FK.Func,
'Stride': FK.Func,
'Freqs': FK.Func,
't Range': FK.Func,
't Stride': FK.Func,
},
input_sockets_optional={'Freqs', 't Range', 't Stride'},
scale_input_sockets={
'Center': ct.UNITS_TIDY3D,
'Size': ct.UNITS_TIDY3D,
'Freqs': ct.UNITS_TIDY3D,
't Range': ct.UNITS_TIDY3D,
},
)
def compute_func(self, input_sockets, props) -> ct.FuncFlow: # noqa: C901
"""Compute the correct flux monitor."""
center = input_sockets['Center']
size = input_sockets['Size']
stride = input_sockets['Stride']
freqs = input_sockets['Freqs']
t_range = input_sockets['t Range']
t_stride = input_sockets['t Stride']
sim_node_name = props['sim_node_name']
direction_2d = props['direction_2d']
# 3D Flux Monitor: Computed Excluded Directions
## -> The flux is always recorded outgoing.
## -> However, one can exclude certain faces from participating.
include_3d = props['include_3d']
include_3d_x = props['include_3d_x']
include_3d_y = props['include_3d_y']
include_3d_z = props['include_3d_z']
excluded_3d = set()
if ct.SimSpaceAxis.X in include_3d:
if ct.SimAxisDir.Plus in include_3d_x:
excluded_3d.add('x+')
if ct.SimAxisDir.Minus in include_3d_x:
excluded_3d.add('x-')
if ct.SimSpaceAxis.Y in include_3d:
if ct.SimAxisDir.Plus in include_3d_y:
excluded_3d.add('y+')
if ct.SimAxisDir.Minus in include_3d_y:
excluded_3d.add('y-')
if ct.SimSpaceAxis.Z in include_3d:
if ct.SimAxisDir.Plus in include_3d_z:
excluded_3d.add('z+')
if ct.SimAxisDir.Minus in include_3d_z:
excluded_3d.add('z-')
excluded_3d = tuple(ALL_3D_DIRS - excluded_3d)
# Compute Monitor
common_func = center | size | stride
active_socket_set = props['active_socket_set']
match active_socket_set:
case 'Freq Domain' if not FS.check(freqs):
return (common_func | freqs).compose_within(
lambda els: td.FluxMonitor(
name=sim_node_name,
center=els[0],
size=els[1],
interval_space=els[2],
freqs=els[3],
normal_dir=direction_2d.plus_or_minus,
exclude_surfaces=excluded_3d,
)
)
case 'Time Domain' if not FS.check(t_range) and not FS.check(t_stride):
return (common_func | t_range | t_stride).compose_within(
lambda els: td.FluxTimeMonitor(
name=sim_node_name,
center=els[0],
size=els[1],
interval_space=els[2],
start=els[3].item(0),
stop=els[3].item(1),
interval=els[4],
normal_dir=direction_2d.plus_or_minus,
exclude_surfaces=excluded_3d,
)
)
return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Monitor',
kind=FK.Params,
# Loaded
props={'active_socket_set'},
inscks_kinds={
'Center': FK.Params,
'Size': FK.Params,
'Stride': FK.Params,
'Freqs': FK.Params,
't Range': FK.Params,
't Stride': FK.Params,
},
input_sockets_optional={'Freqs', 't Range', 't Stride'},
)
def compute_params(self, input_sockets, props) -> ct.ParamsFlow:
"""Compute the function parameters of the monitor."""
center = input_sockets['Center']
size = input_sockets['Size']
stride = input_sockets['Stride']
freqs = input_sockets['Freqs']
t_range = input_sockets['t Range']
t_stride = input_sockets['t Stride']
common_params = center | size | stride
# Compute Monitor
active_socket_set = props['active_socket_set']
match active_socket_set:
case 'Freq Domain' if not FS.check(freqs):
return common_params | freqs
case 'Time Domain' if not FS.check(t_range) and not FS.check(t_stride):
return common_params | t_range | t_stride
return FS.FlowPending
####################
# - Preview - Changes to Input Sockets
####################
@events.computes_output_socket(
'Time Monitor',
kind=ct.FlowKind.Previews,
'Monitor',
kind=FK.Previews,
# Loaded
props={'sim_node_name'},
)
def compute_previews_time(self, props):
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.computes_output_socket(
'Freq Monitor',
kind=ct.FlowKind.Previews,
# Loaded
props={'sim_node_name'},
)
def compute_previews_freq(self, props):
"""Mark the box structure as participating in the preview."""
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed(
@ -194,27 +311,30 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
prop_name={'direction_2d'},
run_on_init=True,
# Loaded
managed_objs={'mesh', 'modifier'},
props={'direction_2d'},
input_sockets={'Center', 'Size'},
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
managed_objs={'modifier'},
inscks_kinds={
'Center': {FK.Func, FK.Params},
'Size': {FK.Func, FK.Params},
},
scale_input_sockets={
'Center': 'BlenderUnits',
'Center': ct.UNITS_BLENDER,
'Size': ct.UNITS_BLENDER,
},
)
def on_inputs_changed(self, managed_objs, props, input_sockets, unit_systems):
# Push Input Values to GeoNodes Modifier
def on_previewable_changed(self, managed_objs, input_sockets):
"""Push changes in the inputs to the center / size."""
center = events.realize_preview(input_sockets['Center'])
size = events.realize_preview(input_sockets['Size'])
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.MonitorPowerFlux),
'unit_system': unit_systems['BlenderUnits'],
'inputs': {
'Size': input_sockets['Size'],
'Direction': props['direction_2d'].true_or_false,
'Size': size,
},
},
location=input_sockets['Center'],
location=center,
)

View File

@ -20,8 +20,8 @@ import sympy as sp
import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets
@ -29,6 +29,11 @@ from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
PT = spux.PhysicalType
class PermittivityMonitorNode(base.MaxwellSimNode):
"""Provides a bounded 1D/2D/3D recording region for the diagonal of the complex-valued permittivity tensor."""
@ -43,23 +48,24 @@ class PermittivityMonitorNode(base.MaxwellSimNode):
input_sockets: typ.ClassVar = {
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
physical_type=spux.PhysicalType.Length,
physical_type=PT.Length,
),
'Size': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
physical_type=spux.PhysicalType.Length,
default_value=sp.Matrix([1, 1, 1]),
physical_type=PT.Length,
default_value=sp.ImmutableMatrix([1, 1, 1]),
abs_min=0,
abs_min_closed=False,
),
'Stride': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
mathtype=spux.MathType.Integer,
default_value=sp.Matrix([10, 10, 10]),
abs_min=0,
mathtype=MT.Integer,
default_value=sp.ImmutableMatrix([10, 10, 10]),
abs_min=1,
),
'Freqs': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Range,
physical_type=spux.PhysicalType.Freq,
active_kind=FK.Range,
physical_type=PT.Freq,
default_unit=spux.THz,
default_min=374.7406, ## 800nm
default_max=1498.962, ## 200nm
@ -67,7 +73,7 @@ class PermittivityMonitorNode(base.MaxwellSimNode):
),
}
output_sockets: typ.ClassVar = {
'Permittivity Monitor': sockets.MaxwellMonitorSocketDef()
'Monitor': sockets.MaxwellMonitorSocketDef(),
}
managed_obj_types: typ.ClassVar = {
@ -75,57 +81,95 @@ class PermittivityMonitorNode(base.MaxwellSimNode):
}
####################
# - Output
# - FlowKind.Value
####################
@events.computes_output_socket(
'Permittivity Monitor',
props={'sim_node_name'},
input_sockets={
'Center',
'Size',
'Stride',
'Freqs',
},
input_socket_kinds={
'Freqs': ct.FlowKind.Range,
},
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
scale_input_sockets={
'Center': 'Tidy3DUnits',
'Size': 'Tidy3DUnits',
'Freqs': 'Tidy3DUnits',
'Monitor',
kind=FK.Value,
# Loaded
outscks_kinds={
'Monitor': {FK.Func, FK.Params},
},
)
def compute_permittivity_monitor(
self,
input_sockets: dict,
props: dict,
unit_systems: dict,
) -> td.FieldMonitor:
log.info(
'Computing PermittivityMonitor (name="%s") with center="%s", size="%s"',
props['sim_node_name'],
input_sockets['Center'],
input_sockets['Size'],
)
return td.PermittivityMonitor(
center=input_sockets['Center'],
size=input_sockets['Size'],
name=props['sim_node_name'],
interval_space=tuple(input_sockets['Stride']),
freqs=input_sockets['Freqs'].realize().values,
def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Realizes the output function w/output parameters."""
value = events.realize_known(output_sockets['Monitor'])
if value is not None:
return value
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Monitor',
kind=FK.Func,
# Loaded
props={'sim_node_name'},
inscks_kinds={
'Center': FK.Func,
'Size': FK.Func,
'Stride': FK.Func,
'Freqs': FK.Func,
},
scale_input_sockets={
'Center': ct.UNITS_TIDY3D,
'Size': ct.UNITS_TIDY3D,
'Freqs': ct.UNITS_TIDY3D,
},
)
def compute_func(self, props, input_sockets) -> td.FieldMonitor:
"""Lazily assemble the permittivity monitor from the input functions."""
center = input_sockets['Center']
size = input_sockets['Size']
stride = input_sockets['Stride']
freqs = input_sockets['Freqs']
sim_node_name = props['sim_node_name']
return (center | size | stride | freqs).compose_within(
lambda els: td.PermittivityMonitor(
name=sim_node_name,
center=els[0].flatten().tolist(),
size=els[1].flatten().tolist(),
interval_space=els[2].flatten().tolist(),
freqs=els[3].flatten(),
)
)
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Monitor',
kind=FK.Params,
# Loaded
inscks_kinds={
'Center': FK.Params,
'Size': FK.Params,
'Stride': FK.Params,
'Freqs': FK.Params,
},
)
def compute_params(self, input_sockets) -> td.FieldMonitor:
center = input_sockets['Center']
size = input_sockets['Size']
stride = input_sockets['Stride']
freqs = input_sockets['Freqs']
return center | size | stride | freqs
####################
# - Preview
####################
@events.computes_output_socket(
'Permittivity Monitor',
kind=ct.FlowKind.Previews,
kind=FK.Previews,
# Loaded
props={'sim_node_name'},
)
def compute_previews_freq(self, props):
"""Mark the monitor as participating in the preview."""
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed(
@ -134,29 +178,30 @@ class PermittivityMonitorNode(base.MaxwellSimNode):
run_on_init=True,
# Loaded
managed_objs={'modifier'},
input_sockets={'Center', 'Size'},
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
inscks_kinds={
'Center': {FK.Func, FK.Params},
'Size': {FK.Func, FK.Params},
},
scale_input_sockets={
'Center': 'BlenderUnits',
'Center': ct.UNITS_BLENDER,
'Size': ct.UNITS_BLENDER,
},
)
def on_inputs_changed(
self,
managed_objs: dict,
input_sockets: dict,
unit_systems: dict,
):
def on_previewable_changed(self, managed_objs, input_sockets):
"""Push changes in the inputs to the center / size."""
center = events.realize_preview(input_sockets['Center'])
size = events.realize_preview(input_sockets['Size'])
# Push Input Values to GeoNodes Modifier
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.MonitorPermittivity),
'unit_system': unit_systems['BlenderUnits'],
'inputs': {
'Size': input_sockets['Size'],
'Size': size,
},
},
location=input_sockets['Center'],
location=center,
)

View File

@ -104,6 +104,7 @@ class ViewerNode(base.MaxwellSimNode):
@events.on_value_changed(
# Trugger
socket_name='Any',
prop_name='auto_expr',
# Loaded
props={'auto_expr', 'console_print_kind'},
)
@ -114,11 +115,7 @@ class ViewerNode(base.MaxwellSimNode):
This **does not** call the flow twice, as `self._compute_input()` will be cached the first time.
"""
# Invalidate PreviewsFlow
setattr(
self,
'input_' + ct.FlowKind.Previews.property_name,
bl_cache.Signal.InvalidateCache,
)
self.input_previews = bl_cache.Signal.InvalidateCache
# Invalidate PreviewsFlow
if props['auto_expr']:
@ -128,35 +125,35 @@ class ViewerNode(base.MaxwellSimNode):
bl_cache.Signal.InvalidateCache,
)
@bl_cache.cached_bl_property(depends_on={'auto_expr'})
@bl_cache.cached_bl_property()
def input_capabilities(self) -> ct.CapabilitiesFlow | None:
return self.get_flow(ct.FlowKind.Capabilities)
@bl_cache.cached_bl_property(depends_on={'auto_expr'})
@bl_cache.cached_bl_property()
def input_previews(self) -> ct.PreviewsFlow | None:
return self.get_flow(ct.FlowKind.Previews, always_load=True)
@bl_cache.cached_bl_property(depends_on={'auto_expr'})
@bl_cache.cached_bl_property()
def input_value(self) -> ct.ValueFlow | None:
return self.get_flow(ct.FlowKind.Value)
@bl_cache.cached_bl_property(depends_on={'auto_expr'})
@bl_cache.cached_bl_property()
def input_array(self) -> ct.ArrayFlow | None:
return self.get_flow(ct.FlowKind.Array)
@bl_cache.cached_bl_property(depends_on={'auto_expr'})
@bl_cache.cached_bl_property()
def input_lazy_range(self) -> ct.RangeFlow | None:
return self.get_flow(ct.FlowKind.Range)
@bl_cache.cached_bl_property(depends_on={'auto_expr'})
@bl_cache.cached_bl_property()
def input_lazy_func(self) -> ct.FuncFlow | None:
return self.get_flow(ct.FlowKind.Func)
@bl_cache.cached_bl_property(depends_on={'auto_expr'})
@bl_cache.cached_bl_property()
def input_params(self) -> ct.ParamsFlow | None:
return self.get_flow(ct.FlowKind.Params)
@bl_cache.cached_bl_property(depends_on={'auto_expr'})
@bl_cache.cached_bl_property()
def input_info(self) -> ct.InfoFlow | None:
return self.get_flow(ct.FlowKind.Info)
@ -214,6 +211,11 @@ class ViewerNode(base.MaxwellSimNode):
if key != 'type'
]
# Parse Straight String
if isinstance(value, str):
lines = value.split('\n')
return [[line] for line in lines]
return None
####################
@ -305,6 +307,9 @@ class ViewerNode(base.MaxwellSimNode):
log.info('Printing to Console')
if isinstance(flow, spux.SympyType):
console.print(sp.pretty(flow, use_unicode=True))
elif isinstance(flow, ct.ArrayFlow):
console.print(flow)
console.print(flow.values)
else:
console.print(flow)

View File

@ -14,13 +14,15 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `Tidy3DWebExporter`."""
import typing as typ
import bpy
import tidy3d as td
from blender_maxwell.services import tdcloud
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from .... import contracts as ct
from .... import sockets
@ -28,47 +30,30 @@ from ... import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
SimArray: typ.TypeAlias = dict[
tuple[sim_symbols.SimSymbol, ...], tuple[typ.Any, ...], td.Simulation
]
SimArrayInfo: typ.TypeAlias = dict[
tuple[sim_symbols.SimSymbol, ...], tuple[typ.Any, ...], td.Simulation
]
####################
# - Operators
####################
class RecomputeSimInfo(bpy.types.Operator):
bl_idname = ct.OperatorType.NodeRecomputeSimInfo
bl_label = 'Recompute Tidy3D Sim Info'
bl_description = 'Recompute info for any currently attached sim info'
@classmethod
def poll(cls, context):
return (
# Check Tidy3DWebExporter is Accessible
hasattr(context, 'node')
and hasattr(context.node, 'node_type')
and context.node.node_type == ct.NodeType.Tidy3DWebExporter
# Check Sim is Available (aka. uploadeable)
and context.node.sim_info_available
and context.node.sim_info_invalidated
)
def execute(self, context):
node = context.node
# Rehydrate the Cache
node.total_monitor_data = bl_cache.Signal.InvalidateCache
node.is_sim_uploadable = bl_cache.Signal.InvalidateCache
# Remove the Invalidation Marker
## -> This is OK, since we manually guaranteed that it's available.
node.sim_info_invalidated = False
return {'FINISHED'}
class UploadSimulation(bpy.types.Operator):
"""Upload the simulation embedded in the `Tidy3DWebExpoerter`."""
bl_idname = ct.OperatorType.NodeUploadSimulation
bl_label = 'Upload Tidy3D Simulation'
bl_description = 'Upload the attached (locked) simulation, such that it is ready to run on the Tidy3D cloud'
@classmethod
def poll(cls, context):
"""Allow running whenever there are simulations to upload."""
return (
# Check Tidy3D Cloud
tdcloud.IS_AUTHENTICATED
@ -77,41 +62,80 @@ class UploadSimulation(bpy.types.Operator):
and hasattr(context.node, 'node_type')
and context.node.node_type == ct.NodeType.Tidy3DWebExporter
# Check Sim is Available (aka. uploadeable)
and context.node.is_sim_uploadable
and context.node.uploaded_task_id == ''
and context.node.sims
and context.node.sims_uploadable
and not context.node.uploaded_task_ids
)
def execute(self, context):
"""Upload either a single or a batch of simulations.
Later retrieval of realized parameter points exist in the form of a realized serializable dictionary attached to the `.attrs` field of any `td.Simulation` object.
"""
node = context.node
cloud_task = tdcloud.TidyCloudTasks.mk_task(
task_name=node.new_cloud_task.task_name,
cloud_folder=node.new_cloud_task.cloud_folder,
sim=node.sim,
verbose=True,
)
node.uploaded_task_id = cloud_task.task_id
if node.base_cloud_task is not None:
base_task_name = node.base_cloud_task.task_name
base_task_folder = node.base_cloud_task.cloud_folder
else:
self.report({'ERROR'}, 'No base cloud task name')
return {'FINISHED'}
if node.active_socket_set == 'Single':
if len(list(node.sims.values())) == 1:
sim = next(iter(node.sims.values()))
else:
self.report({'ERROR'}, '>1 sims for "Single"-mode sim exporter.')
return {'FINISHED'}
cloud_task = tdcloud.TidyCloudTasks.mk_task(
task_name=base_task_name,
cloud_folder=base_task_folder,
sim=sim,
verbose=True,
)
node.uploaded_task_ids = (cloud_task.task_id,)
if node.active_socket_set == 'Batch':
cloud_tasks = [
tdcloud.TidyCloudTasks.mk_task(
task_name=base_task_name + f'_{i}',
cloud_folder=base_task_folder,
sim=sim,
verbose=True,
)
for i, sim in enumerate(node.sims.values())
]
node.uploaded_task_ids = tuple(
[cloud_task.task_id for cloud_task in cloud_tasks]
)
return {'FINISHED'}
class ReleaseUploadedTask(bpy.types.Operator):
"""Release the uploaded simulation embedded in the `Tidy3DWebExpoerter`."""
bl_idname = ct.OperatorType.NodeReleaseUploadedTask
bl_label = 'Release Tracked Tidy3D Cloud Task'
bl_description = 'Release the currently tracked simulation task'
@classmethod
def poll(cls, context):
"""Allow running whenever a particular FDTDSim node is tracking uploaded simulations."""
return (
# Check Tidy3DWebExporter is Accessible
hasattr(context, 'node')
and hasattr(context.node, 'node_type')
and context.node.node_type == ct.NodeType.Tidy3DWebExporter
# Check Sim is Available (aka. uploadeable)
and context.node.uploaded_task_id != ''
and context.node.uploaded_task_ids
)
def execute(self, context):
"""Invalidate the `.sims` property, triggering reevaluation of all downstream information about the simulation."""
node = context.node
node.uploaded_task_id = ''
node.uploaded_task_ids = ()
return {'FINISHED'}
@ -119,303 +143,357 @@ class ReleaseUploadedTask(bpy.types.Operator):
# - Node
####################
class Tidy3DWebExporterNode(base.MaxwellSimNode):
"""Export a simulation to the Tidy3D cloud service, where it can be queried and run."""
node_type = ct.NodeType.Tidy3DWebExporter
bl_label = 'Tidy3D Web Exporter'
input_sockets: typ.ClassVar = {
'Sim': sockets.MaxwellFDTDSimSocketDef(),
'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
should_exist=False,
),
input_socket_sets: typ.ClassVar = {
'Single': {
'Sim': sockets.MaxwellFDTDSimSocketDef(),
'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
active_kind=FK.Value,
should_exist=False,
),
},
'Batch': {
'Sims': sockets.MaxwellFDTDSimSocketDef(
active_kind=FK.Array,
),
'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
active_kind=FK.Value,
should_exist=False,
),
},
}
output_sockets: typ.ClassVar = {
'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
should_exist=True,
),
output_socket_sets: typ.ClassVar = {
'Single': {
'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
active_kind=FK.Value,
should_exist=True,
),
},
'Batch': {
'Cloud Tasks': sockets.Tidy3DCloudTaskSocketDef(
active_kind=FK.Array,
should_exist=True,
),
},
}
####################
# - Properties
####################
sim_info_available: bool = bl_cache.BLField(False)
sim_info_invalidated: bool = bl_cache.BLField(False)
uploaded_task_id: str = bl_cache.BLField('')
uploaded_task_ids: tuple[str, ...] = bl_cache.BLField(())
####################
# - Computed - Sim
# - Properties: Socket -> Props
####################
@bl_cache.cached_bl_property()
def sim(self) -> td.Simulation | None:
sim = self._compute_input('Sim')
has_sim = not ct.FlowSignal.check(sim)
@events.on_value_changed(
socket_name={'Sim': {FK.Value, FK.Array}},
)
def on_sims_changed(self) -> None:
"""Regenerate the simulation property on changes."""
self.sims = bl_cache.Signal.InvalidateCache
@events.on_value_changed(
socket_name={'Cloud Task': {FK.Value}},
)
def on_base_cloud_task_changed(self) -> None:
"""Regenerate the cloud task property on changes."""
self.base_cloud_task = bl_cache.Signal.InvalidateCache
####################
# - Properties: Socket Alias
####################
@bl_cache.cached_bl_property(depends_on={'active_socket_set'})
def sims(self) -> SimArray | None:
"""The complete description of all simulation data input."""
if self.active_socket_set == 'Single':
sim_data_value = self._compute_input('Sim', kind=FK.Value)
has_sim_data_value = not FS.check(sim_data_value)
if has_sim_data_value:
return {(): sim_data_value}
elif self.active_socket_set == 'Batch':
sim_data_array = self._compute_input('Sims', kind=FK.Array)
has_sim_data_array = not FS.check(sim_data_array)
if has_sim_data_array:
return sim_data_array
if has_sim:
return sim
return None
@bl_cache.cached_bl_property()
def total_monitor_data(self) -> float | None:
if self.sim is not None:
return sum(self.sim.monitors_data_size.values())
return None
####################
# - Computed - New Cloud Task
####################
@property
def new_cloud_task(self) -> ct.NewSimCloudTask | None:
"""Retrieve the current new cloud task from the input socket.
If one can't be loaded, return None.
"""
new_cloud_task = self._compute_input(
def base_cloud_task(self) -> ct.NewSimCloudTask | None:
"""The complete description of all simulation input objects."""
base_cloud_task = self._compute_input(
'Cloud Task',
kind=ct.FlowKind.Value,
)
has_new_cloud_task = not ct.FlowSignal.check(new_cloud_task)
if has_new_cloud_task and new_cloud_task.task_name != '':
return new_cloud_task
has_base_cloud_task = not FS.check(base_cloud_task)
if has_base_cloud_task and base_cloud_task.task_name != '':
return base_cloud_task
return None
####################
# - Computed - Uploaded Cloud Task
# - Properties: Simulations
####################
@property
def uploaded_task(self) -> tdcloud.CloudTask | None:
"""Retrieve the uploaded cloud task.
@bl_cache.cached_bl_property(depends_on={'sims'})
def sims_valid(
self,
) -> (
dict[tuple[tuple[sim_symbols.SimSymbol, ...], tuple[typ.Any, ...]], bool] | None
):
"""Whether all sims are valid."""
if self.sims is not None:
validity = {}
for k, sim in self.sims.items(): # noqa: B007
try:
pass ## TODO: VERY slow, batch checking is infeasible
# sim.validate_pre_upload(source_required=True)
except td.exceptions.SetupError:
validity[k] = False
else:
validity[k] = True
If one can't be loaded, return None.
"""
has_uploaded_task = self.uploaded_task_id != ''
has_new_cloud_task = self.new_cloud_task is not None
if has_uploaded_task and has_new_cloud_task:
return tdcloud.TidyCloudTasks.tasks(self.new_cloud_task.cloud_folder).get(
self.uploaded_task_id
)
return validity
return None
@property
def uploaded_task_info(self) -> tdcloud.CloudTask | None:
"""Retrieve the uploaded cloud task.
If one can't be loaded, return None.
"""
uploaded_task = self.uploaded_task
if uploaded_task is not None:
return tdcloud.TidyCloudTasks.task_info(self.uploaded_task_id)
####################
# - Properties: Tasks
####################
@bl_cache.cached_bl_property(depends_on={'uploaded_task_ids'})
def uploaded_task_infos(self) -> list[tdcloud.CloudTask | None] | None:
"""Retrieve information about the uploaded cloud tasks."""
if self.uploaded_task_ids:
return [
tdcloud.TidyCloudTasks.task_info(task_id)
for task_id in self.uploaded_task_ids
]
return None
@bl_cache.cached_bl_property()
def uploaded_est_cost(self) -> float | None:
task_info = self.uploaded_task_info
if task_info is not None:
est_cost = task_info.cost_est()
if est_cost is not None:
return est_cost
@bl_cache.cached_bl_property(depends_on={'uploaded_task_infos'})
def est_costs(self) -> list[float | None] | None:
"""Estimate the FlexCredit cost of each uploaded task."""
if self.uploaded_task_infos is not None and all(
task_info is not None for task_info in self.uploaded_task_infos
):
return [task_info.cost_est() for task_info in self.uploaded_task_infos]
return None
@bl_cache.cached_bl_property(depends_on={'est_costs'})
def total_est_cost(self) -> list[float | None] | None:
"""Estimate the total FlexCredits cost of all uploaded tasks."""
if self.est_costs is not None and all(
est_cost is not None for est_cost in self.est_costs
):
return sum(self.est_costs)
return None
@bl_cache.cached_bl_property(depends_on={'uploaded_task_infos'})
def real_costs(self) -> list[float | None] | None:
"""Estimate the FlexCredit cost of each uploaded task."""
if self.uploaded_task_infos is not None and all(
task_info is not None for task_info in self.uploaded_task_infos
):
return [task_info.cost_real for task_info in self.uploaded_task_infos]
return None
@bl_cache.cached_bl_property(depends_on={'real_costs'})
def total_real_cost(self) -> list[float | None] | None:
"""Estimate the total FlexCredits cost of all uploaded tasks."""
if self.real_costs is not None and all(
real_cost is not None for real_cost in self.real_costs
):
return sum(self.real_costs)
return None
####################
# - Computed - Combined
####################
@bl_cache.cached_bl_property()
def is_sim_uploadable(self) -> bool:
if (
self.sim is not None
and self.uploaded_task_id == ''
and self.new_cloud_task is not None
and self.new_cloud_task.task_name != ''
):
try:
self.sim.validate_pre_upload(source_required=True)
except:
log.exception()
return False
else:
return True
return False
@bl_cache.cached_bl_property(depends_on={'sims_valid'})
def sims_uploadable(self) -> bool:
"""Whether all simulations can be uploaded."""
return self.sims_valid is not None and all(self.sims_valid.values())
####################
# - UI
####################
def draw_operators(self, context, layout):
@bl_cache.cached_bl_property(
depends_on={
'uploaded_task_infos',
'total_est_cost',
'total_real_cost',
}
)
def task_labels(self) -> SimArrayInfo | None:
"""Pre-processed labels for efficient drawing of task info."""
if self.uploaded_task_infos is not None and all(
task_info is not None for task_info in self.uploaded_task_infos
):
return {
task_info.task_id: [
f'Task: {task_info.task_name}',
('Status', task_info.status),
(
'Est.',
(
f'{self.total_est_cost:.2f} creds'
if self.total_est_cost is not None
else 'TBD...'
),
),
(
'Real',
(
f'{self.total_real_cost:.2f} creds'
if self.total_real_cost is not None
else 'TBD'
),
),
]
for task_info in self.uploaded_task_infos
}
return None
def draw_operators(self, _, layout):
"""Draw operators for uploading/releasing simulations."""
# Row: Upload Sim Buttons
row = layout.row(align=True)
row.operator(
ct.OperatorType.NodeUploadSimulation,
text='Upload',
)
if self.uploaded_task_id:
if self.uploaded_task_ids:
row.operator(
ct.OperatorType.NodeReleaseUploadedTask,
icon='LOOP_BACK',
text='',
)
def draw_info(self, context, layout):
def draw_info(self, _, layout):
"""Draw information relevant for simulation uploading."""
# Connection Info
auth_icon = 'CHECKBOX_HLT' if tdcloud.IS_AUTHENTICATED else 'CHECKBOX_DEHLT'
conn_icon = 'CHECKBOX_HLT' if tdcloud.IS_ONLINE else 'CHECKBOX_DEHLT'
row = layout.row()
box = layout.box()
# Cloud Info
row = box.row()
row.alignment = 'CENTER'
row.label(text='Cloud Status')
box = layout.box()
split = box.split(factor=0.85)
## Split: Left Column
col = split.column(align=False)
col.label(text='Authed')
col.label(text='Connected')
## Split: Right Column
col = split.column(align=False)
col.label(icon=auth_icon)
col.label(icon=conn_icon)
# Simulation Info
if self.sim is not None:
row = layout.row()
row.alignment = 'CENTER'
row.label(text='Sim Info')
box = layout.box()
if self.task_labels is not None:
for labels in self.task_labels.values():
row = layout.row(align=True)
row.alignment = 'CENTER'
row.label(text='Task Status')
if self.sim_info_invalidated:
box.operator(ct.OperatorType.NodeRecomputeSimInfo, text='Regenerate')
else:
split = box.split(factor=0.5)
for el in labels:
# Header
if isinstance(el, str):
box = layout.box()
row = box.row(align=True)
row.alignment = 'CENTER'
row.label(text=el)
## Split: Left Column
col = split.column(align=False)
col.label(text='𝝨 Data')
split = box.split(factor=0.4)
col_l = split.column(align=True)
col_r = split.column(align=True)
## Split: Right Column
col = split.column(align=False)
col.alignment = 'RIGHT'
col.label(text=f'{self.total_monitor_data / 1_000_000:.2f}MB')
# Label Pair
elif isinstance(el, tuple):
col_l.label(text=el[0])
col_r.label(text=el[1])
if self.uploaded_task_info is not None:
# Uploaded Task Information
box = layout.box()
split = box.split(factor=0.6)
else:
raise TypeError
## Split: Left Column
col = split.column(align=False)
col.label(text='Status')
col.label(text='Est. Cost')
col.label(text='Real Cost')
## Split: Right Column
cost_est = (
f'{self.uploaded_est_cost:.2f}'
if self.uploaded_est_cost is not None
else 'TBD'
)
cost_real = (
f'{self.uploaded_task_info.cost_real:.2f}'
if self.uploaded_task_info.cost_real is not None
else 'TBD'
)
col = split.column(align=False)
col.alignment = 'RIGHT'
col.label(text=self.uploaded_task_info.status.capitalize())
col.label(text=f'{cost_est} creds')
col.label(text=f'{cost_real} creds')
# Connection Information
break
####################
# - Events
####################
@events.on_value_changed(
socket_name='Sim',
run_on_init=True,
props={'sim_info_available', 'sim_info_invalidated'},
)
def on_sim_changed(self, props) -> None:
# Sim Linked | First Value Change
if self.inputs['Sim'].is_linked and not props['sim_info_available']:
log.debug('%s: First Change; Mark Sim Info Available', self.sim_node_name)
self.sim = bl_cache.Signal.InvalidateCache
self.total_monitor_data = bl_cache.Signal.InvalidateCache
self.is_sim_uploadable = bl_cache.Signal.InvalidateCache
self.sim_info_available = True
# Sim Linked | Second Value Change
if (
self.inputs['Sim'].is_linked
and props['sim_info_available']
and not props['sim_info_invalidated']
):
log.debug('%s: Second Change; Mark Sim Info Invalided', self.sim_node_name)
self.sim_info_invalidated = True
# Sim Linked | Nth Time
## -> Danger of infinite expensive recompute of the sim every change.
## -> Instead, user must manually set "available & not invalidated".
## -> The UI should explain that the caches are dry.
## -> The UI should also provide such a "hydration" button.
# Sim Not Linked
## -> If the sim is straight-up not available, cache needs changing.
## -> Luckily, since we know there's no sim, invalidation is cheap.
## -> Ends up being a "circuit breaker" for sim_info_invalidated.
elif not self.inputs['Sim'].is_linked:
log.debug(
'%s: Unlinked; Short Circuit the Sim Info Cache', self.sim_node_name
)
self.sim = bl_cache.Signal.InvalidateCache
self.total_monitor_data = bl_cache.Signal.InvalidateCache
self.is_sim_uploadable = bl_cache.Signal.InvalidateCache
self.sim_info_available = False
self.sim_info_invalidated = False
@events.on_value_changed(
socket_name='Cloud Task',
run_on_init=True,
)
def on_new_cloud_task_changed(self):
self.is_sim_uploadable = bl_cache.Signal.InvalidateCache
@events.on_value_changed(
# Trigger
prop_name='uploaded_task_id',
run_on_init=True,
prop_name='uploaded_task_ids',
# Loaded
props={'uploaded_task_id'},
props={'uploaded_task_ids'},
)
def on_uploaded_task_changed(self, props):
log.debug('Uploaded Task Changed')
self.is_sim_uploadable = bl_cache.Signal.InvalidateCache
"""When uploaded tasks change, take appropriate action.
if props['uploaded_task_id'] != '':
- Enable/Disable Lock: To prevent node-tree modifications that would invalidate the validity of uploaded tasks.
- Ensure Est Cost: Repeatedly try to load the estimated cost of all tasks, until all are available.
"""
uploaded_task_ids = props['uploaded_task_ids']
# Lock
if uploaded_task_ids:
self.trigger_event(ct.FlowEvent.EnableLock)
self.locked = False
# Force Computation of Estimated Cost
## -> Try recomputing the estimated cost of all tasks.
## -> Once all are non-None, stop.
max_tries = 20
for _ in range(max_tries):
self.est_costs = bl_cache.Signal.InvalidateCache
if self.total_est_cost is not None:
break
else:
self.trigger_event(ct.FlowEvent.DisableLock)
max_tries = 10
for _ in range(max_tries):
self.uploaded_est_cost = bl_cache.Signal.InvalidateCache
if self.uploaded_est_cost is not None:
break
####################
# - Outputs
# - FlowKind.Value
####################
@events.computes_output_socket(
'Cloud Task',
props={'uploaded_task_id', 'uploaded_task'},
kind=ct.FlowKind.Value,
# Loaded
props={'uploaded_task_ids'},
)
def compute_cloud_task(self, props) -> tdcloud.CloudTask | None:
if props['uploaded_task_id'] != '':
return props['uploaded_task']
"""A single uploaded cloud task, when there only is one."""
uploaded_task_ids = props['uploaded_task_ids']
return ct.FlowSignal.FlowPending
if uploaded_task_ids is not None and len(uploaded_task_ids) == 1:
cloud_task = tdcloud.TidyCloudTasks.task(uploaded_task_ids[0])
if cloud_task is not None:
return cloud_task
return FS.FlowPending
####################
# - FlowKind.Array
####################
@events.computes_output_socket(
'Cloud Tasks',
kind=ct.FlowKind.Array,
# Loaded
props={'uploaded_task_ids'},
)
def compute_cloud_tasks(self, props) -> tdcloud.CloudTask | None:
"""All uploaded cloud task, when there are more than one."""
uploaded_task_ids = props['uploaded_task_ids']
if len(uploaded_task_ids) > 1:
cloud_tasks = [
tdcloud.TidyCloudTasks.task(task_id) for task_id in uploaded_task_ids
]
if all(cloud_task is not None for cloud_task in cloud_tasks):
return cloud_tasks
return FS.FlowPending
####################
@ -425,7 +503,6 @@ BL_REGISTER = [
UploadSimulation,
ReleaseUploadedTask,
Tidy3DWebExporterNode,
RecomputeSimInfo,
]
BL_NODES = {
ct.NodeType.Tidy3DWebExporter: (ct.NodeCategory.MAXWELLSIM_OUTPUTS_WEBEXPORTERS)

View File

@ -14,21 +14,28 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# from . import sim_grid
# from . import sim_grid_axes
from . import combine, fdtd_sim, sim_domain
from . import (
bound_cond_faces,
bound_conds,
fdtd_sim,
sim_domain,
sim_grid,
sim_grid_axes,
)
BL_REGISTER = [
*combine.BL_REGISTER,
*sim_domain.BL_REGISTER,
# *sim_grid.BL_REGISTER,
# *sim_grid_axes.BL_REGISTER,
*fdtd_sim.BL_REGISTER,
*sim_domain.BL_REGISTER,
*bound_conds.BL_REGISTER,
*bound_cond_faces.BL_REGISTER,
*sim_grid.BL_REGISTER,
*sim_grid_axes.BL_REGISTER,
]
BL_NODES = {
**combine.BL_NODES,
**sim_domain.BL_NODES,
# **sim_grid.BL_NODES,
# **sim_grid_axes.BL_NODES,
**fdtd_sim.BL_NODES,
**sim_domain.BL_NODES,
**bound_conds.BL_NODES,
**bound_cond_faces.BL_NODES,
**sim_grid.BL_NODES,
**sim_grid_axes.BL_NODES,
}

View File

@ -22,8 +22,8 @@ import bpy
import sympy as sp
import tidy3d as td
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets
@ -31,6 +31,9 @@ from ... import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
class AdiabAbsorbBoundCondNode(base.MaxwellSimNode):
r"""A boundary condition that generically (adiabatically) absorbs outgoing energy, by gradually ramping up the strength of the conductor over many layers, until a final PEC layer.
@ -78,7 +81,7 @@ class AdiabAbsorbBoundCondNode(base.MaxwellSimNode):
),
'σ Range': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec2,
default_value=sp.Matrix([0, 1.5]),
default_value=sp.ImmutableMatrix([0, 1.5]),
abs_min=0,
),
},
@ -91,6 +94,11 @@ class AdiabAbsorbBoundCondNode(base.MaxwellSimNode):
# - UI
####################
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
"""Draw the user interfaces of the node's reported info.
Parameters:
layout: UI target for drawing.
"""
if self.active_socket_set == 'Full':
box = layout.box()
row = box.row()
@ -115,118 +123,71 @@ class AdiabAbsorbBoundCondNode(base.MaxwellSimNode):
@events.computes_output_socket(
'BC',
# Loaded
props={'active_socket_set'},
input_sockets={
'Layers',
'σ Order',
'σ Range',
outscks_kinds={
'BC': {FK.Func, FK.Params},
},
input_sockets_optional={
'σ Order': True,
'σ Range': True,
},
output_sockets={'BC'},
output_socket_kinds={'BC': ct.FlowKind.Params},
)
def compute_bc_value(self, props, input_sockets, output_sockets) -> td.Absorber:
def compute_bc_value(self, output_sockets) -> td.Absorber | FS:
r"""Computes the adiabatic absorber boundary condition based on the active socket set.
- **Simple**: Use `tidy3d`'s default parameters for defining the absorber parameters (apart from number of layers).
- **Full**: Use the user-defined $\sigma$ parameters, specifically polynomial order and sim-relative min/max conductivity values.
"""
output_params = output_sockets['BC']
layers = input_sockets['Layers']
has_output_params = not ct.FlowSignal.check(output_params)
has_layers = not ct.FlowSignal.check(layers)
active_socket_set = props['active_socket_set']
if has_layers and has_output_params and not output_params.symbols:
# Simple PML
if active_socket_set == 'Simple':
return td.Absorber(num_layers=layers)
# Full PML
sig_order = input_sockets['σ Order']
sig_range = input_sockets['σ Range']
has_sig_order = not ct.FlowSignal.check(sig_order)
has_sig_range = not ct.FlowSignal.check(sig_range)
if has_sig_order and has_sig_range:
return td.Absorber(
num_layers=layers,
parameters=td.AbsorberParams(
sigma_order=sig_order,
sigma_min=sig_range[0],
sigma_max=sig_range[1],
),
)
return ct.FlowSignal.FlowPending
value = events.realize_known(output_sockets['BC'])
if value is not None:
return value
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'BC',
kind=ct.FlowKind.Func,
kind=FK.Func,
# Loaded
props={'active_socket_set'},
input_sockets={
'Layers',
'σ Order',
'σ Range',
inscks_kinds={
'Layers': FK.Func,
'σ Order': FK.Func,
'σ Range': FK.Func,
},
input_socket_kinds={
'Layers': ct.FlowKind.Func,
'σ Order': ct.FlowKind.Func,
'σ Range': ct.FlowKind.Func,
},
input_sockets_optional={
'σ Order': True,
'σ Range': True,
},
output_sockets={'BC'},
output_socket_kinds={'BC': ct.FlowKind.Params},
input_sockets_optional={'σ Order', 'σ Range'},
)
def compute_bc_func(self, props, input_sockets, output_sockets) -> td.Absorber:
def compute_bc_func(self, props, input_sockets) -> td.Absorber:
r"""Computes the adiabatic absorber boundary condition based on the active socket set.
- **Simple**: Use `tidy3d`'s default parameters for defining the absorber parameters (apart from number of layers).
- **Full**: Use the user-defined $\sigma$ parameters, specifically polynomial order and sim-relative min/max conductivity values.
"""
layers = input_sockets['Layers']
has_layers = not ct.FlowSignal.check(layers)
active_socket_set = props['active_socket_set']
if has_layers:
# Simple PML
if active_socket_set == 'Simple':
match active_socket_set:
case 'Simple':
return layers.compose_within(
enclosing_func=lambda _layers: td.Absorber(num_layers=_layers),
supports_jax=False,
)
case 'Full':
sig_order = input_sockets['σ Order']
sig_range = input_sockets['σ Range']
# Full PML
sig_order = input_sockets['σ Order']
sig_range = input_sockets['σ Range']
has_sig_order = not FS.check(sig_order)
has_sig_range = not FS.check(sig_range)
has_sig_order = not ct.FlowSignal.check(sig_order)
has_sig_range = not ct.FlowSignal.check(sig_range)
if has_sig_order and has_sig_range:
return (layers | sig_order | sig_range).compose_within(
enclosing_func=lambda els: td.Absorber(
num_layers=els[0],
parameters=td.AbsorberParams(
sigma_order=els[1],
sigma_min=els[2][0],
sigma_max=els[2][1],
if has_sig_order and has_sig_range:
return (layers | sig_order | sig_range).compose_within(
enclosing_func=lambda els: td.Absorber(
num_layers=els[0],
parameters=td.AbsorberParams(
sigma_order=els[1],
sigma_min=els[2].item(0),
sigma_max=els[2].item(1),
),
),
),
supports_jax=False,
)
supports_jax=False,
)
return ct.FlowSignal.FlowPending
####################
@ -234,44 +195,35 @@ class AdiabAbsorbBoundCondNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'BC',
kind=ct.FlowKind.Params,
kind=FK.Params,
# Loaded
props={'active_socket_set'},
input_sockets={
'Layers',
'σ Order',
'σ Range',
},
input_socket_kinds={
'Layers': ct.FlowKind.Params,
'σ Order': ct.FlowKind.Params,
'σ Range': ct.FlowKind.Params,
},
input_sockets_optional={
'σ Order': True,
'σ Range': True,
inscks_kinds={
'Layers': FK.Params,
'σ Order': FK.Params,
'σ Range': FK.Params,
},
input_sockets_optional={'σ Order', 'σ Range'},
)
def compute_params(self, props, input_sockets) -> td.Box:
"""Aggregate the function parameters needed by the absorbing BC."""
layers = input_sockets['Layers']
has_layers = not ct.FlowSignal.check(layers)
active_socket_set = props['active_socket_set']
if has_layers:
# Simple PML
if active_socket_set == 'Simple':
match active_socket_set:
case 'Simple':
return layers
# Full PML
sig_order = input_sockets['σ Order']
sig_range = input_sockets['σ Range']
case 'Full':
sig_order = input_sockets['σ Order']
sig_range = input_sockets['σ Range']
has_sig_order = not ct.FlowSignal.check(sig_order)
has_sig_range = not ct.FlowSignal.check(sig_range)
has_sig_order = not FS.check(sig_order)
has_sig_range = not FS.check(sig_range)
if has_sig_order and has_sig_range:
return layers | sig_order | sig_range
if has_sig_order and has_sig_range:
return layers | sig_order | sig_range
return ct.FlowSignal.FlowPending
@ -281,4 +233,6 @@ class AdiabAbsorbBoundCondNode(base.MaxwellSimNode):
BL_REGISTER = [
AdiabAbsorbBoundCondNode,
]
BL_NODES = {ct.NodeType.AdiabAbsorbBoundCond: (ct.NodeCategory.MAXWELLSIM_BOUNDS)}
BL_NODES = {
ct.NodeType.AdiabAbsorbBoundCond: (ct.NodeCategory.MAXWELLSIM_SIMS_BOUNDCONDFACES)
}

View File

@ -29,6 +29,9 @@ from ... import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
class BlochBoundCondNode(base.MaxwellSimNode):
r"""A boundary condition that declares an "infinitely repeating" window, by applying Bloch's theorem to accurately describe how a boundary would behave if it were interacting with an infinitely repeating simulation structure.
@ -117,7 +120,7 @@ class BlochBoundCondNode(base.MaxwellSimNode):
input_socket_sets: typ.ClassVar = {
'Naive': {},
'Source-Derived': {
'Angled Source': sockets.MaxwellSourceSocketDef(),
'Angled Source': sockets.MaxwellSourceSocketDef(active_kind=FK.Func),
## TODO: Constrain to gaussian beam, plane wafe, and tfsf
'Sim Domain': sockets.MaxwellSimDomainSocketDef(),
},
@ -138,10 +141,20 @@ class BlochBoundCondNode(base.MaxwellSimNode):
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
"""Draw the user interface of the node's properties.
Parameters:
layout: UI target for drawing.
"""
if self.active_socket_set == 'Source-Derived':
layout.prop(self, self.blfields['valid_sim_axis'], expand=True)
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
"""Draw the user interfaces of the node's reported info.
Parameters:
layout: UI target for drawing.
"""
if self.active_socket_set == 'Manual':
box = layout.box()
row = box.row()
@ -204,7 +217,7 @@ class BlochBoundCondNode(base.MaxwellSimNode):
'Bloch Vector': True,
},
output_sockets={'BC'},
output_socket_kinds={'BC': ct.FlowKind.Params},
output_socket_kinds={'BC': FK.Params},
)
def compute_value(
self, props, input_sockets, output_sockets
@ -217,9 +230,9 @@ class BlochBoundCondNode(base.MaxwellSimNode):
- **Manual**: Set the Bloch vector to the user-specified value.
"""
output_params = output_sockets['BC']
has_output_params = not ct.FlowSignal.check(output_params)
has_output_params = not FS.check(output_params)
if not has_output_params or (has_output_params and output_params.symbols):
return ct.FlowSignal.FlowPending
return FS.FlowPending
active_socket_set = props['active_socket_set']
match active_socket_set:
@ -230,8 +243,8 @@ class BlochBoundCondNode(base.MaxwellSimNode):
angled_source = input_sockets['Angled Source']
sim_domain = input_sockets['Sim Domain']
has_angled_source = not ct.FlowSignal.check(angled_source)
has_sim_domain = not ct.FlowSignal.check(sim_domain)
has_angled_source = not FS.check(angled_source)
has_sim_domain = not FS.check(sim_domain)
if has_angled_source and has_sim_domain:
valid_sim_axis = props['valid_sim_axis']
@ -241,22 +254,22 @@ class BlochBoundCondNode(base.MaxwellSimNode):
axis=valid_sim_axis.axis,
medium=sim_domain['medium'],
)
return ct.FlowSignal.FlowPending
return FS.FlowPending
case 'Manual':
bloch_vector = input_sockets['Bloch Vector']
has_bloch_vector = not ct.FlowSignal.check(bloch_vector)
has_bloch_vector = not FS.check(bloch_vector)
if has_bloch_vector:
return td.BlochBoundary(bloch_vec=bloch_vector)
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'BC',
kind=ct.FlowKind.Func,
kind=FK.Func,
# Loaded
props={'active_socket_set', 'valid_sim_axis'},
input_sockets={
@ -265,29 +278,18 @@ class BlochBoundCondNode(base.MaxwellSimNode):
'Bloch Vector',
},
input_socket_kinds={
'Angled Source': ct.FlowKind.Func,
'Sim Domain': ct.FlowKind.Func,
'Bloch Vector': ct.FlowKind.Func,
'Angled Source': FK.Func,
'Sim Domain': FK.Func,
'Bloch Vector': FK.Func,
},
input_sockets_optional={
'Angled Source': True,
'Sim Domain': True,
'Bloch Vector': True,
},
output_sockets={'BC'},
output_socket_kinds={'BC': ct.FlowKind.Params},
input_sockets_optional={'Angled Source', 'Sim Domain', 'Bloch Vector'},
)
def compute_bc_func(self, props, input_sockets, output_sockets) -> td.Absorber:
r"""Computes the adiabatic absorber boundary condition based on the active socket set.
def compute_bc_func(self, props, input_sockets) -> td.Absorber:
r"""Computes the bloch boundary condition based on the active socket set.
- **Simple**: Use `tidy3d`'s default parameters for defining the absorber parameters (apart from number of layers).
- **Full**: Use the user-defined $\sigma$ parameters, specifically polynomial order and sim-relative min/max conductivity values.
"""
output_params = output_sockets['BC']
has_output_params = not ct.FlowSignal.check(output_params)
if not has_output_params:
return ct.FlowSignal.FlowPending
active_socket_set = props['active_socket_set']
match active_socket_set:
case 'Naive':
@ -300,8 +302,8 @@ class BlochBoundCondNode(base.MaxwellSimNode):
angled_source = input_sockets['Angled Source']
sim_domain = input_sockets['Sim Domain']
has_angled_source = not ct.FlowSignal.check(angled_source)
has_sim_domain = not ct.FlowSignal.check(sim_domain)
has_angled_source = not FS.check(angled_source)
has_sim_domain = not FS.check(sim_domain)
if has_angled_source and has_sim_domain:
valid_sim_axis = props['valid_sim_axis']
@ -314,11 +316,11 @@ class BlochBoundCondNode(base.MaxwellSimNode):
),
supports_jax=False,
)
return ct.FlowSignal.FlowPending
return FS.FlowPending
case 'Manual':
bloch_vector = input_sockets['Bloch Vector']
has_bloch_vector = not ct.FlowSignal.check(bloch_vector)
has_bloch_vector = not FS.check(bloch_vector)
if has_bloch_vector:
return bloch_vector.compose_within(
@ -327,33 +329,25 @@ class BlochBoundCondNode(base.MaxwellSimNode):
),
supports_jax=False,
)
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'BC',
kind=ct.FlowKind.Params,
kind=FK.Params,
# Loaded
props={'active_socket_set'},
input_sockets={
'Angled Source',
'Sim Domain',
'Bloch Vector',
},
input_socket_kinds={
'Angled Source': ct.FlowKind.Params,
'Sim Domain': ct.FlowKind.Params,
'Bloch Vector': ct.FlowKind.Params,
},
input_sockets_optional={
'Angled Source': True,
'Sim Domain': True,
'Bloch Vector': True,
inscks_kinds={
'Angled Source': FK.Params,
'Sim Domain': FK.Params,
'Bloch Vector': FK.Params,
},
input_sockets_optional={'Angled Source', 'Sim Domain', 'Bloch Vector'},
)
def compute_bc_params(self, props, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
def compute_bc_params(self, props, input_sockets) -> ct.ParamsFlow | FS:
"""Aggregate the function parameters needed by the bloch BC."""
active_socket_set = props['active_socket_set']
match active_socket_set:
case 'Naive':
@ -363,20 +357,20 @@ class BlochBoundCondNode(base.MaxwellSimNode):
angled_source = input_sockets['Angled Source']
sim_domain = input_sockets['Sim Domain']
has_angled_source = not ct.FlowSignal.check(angled_source)
has_sim_domain = not ct.FlowSignal.check(sim_domain)
has_angled_source = not FS.check(angled_source)
has_sim_domain = not FS.check(sim_domain)
if has_sim_domain and has_angled_source:
return angled_source | sim_domain
return ct.FlowSignal.FlowPending
return FS.FlowPending
case 'Manual':
bloch_vector = input_sockets['Bloch Vector']
has_bloch_vector = not ct.FlowSignal.check(bloch_vector)
has_bloch_vector = not FS.check(bloch_vector)
if has_bloch_vector:
return bloch_vector
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
@ -385,4 +379,6 @@ class BlochBoundCondNode(base.MaxwellSimNode):
BL_REGISTER = [
BlochBoundCondNode,
]
BL_NODES = {ct.NodeType.BlochBoundCond: (ct.NodeCategory.MAXWELLSIM_BOUNDS)}
BL_NODES = {
ct.NodeType.BlochBoundCond: (ct.NodeCategory.MAXWELLSIM_SIMS_BOUNDCONDFACES)
}

View File

@ -22,8 +22,8 @@ import bpy
import sympy as sp
import tidy3d as td
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import sockets
@ -31,6 +31,9 @@ from ... import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
class PMLBoundCondNode(base.MaxwellSimNode):
r"""A "Perfectly Matched Layer" boundary condition, which is a theoretical medium that attempts to _perfectly_ absorb all outgoing waves, so as to represent "infinite space" in FDTD simulations.
@ -82,7 +85,7 @@ class PMLBoundCondNode(base.MaxwellSimNode):
'σ Range': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec2,
mathtype=spux.MathType.Real,
default_value=sp.Matrix([0, 1.5]),
default_value=sp.ImmutableMatrix([0, 1.5]),
abs_min=0,
),
'κ Order': sockets.ExprSocketDef(
@ -94,7 +97,7 @@ class PMLBoundCondNode(base.MaxwellSimNode):
'κ Range': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec2,
mathtype=spux.MathType.Real,
default_value=sp.Matrix([0, 1.5]),
default_value=sp.ImmutableMatrix([0, 1.5]),
abs_min=0,
),
'α Order': sockets.ExprSocketDef(
@ -106,7 +109,7 @@ class PMLBoundCondNode(base.MaxwellSimNode):
'α Range': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec2,
mathtype=spux.MathType.Real,
default_value=sp.Matrix([0, 1.5]),
default_value=sp.ImmutableMatrix([0, 1.5]),
abs_min=0,
),
},
@ -119,6 +122,11 @@ class PMLBoundCondNode(base.MaxwellSimNode):
# - UI
####################
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
"""Draw the user interfaces of the node's reported info.
Parameters:
layout: UI target for drawing.
"""
if self.active_socket_set == 'Full':
box = layout.box()
row = box.row()
@ -144,11 +152,19 @@ class PMLBoundCondNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'BC',
kind=ct.FlowKind.Value,
kind=FK.Value,
# Loaded
props={'active_socket_set'},
input_sockets={
'Layers',
inscks_kinds={
'Layers': FK.Value,
'σ Order': FK.Value,
'σ Range': FK.Value,
'κ Order': FK.Value,
'κ Range': FK.Value,
'α Order': FK.Value,
'α Range': FK.Value,
},
input_sockets_optional={
'σ Order',
'σ Range',
'κ Order',
@ -156,16 +172,8 @@ class PMLBoundCondNode(base.MaxwellSimNode):
'α Order',
'α Range',
},
input_sockets_optional={
'σ Order': True,
'σ Range': True,
'κ Order': True,
'κ Range': True,
'α Order': True,
'α Range': True,
},
output_sockets={'BC'},
output_socket_kinds={'BC': ct.FlowKind.Params},
output_socket_kinds={'BC': FK.Params},
)
def compute_pml_value(self, props, input_sockets, output_sockets) -> td.PML:
r"""Computes the PML boundary condition based on the active socket set.
@ -173,13 +181,11 @@ class PMLBoundCondNode(base.MaxwellSimNode):
- **Simple**: Use `tidy3d`'s default parameters for defining the PML conductor (apart from number of layers).
- **Full**: Use the user-defined $\sigma$, $\kappa$, and $\alpha$ parameters, specifically polynomial order and sim-relative min/max conductivity values.
"""
output_params = output_sockets['BC']
layers = input_sockets['Layers']
output_params = output_sockets['BC']
has_output_params = not FS.check(output_params)
has_layers = not ct.FlowSignal.check(layers)
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_params and has_layers and not output_params.symbols:
if has_output_params and not output_params.symbols:
active_socket_set = props['active_socket_set']
match active_socket_set:
case 'Simple':
@ -193,12 +199,12 @@ class PMLBoundCondNode(base.MaxwellSimNode):
alpha_order = input_sockets['α Order']
alpha_range = input_sockets['α Range']
has_sigma_order = not ct.FlowSignal.check(sigma_order)
has_sigma_range = not ct.FlowSignal.check(sigma_range)
has_kappa_order = not ct.FlowSignal.check(kappa_order)
has_kappa_range = not ct.FlowSignal.check(kappa_range)
has_alpha_order = not ct.FlowSignal.check(alpha_order)
has_alpha_range = not ct.FlowSignal.check(alpha_range)
has_sigma_order = not FS.check(sigma_order)
has_sigma_range = not FS.check(sigma_range)
has_kappa_order = not FS.check(kappa_order)
has_kappa_range = not FS.check(kappa_range)
has_alpha_order = not FS.check(alpha_order)
has_alpha_range = not FS.check(alpha_range)
if (
has_sigma_order
@ -223,18 +229,26 @@ class PMLBoundCondNode(base.MaxwellSimNode):
),
)
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'BC',
kind=ct.FlowKind.Func,
kind=FK.Func,
# Loaded
props={'active_socket_set'},
input_sockets={
'Layers',
inscks_kinds={
'Layers': FK.Func,
'σ Order': FK.Func,
'σ Range': FK.Func,
'κ Order': FK.Func,
'κ Range': FK.Func,
'α Order': FK.Func,
'α Range': FK.Func,
},
input_sockets_optional={
'σ Order',
'σ Range',
'κ Order',
@ -242,101 +256,86 @@ class PMLBoundCondNode(base.MaxwellSimNode):
'α Order',
'α Range',
},
input_socket_kinds={
'Layers': ct.FlowKind.Func,
'σ Order': ct.FlowKind.Func,
'σ Range': ct.FlowKind.Func,
'κ Order': ct.FlowKind.Func,
'κ Range': ct.FlowKind.Func,
'α Order': ct.FlowKind.Func,
'α Range': ct.FlowKind.Func,
},
input_sockets_optional={
'σ Order': True,
'σ Range': True,
'κ Order': True,
'κ Range': True,
'α Order': True,
'α Range': True,
},
output_sockets={'BC'},
output_socket_kinds={'BC': ct.FlowKind.Params},
)
def compute_pml_func(self, props, input_sockets, output_sockets) -> td.PML:
output_params = output_sockets['BC']
def compute_pml_func(self, props, input_sockets) -> td.PML:
r"""Computes the PML boundary condition as a lazy function."""
layers = input_sockets['Layers']
active_socket_set = props['active_socket_set']
has_output_params = not ct.FlowSignal.check(output_params)
has_layers = not ct.FlowSignal.check(layers)
match active_socket_set:
case 'Simple':
return layers.compose_within(
enclosing_func=lambda layers: td.PML(num_layers=layers),
supports_jax=False,
)
if has_output_params and has_layers:
active_socket_set = props['active_socket_set']
match active_socket_set:
case 'Simple':
return layers.compose_within(
enclosing_func=lambda layers: td.PML(num_layers=layers),
supports_jax=False,
case 'Full':
sigma_order = input_sockets['σ Order']
sigma_range = input_sockets['σ Range']
kappa_order = input_sockets['κ Order']
kappa_range = input_sockets['κ Range']
alpha_order = input_sockets['α Order']
alpha_range = input_sockets['α Range']
has_sigma_order = not FS.check(sigma_order)
has_sigma_range = not FS.check(sigma_range)
has_kappa_order = not FS.check(kappa_order)
has_kappa_range = not FS.check(kappa_range)
has_alpha_order = not FS.check(alpha_order)
has_alpha_range = not FS.check(alpha_range)
if (
has_sigma_order
and has_sigma_range
and has_kappa_order
and has_kappa_range
and has_alpha_order
and has_alpha_range
):
return (
sigma_order
| sigma_range
| kappa_order
| kappa_range
| alpha_order
| alpha_range
).compose_within(
enclosing_func=lambda els: td.PML(
num_layers=layers,
parameters=td.PMLParams(
sigma_order=els[0],
sigma_min=els[1][0],
sigma_max=els[1][1],
kappa_order=els[2],
kappa_min=els[3][0],
kappa_max=els[3][1],
alpha_order=els[4][1],
alpha_min=els[5][0],
alpha_max=els[5][1],
),
)
)
case 'Full':
sigma_order = input_sockets['σ Order']
sigma_range = input_sockets['σ Range']
kappa_order = input_sockets['κ Order']
kappa_range = input_sockets['κ Range']
alpha_order = input_sockets['α Order']
alpha_range = input_sockets['α Range']
has_sigma_order = not ct.FlowSignal.check(sigma_order)
has_sigma_range = not ct.FlowSignal.check(sigma_range)
has_kappa_order = not ct.FlowSignal.check(kappa_order)
has_kappa_range = not ct.FlowSignal.check(kappa_range)
has_alpha_order = not ct.FlowSignal.check(alpha_order)
has_alpha_range = not ct.FlowSignal.check(alpha_range)
if (
has_sigma_order
and has_sigma_range
and has_kappa_order
and has_kappa_range
and has_alpha_order
and has_alpha_range
):
return (
sigma_order
| sigma_range
| kappa_order
| kappa_range
| alpha_order
| alpha_range
).compose_within(
enclosing_func=lambda els: td.PML(
num_layers=layers,
parameters=td.PMLParams(
sigma_order=els[0],
sigma_min=els[1][0],
sigma_max=els[1][1],
kappa_order=els[2],
kappa_min=els[3][0],
kappa_max=els[3][1],
alpha_order=els[4][1],
alpha_min=els[5][0],
alpha_max=els[5][1],
),
)
)
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'BC',
kind=ct.FlowKind.Params,
kind=FK.Params,
# Loaded
props={'active_socket_set'},
input_sockets={
'Layers',
inscks_kinds={
'Layers': FK.Params,
'σ Order': FK.Params,
'σ Range': FK.Params,
'κ Order': FK.Params,
'κ Range': FK.Params,
'α Order': FK.Params,
'α Range': FK.Params,
},
input_sockets_optional={
'σ Order',
'σ Range',
'κ Order',
@ -344,72 +343,53 @@ class PMLBoundCondNode(base.MaxwellSimNode):
'α Order',
'α Range',
},
input_socket_kinds={
'Layers': ct.FlowKind.Params,
'σ Order': ct.FlowKind.Params,
'σ Range': ct.FlowKind.Params,
'κ Order': ct.FlowKind.Params,
'κ Range': ct.FlowKind.Params,
'α Order': ct.FlowKind.Params,
'α Range': ct.FlowKind.Params,
},
input_sockets_optional={
'σ Order': True,
'σ Range': True,
'κ Order': True,
'κ Range': True,
'α Order': True,
'α Range': True,
},
)
def compute_pml_params(self, props, input_sockets) -> td.PML:
def compute_pml_params(self, props, input_sockets) -> ct.ParamsFlow | FS:
r"""Computes the PML boundary condition based on the active socket set.
- **Simple**: Use `tidy3d`'s default parameters for defining the PML conductor (apart from number of layers).
- **Full**: Use the user-defined $\sigma$, $\kappa$, and $\alpha$ parameters, specifically polynomial order and sim-relative min/max conductivity values.
"""
layers = input_sockets['Layers']
has_layers = not ct.FlowSignal.check(layers)
active_socket_set = props['active_socket_set']
if has_layers:
active_socket_set = props['active_socket_set']
match active_socket_set:
case 'Simple':
return layers
match active_socket_set:
case 'Simple':
return layers
case 'Full':
sigma_order = input_sockets['σ Order']
sigma_range = input_sockets['σ Range']
kappa_order = input_sockets['σ Order']
kappa_range = input_sockets['σ Range']
alpha_order = input_sockets['σ Order']
alpha_range = input_sockets['σ Range']
case 'Full':
sigma_order = input_sockets['σ Order']
sigma_range = input_sockets['σ Range']
kappa_order = input_sockets['σ Order']
kappa_range = input_sockets['σ Range']
alpha_order = input_sockets['σ Order']
alpha_range = input_sockets['σ Range']
has_sigma_order = not ct.FlowSignal.check(sigma_order)
has_sigma_range = not ct.FlowSignal.check(sigma_range)
has_kappa_order = not ct.FlowSignal.check(kappa_order)
has_kappa_range = not ct.FlowSignal.check(kappa_range)
has_alpha_order = not ct.FlowSignal.check(alpha_order)
has_alpha_range = not ct.FlowSignal.check(alpha_range)
has_sigma_order = not FS.check(sigma_order)
has_sigma_range = not FS.check(sigma_range)
has_kappa_order = not FS.check(kappa_order)
has_kappa_range = not FS.check(kappa_range)
has_alpha_order = not FS.check(alpha_order)
has_alpha_range = not FS.check(alpha_range)
if (
has_sigma_order
and has_sigma_range
and has_kappa_order
and has_kappa_range
and has_alpha_order
and has_alpha_range
):
return (
sigma_order
| sigma_range
| kappa_order
| kappa_range
| alpha_order
| alpha_range
)
if (
has_sigma_order
and has_sigma_range
and has_kappa_order
and has_kappa_range
and has_alpha_order
and has_alpha_range
):
return (
sigma_order
| sigma_range
| kappa_order
| kappa_range
| alpha_order
| alpha_range
)
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
@ -418,4 +398,4 @@ class PMLBoundCondNode(base.MaxwellSimNode):
BL_REGISTER = [
PMLBoundCondNode,
]
BL_NODES = {ct.NodeType.PMLBoundCond: (ct.NodeCategory.MAXWELLSIM_BOUNDS)}
BL_NODES = {ct.NodeType.PMLBoundCond: (ct.NodeCategory.MAXWELLSIM_SIMS_BOUNDCONDFACES)}

View File

@ -30,6 +30,8 @@ log = logger.get(__name__)
SSA = ct.SimSpaceAxis
FK = ct.FlowKind
FS = ct.FlowSignal
class BoundCondsNode(base.MaxwellSimNode):
@ -97,27 +99,26 @@ class BoundCondsNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'BCs',
kind=ct.FlowKind.Value,
kind=FK.Value,
# Loaded
input_sockets={'X', 'Y', 'Z', '+X', '-X', '+Y', '-Y', '+Z', '-Z'},
input_sockets_optional={
'X': True,
'Y': True,
'Z': True,
'+X': True,
'-X': True,
'+Y': True,
'-Y': True,
'+Z': True,
'-Z': True,
'X',
'Y',
'Z',
'+X',
'-X',
'+Y',
'-Y',
'+Z',
'-Z',
},
output_sockets={'BCs'},
output_socket_kinds={'BCs': ct.FlowKind.Params},
outscks_kinds={'BCs': FK.Params},
)
def compute_bcs_value(self, input_sockets, output_sockets) -> td.BoundarySpec:
"""Compute the simulation boundary conditions, by combining the individual input by specified half axis."""
output_params = output_sockets['BCs']
has_output_params = not ct.FlowSignal.check(output_params)
has_output_params = not FS.check(output_params)
# Deduce "Doubledness"
## -> A "doubled" axis defines the same bound cond both ways
@ -125,9 +126,9 @@ class BoundCondsNode(base.MaxwellSimNode):
y = input_sockets['Y']
z = input_sockets['Z']
has_doubled_x = not ct.FlowSignal.check_single(x, ct.FlowSignal.NoFlow)
has_doubled_y = not ct.FlowSignal.check_single(y, ct.FlowSignal.NoFlow)
has_doubled_z = not ct.FlowSignal.check_single(z, ct.FlowSignal.NoFlow)
has_doubled_x = not FS.check_single(x, FS.NoFlow)
has_doubled_y = not FS.check_single(y, FS.NoFlow)
has_doubled_z = not FS.check_single(z, FS.NoFlow)
# Deduce +/- of Each Axis
## -> +/- X
@ -138,8 +139,8 @@ class BoundCondsNode(base.MaxwellSimNode):
x_pos = input_sockets['+X']
x_neg = input_sockets['-X']
has_x_pos = not ct.FlowSignal.check(x_pos)
has_x_neg = not ct.FlowSignal.check(x_neg)
has_x_pos = not FS.check(x_pos)
has_x_neg = not FS.check(x_neg)
## -> +/- Y
if has_doubled_y:
@ -149,8 +150,8 @@ class BoundCondsNode(base.MaxwellSimNode):
y_pos = input_sockets['+Y']
y_neg = input_sockets['-Y']
has_y_pos = not ct.FlowSignal.check(y_pos)
has_y_neg = not ct.FlowSignal.check(y_neg)
has_y_pos = not FS.check(y_pos)
has_y_neg = not FS.check(y_neg)
## -> +/- Z
if has_doubled_z:
@ -160,8 +161,8 @@ class BoundCondsNode(base.MaxwellSimNode):
z_pos = input_sockets['+Z']
z_neg = input_sockets['-Z']
has_z_pos = not ct.FlowSignal.check(z_pos)
has_z_neg = not ct.FlowSignal.check(z_neg)
has_z_pos = not FS.check(z_pos)
has_z_neg = not FS.check(z_neg)
if (
has_x_pos
@ -187,39 +188,29 @@ class BoundCondsNode(base.MaxwellSimNode):
minus=z_neg,
),
)
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'BCs',
kind=ct.FlowKind.Func,
kind=FK.Func,
input_sockets={'X', 'Y', 'Z', '+X', '-X', '+Y', '-Y', '+Z', '-Z'},
input_socket_kinds={
'X': ct.FlowKind.Func,
'Y': ct.FlowKind.Func,
'Z': ct.FlowKind.Func,
'+X': ct.FlowKind.Func,
'-X': ct.FlowKind.Func,
'+Y': ct.FlowKind.Func,
'-Y': ct.FlowKind.Func,
'+Z': ct.FlowKind.Func,
'-Z': ct.FlowKind.Func,
},
input_sockets_optional={
'X': True,
'Y': True,
'Z': True,
'+X': True,
'-X': True,
'+Y': True,
'-Y': True,
'+Z': True,
'-Z': True,
'X': FK.Func,
'Y': FK.Func,
'Z': FK.Func,
'+X': FK.Func,
'-X': FK.Func,
'+Y': FK.Func,
'-Y': FK.Func,
'+Z': FK.Func,
'-Z': FK.Func,
},
input_sockets_optional={'X', 'Y', 'Z', '+X', '-X', '+Y', '-Y', '+Z', '-Z'},
)
def compute_bcs_func(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
def compute_bcs_func(self, input_sockets) -> ct.ParamsFlow | FS:
"""Compute the simulation boundary conditions, by combining the individual input by specified half axis."""
# Deduce "Doubledness"
## -> A "doubled" axis defines the same bound cond both ways
@ -227,9 +218,9 @@ class BoundCondsNode(base.MaxwellSimNode):
y = input_sockets['Y']
z = input_sockets['Z']
has_doubled_x = not ct.FlowSignal.check_single(x, ct.FlowSignal.NoFlow)
has_doubled_y = not ct.FlowSignal.check_single(y, ct.FlowSignal.NoFlow)
has_doubled_z = not ct.FlowSignal.check_single(z, ct.FlowSignal.NoFlow)
has_doubled_x = not FS.check_single(x, FS.NoFlow)
has_doubled_y = not FS.check_single(y, FS.NoFlow)
has_doubled_z = not FS.check_single(z, FS.NoFlow)
# Deduce +/- of Each Axis
## -> +/- X
@ -240,8 +231,8 @@ class BoundCondsNode(base.MaxwellSimNode):
x_pos = input_sockets['+X']
x_neg = input_sockets['-X']
has_x_pos = not ct.FlowSignal.check(x_pos)
has_x_neg = not ct.FlowSignal.check(x_neg)
has_x_pos = not FS.check(x_pos)
has_x_neg = not FS.check(x_neg)
## -> +/- Y
if has_doubled_y:
@ -251,8 +242,8 @@ class BoundCondsNode(base.MaxwellSimNode):
y_pos = input_sockets['+Y']
y_neg = input_sockets['-Y']
has_y_pos = not ct.FlowSignal.check(y_pos)
has_y_neg = not ct.FlowSignal.check(y_neg)
has_y_pos = not FS.check(y_pos)
has_y_neg = not FS.check(y_neg)
## -> +/- Z
if has_doubled_z:
@ -262,8 +253,8 @@ class BoundCondsNode(base.MaxwellSimNode):
z_pos = input_sockets['+Z']
z_neg = input_sockets['-Z']
has_z_pos = not ct.FlowSignal.check(z_pos)
has_z_neg = not ct.FlowSignal.check(z_neg)
has_z_pos = not FS.check(z_pos)
has_z_neg = not FS.check(z_neg)
if (
has_x_pos
@ -290,39 +281,29 @@ class BoundCondsNode(base.MaxwellSimNode):
),
supports_jax=False,
)
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'BCs',
kind=ct.FlowKind.Params,
kind=FK.Params,
input_sockets={'X', 'Y', 'Z', '+X', '-X', '+Y', '-Y', '+Z', '-Z'},
input_socket_kinds={
'X': ct.FlowKind.Params,
'Y': ct.FlowKind.Params,
'Z': ct.FlowKind.Params,
'+X': ct.FlowKind.Params,
'-X': ct.FlowKind.Params,
'+Y': ct.FlowKind.Params,
'-Y': ct.FlowKind.Params,
'+Z': ct.FlowKind.Params,
'-Z': ct.FlowKind.Params,
},
input_sockets_optional={
'X': True,
'Y': True,
'Z': True,
'+X': True,
'-X': True,
'+Y': True,
'-Y': True,
'+Z': True,
'-Z': True,
'X': FK.Params,
'Y': FK.Params,
'Z': FK.Params,
'+X': FK.Params,
'-X': FK.Params,
'+Y': FK.Params,
'-Y': FK.Params,
'+Z': FK.Params,
'-Z': FK.Params,
},
input_sockets_optional={'X', 'Y', 'Z', '+X', '-X', '+Y', '-Y', '+Z', '-Z'},
)
def compute_bcs_params(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
def compute_bcs_params(self, input_sockets) -> ct.ParamsFlow | FS:
"""Compute the simulation boundary conditions, by combining the individual input by specified half axis."""
# Deduce "Doubledness"
## -> A "doubled" axis defines the same bound cond both ways
@ -330,9 +311,9 @@ class BoundCondsNode(base.MaxwellSimNode):
y = input_sockets['Y']
z = input_sockets['Z']
has_doubled_x = not ct.FlowSignal.check_single(x, ct.FlowSignal.NoFlow)
has_doubled_y = not ct.FlowSignal.check_single(y, ct.FlowSignal.NoFlow)
has_doubled_z = not ct.FlowSignal.check_single(z, ct.FlowSignal.NoFlow)
has_doubled_x = not FS.check_single(x, FS.NoFlow)
has_doubled_y = not FS.check_single(y, FS.NoFlow)
has_doubled_z = not FS.check_single(z, FS.NoFlow)
# Deduce +/- of Each Axis
## -> +/- X
@ -343,8 +324,8 @@ class BoundCondsNode(base.MaxwellSimNode):
x_pos = input_sockets['+X']
x_neg = input_sockets['-X']
has_x_pos = not ct.FlowSignal.check(x_pos)
has_x_neg = not ct.FlowSignal.check(x_neg)
has_x_pos = not FS.check(x_pos)
has_x_neg = not FS.check(x_neg)
## -> +/- Y
if has_doubled_y:
@ -354,8 +335,8 @@ class BoundCondsNode(base.MaxwellSimNode):
y_pos = input_sockets['+Y']
y_neg = input_sockets['-Y']
has_y_pos = not ct.FlowSignal.check(y_pos)
has_y_neg = not ct.FlowSignal.check(y_neg)
has_y_pos = not FS.check(y_pos)
has_y_neg = not FS.check(y_neg)
## -> +/- Z
if has_doubled_z:
@ -365,8 +346,8 @@ class BoundCondsNode(base.MaxwellSimNode):
z_pos = input_sockets['+Z']
z_neg = input_sockets['-Z']
has_z_pos = not ct.FlowSignal.check(z_pos)
has_z_neg = not ct.FlowSignal.check(z_neg)
has_z_pos = not FS.check(z_pos)
has_z_neg = not FS.check(z_neg)
if (
has_x_pos
@ -377,7 +358,7 @@ class BoundCondsNode(base.MaxwellSimNode):
and has_z_neg
):
return x_pos | x_neg | y_pos | y_neg | z_pos | z_neg
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
@ -386,4 +367,4 @@ class BoundCondsNode(base.MaxwellSimNode):
BL_REGISTER = [
BoundCondsNode,
]
BL_NODES = {ct.NodeType.BoundConds: (ct.NodeCategory.MAXWELLSIM_BOUNDS)}
BL_NODES = {ct.NodeType.BoundConds: (ct.NodeCategory.MAXWELLSIM_SIMS)}

View File

@ -16,13 +16,19 @@
"""Implements `FDTDSimNode`."""
import itertools
import typing as typ
import bpy
import jax
import numpy as np
import sympy as sp
import sympy.physics.units as spu
import tidy3d as td
import tidy3d.plugins.adjoint as tdadj
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils.frozendict import frozendict
from ... import contracts as ct
from ... import sockets
@ -30,6 +36,44 @@ from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
PT = spux.PhysicalType
SimArray: typ.TypeAlias = frozendict[
tuple[sim_symbols.SimSymbol, ...], tuple[typ.Any, ...], td.Simulation
]
SimArrayInfo: typ.TypeAlias = frozendict[
tuple[sim_symbols.SimSymbol, ...], tuple[typ.Any, ...], td.Simulation
]
class RecomputeSimInfo(bpy.types.Operator):
"""Recompute information about the simulation."""
bl_idname = ct.OperatorType.NodeRecomputeSimInfo
bl_label = 'Recompute Simuation Info'
bl_description = (
'Recompute information of a simulation attached to a `FDTDSimNode`.'
)
@classmethod
def poll(cls, context):
"""Allow running whenever a particular FDTDSim node is available."""
return (
# Check Tidy3DWebExporter is Accessible
hasattr(context, 'node')
and hasattr(context.node, 'node_type')
and context.node.node_type == ct.NodeType.FDTDSim
)
def execute(self, context):
"""Invalidate the `.sims` property, triggering reevaluation of all downstream information about the simulation."""
node = context.node
node.sims = bl_cache.Signal.InvalidateCache
return {'FINISHED'}
class FDTDSimNode(base.MaxwellSimNode):
"""Definition of a complete FDTD simulation, including boundary conditions, domain, sources, structures, monitors, and other configuration."""
@ -44,39 +88,311 @@ class FDTDSimNode(base.MaxwellSimNode):
'BCs': sockets.MaxwellBoundCondsSocketDef(),
'Domain': sockets.MaxwellSimDomainSocketDef(),
'Sources': sockets.MaxwellSourceSocketDef(
active_kind=ct.FlowKind.Array,
active_kind=FK.Array,
),
'Structures': sockets.MaxwellStructureSocketDef(
active_kind=ct.FlowKind.Array,
active_kind=FK.Array,
),
'Monitors': sockets.MaxwellMonitorSocketDef(
active_kind=ct.FlowKind.Array,
active_kind=FK.Array,
),
}
output_socket_sets: typ.ClassVar = {
'Single': {
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Value),
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=FK.Value),
},
'Func': {
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Func),
'Batch': {
'Sims': sockets.MaxwellFDTDSimSocketDef(active_kind=FK.Array),
},
'Lazy': {
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=FK.Func),
},
}
####################
# - Properties
# - Properties: UI
####################
differentiable: bool = bl_cache.BLField(False)
ui_limits: bool = bl_cache.BLField(False)
ui_discretization: bool = bl_cache.BLField(False)
ui_portability: bool = bl_cache.BLField(False)
ui_propagation: bool = bl_cache.BLField(False)
####################
# - UI
# - Properties: Simulation
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
layout.prop(
self,
self.blfields['differentiable'],
text='Differentiable',
toggle=True,
)
@bl_cache.cached_bl_property()
def sims(self) -> SimArray | None:
"""The complete description of all simulation output objects."""
if self.active_socket_set == 'Single':
sim_value = self.compute_output('Sim', kind=FK.Value)
has_sim_value = not FS.check(sim_value)
if has_sim_value:
return {(): sim_value}
elif self.active_socket_set == 'Batch':
sim_array = self.compute_output('Sims', kind=FK.Array)
has_sim_array = not FS.check(sim_array)
if has_sim_array:
return sim_array
return None
####################
# - Properties: Propagation
####################
@bl_cache.cached_bl_property(depends_on={'sims'})
def has_gain(self) -> SimArrayInfo | None:
"""Whether any mediums in the simulation allow gain."""
if self.sims is not None:
return {k: sim.allow_gain for k, sim in self.sims.items()}
return None
@bl_cache.cached_bl_property(depends_on={'sims'})
def has_complex_fields(self) -> SimArrayInfo | None:
"""Whether complex fields are currently used in the simulation."""
if self.sims is not None:
return {k: sim.complex_fields for k, sim in self.sims.items()}
return None
@bl_cache.cached_bl_property(depends_on={'sims'})
def min_wl(self) -> SimArrayInfo | None:
"""The smallest wavelength that occurs in the simulation."""
if self.sims is not None:
return {k: sim.wvl_mat_min * spu.um for k, sim in self.sims.items()}
return None
####################
# - Properties: Discretization
####################
@bl_cache.cached_bl_property(depends_on={'sims'})
def time_range(self) -> SimArrayInfo | None:
"""The time range of the simulation."""
if self.sims is not None:
return {
k: sp.Matrix(
[0, spu.convert_to(sim.run_time * spu.second, spu.picosecond)]
)
for k, sim in self.sims.items()
}
return None
@bl_cache.cached_bl_property(depends_on={'sims'})
def freq_range(self) -> SimArrayInfo | None:
"""The total frequency range of the simulation, across all sources."""
if self.sims is not None:
return {
k: spu.convert_to(
sp.Matrix([sim.frequency_range[0], sim.frequency_range[1]])
* spu.hertz,
spux.THz,
)
for k, sim in self.sims.items()
}
return None
@bl_cache.cached_bl_property(depends_on={'sims'})
def time_step(self) -> SimArrayInfo | None:
"""The time step of the simulation."""
if self.sims is not None:
return {k: sim.dt * spu.second for k, sim in self.sims.items()}
return None
@bl_cache.cached_bl_property(depends_on={'sims'})
def time_steps(self) -> SimArrayInfo | None:
"""The time step of the simulation."""
if self.sims is not None:
return {k: sim.num_time_steps for k, sim in self.sims.items()}
return None
@bl_cache.cached_bl_property(depends_on={'sims'})
def nyquist_step(self) -> SimArrayInfo | None:
"""The number of time-steps needed to theoretically provide for correctly resolved sampling of the simulation grid."""
if self.sims is not None:
return {k: sim.nyquist_step for k, sim in self.sims.items()}
return None
@bl_cache.cached_bl_property(depends_on={'sims'})
def num_cells(self) -> SimArrayInfo | None:
"""The number of 3D cells for which the simulation is discretized."""
if self.sims is not None:
return {k: sim.num_cells for k, sim in self.sims.items()}
return None
####################
# - Properties: Data
####################
@bl_cache.cached_bl_property(depends_on={'sims'})
def monitor_data_sizes(self) -> SimArrayInfo | None:
"""The total data expected to be taken by each monitors."""
if self.sims is not None:
return {k: sim.monitors_data_size for k, sim in self.sims.items()}
return None
@bl_cache.cached_bl_property(depends_on={'monitor_data_sizes'})
def total_monitor_data_size(self) -> SimArrayInfo | None:
"""The total data taken by the monitors."""
if self.monitor_data_sizes is not None:
return {
k: sum(sizes.values()) for k, sizes in self.monitor_data_sizes.items()
}
return None
####################
# - Properties: Lists
####################
@bl_cache.cached_bl_property(depends_on={'sims'})
def list_datasets(self) -> SimArrayInfo | None:
"""List of custom datasets required by the simulation."""
if self.sims is not None:
return {k: sim.custom_datasets for k, sim in self.sims.items()}
return None
@bl_cache.cached_bl_property(depends_on={'sims'})
def list_vol_structures(self) -> SimArrayInfo | None:
"""List of volumetric structures, where 2D mediums were converted to 3D."""
if self.sims is not None:
return {k: sim.volumetric_structures for k, sim in self.sims.items()}
return None
####################
# - Properties: Validated
####################
@bl_cache.cached_bl_property(depends_on={'sims'})
def sims_valid(self) -> SimArrayInfo | None:
"""Whether all sims are valid."""
if self.sims is not None:
validity = {}
for k, sim in self.sims.items(): # noqa: B007
try:
pass ## TODO: VERY slow, batch checking is infeasible
# sim.validate_pre_upload(source_required=True)
except td.exceptions.SetupError:
validity[k] = False
else:
validity[k] = True
return validity
return None
####################
# - Info
####################
@bl_cache.cached_bl_property(
depends_on={
'sims',
'time_range',
'freq_range',
'min_wl',
'num_cells',
'time_steps',
'nyquist_steps',
'time_step',
'sims_valid',
'total_monitor_data_size',
'has_gain',
'has_complex_fields',
'ui_limits',
'ui_discretization',
'ui_portability',
'ui_propagation',
}
)
def sim_labels(self) -> SimArrayInfo | None:
"""Pre-processed labels for efficient drawing of simulation info."""
if self.sims is not None:
sims_vals_labels = {}
for syms_vals in self.sims:
labels = []
if syms_vals:
labels += [
'|'.join(
[
f'{sym.name_pretty}={val:,.2f}'
for sym, val in zip(*syms_vals, strict=True)
]
),
]
labels += [
['Limits', 'ui_limits'],
]
if self.ui_limits:
labels += [
('max t', spux.sp_to_str(self.time_range[syms_vals][1])),
('min f', spux.sp_to_str(self.freq_range[syms_vals][0])),
('max f', spux.sp_to_str(self.freq_range[syms_vals][1])),
('min λ', spux.sp_to_str(self.min_wl[syms_vals].n(2))),
]
labels += [
['Discretization', 'ui_discretization'],
]
if self.ui_discretization:
labels += [
('cells', f'{self.num_cells[syms_vals]:,}'),
('num Δt', f'{self.time_steps[syms_vals]:,}'),
('nyq Δt', f'{self.nyquist_step[syms_vals]}'),
('Δt', spux.sp_to_str(self.time_step[syms_vals].n(2))),
]
labels += [
['Portability', 'ui_portability'],
]
if self.ui_portability:
labels += [
('Valid?', str(self.sims_valid[syms_vals])),
(
'Σ mon',
f'{self.total_monitor_data_size[syms_vals] / 1000000:,.2f}MB',
),
]
labels += [
['Propagation', 'ui_propagation'],
]
if self.ui_propagation:
labels += [
('Gain?', str(self.has_gain[syms_vals])),
('𝐄𝐇', str(self.has_complex_fields[syms_vals])),
]
sims_vals_labels[syms_vals] = labels
return sims_vals_labels
return None
def draw_info(self, _, layout: bpy.types.UILayout):
"""Draw information about the simulation, if any."""
row = layout.row(align=True)
row.alignment = 'CENTER'
row.label(text='Sim Info')
row.operator(ct.OperatorType.NodeRecomputeSimInfo, icon='FILE_REFRESH', text='')
# Simulation Info
if self.sim_labels is not None:
for labels in self.sim_labels.values():
box = layout.box()
for el in labels:
# Header
if isinstance(el, list):
row = box.row(align=True)
# row.alignment = 'EXPAND'
row.prop(self, self.blfields[el[1]], text=el[0], toggle=True)
# row.label(text=el)
split = box.split(factor=0.4)
col_l = split.column(align=True)
col_r = split.column(align=True)
# Label Pair
elif isinstance(el, tuple):
col_l.label(text=el[0])
col_r.label(text=el[1])
else:
raise TypeError
break
####################
# - Events
@ -88,36 +404,31 @@ class FDTDSimNode(base.MaxwellSimNode):
# Loaded
props={'active_socket_set'},
output_sockets={'Sim'},
output_socket_kinds={'Sim': ct.FlowKind.Params},
outscks_kinds={'Sim': FK.Params},
)
def on_any_changed(self, props, output_sockets) -> None:
"""Create loose input sockets."""
"""Manage loose input sockets in response to symbolic simulation elements."""
# Loose Input Sockets
output_params = output_sockets['Sim']
has_output_params = not ct.FlowSignal.check(output_params)
active_socket_set = props['active_socket_set']
if (
active_socket_set == 'Single'
and has_output_params
active_socket_set in ['Single', 'Batch']
and output_params.symbols
and set(self.loose_input_sockets)
!= {sym.name for sym in output_params.symbols}
):
if set(self.loose_input_sockets) != {
sym.name for sym in output_params.symbols
}:
self.loose_input_sockets = {
sym.name: sockets.ExprSocketDef(
**(
expr_info
| {
'active_kind': ct.FlowKind.Value,
'use_value_range_swapper': (
active_socket_set == 'Value'
),
}
)
self.loose_input_sockets = {
sym.name: sockets.ExprSocketDef(
**(
expr_info
| {
'active_kind': FK.Value,
'use_value_range_swapper': (active_socket_set == 'Batch'),
}
)
for sym, expr_info in output_params.sym_expr_infos.items()
}
)
for sym, expr_info in output_params.sym_expr_infos.items()
}
elif self.loose_input_sockets:
self.loose_input_sockets = {}
@ -127,144 +438,184 @@ class FDTDSimNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Sim',
kind=ct.FlowKind.Value,
kind=FK.Value,
# Loaded
props={'active_socket_set'},
outscks_kinds={'Sim': {FK.Func, FK.Params}},
all_loose_input_sockets=True,
output_sockets={'Sim'},
output_socket_kinds={'Sim': {ct.FlowKind.Func, ct.FlowKind.Params}},
loose_input_sockets_kind={FK.Func, FK.Params},
)
def compute_value(
self, loose_input_sockets, output_sockets
) -> ct.ParamsFlow | ct.FlowSignal:
self, props, loose_input_sockets, output_sockets
) -> td.Simulation | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
output_func = output_sockets['Sim'][ct.FlowKind.Func]
output_params = output_sockets['Sim'][ct.FlowKind.Params]
func = output_sockets['Sim'][FK.Func]
params = output_sockets['Sim'][FK.Params]
has_output_func = not ct.FlowSignal.check(output_func)
has_output_params = not ct.FlowSignal.check(output_params)
has_func = not FS.check(func)
has_params = not FS.check(params)
if has_output_func and has_output_params:
return output_func.realize(
output_params,
symbol_values={
sym: loose_input_sockets[sym.name]
for sym in output_params.sorted_symbols
},
disallow_jax=True,
active_socket_set = props['active_socket_set']
if has_func and has_params and active_socket_set == 'Single':
symbol_values = {
sym: events.realize_known(loose_input_sockets[sym.name])
for sym in params.sorted_symbols
}
return func.realize(
params,
symbol_values=frozendict(
{
sym: events.realize_known(loose_input_sockets[sym.name])
for sym in params.sorted_symbols
}
),
).updated_copy(
attrs=dict(
ct.SimMetadata(
realizations=ct.SimRealizations(
syms=tuple(symbol_values.keys()),
vals=tuple(symbol_values.values()),
)
)
)
)
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
# - FlowKind.Array
####################
@events.computes_output_socket(
'Sims',
kind=FK.Array,
# Loaded
props={'active_socket_set'},
outscks_kinds={'Sim': {FK.Func, FK.Params}},
all_loose_input_sockets=True,
loose_input_sockets_kind={FK.Func, FK.Params},
)
def compute_array(
self, props, loose_input_sockets, output_sockets
) -> SimArray | FS:
"""Produce a batch of simulations as a dictionary, indexed by a twice-nested tuple matching a `SimSymbol` tuple to a corresponding tuple of values."""
func = output_sockets['Sim'][FK.Func]
params = output_sockets['Sim'][FK.Params]
active_socket_set = props['active_socket_set']
if active_socket_set == 'Batch':
# Realize Values per-Symbol
## -> First, we realize however many values requested per symbol.
sym_datas: dict[sim_symbols.SimSymbol, list] = {}
for sym in params.sorted_symbols:
if sym.name not in loose_input_sockets:
return FS.FlowPending
# Realize Data for Symbol
## -> This may be a single scalar/vector/matrix.
## -> This may also be many _scalars_.
sym_data = events.realize_known(
loose_input_sockets[sym.name],
freeze=True,
)
if sym_data is None:
return FS.FlowPending
# Single Value per-Symbol
if sym.shape_len == 0:
sym_datas |= {sym: (sym_data,)}
# Many Values per-Symbol
else:
sym_datas |= {sym: sym_data}
# Realize Function per-Combination
## -> td.Simulation requires single, specific values for all syms.
## -> What we have is many specific values for each sym.
## -> With a single comprehension, we resolve this difference.
## -> The end-result is an annotated td.Simulation per-combo.
## -> NOTE: This might be big! Such are parameter-sweeps.
syms = tuple(sym_datas.keys())
return {
(syms, vals): func.realize(
params,
symbol_values=frozendict(zip(syms, vals, strict=True)),
).updated_copy(
attrs=dict(
ct.SimMetadata(
realizations=ct.SimRealizations(syms=syms, vals=vals)
)
)
)
for vals in itertools.product(*sym_datas.values())
}
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Sim',
kind=ct.FlowKind.Func,
kind=FK.Func,
# Loaded
props={'differentiable'},
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
input_socket_kinds={
'BCs': ct.FlowKind.Func,
'Domain': ct.FlowKind.Func,
'Sources': ct.FlowKind.Func,
'Structures': ct.FlowKind.Func,
'Monitors': ct.FlowKind.Func,
inscks_kinds={
'BCs': FK.Func,
'Domain': FK.Func,
'Sources': FK.Func,
'Structures': FK.Func,
'Monitors': FK.Func,
},
output_sockets={'Sim'},
output_socket_kinds={'Sim': ct.FlowKind.Params},
)
def compute_fdtd_sim_func(
self, props, input_sockets, output_sockets
) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
def compute_fdtd_sim_func(self, input_sockets) -> ct.FuncFlow:
"""Compute a single simulation, given that all inputs are non-symbolic."""
bounds = input_sockets['BCs']
sim_domain = input_sockets['Domain']
sources = input_sockets['Sources']
structures = input_sockets['Structures']
monitors = input_sockets['Monitors']
output_params = output_sockets['Sim']
has_bounds = not ct.FlowSignal.check(bounds)
has_sim_domain = not ct.FlowSignal.check(sim_domain)
has_sources = not ct.FlowSignal.check(sources)
has_structures = not ct.FlowSignal.check(structures)
has_monitors = not ct.FlowSignal.check(monitors)
has_output_params = not ct.FlowSignal.check(output_params)
if (
has_sim_domain
and has_sources
and has_structures
and has_bounds
and has_monitors
and has_output_params
):
differentiable = props['differentiable']
if differentiable:
raise NotImplementedError
return (
bounds | sim_domain | sources | structures | monitors
).compose_within(
enclosing_func=lambda els: td.Simulation(
boundary_spec=els[0],
**els[1],
sources=els[2],
structures=els[3],
monitors=els[4],
),
supports_jax=False,
)
return ct.FlowSignal.FlowPending
return (bounds | sim_domain | sources | structures | monitors).compose_within(
lambda els: td.Simulation(
boundary_spec=els[0],
**els[1],
sources=els[2],
structures=els[3],
monitors=els[4],
),
)
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Sim',
kind=ct.FlowKind.Params,
kind=FK.Params,
# Loaded
props={'differentiable'},
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
input_socket_kinds={
'BCs': ct.FlowKind.Params,
'Domain': ct.FlowKind.Params,
'Sources': ct.FlowKind.Params,
'Structures': ct.FlowKind.Params,
'Monitors': ct.FlowKind.Params,
inscks_kinds={
'BCs': FK.Params,
'Domain': FK.Params,
'Sources': FK.Params,
'Structures': FK.Params,
'Monitors': FK.Params,
},
)
def compute_fdtd_sim_params(
self, props, input_sockets
) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
"""Compute a single simulation, given that all inputs are non-symbolic."""
def compute_params(self, input_sockets) -> td.Simulation | FS:
"""Compute all function parameters needed to create the simulation."""
# Compute Output Parameters
bounds = input_sockets['BCs']
sim_domain = input_sockets['Domain']
sources = input_sockets['Sources']
structures = input_sockets['Structures']
monitors = input_sockets['Monitors']
has_bounds = not ct.FlowSignal.check(bounds)
has_sim_domain = not ct.FlowSignal.check(sim_domain)
has_sources = not ct.FlowSignal.check(sources)
has_structures = not ct.FlowSignal.check(structures)
has_monitors = not ct.FlowSignal.check(monitors)
if (
has_bounds
and has_sim_domain
and has_sources
and has_structures
and has_monitors
):
return bounds | sim_domain | sources | structures | monitors
return ct.FlowSignal.FlowPending
return bounds | sim_domain | sources | structures | monitors
####################
# - Blender Registration
####################
BL_REGISTER = [
RecomputeSimInfo,
FDTDSimNode,
]
BL_NODES = {ct.NodeType.FDTDSim: (ct.NodeCategory.MAXWELLSIM_SIMS)}

View File

@ -22,8 +22,8 @@ import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import managed_objs, sockets
@ -31,6 +31,11 @@ from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
PT = spux.PhysicalType
class SimDomainNode(base.MaxwellSimNode):
"""The domain of a simulation in space and time, including bounds, discretization strategy, and the ambient medium."""
@ -41,24 +46,24 @@ class SimDomainNode(base.MaxwellSimNode):
input_sockets: typ.ClassVar = {
'Duration': sockets.ExprSocketDef(
physical_type=spux.PhysicalType.Time,
physical_type=PT.Time,
default_unit=spu.picosecond,
default_value=5,
abs_min=0,
),
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Length,
mathtype=MT.Real,
physical_type=PT.Length,
default_unit=spu.micrometer,
default_value=sp.Matrix([0, 0, 0]),
default_value=sp.ImmutableMatrix([0, 0, 0]),
),
'Size': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Length,
mathtype=MT.Real,
physical_type=PT.Length,
default_unit=spu.micrometer,
default_value=sp.Matrix([1, 1, 1]),
default_value=sp.ImmutableMatrix([1, 1, 1]),
abs_min=0.001,
),
'Grid': sockets.MaxwellSimGridSocketDef(),
@ -77,89 +82,75 @@ class SimDomainNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Domain',
kind=ct.FlowKind.Value,
kind=FK.Value,
# Loaded
output_sockets={'Domain'},
output_socket_kinds={'Domain': {ct.FlowKind.Func, ct.FlowKind.Params}},
outscks_kinds={
'Domain': {FK.Func, FK.Params},
},
)
def compute_domain_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
def compute_domain_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
output_func = output_sockets['Domain'][ct.FlowKind.Func]
output_params = output_sockets['Domain'][ct.FlowKind.Params]
has_output_func = not ct.FlowSignal.check(output_func)
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_func and has_output_params and not output_params.symbols:
return output_func.realize(output_params)
return ct.FlowSignal.FlowPending
value = events.realize_known(output_sockets['Domain'])
if value is not None:
return value
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Domain',
kind=ct.FlowKind.Func,
kind=FK.Func,
# Loaded
input_sockets={'Duration', 'Center', 'Size', 'Grid', 'Ambient Medium'},
input_socket_kinds={
'Duration': ct.FlowKind.Func,
'Center': ct.FlowKind.Func,
'Size': ct.FlowKind.Func,
'Grid': ct.FlowKind.Func,
'Ambient Medium': ct.FlowKind.Func,
'Duration': FK.Func,
'Center': FK.Func,
'Size': FK.Func,
'Grid': FK.Func,
'Ambient Medium': FK.Func,
},
scale_input_sockets={
'Duration': ct.UNITS_TIDY3D,
'Center': ct.UNITS_TIDY3D,
'Size': ct.UNITS_TIDY3D,
},
)
def compute_domain_func(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
def compute_domain_func(self, input_sockets) -> ct.ParamsFlow | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
duration = input_sockets['Duration']
center = input_sockets['Center']
size = input_sockets['Size']
grid = input_sockets['Grid']
medium = input_sockets['Ambient Medium']
has_duration = not ct.FlowSignal.check(duration)
has_center = not ct.FlowSignal.check(center)
has_size = not ct.FlowSignal.check(size)
has_grid = not ct.FlowSignal.check(grid)
has_medium = not ct.FlowSignal.check(medium)
if has_duration and has_center and has_size and has_grid and has_medium:
return (
duration.scale_to_unit_system(ct.UNITS_TIDY3D)
| center.scale_to_unit_system(ct.UNITS_TIDY3D)
| size.scale_to_unit_system(ct.UNITS_TIDY3D)
| grid
| medium
).compose_within(
lambda els: {
'run_time': els[0],
'center': els[1].flatten().tolist(),
'size': els[2].flatten().tolist(),
'grid_spec': els[3],
'medium': els[4],
},
supports_jax=False,
)
return ct.FlowSignal.FlowPending
return (duration | center | size | grid | medium).compose_within(
lambda els: {
'run_time': els[0],
'center': els[1].flatten().tolist(),
'size': els[2].flatten().tolist(),
'grid_spec': els[3],
'medium': els[4],
},
supports_jax=False,
)
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Domain',
kind=ct.FlowKind.Params,
kind=FK.Params,
# Loaded
input_sockets={'Duration', 'Center', 'Size', 'Grid', 'Ambient Medium'},
input_socket_kinds={
'Duration': ct.FlowKind.Params,
'Center': ct.FlowKind.Params,
'Size': ct.FlowKind.Params,
'Grid': ct.FlowKind.Params,
'Ambient Medium': ct.FlowKind.Params,
'Duration': FK.Params,
'Center': FK.Params,
'Size': FK.Params,
'Grid': FK.Params,
'Ambient Medium': FK.Params,
},
)
def compute_domain_params(self, input_sockets) -> ct.ParamsFlow | ct.FlowSignal:
def compute_domain_params(self, input_sockets) -> ct.ParamsFlow | FS:
"""Compute the output `ParamsFlow` of the simulation domain from strictly non-symbolic inputs."""
duration = input_sockets['Duration']
center = input_sockets['Center']
@ -167,22 +158,14 @@ class SimDomainNode(base.MaxwellSimNode):
grid = input_sockets['Grid']
medium = input_sockets['Ambient Medium']
has_duration = not ct.FlowSignal.check(duration)
has_center = not ct.FlowSignal.check(center)
has_size = not ct.FlowSignal.check(size)
has_grid = not ct.FlowSignal.check(grid)
has_medium = not ct.FlowSignal.check(medium)
if has_duration and has_center and has_size and has_grid and has_medium:
return duration | center | size | grid | medium
return ct.FlowSignal.FlowPending
return duration | center | size | grid | medium
####################
# - Preview
####################
@events.computes_output_socket(
'Domain',
kind=ct.FlowKind.Previews,
kind=FK.Previews,
# Loaded
props={'sim_node_name'},
)
@ -192,35 +175,37 @@ class SimDomainNode(base.MaxwellSimNode):
@events.on_value_changed(
# Trigger
socket_name={'Center', 'Size'},
socket_name={
'Center': {FK.Func, FK.Params},
'Size': {FK.Func, FK.Params},
},
run_on_init=True,
# Loaded
input_sockets={'Center', 'Size'},
managed_objs={'modifier'},
output_sockets={'Domain'},
output_socket_kinds={'Domain': ct.FlowKind.Params},
inscks_kinds={
'Center': {FK.Func, FK.Params},
'Size': {FK.Func, FK.Params},
},
scale_input_sockets={
'Center': ct.UNITS_BLENDER,
'Size': ct.UNITS_BLENDER,
},
)
def on_input_changed(self, managed_objs, input_sockets, output_sockets) -> None:
def on_input_changed(self, managed_objs, input_sockets) -> None:
"""Preview the simulation domain based on input parameters, so long as they are not dependent on unrealized symbols."""
output_params = output_sockets['Domain']
center = input_sockets['Center']
center = events.realize_preview(input_sockets['Center'])
size = events.realize_preview(input_sockets['Size'])
has_output_params = not ct.FlowSignal.check(output_params)
has_center = not ct.FlowSignal.check(center)
if has_center and has_output_params and not output_params.symbols:
# Push Loose Input Values to GeoNodes Modifier
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.SimulationSimDomain),
'unit_system': ct.UNITS_BLENDER,
'inputs': {
'Size': input_sockets['Size'],
},
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.SimulationSimDomain),
'inputs': {
'Size': size,
},
location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
)
},
location=center,
)
####################

View File

@ -14,8 +14,164 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `SimGridNode`."""
import typing as typ
import sympy.physics.units as spu
import tidy3d as td
from blender_maxwell.utils import logger
from ... import contracts as ct
from ... import sockets
from .. import base, events
log = logger.get(__name__)
SSA = ct.SimSpaceAxis
FK = ct.FlowKind
FS = ct.FlowSignal
class SimGridNode(base.MaxwellSimNode):
"""Provides a hub for joining custom simulation domain boundary conditions by-axis."""
node_type = ct.NodeType.SimGrid
bl_label = 'Sim Grid'
####################
# - Sockets
####################
input_sockets: typ.ClassVar = {
'X': sockets.MaxwellSimGridAxisSocketDef(active_kind=FK.Func),
'Y': sockets.MaxwellSimGridAxisSocketDef(active_kind=FK.Func),
'Z': sockets.MaxwellSimGridAxisSocketDef(active_kind=FK.Func),
}
input_socket_sets: typ.ClassVar = {
'Relative': {},
'Absolute': {
'WL': sockets.ExprSocketDef(
default_unit=spu.nm,
default_value=500,
abs_min=0,
abs_min_closed=False,
)
},
}
output_sockets: typ.ClassVar = {
'Grid': sockets.MaxwellSimGridSocketDef(),
}
####################
# - FlowKind.Value
####################
@events.computes_output_socket(
'Grid',
kind=FK.Value,
# Loaded
outscks_kinds={'Grid': {FK.Func, FK.Params}},
)
def compute_bcs_value(self, output_sockets) -> td.BoundarySpec:
"""Compute the simulation boundary conditions, by combining the individual input by specified half axis."""
value = events.realize_known(output_sockets['Grid'])
if value is not None:
return value
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Grid',
kind=FK.Func,
# Loaded
props={'active_socket_set'},
inscks_kinds={
'X': FK.Func,
'Y': FK.Func,
'Z': FK.Func,
'WL': FK.Func,
},
input_sockets_optional={'WL'},
scale_input_sockets={
'WL': ct.UNITS_TIDY3D,
},
)
def compute_grid_func(self, props, input_sockets) -> ct.ParamsFlow | FS:
"""Compute the simulation grid lazily, at the specified wavelength."""
# Deduce "Doubledness"
## -> A "doubled" axis defines the same bound cond both ways
x = input_sockets['X']
y = input_sockets['Y']
z = input_sockets['Z']
wl = input_sockets['WL']
active_socket_set = props['active_socket_set']
common_func = x | y | z
match active_socket_set:
case 'Absolute' if not FS.check(wl):
return (common_func | wl).compose_within(
lambda els: td.GridSpec(
grid_x=els[0], grid_y=els[1], grid_z=els[2], wavelength=els[3]
)
)
case 'Relative':
return common_func.compose_within(
lambda els: td.GridSpec(
grid_x=els[0],
grid_y=els[1],
grid_z=els[2],
)
)
return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Grid',
kind=FK.Params,
# Loaded
props={'active_socket_set'},
inscks_kinds={
'X': FK.Params,
'Y': FK.Params,
'Z': FK.Params,
'WL': FK.Params,
},
input_sockets_optional={'WL'},
)
def compute_bcs_params(self, props, input_sockets) -> ct.ParamsFlow | FS:
"""Compute the simulation boundary conditions, by combining the individual input by specified half axis."""
# Deduce "Doubledness"
## -> A "doubled" axis defines the same bound cond both ways
x = input_sockets['X']
y = input_sockets['Y']
z = input_sockets['Z']
wl = input_sockets['WL']
active_socket_set = props['active_socket_set']
common_params = x | y | z
match active_socket_set:
case 'Relative':
return common_params
case 'Absolute' if not FS.check(wl):
return common_params | wl
return FS.FlowPending
####################
# - Blender Registration
####################
BL_REGISTER = []
BL_NODES = {}
BL_REGISTER = [
SimGridNode,
]
BL_NODES = {ct.NodeType.SimGrid: (ct.NodeCategory.MAXWELLSIM_SIMS)}

View File

@ -14,8 +14,126 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `AutoSimGridAxisNode`."""
import typing as typ
import sympy.physics.units as spu
import tidy3d as td
from blender_maxwell.utils import logger
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
SSA = ct.SimSpaceAxis
FK = ct.FlowKind
FS = ct.FlowSignal
class AutoSimGridAxisNode(base.MaxwellSimNode):
"""Declare a uniform grid along a simulation axis."""
node_type = ct.NodeType.AutoSimGridAxis
bl_label = 'Auto Grid Axis'
####################
# - Sockets
####################
input_sockets: typ.ClassVar = {
'min N/λ': sockets.ExprSocketDef(
default_value=10,
),
'min Δℓ': sockets.ExprSocketDef(
default_unit=spu.nm,
default_value=0,
abs_min=0,
),
'max ratio': sockets.ExprSocketDef(
default_value=1.4,
abs_min=1,
),
}
output_sockets: typ.ClassVar = {
'Grid Axis': sockets.MaxwellSimGridAxisSocketDef(active_kind=FK.Func),
}
####################
# - FlowKind.Value
####################
@events.computes_output_socket(
'Grid Axis',
kind=FK.Value,
# Loaded
outscks_kinds={'Grid Axis': {FK.Func, FK.Params}},
)
def compute_bcs_value(self, output_sockets) -> td.BoundarySpec:
"""Compute the simulation boundary conditions, by combining the individual input by specified half axis."""
value = events.realize_known(output_sockets['Grid Axis'])
if value is not None:
return value
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Grid Axis',
kind=FK.Func,
# Loaded
inscks_kinds={
'min N/λ': FK.Func,
'max ratio': FK.Func,
'min Δℓ': FK.Func,
},
scale_input_sockets={
'min Δℓ': ct.UNITS_TIDY3D,
},
)
def compute_grid_func(self, input_sockets) -> ct.ParamsFlow | FS:
"""Compute the simulation grid lazily, at the specified wavelength."""
min_steps_per_wl = input_sockets['min N/λ']
max_consecutive_ratio = input_sockets['max ratio']
min_length = input_sockets['min Δℓ']
return (min_steps_per_wl | max_consecutive_ratio | min_length).compose_within(
lambda els: td.AutoGrid(
min_steps_per_wvl=els[0],
max_scale=els[1],
dl_min=els[2],
)
)
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Grid Axis',
kind=FK.Params,
# Loaded
inscks_kinds={
'min N/λ': FK.Params,
'max ratio': FK.Params,
'min Δℓ': FK.Params,
},
)
def compute_grid_params(self, input_sockets) -> ct.ParamsFlow | FS:
"""Compute the simulation grid lazily, at the specified wavelength."""
min_steps_per_wl = input_sockets['min N/λ']
max_consecutive_ratio = input_sockets['max ratio']
min_length = input_sockets['min Δℓ']
return min_steps_per_wl | max_consecutive_ratio | min_length
####################
# - Blender Registration
####################
BL_REGISTER = []
BL_NODES = {}
BL_REGISTER = [
AutoSimGridAxisNode,
]
BL_NODES = {ct.NodeType.AutoSimGridAxis: (ct.NodeCategory.MAXWELLSIM_SIMS_SIMGRIDAXES)}

View File

@ -14,8 +14,111 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `UniformSimGridAxisNode`."""
import typing as typ
import sympy.physics.units as spu
import tidy3d as td
from blender_maxwell.utils import logger
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
SSA = ct.SimSpaceAxis
FK = ct.FlowKind
FS = ct.FlowSignal
class UniformSimGridAxisNode(base.MaxwellSimNode):
"""Declare a uniform grid along a simulation axis."""
node_type = ct.NodeType.UniformSimGridAxis
bl_label = 'Uniform Grid Axis'
####################
# - Sockets
####################
input_sockets: typ.ClassVar = {
'Δℓ': sockets.ExprSocketDef(
default_unit=spu.nm,
default_value=10,
),
}
output_sockets: typ.ClassVar = {
'Grid Axis': sockets.MaxwellSimGridAxisSocketDef(active_kind=FK.Func),
}
####################
# - FlowKind.Value
####################
@events.computes_output_socket(
'Grid Axis',
kind=FK.Value,
# Loaded
outscks_kinds={'Grid Axis': {FK.Func, FK.Params}},
)
def compute_bcs_value(self, output_sockets) -> td.BoundarySpec:
"""Compute the simulation boundary conditions, by combining the individual input by specified half axis."""
value = events.realize_known(output_sockets['Grid Axis'])
if value is not None:
return value
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Grid Axis',
kind=FK.Func,
# Loaded
inscks_kinds={
'Δℓ': FK.Func,
},
scale_input_sockets={
'Δℓ': ct.UNITS_TIDY3D,
},
)
def compute_grid_func(self, input_sockets) -> ct.ParamsFlow | FS:
"""Compute the simulation grid lazily, at the specified wavelength."""
# Deduce "Doubledness"
## -> A "doubled" axis defines the same bound cond both ways
dl = input_sockets['Δℓ']
return dl.compose_within(
lambda _dl: td.UniformGrid(dl=_dl),
)
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Grid Axis',
kind=FK.Params,
# Loaded
inscks_kinds={
'Δℓ': FK.Params,
},
)
def compute_grid_params(self, input_sockets) -> ct.ParamsFlow | FS:
"""Compute the simulation grid lazily, at the specified wavelength."""
# Deduce "Doubledness"
## -> A "doubled" axis defines the same bound cond both ways
dl = input_sockets['Δℓ']
return dl # noqa: RET504
####################
# - Blender Registration
####################
BL_REGISTER = []
BL_NODES = {}
BL_REGISTER = [
UniformSimGridAxisNode,
]
BL_NODES = {
ct.NodeType.UniformSimGridAxis: (ct.NodeCategory.MAXWELLSIM_SIMS_SIMGRIDAXES)
}

View File

@ -0,0 +1,32 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import (
eme_solver,
fdtd_solver,
mode_solver,
)
BL_REGISTER = [
*fdtd_solver.BL_REGISTER,
*mode_solver.BL_REGISTER,
*eme_solver.BL_REGISTER,
]
BL_NODES = {
**fdtd_solver.BL_NODES,
**mode_solver.BL_NODES,
**eme_solver.BL_NODES,
}

View File

@ -14,13 +14,5 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import bound_cond_nodes, bound_conds
BL_REGISTER = [
*bound_conds.BL_REGISTER,
*bound_cond_nodes.BL_REGISTER,
]
BL_NODES = {
**bound_conds.BL_NODES,
**bound_cond_nodes.BL_NODES,
}
BL_REGISTER = []
BL_NODES = {}

View File

@ -0,0 +1,337 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `FDTDSolverNode`."""
import typing as typ
import bpy
from blender_maxwell.services import tdcloud
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import sockets
from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
####################
# - Operators
####################
class RunSimulation(bpy.types.Operator):
"""Run a Tidy3D simulation given to a `FDTDSolverNode`."""
bl_idname = ct.OperatorType.NodeRunSimulation
bl_label = 'Run Sim'
bl_description = 'Run the currently tracked simulation task'
@classmethod
def poll(cls, context):
"""Allow running when there are runnable tasks."""
return (
# Check Tidy3D Cloud
tdcloud.IS_AUTHENTICATED
# Check FDTDSolverNode is Accessible
and hasattr(context, 'node')
and hasattr(context.node, 'node_type')
and context.node.node_type == ct.NodeType.FDTDSolver
# Check Task is Runnable
and context.node.are_tasks_runnable
)
def execute(self, context):
"""Run all uploaded, runnable tasks."""
node = context.node
for cloud_task in node.cloud_tasks:
log.debug('Submitting Cloud Task %s', cloud_task.task_id)
cloud_task.submit()
return {'FINISHED'}
class ReloadTrackedTask(bpy.types.Operator):
"""Reload information of the selected task in a `FDTDSolverNode`."""
bl_idname = ct.OperatorType.NodeReloadTrackedTask
bl_label = 'Reload Tracked Tidy3D Cloud Task'
bl_description = 'Reload the currently tracked simulation task'
@classmethod
def poll(cls, context):
"""Always allow reloading tasks."""
return (
# Check Tidy3D Cloud
tdcloud.IS_AUTHENTICATED
# Check FDTDSolverNode is Accessible
and hasattr(context, 'node')
and hasattr(context.node, 'node_type')
and context.node.node_type == ct.NodeType.FDTDSolver
)
def execute(self, context):
"""Reload all tasks in all folders for which there are uploaded tasks in the node."""
node = context.node
for folder_id in {cloud_task.folder_id for cloud_task in node.cloud_tasks}:
tdcloud.TidyCloudTasks.update_tasks(folder_id)
return {'FINISHED'}
####################
# - Node
####################
class FDTDSolverNode(base.MaxwellSimNode):
"""Solve an FDTD simulation problem using the Tidy3D cloud solver."""
node_type = ct.NodeType.FDTDSolver
bl_label = 'FDTD Solver'
input_socket_sets: typ.ClassVar = {
'Single': {
'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
should_exist=True,
),
},
'Batch': {
'Cloud Tasks': sockets.Tidy3DCloudTaskSocketDef(
active_kind=FK.Array,
should_exist=True,
),
},
}
output_socket_sets: typ.ClassVar = {
'Single': {
'Cloud Task': sockets.Tidy3DCloudTaskSocketDef(
should_exist=True,
),
},
'Batch': {
'Cloud Tasks': sockets.Tidy3DCloudTaskSocketDef(
active_kind=FK.Array,
should_exist=True,
),
},
}
####################
# - Properties: Incoming InfoFlow
####################
@events.on_value_changed(
# Trigger
socket_name={'Cloud Task': FK.Value, 'Cloud Tasks': FK.Array},
)
def on_cloud_tasks_changed(self) -> None: # noqa: D102
self.cloud_tasks = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def cloud_tasks(self) -> list[tdcloud.CloudTask] | None:
"""Retrieve the current cloud tasks from the input.
If one can't be loaded, return None.
"""
if self.active_socket_set == 'Single':
cloud_task_single = self._compute_input(
'Cloud Task',
kind=FK.Value,
)
has_cloud_task_single = not FS.check(cloud_task_single)
if has_cloud_task_single and cloud_task_single is not None:
return [cloud_task_single]
if self.active_socket_set == 'Batch':
cloud_task_array = self._compute_input(
'Cloud Tasks',
kind=FK.Array,
)
has_cloud_task_array = not FS.check(cloud_task_array)
if has_cloud_task_array and all(
cloud_task is not None for cloud_task in cloud_task_array
):
return cloud_task_array
return None
@bl_cache.cached_bl_property(depends_on={'cloud_tasks'})
def task_infos(self) -> list[tdcloud.CloudTaskInfo | None] | None:
"""Retrieve the current cloud task information from the input socket.
If it can't be loaded, return None.
"""
if self.cloud_tasks is not None:
task_infos = [
tdcloud.TidyCloudTasks.task_info(cloud_task.task_id)
for cloud_task in self.cloud_tasks
]
if task_infos:
return task_infos
return None
@bl_cache.cached_bl_property(depends_on={'cloud_tasks'})
def task_progress(self) -> tuple[float | None, float | None] | None:
"""Retrieve current progress percent (in terms of time steps) and current field decay (normalized to max value).
Either entry can be None, denoting that they aren't yet available.
"""
if self.cloud_tasks is not None:
task_progress = [
cloud_task.get_running_info() for cloud_task in self.cloud_tasks
]
if task_progress:
return task_progress
return None
@bl_cache.cached_bl_property(depends_on={'task_progress'})
def total_progress_pct(self) -> float | None:
"""Retrieve current progress percent, averaged across all running tasks."""
if self.task_progress is not None and all(
progress[0] is not None and progress[1] is not None
for progress in self.task_progress
):
return sum([progress[0] for progress in self.task_progress]) / len(
self.task_progress
)
return None
@bl_cache.cached_bl_property(depends_on={'task_infos'})
def are_tasks_runnable(self) -> bool:
"""Checks whether all conditions are satisfied to be able to actually run a simulation."""
return self.task_infos is not None and all(
task_info is not None and task_info.status == 'draft'
for task_info in self.task_infos
)
####################
# - UI
####################
def draw_operators(self, _, layout):
"""Draw the button that runs the active simulation(s)."""
# Row: Run Sim Buttons
row = layout.row(align=True)
row.operator(
ct.OperatorType.NodeRunSimulation,
text='Run Sim',
)
def draw_info(self, _, layout):
"""Draw information about the running simulation."""
tdcloud.draw_cloud_status(layout)
# Cloud Task Info
if self.task_infos is not None and self.task_progress is not None:
for task_info, task_progress in zip(
self.task_infos, self.task_progress, strict=True
):
if self.task_infos is not None:
# Header
row = layout.row()
row.alignment = 'CENTER'
row.label(text='Task Info')
# Task Run Progress
row = layout.row(align=True)
progress_pct = (
task_progress[0]
if task_progress is not None and task_progress[0] is not None
else 0.0
)
row.progress(
factor=progress_pct,
type='BAR',
text=f'{task_info.status.capitalize()}',
)
row.operator(
ct.OperatorType.NodeReloadTrackedTask,
text='',
icon='FILE_REFRESH',
)
# Task Information
box = layout.box()
split = box.split(factor=0.4)
col = split.column(align=False)
col.label(text='Status')
col = split.column(align=False)
col.alignment = 'RIGHT'
col.label(text=task_info.status.capitalize())
####################
# - FlowKind.Value
####################
@events.computes_output_socket(
'Cloud Task',
kind=FK.Value,
# Loaded
props={'cloud_tasks', 'task_infos'},
)
def compute_cloud_task(self, props) -> str:
"""A single simulation data object, when there only is one."""
cloud_tasks = props['cloud_tasks']
task_infos = props['task_infos']
if (
cloud_tasks is not None
and len(cloud_tasks) == 1
and task_infos is not None
and task_infos[0].status == 'success'
):
return cloud_tasks[0]
return FS.FlowPending
####################
# - FlowKind.Array
####################
@events.computes_output_socket(
'Sim Datas',
kind=FK.Array,
# Loaded
props={'cloud_tasks', 'task_infos'},
)
def compute_cloud_tasks(self, props) -> list[tdcloud.CloudTask]:
"""All simulation data objects, for when there are more than one.
Generally part of the same batch.
"""
cloud_tasks = props['cloud_tasks']
task_infos = props['task_infos']
if (
cloud_tasks is not None
and len(cloud_tasks) > 1
and task_infos is not None
and all(task_info.status == 'success' for task_info in task_infos)
):
return cloud_tasks
return FS.FlowPending
####################
# - Blender Registration
####################
BL_REGISTER = [
RunSimulation,
ReloadTrackedTask,
FDTDSolverNode,
]
BL_NODES = {ct.NodeType.FDTDSolver: (ct.NodeCategory.MAXWELLSIM_SOLVERS)}

View File

@ -0,0 +1,18 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
BL_REGISTER = []
BL_NODES = {}

View File

@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `GaussianBeamSourceNode`."""
import typing as typ
import bpy
@ -29,6 +31,8 @@ from ... import managed_objs, sockets
from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
class GaussianBeamSourceNode(base.MaxwellSimNode):
@ -57,13 +61,13 @@ class GaussianBeamSourceNode(base.MaxwellSimNode):
size=spux.NumberSize1D.Vec3,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Length,
default_value=sp.Matrix([0, 0, 0]),
default_value=sp.ImmutableMatrix([0, 0, 0]),
),
'Size': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec2,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Length,
default_value=sp.Matrix([1, 1]),
default_value=sp.ImmutableMatrix([1, 1]),
),
'Waist Dist': sockets.ExprSocketDef(
mathtype=spux.MathType.Real,
@ -80,7 +84,7 @@ class GaussianBeamSourceNode(base.MaxwellSimNode):
size=spux.NumberSize1D.Vec2,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Angle,
default_value=sp.Matrix([0, 0]),
default_value=sp.ImmutableMatrix([0, 0]),
),
'Pol ∡': sockets.ExprSocketDef(
physical_type=spux.PhysicalType.Angle,
@ -106,129 +110,225 @@ class GaussianBeamSourceNode(base.MaxwellSimNode):
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
"""Draw choices of injection axis and direction."""
layout.prop(self, self.blfields['injection_axis'], expand=True)
layout.prop(self, self.blfields['injection_direction'], expand=True)
# layout.prop(self, self.blfields['num_freqs'], text='f Points')
## TODO: UI is a bit crowded already!
####################
# - Outputs
# - FlowKind.Value
####################
@events.computes_output_socket(
'Angled Source',
props={'sim_node_name', 'injection_axis', 'injection_direction', 'num_freqs'},
input_sockets={
'Temporal Shape',
'Center',
'Size',
'Waist Dist',
'Waist Radius',
'Spherical',
'Pol ∡',
},
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
scale_input_sockets={
'Center': 'Tidy3DUnits',
'Size': 'Tidy3DUnits',
'Waist Dist': 'Tidy3DUnits',
'Waist Radius': 'Tidy3DUnits',
'Spherical': 'Tidy3DUnits',
'Pol ∡': 'Tidy3DUnits',
kind=FK.Value,
# Loaded
outscks_kinds={
'Angled Source': {FK.Func, FK.Params},
},
)
def compute_source(self, props, input_sockets, unit_systems):
size_2d = input_sockets['Size']
size = {
ct.SimSpaceAxis.X: (0, *size_2d),
ct.SimSpaceAxis.Y: (size_2d[0], 0, size_2d[1]),
ct.SimSpaceAxis.Z: (*size_2d, 0),
}[props['injection_axis']]
def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
output_func = output_sockets['Structure'][FK.Func]
output_params = output_sockets['Structure'][FK.Params]
# Display the results
return td.GaussianBeam(
name=props['sim_node_name'],
center=input_sockets['Center'],
size=size,
source_time=input_sockets['Temporal Shape'],
num_freqs=props['num_freqs'],
direction=props['injection_direction'].plus_or_minus,
angle_theta=input_sockets['Spherical'][0],
angle_phi=input_sockets['Spherical'][1],
pol_angle=input_sockets['Pol ∡'],
waist_radius=input_sockets['Waist Radius'],
waist_distance=input_sockets['Waist Dist'],
## NOTE: Waist is place at this signed dist along neg. direction
if not output_params.symbols:
return output_func.realize(output_params, disallow_jax=True)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Value
####################
@events.computes_output_socket(
'Angled Source',
kind=FK.Func,
# Loaded
props={'sim_node_name', 'injection_axis', 'injection_direction', 'num_freqs'},
inscks_kinds={
'Temporal Shape': FK.Func,
'Center': FK.Func,
'Size': FK.Func,
'Waist Dist': FK.Func,
'Waist Radius': FK.Func,
'Spherical': FK.Func,
'Pol ∡': FK.Func,
},
scale_input_sockets={
'Center': ct.UNITS_TIDY3D,
'Size': ct.UNITS_TIDY3D,
'Waist Dist': ct.UNITS_TIDY3D,
'Waist Radius': ct.UNITS_TIDY3D,
'Spherical': ct.UNITS_TIDY3D,
'Pol ∡': ct.UNITS_TIDY3D,
},
)
def compute_func(self, props, input_sockets) -> td.GaussianBeam:
"""Compute a function that returns a gaussian beam object, given sufficient parameters."""
inj_dir = props['injection_axis']
def size(size_2d: tuple[float, float]) -> tuple[float, float, float]:
return {
ct.SimSpaceAxis.X: (0, *size_2d),
ct.SimSpaceAxis.Y: (size_2d[0], 0, size_2d[1]),
ct.SimSpaceAxis.Z: (*size_2d, 0),
}[inj_dir]
center = input_sockets['Center']
size_2d = input_sockets['Size']
temporal_shape = input_sockets['Temporal Shape']
spherical = input_sockets['Spherical']
pol_ang = input_sockets['Pol ∡']
waist_radius = input_sockets['Waist Radius']
waist_dist = input_sockets['Waist Dist']
sim_node_name = props['sim_node_name']
num_freqs = props['num_freqs']
return (
center
| size_2d
| temporal_shape
| spherical
| pol_ang
| waist_radius
| waist_dist
).compose_within(
lambda els: td.GaussianBeam(
name=sim_node_name,
center=els[0],
size=size(els[1]),
source_time=els[2],
num_freqs=num_freqs,
direction=inj_dir.plus_or_minus,
angle_theta=els[3].item(0),
angle_phi=els[3].item(1),
pol_angle=els[4],
waist_radius=els[5],
waist_distance=els[6],
)
)
####################
# - Preview - Changes to Input Sockets
# - FlowKind.Params
####################
@events.computes_output_socket(
'Structure',
kind=FK.Params,
# Loaded
inscks_kinds={
'Temporal Shape': FK.Func,
'Center': FK.Func,
'Size': FK.Func,
'Waist Dist': FK.Func,
'Waist Radius': FK.Func,
'Spherical': FK.Func,
'Pol ∡': FK.Func,
},
)
def compute_params(self, input_sockets) -> ct.ParamsFlow:
"""Propagate function parameters from inputs."""
center = input_sockets['Center']
size_2d = input_sockets['Size']
temporal_shape = input_sockets['Temporal Shape']
spherical = input_sockets['Spherical']
pol_ang = input_sockets['Pol ∡']
waist_radius = input_sockets['Waist Radius']
waist_dist = input_sockets['Waist Dist']
return (
center
| size_2d
| temporal_shape
| spherical
| pol_ang
| waist_radius
| waist_dist
)
####################
# - Events: Preview
####################
@events.computes_output_socket(
'Angled Source',
kind=ct.FlowKind.Previews,
kind=FK.Previews,
# Loaded
props={'sim_node_name'},
outscks_kinds={'Angled Source': ct.FlowKind.Func},
output_sockets_optional={'Angled Source'},
)
def compute_previews(self, props):
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
def compute_previews(self, props, output_sockets):
"""Update the preview state when the name or output function change."""
if not FS.check(output_sockets['Structure']):
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
return ct.PreviewsFlow()
@events.on_value_changed(
# Trigger
socket_name={
'Center',
'Size',
'Waist Dist',
'Waist Radius',
'Spherical',
'Pol ∡',
'Center': {FK.Func, FK.Params},
'Size': {FK.Func, FK.Params},
'Waist Dist': {FK.Func, FK.Params},
'Waist Radius': {FK.Func, FK.Params},
'Spherical': {FK.Func, FK.Params},
'Pol ∡': {FK.Func, FK.Params},
},
prop_name={'injection_axis', 'injection_direction'},
run_on_init=True,
# Loaded
managed_objs={'modifier'},
props={'injection_axis', 'injection_direction'},
input_sockets={
'Temporal Shape',
'Center',
'Size',
'Waist Dist',
'Waist Radius',
'Spherical',
'Pol ∡',
inscks_kinds={
'Center': {FK.Func, FK.Params},
'Size': {FK.Func, FK.Params},
'Waist Dist': {FK.Func, FK.Params},
'Waist Radius': {FK.Func, FK.Params},
'Spherical': {FK.Func, FK.Params},
'Pol ∡': {FK.Func, FK.Params},
},
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
scale_input_sockets={
'Center': 'BlenderUnits',
'Center': ct.UNITS_BLENDER,
'Size': ct.UNITS_BLENDER,
'Waist Dist': ct.UNITS_BLENDER,
'Waist Radius': ct.UNITS_BLENDER,
'Spherical': ct.UNITS_BLENDER,
'Pol ∡': ct.UNITS_BLENDER,
},
)
def on_inputs_changed(self, managed_objs, props, input_sockets, unit_systems):
size_2d = input_sockets['Size']
def on_previewable_chnaged(self, managed_objs, props, input_sockets):
"""Update the preview when relevant inputs change."""
center = events.realize_preview(input_sockets['Center'])
size_2d = events.realize_preview(input_sockets['Size'])
spherical = events.realize_preview(input_sockets['Spherical'])
pol_ang = events.realize_preview(input_sockets['Pol ∡'])
waist_radius = events.realize_preview(input_sockets['Waist Radius'])
waist_dist = events.realize_preview(input_sockets['Waist Dist'])
# Retrieve Properties
inj_dir = props['injection_axis']
size = {
ct.SimSpaceAxis.X: sp.Matrix([0, *size_2d]),
ct.SimSpaceAxis.Y: sp.Matrix([size_2d[0], 0, size_2d[1]]),
ct.SimSpaceAxis.Z: sp.Matrix([*size_2d, 0]),
ct.SimSpaceAxis.X: sp.ImmutableMatrix([0, *size_2d]),
ct.SimSpaceAxis.Y: sp.ImmutableMatrix([size_2d[0], 0, size_2d[1]]),
ct.SimSpaceAxis.Z: sp.ImmutableMatrix([*size_2d, 0]),
}[props['injection_axis']]
# Push Input Values to GeoNodes Modifier
# Push Updated Values
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.SourceGaussianBeam),
'unit_system': unit_systems['BlenderUnits'],
'inputs': {
# Orientation
'Inj Axis': props['injection_axis'].axis,
'Direction': props['injection_direction'].true_or_false,
'theta': input_sockets['Spherical'][0],
'phi': input_sockets['Spherical'][1],
'Pol Angle': input_sockets['Pol ∡'],
'Inj Axis': inj_dir.axis,
'Direction': inj_dir.true_or_false,
'theta': spherical[0],
'phi': spherical[1],
'Pol Angle': pol_ang,
# Gaussian Beam
'Size': size,
'Waist Dist': input_sockets['Waist Dist'],
'Waist Radius': input_sockets['Waist Radius'],
'Waist Radius': waist_radius,
'Waist Dist': waist_dist,
},
},
location=input_sockets['Center'],
location=center,
)

View File

@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `PlaneWaveSourceNode`."""
import typing as typ
import bpy
@ -30,6 +32,11 @@ from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
PT = spux.PhysicalType
class PlaneWaveSourceNode(base.MaxwellSimNode):
"""An infinite-extent angled source simulating an plane wave with linear polarization.
@ -52,26 +59,26 @@ class PlaneWaveSourceNode(base.MaxwellSimNode):
# - Sockets
####################
input_sockets: typ.ClassVar = {
'Temporal Shape': sockets.MaxwellTemporalShapeSocketDef(),
'Temporal Shape': sockets.MaxwellTemporalShapeSocketDef(active_kind=FK.Func),
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Length,
default_value=sp.Matrix([0, 0, 0]),
mathtype=MT.Real,
physical_type=PT.Length,
default_value=sp.ImmutableMatrix([0, 0, 0]),
),
'Spherical': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec2,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Angle,
default_value=sp.Matrix([0, 0]),
mathtype=MT.Real,
physical_type=PT.Angle,
default_value=sp.ImmutableMatrix([0, 0]),
),
'Pol ∡': sockets.ExprSocketDef(
physical_type=spux.PhysicalType.Angle,
physical_type=PT.Angle,
default_value=0,
),
}
output_sockets: typ.ClassVar = {
'Angled Source': sockets.MaxwellSourceSocketDef(active_kind=ct.FlowKind.Func),
'Angled Source': sockets.MaxwellSourceSocketDef(active_kind=FK.Func),
}
managed_obj_types: typ.ClassVar = {
@ -88,6 +95,7 @@ class PlaneWaveSourceNode(base.MaxwellSimNode):
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
"""Draw choices of injection axis and direction."""
layout.prop(self, self.blfields['injection_axis'], expand=True)
layout.prop(self, self.blfields['injection_direction'], expand=True)
@ -96,175 +104,145 @@ class PlaneWaveSourceNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Angled Source',
kind=ct.FlowKind.Value,
kind=FK.Value,
# Loaded
output_sockets={'Angled Source'},
output_socket_kinds={'Angled Source': {ct.FlowKind.Func, ct.FlowKind.Params}},
outscks_kinds={'Angled Source': {FK.Func, FK.Params}},
)
def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
output_func = output_sockets['Angled Source'][ct.FlowKind.Func]
output_params = output_sockets['Angled Source'][ct.FlowKind.Params]
has_output_func = not ct.FlowSignal.check(output_func)
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_func and has_output_params and not output_params.symbols:
return output_func.realize(output_params, disallow_jax=True)
return ct.FlowSignal.FlowPending
value = events.realize_known(output_sockets['Angled Source'])
if value is not None:
return value
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Angled Source',
kind=ct.FlowKind.Func,
kind=FK.Func,
# Loaded
props={'sim_node_name', 'injection_axis', 'injection_direction'},
input_sockets={'Temporal Shape', 'Center', 'Spherical', 'Pol ∡'},
input_socket_kinds={
'Temporal Shape': ct.FlowKind.Func,
'Center': ct.FlowKind.Func,
'Spherical': ct.FlowKind.Func,
'Pol ∡': ct.FlowKind.Func,
inscks_kinds={
'Temporal Shape': FK.Func,
'Center': FK.Func,
'Spherical': FK.Func,
'Pol ∡': FK.Func,
},
scale_input_sockets={
'Center': ct.UNITS_TIDY3D,
'Spherical': ct.UNITS_TIDY3D,
'Pol ∡': ct.UNITS_TIDY3D,
},
)
def compute_func(self, props, input_sockets) -> None:
"""Compute a lazy function for the plane wave source."""
center = input_sockets['Center']
temporal_shape = input_sockets['Temporal Shape']
spherical = input_sockets['Spherical']
pol_ang = input_sockets['Pol ∡']
has_center = not ct.FlowSignal.check(center)
has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
has_spherical = not ct.FlowSignal.check(spherical)
has_pol_ang = not ct.FlowSignal.check(pol_ang)
name = props['sim_node_name']
inj_dir = props['injection_direction'].plus_or_minus
size = {
ct.SimSpaceAxis.X: (0, td.inf, td.inf),
ct.SimSpaceAxis.Y: (td.inf, 0, td.inf),
ct.SimSpaceAxis.Z: (td.inf, td.inf, 0),
}[props['injection_axis']]
if has_center and has_temporal_shape and has_spherical and has_pol_ang:
name = props['sim_node_name']
inj_dir = props['injection_direction'].plus_or_minus
size = {
ct.SimSpaceAxis.X: (0, td.inf, td.inf),
ct.SimSpaceAxis.Y: (td.inf, 0, td.inf),
ct.SimSpaceAxis.Z: (td.inf, td.inf, 0),
}[props['injection_axis']]
return (
center.scale_to_unit_system(ct.UNITS_TIDY3D)
| temporal_shape
| spherical.scale_to_unit_system(ct.UNITS_TIDY3D)
| pol_ang.scale_to_unit_system(ct.UNITS_TIDY3D)
).compose_within(
lambda els: td.PlaneWave(
name=name,
center=els[0].flatten().tolist(),
size=size,
source_time=els[1],
direction=inj_dir,
angle_theta=els[2][0].item(0),
angle_phi=els[2][1].item(0),
pol_angle=els[3],
)
return (center | temporal_shape | spherical | pol_ang).compose_within(
lambda els: td.PlaneWave(
name=name,
center=els[0].flatten().tolist(),
size=size,
source_time=els[1],
direction=inj_dir,
angle_theta=els[2].item(0),
angle_phi=els[2].item(1),
pol_angle=els[3],
)
return ct.FlowSignal.FlowPending
)
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Angled Source',
kind=ct.FlowKind.Params,
kind=FK.Params,
# Loaded
input_sockets={'Temporal Shape', 'Center', 'Spherical', 'Pol ∡'},
input_socket_kinds={
'Temporal Shape': ct.FlowKind.Params,
'Center': ct.FlowKind.Params,
'Spherical': ct.FlowKind.Params,
'Pol ∡': ct.FlowKind.Params,
inscks_kinds={
'Temporal Shape': FK.Params,
'Center': FK.Params,
'Spherical': FK.Params,
'Pol ∡': FK.Params,
},
)
def compute_params(self, input_sockets) -> None:
"""Compute the function parameters of the lazy function."""
center = input_sockets['Center']
temporal_shape = input_sockets['Temporal Shape']
spherical = input_sockets['Spherical']
pol_ang = input_sockets['Pol ∡']
has_center = not ct.FlowSignal.check(center)
has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
has_spherical = not ct.FlowSignal.check(spherical)
has_pol_ang = not ct.FlowSignal.check(pol_ang)
if has_center and has_temporal_shape and has_spherical and has_pol_ang:
return center | temporal_shape | spherical | pol_ang
return ct.FlowSignal.FlowPending
return center | temporal_shape | spherical | pol_ang
####################
# - Preview - Changes to Input Sockets
####################
@events.computes_output_socket(
'Angled Source',
kind=ct.FlowKind.Previews,
kind=FK.Previews,
# Loaded
props={'sim_node_name'},
output_sockets={'Angled Source'},
output_socket_kinds={'Angled Source': ct.FlowKind.Params},
)
def compute_previews(self, props, output_sockets):
output_params = output_sockets['Angled Source']
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_params and not output_params.symbols:
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
return ct.PreviewsFlow()
def compute_previews(self, props):
"""Mark the plane wave as participating in the 3D preview."""
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed(
# Trigger
socket_name={'Center', 'Spherical', 'Pol ∡'},
socket_name={
'Center': {FK.Func, FK.Params},
'Spherical': {FK.Func, FK.Params},
'Pol ∡': {FK.Func, FK.Params},
},
prop_name={'injection_axis', 'injection_direction'},
run_on_init=True,
# Loaded
managed_objs={'modifier'},
props={'injection_axis', 'injection_direction'},
input_sockets={'Temporal Shape', 'Center', 'Spherical', 'Pol ∡'},
output_sockets={'Angled Source'},
output_socket_kinds={'Angled Source': ct.FlowKind.Params},
inscks_kinds={
'Center': {FK.Func, FK.Params},
'Spherical': {FK.Func, FK.Params},
'Pol ∡': {FK.Func, FK.Params},
},
scale_input_sockets={
'Center': ct.UNITS_BLENDER,
'Spherical': ct.UNITS_BLENDER,
'Pol ∡': ct.UNITS_BLENDER,
},
)
def on_previewable_changed(
self, managed_objs, props, input_sockets, output_sockets
):
center = input_sockets['Center']
spherical = input_sockets['Spherical']
pol_ang = input_sockets['Pol ∡']
output_params = output_sockets['Angled Source']
def on_previewable_changed(self, managed_objs, props, input_sockets):
"""Push changes in the inputs to the pol/center."""
center = events.realize_preview(input_sockets['Center'])
spherical = events.realize_preview(input_sockets['Spherical'])
pol_ang = events.realize_preview(input_sockets['Pol ∡'])
has_center = not ct.FlowSignal.check(center)
has_spherical = not ct.FlowSignal.check(spherical)
has_pol_ang = not ct.FlowSignal.check(pol_ang)
has_output_params = not ct.FlowSignal.check(output_params)
if (
has_center
and has_spherical
and has_pol_ang
and has_output_params
and not output_params.symbols
):
# Push Input Values to GeoNodes Modifier
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.SourcePlaneWave),
'unit_system': ct.UNITS_BLENDER,
'inputs': {
'Inj Axis': props['injection_axis'].axis,
'Direction': props['injection_direction'].true_or_false,
'theta': spherical[0],
'phi': spherical[1],
'Pol Angle': pol_ang,
},
# Push Input Values to GeoNodes Modifier
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.SourcePlaneWave),
'inputs': {
'Inj Axis': props['injection_axis'].axis,
'Direction': props['injection_direction'].true_or_false,
'theta': spherical.item(0),
'phi': spherical.item(1),
'Pol Angle': pol_ang,
},
location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
)
},
location=center,
)
####################

View File

@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `PointDipoleSourceNode`."""
import typing as typ
import bpy
@ -30,8 +32,15 @@ from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
PT = spux.PhysicalType
class PointDipoleSourceNode(base.MaxwellSimNode):
"""A point dipole with E or H oriented linear polarization along an axis-aligned angle."""
node_type = ct.NodeType.PointDipoleSource
bl_label = 'Point Dipole Source'
use_sim_node_name = True
@ -43,16 +52,16 @@ class PointDipoleSourceNode(base.MaxwellSimNode):
'Temporal Shape': sockets.MaxwellTemporalShapeSocketDef(),
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Length,
default_value=sp.Matrix([0, 0, 0]),
mathtype=MT.Real,
physical_type=PT.Length,
default_value=sp.ImmutableMatrix([0, 0, 0]),
),
'Interpolate': sockets.BoolSocketDef(
default_value=True,
),
}
output_sockets: typ.ClassVar = {
'Source': sockets.MaxwellSourceSocketDef(active_kind=ct.FlowKind.Func),
'Source': sockets.MaxwellSourceSocketDef(active_kind=FK.Func),
}
managed_obj_types: typ.ClassVar = {
@ -68,6 +77,7 @@ class PointDipoleSourceNode(base.MaxwellSimNode):
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
"""Draw choices of polarization direction."""
layout.prop(self, self.blfields['pol'], expand=True)
####################
@ -75,36 +85,33 @@ class PointDipoleSourceNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Source',
kind=ct.FlowKind.Value,
kind=FK.Value,
# Loaded
output_sockets={'Source'},
output_socket_kinds={'Source': {ct.FlowKind.Func, ct.FlowKind.Params}},
output_socket_kinds={'Source': {FK.Func, FK.Params}},
)
def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
output_func = output_sockets['Source'][ct.FlowKind.Func]
output_params = output_sockets['Source'][ct.FlowKind.Params]
has_output_func = not ct.FlowSignal.check(output_func)
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_func and has_output_params and not output_params.symbols:
return output_func.realize(output_params, disallow_jax=True)
return ct.FlowSignal.FlowPending
value = events.realize_known(output_sockets['Source'])
if value is not None:
return value
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Source',
kind=ct.FlowKind.Func,
kind=FK.Func,
# Loaded
props={'pol'},
input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
input_socket_kinds={
'Temporal Shape': ct.FlowKind.Func,
'Center': ct.FlowKind.Func,
'Interpolate': ct.FlowKind.Func,
inscks_kinds={
'Temporal Shape': FK.Func,
'Center': FK.Func,
'Interpolate': FK.Func,
},
scale_input_sockets={
'Center': ct.UNITS_TIDY3D,
},
)
def compute_func(self, props, input_sockets) -> td.Box:
@ -113,121 +120,95 @@ class PointDipoleSourceNode(base.MaxwellSimNode):
temporal_shape = input_sockets['Temporal Shape']
interpolate = input_sockets['Interpolate']
has_center = not ct.FlowSignal.check(center)
has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
has_interpolate = not ct.FlowSignal.check(interpolate)
pol = props['pol']
if has_temporal_shape and has_center and has_interpolate:
pol = props['pol']
return (
center.scale_to_unit_system(ct.UNITS_TIDY3D)
| temporal_shape
| interpolate
).compose_within(
enclosing_func=lambda els: td.PointDipole(
center=els[0].flatten().tolist(),
source_time=els[1],
interpolate=els[2],
polarization=pol.name,
)
return (center | temporal_shape | interpolate).compose_within(
lambda els: td.PointDipole(
center=els[0].flatten().tolist(),
source_time=els[1],
interpolate=els[2],
polarization=pol.name,
)
return ct.FlowSignal.FlowPending
)
return FS.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Source',
kind=ct.FlowKind.Params,
kind=FK.Params,
# Loaded
input_sockets={'Temporal Shape', 'Center', 'Interpolate'},
input_socket_kinds={
'Temporal Shape': ct.FlowKind.Params,
'Center': ct.FlowKind.Params,
'Interpolate': ct.FlowKind.Params,
'Temporal Shape': FK.Params,
'Center': FK.Params,
'Interpolate': FK.Params,
},
)
def compute_params(
self,
input_sockets,
) -> td.PointDipole | ct.FlowSignal:
"""Compute the point dipole source, given that all inputs are non-symbolic."""
) -> td.PointDipole:
"""Compute the function parameters of the lazy function."""
center = input_sockets['Center']
temporal_shape = input_sockets['Temporal Shape']
interpolate = input_sockets['Interpolate']
has_center = not ct.FlowSignal.check(center)
has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
has_interpolate = not ct.FlowSignal.check(interpolate)
if has_center and has_temporal_shape and has_interpolate:
return center | temporal_shape | interpolate
return ct.FlowSignal.FlowPending
return center | temporal_shape | interpolate
####################
# - Preview
####################
@events.computes_output_socket(
'Source',
kind=ct.FlowKind.Previews,
kind=FK.Previews,
# Loaded
props={'sim_node_name'},
output_sockets={'Source'},
output_socket_kinds={'Source': ct.FlowKind.Params},
)
def compute_previews(self, props, output_sockets):
output_params = output_sockets['Source']
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_params and not output_params.symbols:
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
return ct.PreviewsFlow()
def compute_previews(self, props):
"""Mark the point dipole as participating in the 3D preview."""
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed(
# Trigger
socket_name={'Center'},
socket_name={'Center': {FK.Func, FK.Params}},
prop_name='pol',
run_on_init=True,
# Loaded
managed_objs={'modifier'},
props={'pol'},
input_sockets={'Center'},
output_sockets={'Source'},
output_socket_kinds={'Source': ct.FlowKind.Params},
inscks_kinds={
'Center': {FK.Func, FK.Params},
},
scale_input_sockets={
'Center': ct.UNITS_BLENDER,
},
)
def on_previewable_changed(
self, managed_objs, props, input_sockets, output_sockets
) -> None:
def on_previewable_changed(self, managed_objs, props, input_sockets) -> None:
"""Push changes in the inputs to the pol/center."""
SFP = ct.SimFieldPols
center = input_sockets['Center']
output_params = output_sockets['Source']
center = events.realize_preview(input_sockets['Center'])
axis = {
SFP.Ex: 0,
SFP.Ey: 1,
SFP.Ez: 2,
SFP.Hx: 0,
SFP.Hy: 1,
SFP.Hz: 2,
}[props['pol']]
has_center = not ct.FlowSignal.check(center)
has_output_params = not ct.FlowSignal.check(output_params)
if has_center and has_output_params and not output_params.symbols:
axis = {
SFP.Ex: 0,
SFP.Ey: 1,
SFP.Ez: 2,
SFP.Hx: 0,
SFP.Hy: 1,
SFP.Hz: 2,
}[props['pol']]
# Push Loose Input Values to GeoNodes Modifier
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.SourcePointDipole),
'unit_system': ct.UNITS_BLENDER,
'inputs': {
'Axis': axis,
},
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.SourcePointDipole),
'inputs': {
'Axis': axis,
},
location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
)
},
location=center,
)
####################

View File

@ -36,11 +36,15 @@ from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
PT = spux.PhysicalType
# Select Default Time Unit for Envelope
## -> Chosen to align with the default envelope_time_unit.
## -> This causes it to be correct from the start.
t_def = sim_symbols.t(spux.PhysicalType.Time.valid_units[0])
t_def = sim_symbols.t(PT.Time.valid_units[0])
class TemporalShapeNode(base.MaxwellSimNode):
@ -54,19 +58,19 @@ class TemporalShapeNode(base.MaxwellSimNode):
####################
input_sockets: typ.ClassVar = {
'μ Freq': sockets.ExprSocketDef(
physical_type=spux.PhysicalType.Freq,
physical_type=PT.Freq,
default_unit=spux.THz,
default_value=500,
),
'σ Freq': sockets.ExprSocketDef(
physical_type=spux.PhysicalType.Freq,
physical_type=PT.Freq,
default_unit=spux.THz,
default_value=200,
),
'max E': sockets.ExprSocketDef(
mathtype=spux.MathType.Complex,
physical_type=spux.PhysicalType.EField,
default_value=1 + 0j,
physical_type=PT.EField,
default_value=1,
),
'Offset Time': sockets.ExprSocketDef(default_value=5, abs_min=2.5),
}
@ -77,8 +81,8 @@ class TemporalShapeNode(base.MaxwellSimNode):
'Constant': {},
'Symbolic': {
't Range': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Range,
physical_type=spux.PhysicalType.Time,
active_kind=FK.Range,
physical_type=PT.Time,
default_unit=spu.picosecond,
default_min=0,
default_max=10,
@ -92,11 +96,7 @@ class TemporalShapeNode(base.MaxwellSimNode):
}
output_sockets: typ.ClassVar = {
'Temporal Shape': sockets.MaxwellTemporalShapeSocketDef(),
}
managed_obj_types: typ.ClassVar = {
'plot': managed_objs.ManagedBLImage,
'Temporal Shape': sockets.MaxwellTemporalShapeSocketDef(active_kind=FK.Func),
}
####################
@ -110,7 +110,7 @@ class TemporalShapeNode(base.MaxwellSimNode):
"""Compute all valid time units."""
return [
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
for i, unit in enumerate(spux.PhysicalType.Time.valid_units)
for i, unit in enumerate(PT.Time.valid_units)
]
@bl_cache.cached_bl_property(depends_on={'active_envelope_time_unit'})
@ -178,8 +178,9 @@ class TemporalShapeNode(base.MaxwellSimNode):
)
def on_envelope_time_unit_changed(self, props) -> None:
"""Ensure the envelope expression's time symbol has the time unit defined by the node."""
active_socket_set = props['active_socket_set']
envelope_time_unit = props['envelope_time_unit']
active_socket_set = props['active_socket_set']
if active_socket_set == 'Symbolic':
bl_socket = self.inputs['Envelope']
wanted_t_sym = sim_symbols.t(envelope_time_unit)
@ -192,48 +193,40 @@ class TemporalShapeNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Temporal Shape',
kind=ct.FlowKind.Value,
kind=FK.Value,
# Loaded
output_sockets={'Temporal Shape'},
output_socket_kinds={'Temporal Shape': {ct.FlowKind.Func, ct.FlowKind.Params}},
outscks_kinds={'Temporal Shape': {FK.Func, FK.Params}},
)
def compute_domain_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Compute a single temporal shape."""
output_func = output_sockets['Temporal Shape'][ct.FlowKind.Func]
output_params = output_sockets['Temporal Shape'][ct.FlowKind.Params]
has_output_func = not ct.FlowSignal.check(output_func)
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_func and has_output_params and not output_params.symbols:
return output_func.realize(output_params)
return ct.FlowSignal.FlowPending
value = events.realize_known(output_sockets['Temporal Shape'])
if value is not None:
return value
return FS.FlowPending
####################
# - FlowKind: Func
####################
@events.computes_output_socket(
'Temporal Shape',
kind=ct.FlowKind.Func,
kind=FK.Func,
# Loaded
props={'active_socket_set'},
input_sockets={
'max E',
'μ Freq',
'σ Freq',
'Offset Time',
'Remove DC',
't Range',
'Envelope',
inscks_kinds={
'μ Freq': FK.Func,
'σ Freq': FK.Func,
'max E': FK.Func,
'Offset Time': FK.Func,
'Remove DC': FK.Value,
't Range': FK.Func,
'Envelope': {FK.Func, FK.Params},
},
input_socket_kinds={
'max E': ct.FlowKind.Func,
'μ Freq': ct.FlowKind.Func,
'σ Freq': ct.FlowKind.Func,
'Offset Time': ct.FlowKind.Func,
'Remove DC': ct.FlowKind.Value,
't Range': ct.FlowKind.Func,
'Envelope': {ct.FlowKind.Func, ct.FlowKind.Params},
input_sockets_optional={'Remove DC', 't Range', 'Envelope'},
scale_input_sockets={
'μ Freq': ct.UNITS_TIDY3D,
'σ Freq': ct.UNITS_TIDY3D,
'max E': ct.UNITS_TIDY3D,
't Range': ct.UNITS_TIDY3D,
},
)
def compute_temporal_shape_func(
@ -247,158 +240,114 @@ class TemporalShapeNode(base.MaxwellSimNode):
max_e = input_sockets['max E']
offset = input_sockets['Offset Time']
has_mean_freq = not ct.FlowSignal.check(mean_freq)
has_std_freq = not ct.FlowSignal.check(std_freq)
has_max_e = not ct.FlowSignal.check(max_e)
has_offset = not ct.FlowSignal.check(offset)
remove_dc = input_sockets['Remove DC']
t_range = input_sockets['t Range']
envelope = input_sockets['Envelope'][FK.Func]
envelope_params = input_sockets['Envelope'][FK.Params]
if has_mean_freq and has_std_freq and has_max_e and has_offset:
common_func = (
max_e.scale_to_unit_system(ct.UNITS_TIDY3D)
| mean_freq.scale_to_unit_system(ct.UNITS_TIDY3D)
| std_freq.scale_to_unit_system(ct.UNITS_TIDY3D)
| offset ## Already unitless
)
match props['active_socket_set']:
case 'Pulse':
remove_dc = input_sockets['Remove DC']
common_func = max_e | mean_freq | std_freq | offset
active_socket_set = props['active_socket_set']
match active_socket_set:
case 'Pulse' if not FS.check(remove_dc):
return common_func.compose_within(
lambda els: td.GaussianPulse(
amplitude=complex(els[0]).real,
phase=complex(els[0]).imag,
freq0=els[1],
fwidth=els[2],
offset=els[3],
remove_dc_component=remove_dc,
),
)
has_remove_dc = not ct.FlowSignal.check(remove_dc)
case 'Constant':
return common_func.compose_within(
lambda els: td.ContinuousWave(
amplitude=complex(els[0]).real,
phase=complex(els[0]).imag,
freq0=els[1],
fwidth=els[2],
offset=els[3],
),
)
if has_remove_dc:
return common_func.compose_within(
lambda els: td.GaussianPulse(
amplitude=complex(els[0]).real,
phase=complex(els[0]).imag,
freq0=els[1],
fwidth=els[2],
offset=els[3],
remove_dc_component=remove_dc,
),
)
case 'Symbolic' if (
not FS.check(t_range)
and not FS.check(envelope)
and not FS.check(envelope_params)
and len(envelope_params.symbols) == 1
and next(iter(envelope_params.symbols)).physical_type is PT.Time
and any(sym.physical_type is PT.Time for sym in envelope_params.symbols)
):
envelope_time_unit = next(iter(envelope_params.symbols)).unit
case 'Constant':
return common_func.compose_within(
lambda els: td.GaussianPulse(
amplitude=complex(els[0]).real,
phase=complex(els[0]).imag,
freq0=els[1],
fwidth=els[2],
offset=els[3],
# Deduce Partially Realized Envelope Function
## -> We need a pure-numerical function w/pre-realized stuff baked in.
## -> 'generate_realizer' does this for us.
envelope_realizer = envelope.generate_realizer(envelope_params)
# Compose w/Envelope Function
## -> First, the numerical time values must be converted.
## -> This ensures that the raw array is compatible w/the envelope.
## -> Then, we can compose w/the purely numerical 'envelope_realizer'.
## -> Because of the checks, we've guaranteed that all this is correct.
return (
common_func
| t_range
| t_range.scale_to_unit(envelope_time_unit).compose_within(
lambda t: envelope_realizer(t)
)
).compose_within(
lambda els: td.CustomSourceTime(
amplitude=complex(els[0]).real,
phase=complex(els[0]).imag,
freq0=els[1],
fwidth=els[2],
offset=els[3],
source_time_dataset=td_TimeDataset(
values=td_TimeDataArray(
els[5], coords={'t': np.array(els[4])}
)
),
)
)
case 'Symbolic':
t_range = input_sockets['t Range']
envelope = input_sockets['Envelope'][ct.FlowKind.Func]
envelope_params = input_sockets['Envelope'][ct.FlowKind.Params]
has_t_range = not ct.FlowSignal.check(t_range)
has_envelope = not ct.FlowSignal.check(envelope)
has_envelope_params = not ct.FlowSignal.check(envelope_params)
if (
has_t_range
and has_envelope
and has_envelope_params
and len(envelope_params.symbols) == 1
## TODO: Allow unrealized envelope symbols
and any(
sym.physical_type is spux.PhysicalType.Time
for sym in envelope_params.symbols
)
):
envelope_time_unit = next(
sym.unit
for sym in envelope_params.symbols
if sym.physical_type is spux.PhysicalType.Time
)
# Deduce Partially Realized Envelope Function
## -> We need a pure-numerical function w/pre-realized stuff baked in.
## -> 'realize_partial' does this for us.
envelope_realizer = envelope.realize_partial(envelope_params)
# Compose w/Envelope Function
## -> First, the numerical time values must be converted.
## -> This ensures that the raw array is compatible w/the envelope.
## -> Then, we can compose w/the purely numerical 'envelope_realizer'.
## -> Because of the checks, we've guaranteed that all this is correct.
return (
common_func ## 1 | freq0, 2 | fwidth, 3 | offset
| t_range.scale_to_unit_system(ct.UNITS_TIDY3D) ## 4
| t_range.scale_to_unit(envelope_time_unit).compose_within(
lambda t: envelope_realizer(t)
) ## 5
).compose_within(
lambda els: td.CustomSourceTime(
amplitude=complex(els[0]).real,
phase=complex(els[0]).imag,
freq0=els[1],
fwidth=els[2],
offset=els[3],
source_time_dataset=td_TimeDataset(
values=td_TimeDataArray(
els[5], coords={'t': np.array(els[4])}
)
),
)
)
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
# - FlowKind: Params
####################
@events.computes_output_socket(
'Temporal Shape',
kind=ct.FlowKind.Params,
kind=FK.Params,
# Loaded
props={'active_socket_set', 'envelope_time_unit'},
input_sockets={
'max E',
'μ Freq',
'σ Freq',
'Offset Time',
't Range',
},
input_socket_kinds={
'max E': ct.FlowKind.Params,
'μ Freq': ct.FlowKind.Params,
'σ Freq': ct.FlowKind.Params,
'Offset Time': ct.FlowKind.Params,
't Range': ct.FlowKind.Params,
props={'active_socket_set'},
inscks_kinds={
'μ Freq': FK.Params,
'σ Freq': FK.Params,
'max E': FK.Params,
'Offset Time': FK.Params,
't Range': FK.Params,
},
input_sockets_optional={'t Range'},
)
def compute_temporal_shape_params(
self,
props,
input_sockets,
) -> td.GaussianPulse:
def compute_temporal_shape_params(self, props, input_sockets) -> td.GaussianPulse:
"""Compute a single temporal shape from non-parameterized inputs."""
mean_freq = input_sockets['μ Freq']
std_freq = input_sockets['σ Freq']
max_e = input_sockets['max E']
offset = input_sockets['Offset Time']
has_mean_freq = not ct.FlowSignal.check(mean_freq)
has_std_freq = not ct.FlowSignal.check(std_freq)
has_max_e = not ct.FlowSignal.check(max_e)
has_offset = not ct.FlowSignal.check(offset)
t_range = input_sockets['t Range']
if has_mean_freq and has_std_freq and has_max_e and has_offset:
common_params = max_e | mean_freq | std_freq | offset
match props['active_socket_set']:
case 'Pulse' | 'Constant':
return common_params
common_params = max_e | mean_freq | std_freq | offset
match props['active_socket_set']:
case 'Pulse' | 'Constant':
return common_params
case 'Symbolic':
t_range = input_sockets['t Range']
has_t_range = not ct.FlowSignal.check(t_range)
if has_t_range:
return common_params | t_range | t_range
return ct.FlowSignal.FlowPending
case 'Symbolic' if not FS.check(t_range):
return common_params | t_range | t_range
return FS.FlowPending
####################

View File

@ -14,14 +14,17 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `GeoNodesStructureNode`."""
import functools
import typing as typ
import sympy as sp
import sympy.physics.units as spu
import tidy3d as td
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils import sympy_extra as spux
from ... import bl_socket_map, managed_objs, sockets
from ... import contracts as ct
@ -29,8 +32,14 @@ from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
class GeoNodesStructureNode(base.MaxwellSimNode):
"""A generic mesh structure defined by an arbitrary Geometry Nodes tree."""
node_type = ct.NodeType.GeoNodesStructure
bl_label = 'GeoNodes Structure'
use_sim_node_name = True
@ -40,44 +49,106 @@ class GeoNodesStructureNode(base.MaxwellSimNode):
####################
input_sockets: typ.ClassVar = {
'GeoNodes': sockets.BlenderGeoNodesSocketDef(),
'Medium': sockets.MaxwellMediumSocketDef(),
'Medium': sockets.MaxwellMediumSocketDef(active_kind=FK.Func),
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
default_unit=spu.micrometer,
default_value=sp.Matrix([0, 0, 0]),
default_value=sp.ImmutableMatrix([0, 0, 0]),
),
}
output_sockets: typ.ClassVar = {
'Structure': sockets.MaxwellStructureSocketDef(),
'Structure': sockets.MaxwellStructureSocketDef(active_kind=FK.Func),
}
managed_obj_types: typ.ClassVar = {
'modifier': managed_objs.ManagedBLModifier,
'preview_mesh': managed_objs.ManagedBLModifier,
'mesh': managed_objs.ManagedBLModifier,
}
####################
# - Outputs
# - FlowKind.Value
####################
@events.computes_output_socket(
'Structure',
input_sockets={'Medium'},
managed_objs={'modifier'},
kind=FK.Value,
# Loaded
outscks_kinds={
'Structure': {FK.Func, FK.Params},
},
)
def compute_structure(
def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
value = events.realize_known(output_sockets['Structure'])
if value is not None:
return value
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Structure',
kind=FK.Func,
# Loaded
props={'sim_node_name'},
managed_objs={'mesh'},
inscks_kinds={
'GeoNodes': FK.Value,
'Medium': FK.Func,
'Center': FK.Func,
},
scale_input_sockets={
'Center': ct.UNITS_TIDY3D,
},
all_loose_input_sockets=True,
loose_input_sockets_kind=FK.Func,
scale_loose_input_sockets=ct.UNITS_BLENDER,
)
def compute_func(
self,
input_sockets,
managed_objs,
props,
input_sockets,
loose_input_sockets,
) -> td.Structure:
"""Computes a triangle-mesh based Tidy3D structure, by manually copying mesh data from Blender to a `td.TriangleMesh`."""
"""Lazily computes a triangle-mesh based Tidy3D structure, by manually copying mesh data from Blender to a `td.TriangleMesh`."""
## TODO: mesh_as_arrays might not take the Center into account.
## - Alternatively, Tidy3D might have a way to transform?
mesh_as_arrays = managed_objs['modifier'].mesh_as_arrays
return td.Structure(
geometry=td.TriangleMesh.from_vertices_faces(
mesh_as_arrays['verts'],
mesh_as_arrays['faces'],
),
medium=input_sockets['Medium'],
mesh = managed_objs['mesh']
geonodes = input_sockets['GeoNodes']
medium = input_sockets['Medium']
center = input_sockets['Center']
sim_node_name = props['sim_node_name']
gn_inputs = list(loose_input_sockets.keys())
def verts_faces(els: tuple[typ.Any]) -> dict[str, typ.Any]:
# Push Realized Values to Managed Mesh
mesh.bl_modifier(
'NODES',
{
'node_group': geonodes,
'inputs': dict(zip(gn_inputs, els[2:], strict=True)),
},
location=els[1],
)
# Extract Vert/Face Data
mesh_as_arrays = mesh.mesh_as_arrays
return (mesh_as_arrays['verts'], mesh_as_arrays['faces'])
loose_sck_values = functools.reduce(
lambda a, b: a | b, loose_input_sockets.values()
)
return (medium | center | loose_sck_values).compose_within(
lambda els: td.Structure(
name=sim_node_name,
geometry=td.TriangleMesh.from_vertices_faces(
*verts_faces(els),
),
medium=els[0],
)
)
####################
@ -99,61 +170,73 @@ class GeoNodesStructureNode(base.MaxwellSimNode):
# - Events: Swapped GN Node Tree
####################
@events.on_value_changed(
socket_name={'GeoNodes'},
socket_name={
'GeoNodes': FK.Value,
},
# Loaded
managed_objs={'modifier'},
input_sockets={'GeoNodes'},
inscks_kinds={
'GeoNodes': FK.Value,
},
)
def on_input_changed(
self,
managed_objs,
input_sockets,
) -> None:
"""Declares new loose input sockets in response to a new GeoNodes tree (if any)."""
geonodes = input_sockets['GeoNodes']
has_geonodes = not ct.FlowSignal.check(geonodes) and geonodes is not None
"""Synchronizes the GeoNodes tree input sockets to the loose input sockets of this node.
if has_geonodes:
Utilizes `bl_socket_map.sockets_from_geonodes` to generate, primarily, `ExprSocketDef`s for use with `self.loose_input_sockets.
"""
geonodes = input_sockets['GeoNodes']
if geonodes is not None:
# Fill the Loose Input Sockets
## -> The SocketDefs contain the default values from the interface.
log.info(
log.debug(
'Initializing GeoNodes Structure Node "%s" from GeoNodes Group "%s"',
self.bl_label,
str(geonodes),
)
self.loose_input_sockets = bl_socket_map.sockets_from_geonodes(geonodes)
## -> The loose socket creation triggers 'on_input_socket_changed'
elif self.loose_input_sockets:
self.loose_input_sockets = {}
managed_objs['modifier'].free()
####################
# - Events: Preview
####################
@events.computes_output_socket(
'Structure',
kind=ct.FlowKind.Previews,
kind=FK.Previews,
# Loaded
props={'sim_node_name'},
)
def compute_previews(self, props):
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
"""Mark the box structure as participating in the preview."""
sim_node_name = props['sim_node_name']
return ct.PreviewsFlow(bl_object_names={sim_node_name + '_1'})
@events.on_value_changed(
# Trigger
socket_name={'Center', 'GeoNodes'}, ## MUST run after on_input_changed
socket_name={
'GeoNodes': FK.Value,
'Center': {FK.Func, FK.Params},
},
any_loose_input_socket=True,
run_on_init=True,
# Loaded
managed_objs={'modifier'},
input_sockets={'Center', 'GeoNodes'},
managed_objs={'preview_mesh'},
inscks_kinds={
'GeoNodes': FK.Value,
'Center': {FK.Func, FK.Params},
},
scale_input_sockets={
'Center': ct.UNITS_BLENDER,
},
all_loose_input_sockets=True,
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
scale_input_sockets={'Center': 'BlenderUnits'},
loose_input_sockets_kind={FK.Func, FK.Params},
scale_loose_input_sockets=ct.UNITS_BLENDER,
)
def on_input_socket_changed(
self, managed_objs, input_sockets, loose_input_sockets, unit_systems
self, managed_objs, input_sockets, loose_input_sockets
) -> None:
"""Pushes any change in GeoNodes-bound input sockets to the GeoNodes modifier.
@ -163,19 +246,16 @@ class GeoNodesStructureNode(base.MaxwellSimNode):
Also pushes the `Center:Value` socket to govern the object's center in 3D space.
"""
geonodes = input_sockets['GeoNodes']
has_geonodes = not ct.FlowSignal.check(geonodes) and geonodes is not None
center = events.realize_preview(input_sockets['Center'])
if has_geonodes:
# Push Loose Input Values to GeoNodes Modifier
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': geonodes,
'inputs': loose_input_sockets,
'unit_system': unit_systems['BlenderUnits'],
},
location=input_sockets['Center'],
)
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': geonodes,
'inputs': loose_input_sockets,
},
location=center,
)
####################

View File

@ -14,16 +14,16 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `BoxStructureNode`."""
import typing as typ
import bpy
import sympy as sp
import sympy.physics.units as spu
import tidy3d as td
import tidy3d.plugins.adjoint as tdadj
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import logger
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
@ -32,9 +32,13 @@ from ... import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
class BoxStructureNode(base.MaxwellSimNode):
"""A generic, differentiable box structure with configurable size and center."""
"""A generic box structure with configurable size and center."""
node_type = ct.NodeType.BoxStructure
bl_label = 'Box Structure'
@ -48,198 +52,141 @@ class BoxStructureNode(base.MaxwellSimNode):
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
default_unit=spu.micrometer,
default_value=sp.Matrix([0, 0, 0]),
default_value=sp.ImmutableMatrix([0, 0, 0]),
),
'Size': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
default_unit=spu.nanometer,
default_value=sp.Matrix([500, 500, 500]),
default_value=sp.ImmutableMatrix([500, 500, 500]),
abs_min=0.001,
),
}
output_sockets: typ.ClassVar = {
'Structure': sockets.MaxwellStructureSocketDef(active_kind=ct.FlowKind.Func),
'Structure': sockets.MaxwellStructureSocketDef(active_kind=FK.Func),
}
managed_obj_types: typ.ClassVar = {
'modifier': managed_objs.ManagedBLModifier,
}
####################
# - Properties
####################
differentiable: bool = bl_cache.BLField(False)
####################
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
layout.prop(
self,
self.blfields['differentiable'],
text='Differentiable',
toggle=True,
)
####################
# - FlowKind.Value
####################
@events.computes_output_socket(
'Structure',
kind=ct.FlowKind.Value,
kind=FK.Value,
# Loaded
output_sockets={'Structure'},
output_socket_kinds={'Structure': {ct.FlowKind.Func, ct.FlowKind.Params}},
outscks_kinds={
'Structure': {FK.Func, FK.Params},
},
)
def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
output_func = output_sockets['Structure'][ct.FlowKind.Func]
output_params = output_sockets['Structure'][ct.FlowKind.Params]
has_output_func = not ct.FlowSignal.check(output_func)
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_func and has_output_params and not output_params.symbols:
return output_func.realize(output_params, disallow_jax=True)
return ct.FlowSignal.FlowPending
value = events.realize_known(output_sockets['Structure'])
if value is not None:
return value
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Structure',
kind=ct.FlowKind.Func,
kind=FK.Func,
# Loaded
props={'differentiable'},
input_sockets={'Medium', 'Center', 'Size'},
input_socket_kinds={
'Medium': ct.FlowKind.Func,
'Center': ct.FlowKind.Func,
'Size': ct.FlowKind.Func,
inscks_kinds={
'Medium': FK.Func,
'Center': FK.Func,
'Size': FK.Func,
},
scale_input_sockets={
'Center': ct.UNITS_TIDY3D,
'Size': ct.UNITS_TIDY3D,
},
)
def compute_structure_func(self, props, input_sockets) -> td.Box:
"""Compute a possibly-differentiable function, producing a box structure from the input parameters."""
def compute_func(self, input_sockets) -> td.Box:
"""Compute a function, producing a box structure from the input parameters."""
center = input_sockets['Center']
size = input_sockets['Size']
medium = input_sockets['Medium']
has_center = not ct.FlowSignal.check(center)
has_size = not ct.FlowSignal.check(size)
has_medium = not ct.FlowSignal.check(medium)
if has_center and has_size and has_medium:
differentiable = props['differentiable']
if differentiable:
return (
center.scale_to_unit_system(ct.UNITS_TIDY3D)
| size.scale_to_unit_system(ct.UNITS_TIDY3D)
| medium
).compose_within(
lambda els: tdadj.JaxStructure(
geometry=tdadj.JaxBox(
center_jax=tuple(els[0].flatten()),
size_jax=tuple(els[1].flatten()),
),
medium=els[2],
),
supports_jax=True,
)
return (
center.scale_to_unit_system(ct.UNITS_TIDY3D)
| size.scale_to_unit_system(ct.UNITS_TIDY3D)
| medium
).compose_within(
lambda els: td.Structure(
geometry=td.Box(
center=els[0].flatten().tolist(),
size=els[1].flatten().tolist(),
),
medium=els[2],
return (center | size | medium).compose_within(
lambda els: td.Structure(
geometry=td.Box(
center=els[0].flatten().tolist(),
size=els[1].flatten().tolist(),
),
supports_jax=False,
)
return ct.FlowSignal.FlowPending
medium=els[2],
),
)
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Structure',
kind=ct.FlowKind.Params,
kind=FK.Params,
# Loaded
input_sockets={'Medium', 'Center', 'Size'},
input_socket_kinds={
'Medium': ct.FlowKind.Params,
'Center': ct.FlowKind.Params,
'Size': ct.FlowKind.Params,
inscks_kinds={
'Medium': FK.Params,
'Center': FK.Params,
'Size': FK.Params,
},
)
def compute_params(self, input_sockets) -> td.Box:
"""Aggregate the function parameters needed by the box."""
center = input_sockets['Center']
size = input_sockets['Size']
medium = input_sockets['Medium']
has_center = not ct.FlowSignal.check(center)
has_size = not ct.FlowSignal.check(size)
has_medium = not ct.FlowSignal.check(medium)
if has_center and has_size and has_medium:
return center | size | medium
return ct.FlowSignal.FlowPending
return center | size | medium
####################
# - Events: Preview
####################
@events.computes_output_socket(
'Structure',
kind=ct.FlowKind.Previews,
kind=FK.Previews,
# Loaded
props={'sim_node_name'},
output_sockets={'Structure'},
output_socket_kinds={'Structure': ct.FlowKind.Params},
)
def compute_previews(self, props, output_sockets):
"""Mark the managed preview object when recursively linked to a viewer."""
output_params = output_sockets['Structure']
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_params and not output_params.symbols:
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
return ct.PreviewsFlow()
def compute_previews(self, props):
"""Mark the box structure as participating in the preview."""
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed(
# Trigger
socket_name={'Center', 'Size'},
socket_name={
'Center': {FK.Func, FK.Params},
'Size': {FK.Func, FK.Params},
},
run_on_init=True,
# Loaded
managed_objs={'modifier'},
input_sockets={'Center', 'Size'},
output_sockets={'Structure'},
output_socket_kinds={'Structure': ct.FlowKind.Params},
inscks_kinds={
'Center': {FK.Func, FK.Params},
'Size': {FK.Func, FK.Params},
},
scale_input_sockets={
'Center': ct.UNITS_BLENDER,
'Size': ct.UNITS_BLENDER,
},
)
def on_previewable_changed(self, managed_objs, input_sockets, output_sockets):
center = input_sockets['Center']
size = input_sockets['Size']
output_params = output_sockets['Structure']
def on_previewable_changed(self, managed_objs, input_sockets) -> None:
"""Push changes in the inputs to the center / size."""
center = events.realize_preview(input_sockets['Center'])
size = events.realize_preview(input_sockets['Size'])
has_center = not ct.FlowSignal.check(center)
has_size = not ct.FlowSignal.check(size)
has_output_params = not ct.FlowSignal.check(output_params)
if has_center and has_size and has_output_params and not output_params.symbols:
# Push Loose Input Values to GeoNodes Modifier
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.StructurePrimitiveBox),
'unit_system': ct.UNITS_BLENDER,
'inputs': {
'Size': size,
},
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.StructurePrimitiveBox),
'inputs': {
'Size': size,
},
location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
)
},
location=center,
)
####################

View File

@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `CylinderStructureNode`."""
import typing as typ
import sympy as sp
@ -21,8 +23,8 @@ import sympy.physics.units as spu
import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import managed_objs, sockets
@ -30,6 +32,10 @@ from ... import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
class CylinderStructureNode(base.MaxwellSimNode):
"""A generic cylinder structure with configurable radius and height."""
@ -46,7 +52,7 @@ class CylinderStructureNode(base.MaxwellSimNode):
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
default_unit=spu.micrometer,
default_value=sp.Matrix([0, 0, 0]),
default_value=sp.ImmutableMatrix([0, 0, 0]),
),
'Radius': sockets.ExprSocketDef(
default_unit=spu.nanometer,
@ -58,7 +64,7 @@ class CylinderStructureNode(base.MaxwellSimNode):
),
}
output_sockets: typ.ClassVar = {
'Structure': sockets.MaxwellStructureSocketDef(active_kind=ct.FlowKind.Func),
'Structure': sockets.MaxwellStructureSocketDef(active_kind=FK.Func),
}
managed_obj_types: typ.ClassVar = {
@ -70,36 +76,35 @@ class CylinderStructureNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Structure',
kind=ct.FlowKind.Value,
kind=FK.Value,
# Loaded
output_sockets={'Structure'},
output_socket_kinds={'Structure': {ct.FlowKind.Func, ct.FlowKind.Params}},
output_socket_kinds={'Structure': {FK.Func, FK.Params}},
)
def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
output_func = output_sockets['Structure'][ct.FlowKind.Func]
output_params = output_sockets['Structure'][ct.FlowKind.Params]
has_output_func = not ct.FlowSignal.check(output_func)
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_func and has_output_params and not output_params.symbols:
return output_func.realize(output_params, disallow_jax=True)
return ct.FlowSignal.FlowPending
value = events.realize_known(output_sockets['Structure'])
if value is not None:
return value
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Structure',
kind=ct.FlowKind.Func,
kind=FK.Func,
# Loaded
input_sockets={'Center', 'Radius', 'Medium', 'Height'},
input_socket_kinds={
'Center': ct.FlowKind.Func,
'Radius': ct.FlowKind.Func,
'Height': ct.FlowKind.Func,
'Medium': ct.FlowKind.Func,
inscks_kinds={
'Center': FK.Func,
'Radius': FK.Func,
'Height': FK.Func,
'Medium': FK.Func,
},
scale_input_sockets={
'Center': ct.UNITS_TIDY3D,
'Radius': ct.UNITS_TIDY3D,
'Height': ct.UNITS_TIDY3D,
},
)
def compute_func(self, input_sockets) -> td.Box:
@ -109,119 +114,90 @@ class CylinderStructureNode(base.MaxwellSimNode):
height = input_sockets['Height']
medium = input_sockets['Medium']
has_center = not ct.FlowSignal.check(center)
has_radius = not ct.FlowSignal.check(radius)
has_height = not ct.FlowSignal.check(height)
has_medium = not ct.FlowSignal.check(medium)
if has_center and has_radius and has_height and has_medium:
return (
center.scale_to_unit_system(ct.UNITS_TIDY3D)
| radius.scale_to_unit_system(ct.UNITS_TIDY3D)
| height.scale_to_unit_system(ct.UNITS_TIDY3D)
| medium
).compose_within(
lambda els: td.Structure(
geometry=td.Cylinder(
center=els[0].flatten().tolist(),
radius=els[1],
length=els[2],
),
medium=els[3],
)
return (center | radius | height | medium).compose_within(
lambda els: td.Structure(
geometry=td.Cylinder(
center=els[0].flatten().tolist(),
radius=els[1],
length=els[2],
),
medium=els[3],
)
return ct.FlowSignal.FlowPending
)
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Structure',
kind=ct.FlowKind.Params,
kind=FK.Params,
# Loaded
input_sockets={'Center', 'Radius', 'Medium', 'Height'},
input_socket_kinds={
'Center': ct.FlowKind.Params,
'Radius': ct.FlowKind.Params,
'Height': ct.FlowKind.Params,
'Medium': ct.FlowKind.Params,
inscks_kinds={
'Center': FK.Params,
'Radius': FK.Params,
'Height': FK.Params,
'Medium': FK.Params,
},
)
def compute_params(self, input_sockets) -> td.Box:
"""Aggregate the function parameters needed by the cylinder."""
center = input_sockets['Center']
radius = input_sockets['Radius']
height = input_sockets['Height']
medium = input_sockets['Medium']
has_center = not ct.FlowSignal.check(center)
has_radius = not ct.FlowSignal.check(radius)
has_height = not ct.FlowSignal.check(height)
has_medium = not ct.FlowSignal.check(medium)
if has_center and has_radius and has_height and has_medium:
return center | radius | height | medium
return ct.FlowSignal.FlowPending
return center | radius | height | medium
####################
# - Preview
####################
@events.computes_output_socket(
'Structure',
kind=ct.FlowKind.Previews,
kind=FK.Previews,
# Loaded
props={'sim_node_name'},
output_sockets={'Structure'},
output_socket_kinds={'Structure': ct.FlowKind.Params},
)
def compute_previews(self, props, output_sockets):
output_params = output_sockets['Structure']
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_params and not output_params.symbols:
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
return ct.PreviewsFlow()
def compute_previews(self, props):
"""Mark the cylinder structure as participating in the preview."""
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed(
# Trigger
socket_name={'Center', 'Radius', 'Medium', 'Height'},
run_on_init=True,
# Loaded
input_sockets={'Center', 'Radius', 'Medium', 'Height'},
managed_objs={'modifier'},
output_sockets={'Structure'},
output_socket_kinds={'Structure': ct.FlowKind.Params},
inscks_kinds={
'Center': {FK.Func, FK.Params},
'Radius': {FK.Func, FK.Params},
'Medium': {FK.Func, FK.Params},
'Height': {FK.Func, FK.Params},
},
scale_input_sockets={
'Center': ct.UNITS_BLENDER,
'Radius': ct.UNITS_BLENDER,
'Height': ct.UNITS_BLENDER,
},
)
def on_previewable_changed(self, managed_objs, input_sockets, output_sockets):
center = input_sockets['Center']
radius = input_sockets['Radius']
height = input_sockets['Height']
output_params = output_sockets['Structure']
def on_previewable_changed(self, managed_objs, input_sockets) -> None:
"""Push changes in the inputs to the center / size."""
center = events.realize_preview(input_sockets['Center'])
radius = events.realize_preview(input_sockets['Radius'])
height = events.realize_preview(input_sockets['Height'])
has_center = not ct.FlowSignal.check(center)
has_radius = not ct.FlowSignal.check(radius)
has_height = not ct.FlowSignal.check(height)
has_output_params = not ct.FlowSignal.check(output_params)
if (
has_center
and has_radius
and has_height
and has_output_params
and not output_params.symbols
):
# Push Loose Input Values to GeoNodes Modifier
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.StructurePrimitiveCylinder),
'inputs': {
'Radius': radius,
'Height': height,
},
'unit_system': ct.UNITS_BLENDER,
# Push Loose Input Values to GeoNodes Modifier
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.StructurePrimitiveCylinder),
'inputs': {
'Radius': radius,
'Height': height,
},
location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
)
'unit_system': ct.UNITS_BLENDER,
},
location=center,
)
####################

View File

@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `SphereStructureNode`."""
import typing as typ
import sympy as sp
@ -21,8 +23,8 @@ import sympy.physics.units as spu
import tidy3d as td
from blender_maxwell.assets.geonodes import GeoNodes, import_geonodes
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils import sympy_extra as spux
from .... import contracts as ct
from .... import managed_objs, sockets
@ -30,8 +32,14 @@ from ... import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
class SphereStructureNode(base.MaxwellSimNode):
"""A generic sphere structure with configurable size and center."""
node_type = ct.NodeType.SphereStructure
bl_label = 'Sphere Structure'
use_sim_node_name = True
@ -44,7 +52,7 @@ class SphereStructureNode(base.MaxwellSimNode):
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
default_unit=spu.micrometer,
default_value=sp.Matrix([0, 0, 0]),
default_value=sp.ImmutableMatrix([0, 0, 0]),
),
'Radius': sockets.ExprSocketDef(
default_unit=spu.nanometer,
@ -52,7 +60,7 @@ class SphereStructureNode(base.MaxwellSimNode):
),
}
output_sockets: typ.ClassVar = {
'Structure': sockets.MaxwellStructureSocketDef(),
'Structure': sockets.MaxwellStructureSocketDef(active_kind=FK.Func),
}
managed_obj_types: typ.ClassVar = {
@ -60,70 +68,122 @@ class SphereStructureNode(base.MaxwellSimNode):
}
####################
# - Outputs
# - FlowKind.Value
####################
@events.computes_output_socket(
'Structure',
input_sockets={'Center', 'Radius', 'Medium'},
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
scale_input_sockets={
'Center': 'Tidy3DUnits',
'Radius': 'Tidy3DUnits',
kind=FK.Value,
# Loaded
outscks_kinds={
'Structure': {FK.Func, FK.Params},
},
)
def compute_structure(self, input_sockets, unit_systems) -> td.Box:
return td.Structure(
geometry=td.Sphere(
radius=input_sockets['Radius'],
center=input_sockets['Center'],
def compute_value(self, output_sockets) -> ct.ParamsFlow | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
value = events.realize_known(output_sockets['Structure'])
if value is not None:
return value
return FS.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Structure',
kind=FK.Func,
# Loaded
inscks_kinds={
'Medium': FK.Func,
'Center': FK.Func,
'Radius': FK.Func,
},
scale_input_sockets={
'Center': ct.UNITS_TIDY3D,
'Radius': ct.UNITS_TIDY3D,
},
)
def compute_func(self, input_sockets) -> td.Box:
"""Compute a function, producing a box structure from the input parameters."""
center = input_sockets['Center']
radius = input_sockets['Radius']
medium = input_sockets['Medium']
return (center | radius | medium).compose_within(
lambda els: td.Structure(
geometry=td.Sphere(
center=els[0].flatten().tolist(),
radius=els[1],
),
medium=els[2],
),
medium=input_sockets['Medium'],
)
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Structure',
kind=FK.Params,
# Loaded
inscks_kinds={
'Medium': FK.Params,
'Center': FK.Params,
'Radius': FK.Params,
},
)
def compute_params(self, input_sockets) -> td.Box:
"""Aggregate the function parameters needed by the sphere."""
center = input_sockets['Center']
radius = input_sockets['Radius']
medium = input_sockets['Medium']
return center | radius | medium
####################
# - Preview
####################
@events.computes_output_socket(
'Structure',
kind=ct.FlowKind.Previews,
kind=FK.Previews,
# Loaded
props={'sim_node_name'},
)
def compute_previews(self, props):
"""Mark the sphere structure as participating in the preview."""
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
@events.on_value_changed(
# Trigger
socket_name={'Center', 'Radius'},
socket_name={
'Center': {FK.Func, FK.Params},
'Size': {FK.Func, FK.Params},
},
run_on_init=True,
# Loaded
input_sockets={'Center', 'Radius'},
managed_objs={'modifier'},
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
inscks_kinds={
'Center': {FK.Func, FK.Params},
'Radius': {FK.Func, FK.Params},
},
scale_input_sockets={
'Center': 'BlenderUnits',
'Center': ct.UNITS_BLENDER,
'Radius': ct.UNITS_BLENDER,
},
)
def on_inputs_changed(
self,
managed_objs,
input_sockets,
unit_systems,
):
modifier = managed_objs['modifier']
unit_system = unit_systems['BlenderUnits']
def on_previewable_changed(self, managed_objs, input_sockets):
"""Push changes in the inputs to the center / size."""
center = events.realize_preview(input_sockets['Center'])
radius = events.realize_preview(input_sockets['Radius'])
# Push Loose Input Values to GeoNodes Modifier
modifier.bl_modifier(
managed_objs['modifier'].bl_modifier(
'NODES',
{
'node_group': import_geonodes(GeoNodes.StructurePrimitiveSphere),
'inputs': {
'Radius': input_sockets['Radius'],
'Radius': radius,
},
'unit_system': unit_system,
},
location=input_sockets['Center'],
location=center,
)

View File

@ -14,18 +14,15 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# from . import math
from . import combine
# from . import separate
from . import combine, view_text, wave_constant
BL_REGISTER = [
# *math.BL_REGISTER,
*wave_constant.BL_REGISTER,
*combine.BL_REGISTER,
# *separate.BL_REGISTER,
*view_text.BL_REGISTER,
]
BL_NODES = {
# **math.BL_NODES,
**wave_constant.BL_NODES,
**combine.BL_NODES,
# **separate.BL_NODES,
**view_text.BL_NODES,
}

View File

@ -20,12 +20,17 @@ import typing as typ
import bpy
import sympy as sp
from blender_maxwell.utils import bl_cache
from blender_maxwell.utils import bl_cache, logger
from ... import contracts as ct
from ... import sockets
from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
class CombineNode(base.MaxwellSimNode):
"""Combine single objects (ex. Source, Monitor, Structure) into a list."""
@ -44,17 +49,17 @@ class CombineNode(base.MaxwellSimNode):
output_socket_sets: typ.ClassVar = {
'Sources': {
'Sources': sockets.MaxwellSourceSocketDef(
active_kind=ct.FlowKind.Array,
active_kind=FK.Array,
),
},
'Structures': {
'Structures': sockets.MaxwellStructureSocketDef(
active_kind=ct.FlowKind.Array,
active_kind=FK.Array,
),
},
'Monitors': {
'Monitors': sockets.MaxwellMonitorSocketDef(
active_kind=ct.FlowKind.Array,
active_kind=FK.Array,
),
},
}
@ -63,14 +68,14 @@ class CombineNode(base.MaxwellSimNode):
# - Properties
####################
concatenate_first: bool = bl_cache.BLField(False)
value_or_func: ct.FlowKind = bl_cache.BLField(
value_or_func: FK = bl_cache.BLField(
enum_cb=lambda self, _: self._value_or_func(),
)
def _value_or_func(self):
return [
flow_kind.bl_enum_element(i)
for i, flow_kind in enumerate([ct.FlowKind.Value, ct.FlowKind.Func])
for i, flow_kind in enumerate([FK.Value, FK.Func])
]
####################
@ -79,7 +84,7 @@ class CombineNode(base.MaxwellSimNode):
def draw_props(self, _, layout: bpy.types.UILayout):
layout.prop(self, self.blfields['value_or_func'], text='')
if self.value_or_func is ct.FlowKind.Value:
if self.value_or_func is FK.Value:
layout.prop(
self,
self.blfields['concatenate_first'],
@ -91,15 +96,17 @@ class CombineNode(base.MaxwellSimNode):
# - Events
####################
@events.on_value_changed(
any_loose_input_socket=True,
prop_name={'active_socket_set', 'concatenate_first', 'value_or_func'},
any_loose_input_socket=True,
run_on_init=True,
# Loaded
props={'active_socket_set', 'concatenate_first', 'value_or_func'},
)
def on_inputs_changed(self, props) -> None:
"""Always create one extra loose input socket."""
"""Always create one extra loose input socket off the end of the last linked loose socket."""
active_socket_set = props['active_socket_set']
concatenate_first = props['concatenate_first']
flow_kind = props['value_or_func']
# Deduce SocketDef
## -> Cheat by retrieving the class from the output sockets.
@ -121,14 +128,11 @@ class CombineNode(base.MaxwellSimNode):
new_amount = current_filled + 1
# Deduce SocketDef | Current Amount
concatenate_first = props['concatenate_first']
flow_kind = props['value_or_func']
self.loose_input_sockets = {
'#0': SocketDef(
active_kind=flow_kind
if flow_kind is ct.FlowKind.Func or not concatenate_first
else ct.FlowKind.Array
if flow_kind is FK.Func or not concatenate_first
else FK.Array
)
} | {f'#{i}': SocketDef(active_kind=flow_kind) for i in range(1, new_amount)}
@ -138,29 +142,25 @@ class CombineNode(base.MaxwellSimNode):
def compute_combined(
self,
loose_input_sockets,
input_flow_kind: typ.Literal[ct.FlowKind.Value, ct.FlowKind.Func],
output_flow_kind: typ.Literal[ct.FlowKind.Array, ct.FlowKind.Func],
) -> list[typ.Any] | ct.FuncFlow | ct.FlowSignal:
input_flow_kind: typ.Literal[FK.Value, FK.Func],
output_flow_kind: typ.Literal[FK.Array, FK.Func],
) -> list[typ.Any] | ct.FuncFlow | FS:
"""Correctly compute the combined loose input sockets, given a valid combination of input and output `FlowKind`s.
If there is no output, or the flows aren't compatible, return `FlowPending`.
"""
match (input_flow_kind, output_flow_kind):
case (ct.FlowKind.Value, ct.FlowKind.Array):
case (FK.Value, FK.Array):
value_flows = [
inp
for inp in loose_input_sockets.values()
if not ct.FlowSignal.check(inp)
inp for inp in loose_input_sockets.values() if not FS.check(inp)
]
if value_flows:
return value_flows
return ct.FlowSignal.FlowPending
return FS.FlowPending
case (ct.FlowKind.Func, ct.FlowKind.Func):
case (FK.Func, FK.Func):
func_flows = [
inp
for inp in loose_input_sockets.values()
if not ct.FlowSignal.check(inp)
inp for inp in loose_input_sockets.values() if not FS.check(inp)
]
if len(func_flows) > 1:
@ -171,63 +171,59 @@ class CombineNode(base.MaxwellSimNode):
if len(func_flows) == 1:
return func_flows[0].compose_within(lambda el: [el])
return ct.FlowSignal.FlowPending
return FS.FlowPending
case (ct.FlowKind.Func, ct.FlowKind.Params):
case (FK.Func, FK.Params):
params_flows = [
params_flow
for inp_sckname in self.inputs.keys() # noqa: SIM118
if not ct.FlowSignal.check(
params_flow := self._compute_input(
inp_sckname, kind=ct.FlowKind.Params
)
if not FS.check(
params_flow := self._compute_input(inp_sckname, kind=FK.Params)
)
]
if params_flows:
return functools.reduce(lambda a, b: a | b, params_flows)
return ct.FlowSignal.FlowPending
return FS.FlowPending
return ct.FlowSignal.FlowPending
return FS.FlowPending
####################
# - Output: Sources
####################
@events.computes_output_socket(
'Sources',
kind=ct.FlowKind.Array,
kind=FK.Array,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_sources_array(
self, props, loose_input_sockets
) -> list[typ.Any] | ct.FlowSignal:
def compute_sources_array(self, props, loose_input_sockets) -> list[typ.Any] | FS:
"""Compute sources."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
loose_input_sockets, props['value_or_func'], FK.Array
)
@events.computes_output_socket(
'Sources',
kind=ct.FlowKind.Func,
kind=FK.Func,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_sources_func(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) sources."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
loose_input_sockets, props['value_or_func'], FK.Func
)
@events.computes_output_socket(
'Sources',
kind=ct.FlowKind.Params,
kind=FK.Params,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_sources_params(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) sources."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Params
loose_input_sockets, props['value_or_func'], FK.Params
)
####################
@ -235,38 +231,38 @@ class CombineNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Structures',
kind=ct.FlowKind.Array,
kind=FK.Array,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_structures_array(self, props, loose_input_sockets) -> sp.Expr:
"""Compute structures."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
loose_input_sockets, props['value_or_func'], FK.Array
)
@events.computes_output_socket(
'Structures',
kind=ct.FlowKind.Func,
kind=FK.Func,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_structures_func(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) structures."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
loose_input_sockets, props['value_or_func'], FK.Func
)
@events.computes_output_socket(
'Structures',
kind=ct.FlowKind.Params,
kind=FK.Params,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_structures_params(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) structures."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Params
loose_input_sockets, props['value_or_func'], FK.Params
)
####################
@ -274,38 +270,38 @@ class CombineNode(base.MaxwellSimNode):
####################
@events.computes_output_socket(
'Monitors',
kind=ct.FlowKind.Array,
kind=FK.Array,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_monitors_array(self, props, loose_input_sockets) -> sp.Expr:
"""Compute monitors."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Array
loose_input_sockets, props['value_or_func'], FK.Array
)
@events.computes_output_socket(
'Monitors',
kind=ct.FlowKind.Func,
kind=FK.Func,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_monitors_func(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) monitors."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Func
loose_input_sockets, props['value_or_func'], FK.Func
)
@events.computes_output_socket(
'Monitors',
kind=ct.FlowKind.Params,
kind=FK.Params,
all_loose_input_sockets=True,
props={'value_or_func'},
)
def compute_monitors_params(self, props, loose_input_sockets) -> list[typ.Any]:
"""Compute (lazy) structures."""
return self.compute_combined(
loose_input_sockets, props['value_or_func'], ct.FlowKind.Params
loose_input_sockets, props['value_or_func'], FK.Params
)
@ -315,4 +311,4 @@ class CombineNode(base.MaxwellSimNode):
BL_REGISTER = [
CombineNode,
]
BL_NODES = {ct.NodeType.Combine: (ct.NodeCategory.MAXWELLSIM_SIMS)}
BL_NODES = {ct.NodeType.Combine: (ct.NodeCategory.MAXWELLSIM_UTILITIES)}

View File

@ -0,0 +1,127 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import typing as typ
import bpy
import sympy as sp
import tidy3d as td
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
from ... import sockets
from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
class ConsoleViewOperator(bpy.types.Operator):
bl_idname = 'blender_maxwell.console_view_operator'
bl_label = 'View Plots'
@classmethod
def poll(cls, _: bpy.types.Context):
return True
def execute(self, context):
node = context.node
node.print_data_to_console()
return {'FINISHED'}
class RefreshPlotViewOperator(bpy.types.Operator):
bl_idname = 'blender_maxwell.refresh_plot_view_operator'
bl_label = 'Refresh Plots'
@classmethod
def poll(cls, _: bpy.types.Context):
return True
def execute(self, context):
node = context.node
node.on_changed_plot_preview()
return {'FINISHED'}
####################
# - Node
####################
class ViewTextNode(base.MaxwellSimNode):
node_type = ct.NodeType.ViewText
bl_label = 'View Text'
# use_sim_node_name = True
input_sockets: typ.ClassVar = {
'Text': sockets.StringSocketDef(),
}
####################
# - Properties
####################
push_live: bool = bl_cache.BLField(True)
####################
# - Properties: Computed FlowKinds
####################
@events.on_value_changed(
socket_name={'Text': FK.Value},
prop_name={'push_live', 'sim_node_name'},
# Loaded
inscks_kinds={'Text': FK.Value},
input_sockets_optional={'Text'},
props={'push_live', 'sim_node_name'},
)
def on_text_changed(self, props, input_sockets) -> None:
sim_node_name = props['sim_node_name']
push_live = props['push_live']
if push_live:
if bpy.data.texts.get(sim_node_name) is None:
bpy.data.texts.new(sim_node_name)
bl_text = bpy.data.texts[sim_node_name]
bl_text.clear()
text = input_sockets['Text']
has_text = not FS.check(text)
if has_text:
bl_text.write(text)
####################
# - UI
####################
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
row = layout.row(align=True)
row.prop(self, self.blfields['push_live'], text='Write Live', toggle=True)
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout):
pass
####################
# - Blender Registration
####################
BL_REGISTER = [
ViewTextNode,
]
BL_NODES = {ct.NodeType.ViewText: (ct.NodeCategory.MAXWELLSIM_UTILITIES)}

View File

@ -22,7 +22,7 @@ import bpy
import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, logger, sci_constants
from blender_maxwell.utils import bl_cache, logger, sci_constants, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
@ -31,9 +31,13 @@ from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
class WaveConstantNode(base.MaxwellSimNode):
"""Translates vaccum wavelength/frequency into both, either as a scalar, or as a memory-efficient uniform range of values.
"""Translates vacuum wavelength / non-angular frequency into both, either as a scalar, or as a memory-efficient uniform range of values.
Socket Sets:
Wavelength: Input a wavelength (range) to produce both wavelength/frequency (ranges).
@ -100,7 +104,7 @@ class WaveConstantNode(base.MaxwellSimNode):
run_on_init=True,
)
def on_use_range_changed(self, props: dict) -> None:
"""Synchronize the `active_kind` of input/output sockets, to either produce a `ct.FlowKind.Value` or a `ct.FlowKind.Range`."""
"""Synchronize the `active_kind` of input/output sockets, to either produce a `FK.Value` or a `FK.Range`."""
if self.inputs.get('WL') is not None:
active_input = self.inputs['WL']
else:
@ -108,46 +112,46 @@ class WaveConstantNode(base.MaxwellSimNode):
# Modify Active Kind(s)
## Input active_kind -> Value/Range
active_input_uses_range = active_input.active_kind == ct.FlowKind.Range
active_input_uses_range = active_input.active_kind == FK.Range
if active_input_uses_range != props['use_range']:
active_input.active_kind = (
ct.FlowKind.Range if props['use_range'] else ct.FlowKind.Value
)
active_input.active_kind = FK.Range if props['use_range'] else FK.Value
## Output active_kind -> Value/Range
for active_output in self.outputs.values():
active_output_uses_range = active_output.active_kind == ct.FlowKind.Range
active_output_uses_range = active_output.active_kind == FK.Range
if active_output_uses_range != props['use_range']:
active_output.active_kind = (
ct.FlowKind.Range if props['use_range'] else ct.FlowKind.Value
)
active_output.active_kind = FK.Range if props['use_range'] else FK.Value
####################
# - FlowKinds
# - FlowKind.Value
####################
@events.computes_output_socket(
'WL',
kind=ct.FlowKind.Value,
input_sockets={'WL', 'Freq'},
input_sockets_optional={'WL': True, 'Freq': True},
kind=FK.Value,
# Loaded
inscks_kinds={'WL': FK.Value, 'Freq': FK.Value},
input_sockets_optional={'WL', 'Freq'},
)
def compute_wl_value(self, input_sockets: dict) -> sp.Expr:
"""Compute a single wavelength value from either wavelength/frequency."""
has_wl = not ct.FlowSignal.check(input_sockets['WL'])
has_wl = not FS.check(input_sockets['WL'])
if has_wl:
return input_sockets['WL']
return sci_constants.vac_speed_of_light / input_sockets['Freq']
return spu.convert_to(
sci_constants.vac_speed_of_light / input_sockets['Freq'], spu.um
)
@events.computes_output_socket(
'Freq',
kind=ct.FlowKind.Value,
input_sockets={'WL', 'Freq'},
input_sockets_optional={'WL': True, 'Freq': True},
kind=FK.Value,
# Loaded
inscks_kinds={'WL': FK.Value, 'Freq': FK.Value},
input_sockets_optional={'WL', 'Freq'},
)
def compute_freq_value(self, input_sockets: dict) -> sp.Expr:
"""Compute a single frequency value from either wavelength/frequency."""
has_freq = not ct.FlowSignal.check(input_sockets['Freq'])
has_freq = not FS.check(input_sockets['Freq'])
if has_freq:
return input_sockets['Freq']
@ -155,62 +159,117 @@ class WaveConstantNode(base.MaxwellSimNode):
sci_constants.vac_speed_of_light / input_sockets['WL'], spux.THz
)
####################
# - FlowKind.Range
####################
@events.computes_output_socket(
'WL',
kind=ct.FlowKind.Range,
input_sockets={'WL', 'Freq'},
input_socket_kinds={
'WL': ct.FlowKind.Range,
'Freq': ct.FlowKind.Range,
},
input_sockets_optional={'WL': True, 'Freq': True},
kind=FK.Range,
# Loaded
inscks_kinds={'WL': FK.Range, 'Freq': FK.Range},
input_sockets_optional={'WL', 'Freq'},
)
def compute_wl_range(self, input_sockets: dict) -> sp.Expr:
def compute_wl_range(self, input_sockets) -> sp.Expr:
"""Compute wavelength range from either wavelength/frequency ranges."""
has_wl = not ct.FlowSignal.check(input_sockets['WL'])
has_wl = not FS.check(input_sockets['WL'])
if has_wl:
return input_sockets['WL']
freq = input_sockets['Freq']
return ct.RangeFlow(
start=spux.scale_to_unit(
sci_constants.vac_speed_of_light / (freq.stop * freq.unit), spu.um
),
stop=spux.scale_to_unit(
sci_constants.vac_speed_of_light / (freq.start * freq.unit), spu.um
),
steps=freq.steps,
scaling=freq.scaling,
unit=spu.um,
return freq.rescale(
lambda bound: sci_constants.vac_speed_of_light / bound,
reverse=True,
new_unit=spu.um,
)
@events.computes_output_socket(
'Freq',
kind=ct.FlowKind.Range,
kind=FK.Range,
input_sockets={'WL', 'Freq'},
input_socket_kinds={
'WL': ct.FlowKind.Range,
'Freq': ct.FlowKind.Range,
'WL': FK.Range,
'Freq': FK.Range,
},
input_sockets_optional={'WL': True, 'Freq': True},
)
def compute_freq_range(self, input_sockets: dict) -> sp.Expr:
"""Compute frequency range from either wavelength/frequency ranges."""
has_freq = not ct.FlowSignal.check(input_sockets['Freq'])
has_freq = not FS.check(input_sockets['Freq'])
if has_freq:
return input_sockets['Freq']
wl = input_sockets['WL']
return ct.RangeFlow(
start=spux.scale_to_unit(
sci_constants.vac_speed_of_light / (wl.stop * wl.unit), spux.THz
),
stop=spux.scale_to_unit(
sci_constants.vac_speed_of_light / (wl.start * wl.unit), spux.THz
),
steps=wl.steps,
scaling=wl.scaling,
unit=spux.THz,
return wl.rescale(
lambda bound: sci_constants.vac_speed_of_light / bound,
reverse=True,
new_unit=spux.THz,
)
####################
# - FlowKind.Func
####################
# @events.computes_output_socket(
# 'WL',
# kind=FK.Func,
# # Loaded
# inscks_kinds={'WL': FK.Func, 'Freq': FK.Func},
# input_sockets_optional={'WL', 'Freq'},
# )
# def compute_wl_func(self, input_sockets: dict) -> ct.FuncFlow | FS:
# """Compute a single wavelength value from either wavelength/frequency."""
# wl = input_sockets['WL']
# has_wl = not FS.check(wl)
# if has_wl:
# return wl
# freq = input_sockets['Freq']
# has_freq = not FS.check(freq)
# if has_freq:
# return wl.compose_within(
# return spu.convert_to(
# sci_constants.vac_speed_of_light / input_sockets['Freq'], spu.um
# )
# return FS.FlowPending
# @events.computes_output_socket(
# 'Freq',
# kind=FK.Value,
# # Loaded
# inscks_kinds={'WL': FK.Value, 'Freq': FK.Value},
# input_sockets_optional={'WL', 'Freq'},
# )
# def compute_freq_value(self, input_sockets: dict) -> sp.Expr:
# """Compute a single frequency value from either wavelength/frequency."""
# has_freq = not FS.check(input_sockets['Freq'])
# if has_freq:
# return input_sockets['Freq']
# return spu.convert_to(
# sci_constants.vac_speed_of_light / input_sockets['WL'], spux.THz
# )
####################
# - FlowKind.Info
####################
@events.computes_output_socket(
'WL',
kind=FK.Info,
)
def compute_wl_info(self) -> ct.InfoFlow:
"""Just enough InfoFlow to enable `linked_capabilities`."""
return ct.InfoFlow(
output=sim_symbols.wl(spu.um),
)
@events.computes_output_socket(
'Freq',
kind=FK.Info,
)
def compute_freq_info(self) -> sp.Expr:
"""Compute frequency range from either wavelength/frequency ranges."""
return ct.InfoFlow(
output=sim_symbols.freq(spux.THz),
)
@ -220,4 +279,4 @@ class WaveConstantNode(base.MaxwellSimNode):
BL_REGISTER = [
WaveConstantNode,
]
BL_NODES = {ct.NodeType.WaveConstant: (ct.NodeCategory.MAXWELLSIM_INPUTS)}
BL_NODES = {ct.NodeType.WaveConstant: (ct.NodeCategory.MAXWELLSIM_UTILITIES)}

View File

@ -26,6 +26,9 @@ from .. import contracts as ct
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
####################
# - SocketDef
@ -171,6 +174,7 @@ class SocketDef(pyd.BaseModel, abc.ABC):
####################
# - Socket
####################
FLOW_ERROR_COLOR: tuple[float, float, float, float] = (1.0, 0.0, 0.0, 1.0)
MANDATORY_PROPS: set[str] = {'socket_type', 'bl_label'}
@ -211,6 +215,8 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
(0, 0, 0, 0), use_prop_update=False
)
flow_error: bool = bl_cache.BLField(False, use_prop_update=False)
####################
# - Initialization
####################
@ -352,7 +358,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
## -> The tradeoff: No link if there is no InfoFlow.
if self.use_linked_capabilities:
info = self.compute_data(kind=ct.FlowKind.Info)
has_info = not ct.FlowSignal.check(info)
has_info = not FS.check(info)
if has_info:
incoming_capabilities = link.from_socket.linked_capabilities(info)
else:
@ -477,30 +483,10 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
# Run Socket Callbacks
self.on_socket_data_changed(socket_kinds)
# Mark Active FlowKind Links as Invalid
## -> Mark link as invalid (very red) if a FlowSignal is traveling.
## -> This helps explain why whatever isn't working isn't working.
## -> TODO: We need a different approach.
# log.debug(
# '[%s] Checking FlowKind Validity (socket_kinds=%s)',
# self.name,
# str(socket_kinds),
# )
# if self.is_linked and not self.is_output:
# link = self.links[0]
# linked_flow = self.compute_data(kind=self.active_kind)
# if (
# link.is_valid
# and self.active_kind in socket_kinds
# and ct.FlowSignal.check_single(linked_flow, ct.FlowSignal.FlowPending)
# ):
# node_tree = self.id_data
# node_tree.report_link_validity(link, False)
# elif not link.is_valid:
# node_tree = self.id_data
# node_tree.report_link_validity(link, True)
# Clear FlowErrors
## -> We should presume by default that the updated value is OK.
if self.flow_error:
bpy.app.timers.register(self.clear_flow_error)
def on_socket_data_changed(self, socket_kinds: set[ct.FlowKind]) -> None:
"""Called when `ct.FlowEvent.DataChanged` flows through this socket.
@ -646,7 +632,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Returns:
An empty `ct.InfoFlow`.
"""
return ct.FlowSignal.NoFlow
return FS.NoFlow
# Param
@property
@ -659,7 +645,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Returns:
An empty `ct.ParamsFlow`.
"""
return ct.FlowSignal.NoFlow
return FS.NoFlow
####################
# - FlowKind: Auxiliary
@ -675,7 +661,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Raises:
NotImplementedError: When used without being overridden.
"""
return ct.FlowSignal.NoFlow
return FS.NoFlow
@value.setter
def value(self, value: ct.ValueFlow) -> None:
@ -701,7 +687,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Raises:
NotImplementedError: When used without being overridden.
"""
return ct.FlowSignal.NoFlow
return FS.NoFlow
@array.setter
def array(self, value: ct.ArrayFlow) -> None:
@ -727,7 +713,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Raises:
NotImplementedError: When used without being overridden.
"""
return ct.FlowSignal.NoFlow
return FS.NoFlow
@lazy_func.setter
def lazy_func(self, lazy_func: ct.FuncFlow) -> None:
@ -753,7 +739,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Raises:
NotImplementedError: When used without being overridden.
"""
return ct.FlowSignal.NoFlow
return FS.NoFlow
@lazy_range.setter
def lazy_range(self, value: ct.RangeFlow) -> None:
@ -818,29 +804,42 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
"""
# Compute Output Socket
if self.is_output:
return self.node.compute_output(self.name, kind=kind)
flow = self.node.compute_output(self.name, kind=kind)
# Compute Input Socket
## -> Unlinked: Retrieve Socket Value
if not self.is_linked:
return self._compute_data(kind)
elif not self.is_linked:
flow = self._compute_data(kind)
## Linked: Compute Data on Linked Socket
## -> Capabilities are guaranteed compatible by 'allow_link_add'.
## -> There is no point in rechecking every time data flows.
linked_values = [link.from_socket.compute_data(kind) for link in self.links]
else:
# Linked: Compute Data on Linked Socket
## -> Capabilities are guaranteed compatible by 'allow_link_add'.
## -> There is no point in rechecking every time data flows.
linked_values = [link.from_socket.compute_data(kind) for link in self.links]
# Return Single Value / List of Values
## -> Multi-input sockets are not (yet) supported.
if linked_values:
return linked_values[0]
# Return Single Value / List of Values
## -> Multi-input sockets are not (yet) supported.
if linked_values: # noqa: SIM108
flow = linked_values[0]
# Edge Case: While Dragging Link (but not yet removed)
## While the user is dragging a link:
## - self.is_linked = True, since the user hasn't confirmed anything.
## - self.links will be empty, since the link object was freed.
## When this particular condition is met, pretend that we're not linked.
return self._compute_data(kind)
# Edge Case: While Dragging Link (but not yet removed)
## While the user is dragging a link:
## - self.is_linked = True, since the user hasn't confirmed anything.
## - self.links will be empty, since the link object was freed.
## When this particular condition is met, pretend that we're not linked.
else:
flow = self._compute_data(kind)
if FS.check_single(flow, FS.FlowPending) and not self.flow_error:
bpy.app.timers.register(self.declare_flow_error)
return flow
def declare_flow_error(self):
self.flow_error = True
def clear_flow_error(self):
self.flow_error = False
####################
# - UI - Color
@ -858,6 +857,8 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
Notes:
Called by Blender to call the socket color.
"""
if self.flow_error:
return FLOW_ERROR_COLOR
if self.use_socket_color:
return self.socket_color
return ct.SOCKET_COLORS[self.socket_type]
@ -978,7 +979,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
# Info Drawing
if self.use_info_draw:
info = self.compute_data(kind=ct.FlowKind.Info)
if not ct.FlowSignal.check(info):
if not FS.check(info):
self.draw_info(info, col)
def draw_output(
@ -1009,7 +1010,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
# Draw FlowKind.Info related Information
if self.use_info_draw:
info = self.compute_data(kind=ct.FlowKind.Info)
if not ct.FlowSignal.check(info):
if not FS.check(info):
self.draw_info(info, col)
####################

View File

@ -38,7 +38,7 @@ class StringBLSocket(base.MaxwellSimSocket):
# - Socket UI
####################
def draw_label_row(self, label_col_row: bpy.types.UILayout, text: str) -> None:
label_col_row.prop(self, 'raw_value', text=text)
label_col_row.prop(self, self.blfields['raw_value'], text=text)
####################
# - Computation of Default Value

View File

@ -25,11 +25,17 @@ import sympy as sp
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils.frozendict import frozendict
from .. import contracts as ct
from . import base
log = logger.get(__name__)
FK = ct.FlowKind
MT = spux.MathType
UI_FLOAT_EPS = sp.Float(0.0001, 1)
UI_FLOAT_PREC = 4
Int2: typ.TypeAlias = tuple[int, int]
Int3: typ.TypeAlias = tuple[int, int, int]
@ -44,19 +50,21 @@ Float32: typ.TypeAlias = tuple[
####################
# - Utilitives
# - Utilities
####################
def unicode_superscript(n: int) -> str:
"""Transform an integer into its unicode-based superscript character."""
return ''.join(['⁰¹²³⁴⁵⁶⁷⁸⁹'[ord(c) - ord('0')] for c in str(n)])
def _check_sym_oo(sym):
return sym.is_real or sym.is_rational or sym.is_integer
class InfoDisplayCol(enum.StrEnum):
"""Valid columns for specifying displayed information from an `ct.InfoFlow`."""
"""Valid columns for specifying displayed information from an `ct.InfoFlow`.
Attributes:
Length: Display the size of the dimensional index.
MathType: Display the `MT` of the dimensional symbol.
Unit: Display the unit of the dimensional symbol.
"""
Length = enum.auto()
MathType = enum.auto()
@ -112,25 +120,33 @@ class ExprBLSocket(base.MaxwellSimSocket):
use_socket_color = True
####################
# - Socket Interface
# - Identifier
####################
size: spux.NumberSize1D = bl_cache.BLField(spux.NumberSize1D.Scalar)
mathtype: spux.MathType = bl_cache.BLField(spux.MathType.Real)
mathtype: MT = bl_cache.BLField(MT.Real)
physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical)
####################
# - Output Symbol
####################
@bl_cache.cached_bl_property(
## -> CAREFUL: 'output_sym' changes recompiles `FuncFlow`.
depends_on={
'active_kind',
'symbols',
'raw_value_spstr',
'raw_min_spstr',
'raw_max_spstr',
# Identity
'output_name',
'active_kind',
'mathtype',
'physical_type',
'unit',
'size',
'value',
# Symbols / Symbolic Expression
'symbols',
'raw_value_spstr',
'raw_min_spstr',
'raw_max_spstr',
# Domain
'domain',
'steps', ## -> Func needs to recompile anyway if steps changes.
}
)
def output_sym(self) -> sim_symbols.SimSymbol | None:
@ -142,12 +158,12 @@ class ExprBLSocket(base.MaxwellSimSocket):
NotImplementedError: When `active_kind` is neither `Value`, `Func`, or `Range`.
"""
match self.active_kind:
case ct.FlowKind.Value | ct.FlowKind.Func if self.symbols:
case FK.Value | FK.Func if self.symbols:
return self._parse_expr_symbol(
self._parse_expr_str(self.raw_value_spstr)
)
case ct.FlowKind.Value | ct.FlowKind.Func if not self.symbols:
case FK.Value | FK.Func if not self.symbols:
return sim_symbols.SimSymbol(
sym_name=self.output_name,
mathtype=self.mathtype,
@ -155,17 +171,10 @@ class ExprBLSocket(base.MaxwellSimSocket):
unit=self.unit,
rows=self.size.rows,
cols=self.size.cols,
is_constant=True,
## TODO: Should we set preview values
exclude_zero=(
not self.value.is_zero
if self.value.is_zero is not None
else False
),
## TODO: Does this 0-check work for matrix elements?
domain=self.domain,
)
case ct.FlowKind.Range if self.symbols:
case FK.Range if self.symbols:
## TODO: Support RangeFlow
## -- It's hard; we need a min-span set over bound domains.
## -- We... Don't use this anywhere. Yet?
@ -178,29 +187,181 @@ class ExprBLSocket(base.MaxwellSimSocket):
msg = 'RangeFlow support not yet implemented for when self.symbols is not empty'
raise NotImplementedError(msg)
case ct.FlowKind.Range if not self.symbols:
case FK.Range if not self.symbols:
return sim_symbols.SimSymbol(
sym_name=self.output_name,
mathtype=self.mathtype,
physical_type=self.physical_type,
unit=self.unit,
rows=self.lazy_range.steps,
rows=self.steps,
cols=1,
exclude_zero=not self.lazy_range.is_always_nonzero,
domain=self.domain,
)
####################
# - Domain
####################
exclude_zero: bool = bl_cache.BLField(True)
abs_min_infinite: bool = bl_cache.BLField(True)
abs_max_infinite: bool = bl_cache.BLField(True)
abs_min_infinite_im: bool = bl_cache.BLField(True)
abs_max_infinite_im: bool = bl_cache.BLField(True)
abs_min_closed: bool = bl_cache.BLField(True)
abs_max_closed: bool = bl_cache.BLField(True)
abs_min_closed_im: bool = bl_cache.BLField(True)
abs_max_closed_im: bool = bl_cache.BLField(True)
abs_min_int: int = bl_cache.BLField(0)
abs_min_rat: Int2 = bl_cache.BLField((0, 1))
abs_min_float: float = bl_cache.BLField(0.0, float_prec=UI_FLOAT_PREC)
abs_min_complex: Float2 = bl_cache.BLField((0.0, 0.0), float_prec=UI_FLOAT_PREC)
abs_max_int: int = bl_cache.BLField(0)
abs_max_rat: Int2 = bl_cache.BLField((0, 1))
abs_max_float: float = bl_cache.BLField(0.0, float_prec=UI_FLOAT_PREC)
abs_max_complex: Float2 = bl_cache.BLField((0.0, 0.0), float_prec=UI_FLOAT_PREC)
@bl_cache.cached_bl_property(
depends_on={
'mathtype',
'abs_min_infinite',
'abs_min_infinite_im',
'abs_min_int',
'abs_min_rat',
'abs_min_float',
'abs_min_complex',
}
)
def abs_inf(self) -> sp.Integer | sp.Rational | sp.Float | spux.ComplexNumber:
"""Deduce the infimum of values expressable by this socket."""
match self.mathtype:
case MT.Integer | MT.Rational | MT.Real if self.abs_min_infinite:
return -sp.oo
case MT.Integer:
abs_min = sp.Integer(self.abs_min_int)
case MT.Rational:
abs_min = sp.Rational(*self.abs_min_rat)
case MT.Real:
abs_min = sp.Float(self.abs_min_float, UI_FLOAT_PREC)
case MT.Complex:
cplx = self.abs_min_complex
abs_min_re = (
sp.Float(cplx[0], UI_FLOAT_PREC)
if not self.abs_min_infinite
else -sp.oo
)
abs_min_im = (
sp.Float(cplx[1], UI_FLOAT_PREC)
if not self.abs_min_infinite_im
else -sp.oo
)
abs_min = abs_min_re + sp.I * abs_min_im
return abs_min
@bl_cache.cached_bl_property(
depends_on={
'mathtype',
'abs_max_infinite',
'abs_max_infinite_im',
'abs_max_int',
'abs_max_rat',
'abs_max_float',
'abs_max_complex',
}
)
def abs_sup(self) -> sp.Integer | sp.Rational | sp.Float | spux.ComplexNumber:
"""Deduce the infimum of values expressable by this socket."""
match self.mathtype:
case MT.Integer | MT.Rational | MT.Real if self.abs_max_infinite:
return sp.oo
case MT.Integer:
abs_max = sp.Integer(self.abs_max_int)
case MT.Rational:
abs_max = sp.Rational(*self.abs_max_rat)
case MT.Real:
abs_max = sp.Float(self.abs_max_float, UI_FLOAT_PREC)
case MT.Complex:
cplx = self.abs_max_complex
abs_max_re = (
sp.Float(cplx[0], UI_FLOAT_PREC)
if not self.abs_max_infinite
else sp.oo
)
abs_max_im = (
sp.Float(cplx[1], UI_FLOAT_PREC)
if not self.abs_max_infinite_im
else sp.oo
)
abs_max = abs_max_re + sp.I * abs_max_im
return abs_max
@bl_cache.cached_bl_property(
depends_on={
'abs_inf',
'abs_sup',
'exclude_zero',
'abs_min_closed',
'abs_max_closed',
'abs_min_closed_im',
'abs_max_closed_im',
}
)
def domain(self) -> spux.BlessedSet:
"""Deduce the domain of the socket's expression."""
match self.mathtype:
case MT.Integer:
domain = spux.BlessedSet(
sp.Range(
self.abs_inf if self.abs_min_closed else self.abs_inf + 1,
self.abs_sup + 1 if self.abs_max_closed else self.abs_sup,
)
)
case MT.Rational | MT.Real:
domain = spux.BlessedSet(
sp.Interval(
self.abs_inf,
self.abs_sup,
left_open=not self.abs_min_closed,
right_open=not self.abs_max_closed,
)
)
case MT.Complex:
domain = spux.BlessedSet.reals_to_complex(
sp.Interval(
sp.re(self.abs_inf),
sp.re(self.abs_sup),
left_open=not self.abs_min_closed,
right_open=not self.abs_max_closed,
),
sp.Interval(
sp.im(self.abs_inf),
sp.im(self.abs_sup),
left_open=not self.abs_min_closed_im,
right_open=not self.abs_max_closed_im,
),
)
if self.exclude_zero:
return domain - sp.FiniteSet(0)
return domain
####################
# - Value|Range Swapper
####################
use_value_range_swapper: bool = bl_cache.BLField(False)
selected_value_range: ct.FlowKind = bl_cache.BLField(
enum_cb=lambda self, _: self._value_or_range(),
selected_value_range: FK = bl_cache.BLField(
enum_cb=lambda self, _: self.search_value_or_range(),
)
def _value_or_range(self):
def search_value_or_range(self):
"""Either `FlowKind.Value` or `FlowKind.Range`."""
return [
flow_kind.bl_enum_element(i)
for i, flow_kind in enumerate([ct.FlowKind.Value, ct.FlowKind.Range])
for i, flow_kind in enumerate([FK.Value, FK.Range])
]
####################
@ -369,7 +530,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
####################
# - Event Callbacks
####################
def on_socket_data_changed(self, socket_kinds: set[ct.FlowKind]) -> None:
def on_socket_data_changed(self, socket_kinds: set[FK]) -> None:
"""Alter the socket's color in response to flow.
- `FlowKind.Info`: Any change causes the socket color to be updated with the physical type of the output symbol.
@ -381,8 +542,8 @@ class ExprBLSocket(base.MaxwellSimSocket):
"""
## NOTE: Depends on suppressed on_prop_changed
if ct.FlowKind.Info in socket_kinds:
info = self.compute_data(kind=ct.FlowKind.Info)
if FK.Info in socket_kinds:
info = self.compute_data(kind=FK.Info)
has_info = not ct.FlowSignal.check(info)
# Alter Color
@ -443,7 +604,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
####################
def _to_raw_value(self, expr: spux.SympyExpr, force_complex: bool = False):
"""Cast the given expression to the appropriate raw value, with scaling guided by `self.unit`."""
pyvalue = spux.scale_to_unit(expr, self.unit)
pyvalue = spux.scale_to_unit(expr, self.unit, cast_to_pytype=True)
# Cast complex -> tuple[float, float]
## -> We can't set complex to BLProps.
@ -452,6 +613,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
isinstance(pyvalue, int | float) and force_complex
):
return (pyvalue.real, pyvalue.imag)
if isinstance(pyvalue, tuple) and all(
isinstance(v, complex)
or (isinstance(pyvalue, int | float) and force_complex)
@ -557,6 +719,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
'unit',
'mathtype',
'size',
'domain',
'raw_value_sp',
'raw_value_int',
'raw_value_rat',
@ -601,20 +764,21 @@ class ExprBLSocket(base.MaxwellSimSocket):
if self.size is NS.Vec4:
return ct.FlowSignal.NoFlow
MT_Z = spux.MathType.Integer
MT_Q = spux.MathType.Rational
MT_R = spux.MathType.Real
MT_C = spux.MathType.Complex
MT_Z = MT.Integer
MT_Q = MT.Rational
MT_R = MT.Real
MT_C = MT.Complex
Z = sp.Integer
Q = sp.Rational
R = sp.RealNumber
return {
R = sp.Float
raw_value = {
NS.Scalar: {
MT_Z: lambda: Z(self.raw_value_int),
MT_Q: lambda: Q(self.raw_value_rat[0], self.raw_value_rat[1]),
MT_R: lambda: R(self.raw_value_float),
MT_R: lambda: R(self.raw_value_float, UI_FLOAT_PREC),
MT_C: lambda: (
self.raw_value_complex[0] + sp.I * self.raw_value_complex[1]
R(self.raw_value_complex[0], UI_FLOAT_PREC)
+ sp.I * R(self.raw_value_complex[1], UI_FLOAT_PREC)
),
},
NS.Vec2: {
@ -622,9 +786,14 @@ class ExprBLSocket(base.MaxwellSimSocket):
MT_Q: lambda: sp.ImmutableMatrix(
[Q(q[0], q[1]) for q in self.raw_value_rat2]
),
MT_R: lambda: sp.ImmutableMatrix([R(r) for r in self.raw_value_float2]),
MT_R: lambda: sp.ImmutableMatrix(
[R(r, UI_FLOAT_PREC) for r in self.raw_value_float2]
),
MT_C: lambda: sp.ImmutableMatrix(
[c[0] + sp.I * c[1] for c in self.raw_value_complex2]
[
R(c[0], UI_FLOAT_PREC) + sp.I * R(c[1], UI_FLOAT_PREC)
for c in self.raw_value_complex2
]
),
},
NS.Vec3: {
@ -632,12 +801,21 @@ class ExprBLSocket(base.MaxwellSimSocket):
MT_Q: lambda: sp.ImmutableMatrix(
[Q(q[0], q[1]) for q in self.raw_value_rat3]
),
MT_R: lambda: sp.ImmutableMatrix([R(r) for r in self.raw_value_float3]),
MT_R: lambda: sp.ImmutableMatrix(
[R(r, UI_FLOAT_PREC) for r in self.raw_value_float3]
),
MT_C: lambda: sp.ImmutableMatrix(
[c[0] + sp.I * c[1] for c in self.raw_value_complex3]
[
R(c[0], UI_FLOAT_PREC) + sp.I * R(c[1], UI_FLOAT_PREC)
for c in self.raw_value_complex3
]
),
},
}[self.size][self.mathtype]() * (self.unit if self.unit is not None else 1)
}[self.size][self.mathtype]()
if raw_value not in self.domain:
return ct.FlowSignal.FlowPending
return raw_value * self.unit_factor
@value.setter
def value(self, expr: spux.SympyExpr) -> None:
@ -650,7 +828,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
self.raw_value_spstr = sp.sstr(expr)
else:
NS = spux.NumberSize1D
MT = spux.MathType
match (self.size, self.mathtype):
case (NS.Scalar, MT.Integer):
self.raw_value_int = self._to_raw_value(expr)
@ -694,6 +871,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
'unit',
'mathtype',
'size',
'domain',
'steps',
'scaling',
'raw_min_sp',
@ -723,10 +901,10 @@ class ExprBLSocket(base.MaxwellSimSocket):
symbols=self.symbols,
)
MT_Z = spux.MathType.Integer
MT_Q = spux.MathType.Rational
MT_R = spux.MathType.Real
MT_C = spux.MathType.Complex
MT_Z = MT.Integer
MT_Q = MT.Rational
MT_R = MT.Real
MT_C = MT.Complex
Z = sp.Integer
Q = sp.Rational
R = sp.RealNumber
@ -739,10 +917,12 @@ class ExprBLSocket(base.MaxwellSimSocket):
bound[0] + sp.I * bound[1] for bound in self.raw_range_complex
],
}[self.mathtype]()
if min_bound not in self.domain or max_bound not in self.domain:
return ct.FlowSignal.FlowPending
return ct.RangeFlow(
start=min_bound,
stop=max_bound,
start=sp.Float(min_bound, 4),
stop=sp.Float(max_bound, 4),
steps=self.steps,
scaling=self.scaling,
unit=self.unit,
@ -763,10 +943,10 @@ class ExprBLSocket(base.MaxwellSimSocket):
self.raw_max_spstr = sp.sstr(lazy_range.stop)
else:
MT_Z = spux.MathType.Integer
MT_Q = spux.MathType.Rational
MT_R = spux.MathType.Real
MT_C = spux.MathType.Complex
MT_Z = MT.Integer
MT_Q = MT.Rational
MT_R = MT.Real
MT_C = MT.Complex
unit = lazy_range.unit if lazy_range.unit is not None else 1
if self.mathtype == MT_Z:
@ -807,6 +987,12 @@ class ExprBLSocket(base.MaxwellSimSocket):
func_output=self.output_sym,
supports_jax=True,
)
return ct.FuncFlow(
func=lambda v: v,
func_args=[self.output_sym],
func_output=self.output_sym,
supports_jax=True,
)
return ct.FlowSignal.FlowPending
@ -821,27 +1007,23 @@ class ExprBLSocket(base.MaxwellSimSocket):
"""
if self.output_sym is not None:
match self.active_kind:
case ct.FlowKind.Value | ct.FlowKind.Func if (
not ct.FlowSignal.check(self.value)
):
case FK.Value | FK.Func if (not ct.FlowSignal.check(self.value)):
return ct.ParamsFlow(
arg_targets=[self.output_sym],
func_args=[self.value],
symbols=set(self.sorted_symbols),
)
case ct.FlowKind.Range if self.sorted_symbols:
case FK.Range if self.sorted_symbols:
msg = 'RangeFlow support not yet implemented for when self.sorted_symbols is not empty'
raise NotImplementedError(msg)
case ct.FlowKind.Range if (
case FK.Range if (
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
):
return ct.ParamsFlow(
arg_targets=[self.output_sym],
func_args=[self.output_sym.sp_symbol_matsym],
symbols={self.output_sym},
).realize_partial({self.output_sym: self.lazy_range})
).realize_partial(frozendict({self.output_sym: self.lazy_range}))
return ct.FlowSignal.FlowPending
@ -861,17 +1043,17 @@ class ExprBLSocket(base.MaxwellSimSocket):
"""
if self.output_sym is not None:
match self.active_kind:
case ct.FlowKind.Value | ct.FlowKind.Func:
case FK.Value | FK.Func:
return ct.InfoFlow(
dims={sym: None for sym in self.sorted_symbols},
output=self.output_sym,
)
case ct.FlowKind.Range if self.sorted_symbols:
case FK.Range if self.sorted_symbols:
msg = 'InfoFlow support not yet implemented for when self.sorted_symbols is not empty'
raise NotImplementedError(msg)
case ct.FlowKind.Range if (
case FK.Range if (
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
):
return ct.InfoFlow(
@ -893,11 +1075,11 @@ class ExprBLSocket(base.MaxwellSimSocket):
socket_type=self.socket_type,
active_kind=self.active_kind,
allow_out_to_in={
ct.FlowKind.Func: ct.FlowKind.Value,
FK.Func: FK.Value,
},
allow_out_to_in_if_matches={
ct.FlowKind.Value: (
ct.FlowKind.Func,
FK.Value: (
FK.Func,
(
info.output.physical_type,
info.output.mathtype,
@ -917,8 +1099,8 @@ class ExprBLSocket(base.MaxwellSimSocket):
output_sym = self.output_sym
if output_sym is not None:
allow_out_to_in_if_matches = {
ct.FlowKind.Value: (
ct.FlowKind.Func,
FK.Value: (
FK.Func,
(
output_sym.physical_type,
output_sym.mathtype,
@ -934,7 +1116,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
socket_type=self.socket_type,
active_kind=self.active_kind,
allow_out_to_in={
ct.FlowKind.Func: ct.FlowKind.Value,
FK.Func: FK.Value,
},
allow_out_to_in_if_matches=allow_out_to_in_if_matches,
)
@ -964,8 +1146,8 @@ class ExprBLSocket(base.MaxwellSimSocket):
Notes:
Whether information about the expression passing through a linked socket is shown is governed by `self.show_info_columns`.
"""
if self.active_kind is ct.FlowKind.Func:
info = self.compute_data(kind=ct.FlowKind.Info)
if self.active_kind is FK.Func:
info = self.compute_data(kind=FK.Info)
has_info = not ct.FlowSignal.check(info)
if has_info:
@ -999,8 +1181,8 @@ class ExprBLSocket(base.MaxwellSimSocket):
Notes:
Whether information about the expression passing through a linked socket is shown is governed by `self.show_info_columns`.
"""
if self.active_kind is ct.FlowKind.Func:
info = self.compute_data(kind=ct.FlowKind.Info)
if self.active_kind is FK.Func:
info = self.compute_data(kind=FK.Info)
has_info = not ct.FlowSignal.check(info)
if has_info:
@ -1054,7 +1236,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
else:
NS = spux.NumberSize1D
MT = spux.MathType
match (self.size, self.mathtype):
case (NS.Scalar, MT.Integer):
col.prop(self, self.blfields['raw_value_int'], text='')
@ -1115,10 +1296,10 @@ class ExprBLSocket(base.MaxwellSimSocket):
col.prop(self, self.blfields['raw_max_spstr'], text='')
else:
MT_Z = spux.MathType.Integer
MT_Q = spux.MathType.Rational
MT_R = spux.MathType.Real
MT_C = spux.MathType.Complex
MT_Z = MT.Integer
MT_Q = MT.Rational
MT_R = MT.Real
MT_C = MT.Complex
if self.mathtype == MT_Z:
col.prop(self, self.blfields['raw_range_int'], text='')
elif self.mathtype == MT_Q:
@ -1173,7 +1354,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None:
"""Visualize the `InfoFlow` information passing through the socket."""
if (
self.active_kind is ct.FlowKind.Func
self.active_kind is FK.Func
and self.show_info_columns
and (self.is_linked or self.is_output)
):
@ -1219,7 +1400,7 @@ class ExprSocketDef(base.SocketDef):
# Socket Interface
size: spux.NumberSize1D = spux.NumberSize1D.Scalar
mathtype: spux.MathType = spux.MathType.Real
mathtype: MT = MT.Real
physical_type: spux.PhysicalType = spux.PhysicalType.NonPhysical
default_unit: spux.Unit | None = None
@ -1227,8 +1408,6 @@ class ExprSocketDef(base.SocketDef):
# FlowKind: Value
default_value: spux.SympyExpr = 0
abs_min: spux.SympyExpr | None = None
abs_max: spux.SympyExpr | None = None
# FlowKind: Range
default_min: spux.SympyExpr = 0
@ -1236,6 +1415,15 @@ class ExprSocketDef(base.SocketDef):
default_steps: int = 2
default_scaling: ct.ScalingMode = ct.ScalingMode.Lin
# Domain
abs_min: spux.SympyExpr | None = None
abs_max: spux.SympyExpr | None = None
abs_min_closed: bool = True
abs_max_closed: bool = True
abs_min_closed_im: bool = True
abs_max_closed_im: bool = True
exclude_zero: bool = False
# UI
show_name_selector: bool = False
show_func_ui: bool = True
@ -1347,14 +1535,14 @@ class ExprSocketDef(base.SocketDef):
# Coerce Number -> Column 0-Vector
## -> TODO: We don't strictly know if default_value is a number.
if len(self.size.shape) == 1:
self.default_value = self.default_value * sp.Matrix.ones(
self.default_value = self.default_value * sp.ImmutableMatrix.ones(
self.size.shape[0], 1
)
# Coerce Number -> 0-Matrix
## -> TODO: We don't strictly know if default_value is a number.
if len(self.size.shape) > 1:
self.default_value = self.default_value * sp.Matrix.ones(
self.default_value = self.default_value * sp.ImmutableMatrix.ones(
*self.size.shape
)
@ -1367,9 +1555,9 @@ class ExprSocketDef(base.SocketDef):
If `self.default_value` is a scalar Python type, it will be coerced into the corresponding Sympy type using `sp.S`, after coersion to the correct Python type using `self.mathtype.coerce_compatible_pyobj()`.
Raises:
ValueError: If `self.default_value` has no obvious, coerceable `spux.MathType` compatible with `self.mathtype`, as determined by `spux.MathType.has_mathtype`.
ValueError: If `self.default_value` has no obvious, coerceable `MT` compatible with `self.mathtype`, as determined by `MT.has_mathtype`.
"""
mathtype_guide = spux.MathType.has_mathtype(self.default_value)
mathtype_guide = MT.has_mathtype(self.default_value)
# None: No Obvious Mathtype
if mathtype_guide is None:
@ -1378,7 +1566,7 @@ class ExprSocketDef(base.SocketDef):
# PyType: Coerce from PyType
if mathtype_guide == 'pytype':
dv_mathtype = spux.MathType.from_pytype(type(self.default_value))
dv_mathtype = MT.from_pytype(type(self.default_value))
if self.mathtype.is_compatible(dv_mathtype):
self.default_value = sp.S(
self.mathtype.coerce_compatible_pyobj(self.default_value)
@ -1389,7 +1577,7 @@ class ExprSocketDef(base.SocketDef):
# Expr: Merely Check MathType Compatibility
if mathtype_guide == 'expr':
dv_mathtype = spux.MathType.from_expr(self.default_value)
dv_mathtype = MT.from_expr(self.default_value)
if not self.mathtype.is_compatible(dv_mathtype):
msg = f'ExprSocket: Mathtype {dv_mathtype} of default value expression {self.default_value} (type {type(self.default_value)}) is incompatible with socket MathType {self.mathtype}'
raise ValueError(msg)
@ -1416,11 +1604,11 @@ class ExprSocketDef(base.SocketDef):
If `self.default_value` is a scalar Python type, it will be coerced into the corresponding Sympy type using `sp.S`.
Raises:
ValueError: If `self.default_value` has no obvious `spux.MathType`, as determined by `spux.MathType.has_mathtype`.
ValueError: If `self.default_value` has no obvious `MT`, as determined by `MT.has_mathtype`.
"""
new_bounds = [None, None]
for i, bound in enumerate([self.default_min, self.default_max]):
mathtype_guide = spux.MathType.has_mathtype(bound)
mathtype_guide = MT.has_mathtype(bound)
# None: No Obvious Mathtype
if mathtype_guide is None:
@ -1429,7 +1617,7 @@ class ExprSocketDef(base.SocketDef):
# PyType: Coerce from PyType
if mathtype_guide == 'pytype':
dv_mathtype = spux.MathType.from_pytype(type(bound))
dv_mathtype = MT.from_pytype(type(bound))
if self.mathtype.is_compatible(dv_mathtype):
new_bounds[i] = sp.S(self.mathtype.coerce_compatible_pyobj(bound))
else:
@ -1438,18 +1626,18 @@ class ExprSocketDef(base.SocketDef):
# Expr: Merely Check MathType Compatibility
if mathtype_guide == 'expr':
dv_mathtype = spux.MathType.from_expr(bound)
dv_mathtype = MT.from_expr(bound)
if not self.mathtype.is_compatible(dv_mathtype):
msg = f'ExprSocket: Mathtype {dv_mathtype} of a default Range min or max expression {bound} (type {type(self.default_value)}) is incompatible with socket MathType {self.mathtype}'
raise ValueError(msg)
# Coerce from Infinite
if isinstance(bound, spux.SympyType):
if bound.is_infinite and self.mathtype is spux.MathType.Integer:
if bound.is_infinite and self.mathtype is MT.Integer:
new_bounds[i] = sp.S(-1) if i == 0 else sp.S(1)
if bound.is_infinite and self.mathtype is spux.MathType.Rational:
if bound.is_infinite and self.mathtype is MT.Rational:
new_bounds[i] = sp.Rational(-1, 1) if i == 0 else sp.Rational(1, 1)
if bound.is_infinite and self.mathtype is spux.MathType.Real:
if bound.is_infinite and self.mathtype is MT.Real:
new_bounds[i] = sp.S(-1.0) if i == 0 else sp.S(1.0)
if new_bounds[0] is not None:
@ -1468,10 +1656,7 @@ class ExprSocketDef(base.SocketDef):
"""
# Check ActiveKind and Size
## -> NOTE: This doesn't protect against dynamic changes to either.
if (
self.active_kind is ct.FlowKind.Range
and self.size is not spux.NumberSize1D.Scalar
):
if self.active_kind is FK.Range and self.size is not spux.NumberSize1D.Scalar:
msg = "Can't have a non-Scalar size when Range is set as the active kind."
raise ValueError(msg)
@ -1524,8 +1709,61 @@ class ExprSocketDef(base.SocketDef):
bl_socket.size = self.size
bl_socket.mathtype = self.mathtype
bl_socket.physical_type = self.physical_type
bl_socket.active_unit = bl_cache.Signal.ResetEnumItems
bl_socket.unit = bl_cache.Signal.InvalidateCache
bl_socket.unit_factor = bl_cache.Signal.InvalidateCache
bl_socket.symbols = self.default_symbols
# Domain
bl_socket.exclude_zero = self.exclude_zero
if self.abs_min is None:
bl_socket.abs_min_infinite = True
bl_socket.abs_min_infinite_im = True
else:
bl_socket.abs_min_closed = self.abs_min_closed
if self.abs_max is None:
bl_socket.abs_max_infinite = True
bl_socket.abs_max_infinite_im = True
else:
bl_socket.abs_max_closed = self.abs_max_closed
match self.mathtype:
case MT.Integer if self.abs_min is not None:
bl_socket.abs_min_int = int(self.abs_min)
case MT.Integer if self.abs_max is not None:
bl_socket.abs_max_int = int(self.abs_max)
case MT.Rational if self.abs_min is not None:
bl_socket.abs_min_rat = (
self.abs_min.numerator,
self.abs_min.denominator,
)
case MT.Rational if self.abs_max is not None:
bl_socket.abs_max_rat = (
self.abs_max.numerator,
self.abs_max.denominator,
)
case MT.Real if self.abs_min is not None:
bl_socket.abs_min_float = float(self.abs_min)
case MT.Real if self.abs_max is not None:
bl_socket.abs_max_float = float(self.abs_max)
case MT.Complex if self.abs_min is not None:
bl_socket.abs_min_complex = (
float(sp.re(self.abs_min)),
float(sp.im(self.abs_min)),
)
bl_socket.abs_min_closed_im = self.abs_min_closed_im
case MT.Complex if self.abs_max is not None:
bl_socket.abs_max_complex = (
float(sp.re(self.abs_max)),
float(sp.im(self.abs_max)),
)
bl_socket.abs_max_closed_im = self.abs_max_closed_im
# FlowKind.Value
## -> We must take units into account when setting bl_socket.value
if self.physical_type is not spux.PhysicalType.NonPhysical:
@ -1570,7 +1808,9 @@ class ExprSocketDef(base.SocketDef):
and cmp('show_func_ui')
and cmp('show_info_columns')
and cmp('show_name_selector')
and cmp('show_name_selector')
and bl_socket.use_info_draw
## TODO: Include domain?
)

View File

@ -14,16 +14,13 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import math
import bpy
import jax.numpy as jnp
import scipy as sc
import sympy as sp
import sympy.physics.units as spu
import tidy3d as td
import tidy3d.plugins.adjoint as tdadj
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import sympy_extra as spux
from ... import contracts as ct
@ -51,49 +48,44 @@ class MaxwellMediumBLSocket(base.MaxwellSimSocket):
####################
eps_rel: tuple[float, float] = bl_cache.BLField((1.0, 0.0), float_prec=2)
differentiable: bool = bl_cache.BLField(False)
####################
# - FlowKinds
####################
@bl_cache.cached_bl_property(depends_on={'eps_rel', 'differentiable'})
@bl_cache.cached_bl_property(depends_on={'eps_rel'})
def value(self) -> td.Medium:
eps_r_re = self.eps_rel[0]
conductivity = FIXED_FREQ * self.eps_rel[1] ## TODO: Um?
# conductivity = FIXED_FREQ * self.eps_rel[1] ## TODO: Um?
if self.differentiable:
return tdadj.JaxMedium(
permittivity_jax=jnp.array(eps_r_re, dtype=float),
conductivity_jax=jnp.array(conductivity, dtype=float),
)
return td.Medium(
permittivity=eps_r_re,
conductivity=conductivity,
# conductivity=conductivity,
)
@value.setter
def value(self, eps_rel: tuple[float, float]) -> None:
self.eps_rel = eps_rel
@bl_cache.cached_bl_property(depends_on={'value'})
@bl_cache.cached_bl_property()
def lazy_func(self) -> ct.FuncFlow:
return ct.FuncFlow(
func=lambda: self.value,
supports_jax=self.differentiable,
func=lambda eps_r_re, eps_r_im: td.Medium(
permittivity=eps_r_re,
# conductivity=FIXED_FREQ * eps_r_im,
),
func_args=[sim_symbols.rel_eps_re(None), sim_symbols.rel_eps_im(None)],
)
@bl_cache.cached_bl_property(depends_on={'differentiable'})
@bl_cache.cached_bl_property(depends_on={'eps_rel'})
def params(self) -> ct.FuncFlow:
return ct.ParamsFlow(
func_args=[sp.S(self.eps_rel[0]), sp.S(self.eps_rel[1])],
)
return ct.ParamsFlow()
####################
# - UI
####################
def draw_value(self, col: bpy.types.UILayout) -> None:
col.prop(
self, self.blfields['differentiable'], text='Differentiable', toggle=True
)
col.separator()
split = col.split(factor=0.25, align=False)
_col = split.column(align=True)

View File

@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `Tidy3DCloudTaskBLSocket`."""
import enum
import bpy
@ -29,49 +31,59 @@ from .. import base
# - Operators
####################
class ReloadFolderList(bpy.types.Operator):
"""Reload the list of available folders."""
bl_idname = ct.OperatorType.SocketReloadCloudFolderList
bl_label = 'Reload Tidy3D Folder List'
bl_description = 'Reload the the cached Tidy3D folder list'
@classmethod
def poll(cls, context):
"""Allow reloading the folder list when online, authenticated, and attached to a `Tidy3DCloudTask` socket."""
return (
tdcloud.IS_ONLINE
and tdcloud.IS_AUTHENTICATED
and hasattr(context, 'socket')
and hasattr(context.socket, 'socket_type')
and context.socket.socket_type == ct.SocketType.Tidy3DCloudTask
and context.socket.socket_type is ct.SocketType.Tidy3DCloudTask
)
def execute(self, context):
"""Update the folder list, as well as any tasks attached to any existing folder ID."""
bl_socket = context.socket
tdcloud.TidyCloudFolders.update_folders()
tdcloud.TidyCloudTasks.update_tasks(bl_socket.existing_folder_id)
bl_socket.existing_folder_id = bl_cache.Signal.ResetEnumItems
bl_socket.existing_folder_id = bl_cache.Signal.InvalidateCache
if bl_socket.existing_folder_id is not None:
tdcloud.TidyCloudTasks.update_tasks(bl_socket.existing_folder_id)
bl_socket.existing_task_id = bl_cache.Signal.ResetEnumItems
bl_socket.existing_task_id = bl_cache.Signal.InvalidateCache
return {'FINISHED'}
class Authenticate(bpy.types.Operator):
"""Authenticate with the Tidy3D API."""
bl_idname = ct.OperatorType.SocketCloudAuthenticate
bl_label = 'Authenticate Tidy3D'
bl_description = 'Authenticate the Tidy3D Web API from a Cloud Task socket'
@classmethod
def poll(cls, context):
"""Allow authenticating when online, not authenticated, and attached to a `Tidy3DCloudTask` socket."""
return (
not tdcloud.IS_AUTHENTICATED
tdcloud.IS_ONLINE
and not tdcloud.IS_AUTHENTICATED
and hasattr(context, 'socket')
and hasattr(context.socket, 'socket_type')
and context.socket.socket_type == ct.SocketType.Tidy3DCloudTask
and context.socket.socket_type is ct.SocketType.Tidy3DCloudTask
)
def execute(self, context):
"""Try authenticating the socket with the web service."""
bl_socket = context.socket
if not tdcloud.check_authentication():
@ -104,11 +116,14 @@ class Tidy3DCloudTaskBLSocket(base.MaxwellSimSocket):
bl_label = 'Tidy3D Cloud Task'
####################
# - Properties
# - Properties: Authentication
####################
api_key: str = bl_cache.BLField('', str_secret=True)
should_exist: bool = bl_cache.BLField(False)
####################
# - Properties: Existance
####################
should_exist: bool = bl_cache.BLField(False)
new_task_name: str = bl_cache.BLField('')
####################
@ -119,10 +134,11 @@ class Tidy3DCloudTaskBLSocket(base.MaxwellSimSocket):
)
def search_cloud_folders(self) -> list[ct.BLEnumElement]:
"""Get all Tidy3D cloud folders."""
if tdcloud.IS_AUTHENTICATED:
return [
(
cloud_folder.folder_id,
folder_id,
cloud_folder.folder_name,
f'Folder {cloud_folder.folder_name} (ID={folder_id})',
'',
@ -139,10 +155,12 @@ class Tidy3DCloudTaskBLSocket(base.MaxwellSimSocket):
# - Properties: Cloud Tasks
####################
existing_task_id: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_cloud_tasks()
enum_cb=lambda self, _: self.search_cloud_tasks(),
cb_depends_on={'existing_folder_id'},
)
def search_cloud_tasks(self) -> list[ct.BLEnumElement]:
"""Retrieve all Tidy3D cloud folders from the cloud service."""
if self.existing_folder_id is None or not tdcloud.IS_AUTHENTICATED:
return []
@ -200,6 +218,7 @@ class Tidy3DCloudTaskBLSocket(base.MaxwellSimSocket):
####################
@bl_cache.cached_bl_property(depends_on={'active_kind', 'should_exist'})
def capabilities(self) -> ct.CapabilitiesFlow:
"""Cloud task sockets are compatible by presumtion of existance."""
return ct.CapabilitiesFlow(
socket_type=self.socket_type,
active_kind=self.active_kind,
@ -217,48 +236,41 @@ class Tidy3DCloudTaskBLSocket(base.MaxwellSimSocket):
def value(
self,
) -> ct.NewSimCloudTask | tdcloud.CloudTask | ct.FlowSignal:
"""Return either the existing cloud task, or an object describing how to make it."""
if tdcloud.IS_AUTHENTICATED:
# Retrieve Folder
cloud_folder = tdcloud.TidyCloudFolders.folders().get(
self.existing_folder_id
)
if cloud_folder is None:
if cloud_folder is not None:
# Case: New Task
if not self.should_exist:
return ct.NewSimCloudTask(
task_name=self.new_task_name, cloud_folder=cloud_folder
)
# Case: Existing Task
if self.existing_task_id is not None:
cloud_task = tdcloud.TidyCloudTasks.tasks(cloud_folder)[
self.existing_task_id
]
if cloud_task is not None:
return cloud_task
else:
return ct.FlowSignal.NoFlow ## Folder deleted somewhere else
# Case: New Task
if not self.should_exist:
return ct.NewSimCloudTask(
task_name=self.new_task_name, cloud_folder=cloud_folder
)
# Case: Existing Task
if self.existing_task_id is not None:
cloud_task = tdcloud.TidyCloudTasks.tasks(cloud_folder).get(
self.existing_task_id
)
if cloud_folder is None:
return ct.FlowSignal.NoFlow ## Task deleted somewhere else
return cloud_task
return ct.FlowSignal.FlowPending
####################
# - UI
####################
def draw_label_row(self, row: bpy.types.UILayout, text: str):
row.label(text=text)
auth_icon = 'LOCKVIEW_ON' if tdcloud.IS_AUTHENTICATED else 'LOCKVIEW_OFF'
row.label(text='', icon=auth_icon)
def draw_prelock(
self,
context: bpy.types.Context,
_: bpy.types.Context,
col: bpy.types.UILayout,
node: bpy.types.Node,
text: str,
node: bpy.types.Node, # noqa: ARG002
text: str, # noqa: ARG002
) -> None:
"""Draw a lock-immune interface for authenticating with the Tidy3D API, when the authentication is invalid."""
if not tdcloud.IS_AUTHENTICATED:
row = col.row()
row.alignment = 'CENTER'
@ -274,6 +286,7 @@ class Tidy3DCloudTaskBLSocket(base.MaxwellSimSocket):
)
def draw_value(self, col: bpy.types.UILayout) -> None:
"""When authenticated, draw the node UI."""
if not tdcloud.IS_AUTHENTICATED:
return
@ -306,11 +319,14 @@ class Tidy3DCloudTaskBLSocket(base.MaxwellSimSocket):
# - Socket Configuration
####################
class Tidy3DCloudTaskSocketDef(base.SocketDef):
"""Declarative object guiding the creation of a `Tidy3DCloudTask` socket."""
socket_type: ct.SocketType = ct.SocketType.Tidy3DCloudTask
should_exist: bool
def init(self, bl_socket: Tidy3DCloudTaskBLSocket) -> None:
"""Initialize the passed socket with the properties of this object."""
bl_socket.should_exist = self.should_exist
bl_socket.use_prelock = True

View File

@ -28,6 +28,7 @@ import urllib
from dataclasses import dataclass
from pathlib import Path
import bpy
import tidy3d as td
import tidy3d.web as td_web
@ -495,3 +496,28 @@ class TidyCloudTasks:
raise RuntimeError(msg) from ex
return cls.update_task(cloud_task)
####################
# - Blender UI Integration
####################
def draw_cloud_status(layout: bpy.types.UILayout) -> None:
"""Draw up-to-date information about the connection to the Tidy3D cloud to a Blender UI."""
# Connection Info
auth_icon = 'CHECKBOX_HLT' if IS_AUTHENTICATED else 'CHECKBOX_DEHLT'
conn_icon = 'CHECKBOX_HLT' if IS_ONLINE else 'CHECKBOX_DEHLT'
row = layout.row()
row.alignment = 'CENTER'
row.label(text='Cloud Status')
box = layout.box()
split = box.split(factor=0.85)
col = split.column(align=False)
col.label(text='Authed')
col.label(text='Connected')
col = split.column(align=False)
col.label(icon=auth_icon)
col.label(icon=conn_icon)

View File

@ -582,7 +582,7 @@ class BLPropType(enum.StrEnum):
return str(value)
# Single Enum: Coerce to set[str]
case BPT.SetEnum | BPT.SetDynEnum if isinstance(value, set):
case BPT.SetEnum | BPT.SetDynEnum if isinstance(value, set | frozenset):
return {str(v) for v in value}
# BLPointer: Don't Alter
@ -594,7 +594,7 @@ class BLPropType(enum.StrEnum):
case BPT.Serialized:
return serialize.encode(value).decode('utf-8')
msg = f'{self}: No encoder defined for argument {value}'
msg = f'{self}: No encoder defined for argument {value} (type={type(value)})'
raise NotImplementedError(msg)
####################
@ -603,7 +603,7 @@ class BLPropType(enum.StrEnum):
def decode(self, raw_value: typ.Any, obj_type: type) -> typ.Any: # noqa: PLR0911
"""Transform a raw value from a form read directly from the Blender property returned by `self.bl_type`, to its intended value of approximate type `obj_type`.
Notes:
Ntes:
`obj_type` is only a hint - for example, `obj_type = enum.StrEnum` is an indicator for a dynamic enum.
Its purpose is to guide ex. sizing and `StrEnum` coersion, not to guarantee a particular output type.
@ -661,7 +661,7 @@ class BLPropType(enum.StrEnum):
## -> This happens independent of whether there's a enum_cb.
case BPT.SingleEnum if isinstance(raw_value, str):
return obj_type(raw_value)
case BPT.SetEnum if isinstance(raw_value, set):
case BPT.SetEnum if isinstance(raw_value, set | frozenset):
SubStrEnum = typ.get_args(obj_type)[0]
return {SubStrEnum(v) for v in raw_value}
@ -670,7 +670,7 @@ class BLPropType(enum.StrEnum):
## -> All dynamic enums have an enum_cb, but this is merely a symptom of ^.
case BPT.SingleDynEnum if isinstance(raw_value, str):
return raw_value
case BPT.SetDynEnum if isinstance(raw_value, set):
case BPT.SetDynEnum if isinstance(raw_value, set | frozenset):
return raw_value
# BLPointer

View File

@ -0,0 +1,57 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements a `pydantic`-compatible field, `FrozenDict`, which encapsulates a `frozendict` in a serializable way, with semantics identical to `dict`."""
import typing as typ
import pydantic as pyd
from frozendict import deepfreeze, frozendict
from pydantic_core import core_schema as pyd_core_schema
class _PydanticFrozenDictAnnotation:
"""Annotated validator providing interoperability between `frozendict` and `pydantic` models.
Semantics are almost identical to `dict`, except for a chained conversion to `frozendict`.
"""
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: typ.Any, handler: pyd.GetCoreSchemaHandler
) -> pyd_core_schema.CoreSchema:
def validate_from_dict(d: dict | frozendict) -> frozendict:
return frozendict(d)
frozendict_schema = pyd_core_schema.chain_schema(
[
handler.generate_schema(dict[*typ.get_args(source_type)]),
pyd_core_schema.no_info_plain_validator_function(validate_from_dict),
pyd_core_schema.is_instance_schema(frozendict),
]
)
return pyd_core_schema.json_or_python_schema(
json_schema=frozendict_schema,
python_schema=frozendict_schema,
serialization=pyd_core_schema.plain_serializer_function_ser_schema(dict),
)
_K = typ.TypeVar('_K')
_V = typ.TypeVar('_V')
FrozenDict = typ.Annotated[frozendict[_K, _V], _PydanticFrozenDictAnnotation]
__all__ = ['deepfreeze', 'frozendict', 'FrozenDict']

View File

@ -17,7 +17,6 @@
"""Useful image processing operations for use in the addon."""
import enum
import functools
import typing as typ
import jax
@ -28,10 +27,11 @@ import matplotlib.axis as mpl_ax
import matplotlib.backends.backend_agg
import matplotlib.figure
import seaborn as sns
import sympy as sp
from blender_maxwell import contracts as ct
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils import logger
from blender_maxwell.utils import sympy_extra as spux
sns.set_theme()
@ -138,7 +138,7 @@ def rgba_image_from_2d_map(
####################
# - MPL Helpers
####################
@functools.lru_cache(maxsize=4)
# @functools.lru_cache(maxsize=4)
def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
fig = matplotlib.figure.Figure(
figsize=[width_inches, height_inches], dpi=dpi, layout='tight'
@ -154,30 +154,51 @@ def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
####################
# - Plotters
####################
def _parse_val(val):
return spux.sp_to_str(sp.S(val).n(2))
def pinned_labels(pinned_data) -> str:
return (
'\n'
+ ', '.join(
[
f'{sym.name_pretty}:' + _parse_val(val)
for sym, val in pinned_data.items()
]
)
# + ']'
)
# () ->
def plot_box_plot_1d(data, ax: mpl_ax.Axis) -> None:
x_sym, y_sym = list(data.keys())
x_sym, y_sym, pinned = list(data.keys())
ax.boxplot([data[y_sym]])
ax.set_title(f'{x_sym.name_pretty}{y_sym.name_pretty}')
ax.set_title(
f'{x_sym.name_pretty}{y_sym.name_pretty} {pinned_labels(data[pinned])}'
)
ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(y_sym.plot_label)
def plot_bar(data, ax: mpl_ax.Axis) -> None:
x_sym, heights_sym = list(data.keys())
x_sym, heights_sym, pinned = list(data.keys())
p = ax.bar(data[x_sym], data[heights_sym])
ax.bar_label(p, label_type='center')
ax.set_title(f'{x_sym.name_pretty} -> {heights_sym.name_pretty}')
ax.set_title(
f'{x_sym.name_pretty}{heights_sym.name_pretty} {pinned_labels(data[pinned])}'
)
ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(heights_sym.plot_label)
# () -> (| sometimes complex)
def plot_curve_2d(data, ax: mpl_ax.Axis) -> None:
x_sym, y_sym = list(data.keys())
x_sym, y_sym, pinned = list(data.keys())
if y_sym.mathtype is spux.MathType.Complex:
ax.plot(data[x_sym], data[y_sym].real, label='')
@ -185,38 +206,47 @@ def plot_curve_2d(data, ax: mpl_ax.Axis) -> None:
ax.legend()
ax.plot(data[x_sym], data[y_sym])
ax.set_title(
f'{x_sym.name_pretty}{y_sym.name_pretty} {pinned_labels(data[pinned])}'
)
ax.set_title(f'{x_sym.name_pretty}{y_sym.name_pretty}')
ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(y_sym.plot_label)
def plot_points_2d(data, ax: mpl_ax.Axis) -> None:
x_sym, y_sym = list(data.keys())
x_sym, y_sym, pinned = list(data.keys())
ax.scatter(data[x_sym], data[y_sym])
ax.set_title(f'{x_sym.name_pretty}{y_sym.name_pretty}')
ax.set_title(
f'{x_sym.name_pretty}{y_sym.name_pretty} {pinned_labels(data[pinned])}'
)
ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(y_sym.plot_label)
# (, ) ->
def plot_curves_2d(data, ax: mpl_ax.Axis) -> None:
x_sym, label_sym, y_sym = list(data.keys())
x_sym, label_sym, y_sym, pinned = list(data.keys())
for i, label in enumerate(data[label_sym]):
ax.plot(data[x_sym], data[y_sym][:, i], label=label)
ax.set_title(f'{x_sym.name_pretty}{y_sym.name_pretty}')
ax.set_title(
f'{x_sym.name_pretty}{y_sym.name_pretty} {pinned_labels(data[pinned])}'
)
ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(y_sym.plot_label)
ax.legend()
def plot_filled_curves_2d(data, ax: mpl_ax.Axis) -> None:
x_sym, _, y_sym = list(data.keys(data))
x_sym, _, y_sym, pinned = list(data.keys(data))
ax.fill_between(data[x_sym], data[y_sym][:, 0], data[x_sym], data[y_sym][:, 1])
ax.set_title(f'{x_sym.name_pretty}{y_sym.name_pretty}')
ax.set_title(
f'{x_sym.name_pretty}{y_sym.name_pretty} {pinned_labels(data[pinned])}'
)
ax.set_xlabel(x_sym.plot_label)
ax.set_ylabel(y_sym.plot_label)
ax.legend()
@ -224,12 +254,14 @@ def plot_filled_curves_2d(data, ax: mpl_ax.Axis) -> None:
# (, ) ->
def plot_heatmap_2d(data, ax: mpl_ax.Axis) -> None:
x_sym, y_sym, c_sym = list(data.keys())
x_sym, y_sym, c_sym, pinned = list(data.keys())
heatmap = ax.imshow(data[c_sym], aspect='equal', interpolation='none')
ax.figure.colorbar(heatmap, cax=ax)
# ax.figure.colorbar(heatmap, ax=ax)
ax.set_title(f'({x_sym.name_pretty}, {y_sym.name_pretty}) → {c_sym.plot_label}')
ax.set_title(
f'({x_sym.name_pretty}, {y_sym.name_pretty}) → {c_sym.plot_label} {pinned_labels(data[pinned])}'
)
ax.set_xlabel(x_sym.plot_label)
ax.set_xlabel(y_sym.plot_label)
ax.legend()

Some files were not shown because too many files have changed in this diff Show More