Compare commits
4 Commits
0f2f494868
...
353a2c997e
Author | SHA1 | Date |
---|---|---|
Sofus Albert Høgsbro Rose | 353a2c997e | |
Sofus Albert Høgsbro Rose | dccf952ad3 | |
Sofus Albert Høgsbro Rose | 84825c2642 | |
Sofus Albert Høgsbro Rose | f5d19abecd |
|
@ -21,11 +21,13 @@ dependencies = [
|
|||
# Pin Blender 4.1.0-Compatible Versions
|
||||
## The dependency resolver will report if anything is wonky.
|
||||
"urllib3==1.26.8",
|
||||
#"requests==2.27.1", ## Conflict with dev-dep commitizen
|
||||
#"requests==2.27.1", ## Conflict with dev-dep commitizen
|
||||
"numpy==1.24.3",
|
||||
"idna==3.3",
|
||||
#"charset-normalizer==2.0.10", ## Conflict with dev-dep commitizen
|
||||
#"charset-normalizer==2.0.10", ## Conflict with dev-dep commitizen
|
||||
"certifi==2021.10.8",
|
||||
"polars>=0.20.26",
|
||||
"seaborn[stats]>=0.13.2",
|
||||
]
|
||||
## When it comes to dev-dep conflicts:
|
||||
## -> It's okay to leave Blender-pinned deps out of prod; Blender still has them.
|
||||
|
|
|
@ -81,6 +81,7 @@ locket==1.0.0
|
|||
markupsafe==2.1.5
|
||||
# via jinja2
|
||||
matplotlib==3.8.3
|
||||
# via seaborn
|
||||
# via tidy3d
|
||||
ml-dtypes==0.4.0
|
||||
# via jax
|
||||
|
@ -102,8 +103,11 @@ numpy==1.24.3
|
|||
# via ml-dtypes
|
||||
# via numba
|
||||
# via opt-einsum
|
||||
# via patsy
|
||||
# via scipy
|
||||
# via seaborn
|
||||
# via shapely
|
||||
# via statsmodels
|
||||
# via tidy3d
|
||||
# via trimesh
|
||||
# via xarray
|
||||
|
@ -114,15 +118,21 @@ packaging==24.0
|
|||
# via dask
|
||||
# via h5netcdf
|
||||
# via matplotlib
|
||||
# via statsmodels
|
||||
# via xarray
|
||||
pandas==2.2.1
|
||||
# via seaborn
|
||||
# via statsmodels
|
||||
# via xarray
|
||||
partd==1.4.1
|
||||
# via dask
|
||||
patsy==0.5.6
|
||||
# via statsmodels
|
||||
pillow==10.2.0
|
||||
# via matplotlib
|
||||
platformdirs==4.2.1
|
||||
# via virtualenv
|
||||
polars==0.20.26
|
||||
pre-commit==3.7.0
|
||||
prompt-toolkit==3.0.36
|
||||
# via questionary
|
||||
|
@ -166,13 +176,19 @@ s3transfer==0.5.2
|
|||
scipy==1.12.0
|
||||
# via jax
|
||||
# via jaxlib
|
||||
# via seaborn
|
||||
# via statsmodels
|
||||
# via tidy3d
|
||||
seaborn==0.13.2
|
||||
setuptools==69.5.1
|
||||
# via nodeenv
|
||||
shapely==2.0.3
|
||||
# via tidy3d
|
||||
six==1.16.0
|
||||
# via patsy
|
||||
# via python-dateutil
|
||||
statsmodels==0.14.2
|
||||
# via seaborn
|
||||
sympy==1.12
|
||||
termcolor==2.4.0
|
||||
# via commitizen
|
||||
|
|
|
@ -59,6 +59,7 @@ llvmlite==0.42.0
|
|||
locket==1.0.0
|
||||
# via partd
|
||||
matplotlib==3.8.3
|
||||
# via seaborn
|
||||
# via tidy3d
|
||||
ml-dtypes==0.4.0
|
||||
# via jax
|
||||
|
@ -78,8 +79,11 @@ numpy==1.24.3
|
|||
# via ml-dtypes
|
||||
# via numba
|
||||
# via opt-einsum
|
||||
# via patsy
|
||||
# via scipy
|
||||
# via seaborn
|
||||
# via shapely
|
||||
# via statsmodels
|
||||
# via tidy3d
|
||||
# via trimesh
|
||||
# via xarray
|
||||
|
@ -89,13 +93,19 @@ packaging==24.0
|
|||
# via dask
|
||||
# via h5netcdf
|
||||
# via matplotlib
|
||||
# via statsmodels
|
||||
# via xarray
|
||||
pandas==2.2.1
|
||||
# via seaborn
|
||||
# via statsmodels
|
||||
# via xarray
|
||||
partd==1.4.1
|
||||
# via dask
|
||||
patsy==0.5.6
|
||||
# via statsmodels
|
||||
pillow==10.2.0
|
||||
# via matplotlib
|
||||
polars==0.20.26
|
||||
pydantic==2.7.1
|
||||
# via tidy3d
|
||||
pydantic-core==2.18.2
|
||||
|
@ -131,11 +141,17 @@ s3transfer==0.5.2
|
|||
scipy==1.12.0
|
||||
# via jax
|
||||
# via jaxlib
|
||||
# via seaborn
|
||||
# via statsmodels
|
||||
# via tidy3d
|
||||
seaborn==0.13.2
|
||||
shapely==2.0.3
|
||||
# via tidy3d
|
||||
six==1.16.0
|
||||
# via patsy
|
||||
# via python-dateutil
|
||||
statsmodels==0.14.2
|
||||
# via seaborn
|
||||
sympy==1.12
|
||||
tidy3d==2.6.3
|
||||
toml==0.10.2
|
||||
|
|
|
@ -41,6 +41,9 @@ class OperatorType(enum.StrEnum):
|
|||
SocketCloudAuthenticate = enum.auto()
|
||||
SocketReloadCloudFolderList = enum.auto()
|
||||
|
||||
# Node: ExportDataFile
|
||||
NodeExportDataFile = enum.auto()
|
||||
|
||||
# Node: Tidy3DWebImporter
|
||||
NodeLoadCloudSim = enum.auto()
|
||||
|
||||
|
|
|
@ -47,8 +47,8 @@ from .flow_kinds import (
|
|||
CapabilitiesFlow,
|
||||
FlowKind,
|
||||
InfoFlow,
|
||||
LazyArrayRangeFlow,
|
||||
LazyValueFuncFlow,
|
||||
RangeFlow,
|
||||
FuncFlow,
|
||||
ParamsFlow,
|
||||
ScalingMode,
|
||||
ValueFlow,
|
||||
|
@ -59,6 +59,7 @@ from .mobj_types import ManagedObjType
|
|||
from .node_types import NodeType
|
||||
from .sim_types import (
|
||||
BoundCondType,
|
||||
DataFileFormat,
|
||||
NewSimCloudTask,
|
||||
SimAxisDir,
|
||||
SimFieldPols,
|
||||
|
@ -103,6 +104,7 @@ __all__ = [
|
|||
'BLSocketType',
|
||||
'NodeType',
|
||||
'BoundCondType',
|
||||
'DataFileFormat',
|
||||
'NewSimCloudTask',
|
||||
'SimAxisDir',
|
||||
'SimFieldPols',
|
||||
|
@ -116,8 +118,8 @@ __all__ = [
|
|||
'CapabilitiesFlow',
|
||||
'FlowKind',
|
||||
'InfoFlow',
|
||||
'LazyArrayRangeFlow',
|
||||
'LazyValueFuncFlow',
|
||||
'RangeFlow',
|
||||
'FuncFlow',
|
||||
'ParamsFlow',
|
||||
'ScalingMode',
|
||||
'ValueFlow',
|
||||
|
|
|
@ -18,8 +18,8 @@ from .array import ArrayFlow
|
|||
from .capabilities import CapabilitiesFlow
|
||||
from .flow_kinds import FlowKind
|
||||
from .info import InfoFlow
|
||||
from .lazy_array_range import LazyArrayRangeFlow, ScalingMode
|
||||
from .lazy_value_func import LazyValueFuncFlow
|
||||
from .lazy_range import RangeFlow, ScalingMode
|
||||
from .lazy_func import FuncFlow
|
||||
from .params import ParamsFlow
|
||||
from .value import ValueFlow
|
||||
|
||||
|
@ -28,9 +28,9 @@ __all__ = [
|
|||
'CapabilitiesFlow',
|
||||
'FlowKind',
|
||||
'InfoFlow',
|
||||
'LazyArrayRangeFlow',
|
||||
'RangeFlow',
|
||||
'ScalingMode',
|
||||
'LazyValueFuncFlow',
|
||||
'FuncFlow',
|
||||
'ParamsFlow',
|
||||
'ValueFlow',
|
||||
]
|
||||
|
|
|
@ -29,9 +29,12 @@ from blender_maxwell.utils import logger
|
|||
log = logger.get(__name__)
|
||||
|
||||
|
||||
# TODO: Our handling of 'is_sorted' is sloppy and probably wrong.
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class ArrayFlow:
|
||||
"""A simple, flat array of values with an optionally-attached unit.
|
||||
"""A homogeneous, realized array of numerical values with an optionally-attached unit and sort-tracking.
|
||||
|
||||
While the principle is simple, arrays-with-units ends up being a powerful basis for derived and computed features/methods/processing.
|
||||
|
||||
Attributes:
|
||||
values: An ND array-like object of arbitrary numerical type.
|
||||
|
@ -44,13 +47,97 @@ class ArrayFlow:
|
|||
|
||||
is_sorted: bool = False
|
||||
|
||||
####################
|
||||
# - Computed Properties
|
||||
####################
|
||||
@property
|
||||
def is_symbolic(self) -> bool:
|
||||
"""Always False, as ArrayFlows are never unrealized."""
|
||||
return False
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Outer length of the contained array."""
|
||||
return len(self.values)
|
||||
|
||||
@functools.cached_property
|
||||
def mathtype(self) -> spux.MathType:
|
||||
"""Deduce the `spux.MathType` of the first element of the contained array.
|
||||
|
||||
This is generally a heuristic, but because `jax` enforces homogeneous arrays, this is actually a well-defined approach.
|
||||
"""
|
||||
return spux.MathType.from_pytype(type(self.values.item(0)))
|
||||
|
||||
@functools.cached_property
|
||||
def physical_type(self) -> spux.MathType:
|
||||
"""Deduce the `spux.PhysicalType` of the unit."""
|
||||
return spux.PhysicalType.from_unit(self.unit)
|
||||
|
||||
####################
|
||||
# - Array Features
|
||||
####################
|
||||
@property
|
||||
def realize_array(self) -> jtyp.Shaped[jtyp.Array, '...']:
|
||||
"""Standardized access to `self.values`."""
|
||||
return self.values
|
||||
|
||||
@functools.cached_property
|
||||
def shape(self) -> int:
|
||||
"""Shape of the contained array."""
|
||||
return self.values.shape
|
||||
|
||||
def __getitem__(self, subscript: slice) -> typ.Self | spux.SympyExpr:
|
||||
"""Implement indexing and slicing in a sane way.
|
||||
|
||||
- **Integer Index**: For scalar output, return a `sympy` expression of the scalar multiplied by the unit, else just a sympy expression of the value.
|
||||
- **Slice**: Slice the internal array directly, and wrap the result in a new `ArrayFlow`.
|
||||
"""
|
||||
if isinstance(subscript, slice):
|
||||
return ArrayFlow(
|
||||
values=self.values[subscript],
|
||||
unit=self.unit,
|
||||
is_sorted=self.is_sorted,
|
||||
)
|
||||
|
||||
if isinstance(subscript, int):
|
||||
value = self.values[subscript]
|
||||
if len(value.shape) == 0:
|
||||
return value * self.unit if self.unit is not None else sp.S(value)
|
||||
return ArrayFlow(values=value, unit=self.unit, is_sorted=self.is_sorted)
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
####################
|
||||
# - Methods
|
||||
####################
|
||||
def rescale(
|
||||
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
|
||||
) -> typ.Self:
|
||||
"""Apply an order-preserving function to each element of the array, then (optionally) transform the result w/new unit and/or order.
|
||||
|
||||
An optimized expression will be built and applied to `self.values` using `sympy.lambdify()`.
|
||||
|
||||
Parameters:
|
||||
rescale_func: An **order-preserving** function to apply to each array element.
|
||||
reverse: Whether to reverse the order of the result.
|
||||
new_unit: An (optional) new unit to scale the result to.
|
||||
"""
|
||||
# Compile JAX-Compatible Rescale Function
|
||||
a = self.mathtype.sp_symbol_a
|
||||
rescale_expr = (
|
||||
spux.scale_to_unit(rescale_func(a * self.unit), new_unit)
|
||||
if self.unit is not None
|
||||
else rescale_func(a)
|
||||
)
|
||||
_rescale_func = sp.lambdify(a, rescale_expr, 'jax')
|
||||
values = _rescale_func(self.values)
|
||||
|
||||
# Return ArrayFlow
|
||||
return ArrayFlow(
|
||||
values=values[::-1] if reverse else values,
|
||||
unit=new_unit,
|
||||
is_sorted=self.is_sorted,
|
||||
)
|
||||
|
||||
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
|
||||
"""Find the index of the value that is closest to the given value.
|
||||
|
||||
|
@ -88,56 +175,26 @@ class ArrayFlow:
|
|||
|
||||
return right_idx
|
||||
|
||||
def correct_unit(self, corrected_unit: spu.Quantity) -> typ.Self:
|
||||
if self.unit is not None:
|
||||
return ArrayFlow(
|
||||
values=self.values, unit=corrected_unit, is_sorted=self.is_sorted
|
||||
)
|
||||
####################
|
||||
# - Unit Transforms
|
||||
####################
|
||||
def correct_unit(self, unit: spux.Unit) -> typ.Self:
|
||||
"""Simply replace the existing unit with the given one.
|
||||
|
||||
msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"'
|
||||
raise ValueError(msg)
|
||||
Parameters:
|
||||
corrected_unit: The new unit to insert.
|
||||
**MUST** be associable with a well-defined `PhysicalType`.
|
||||
"""
|
||||
return ArrayFlow(values=self.values, unit=unit, is_sorted=self.is_sorted)
|
||||
|
||||
def rescale_to_unit(self, unit: spu.Quantity | None) -> typ.Self:
|
||||
## TODO: Cache by unit would be a very nice speedup for Viz node.
|
||||
if self.unit is not None:
|
||||
return ArrayFlow(
|
||||
values=float(spux.scaling_factor(self.unit, unit)) * self.values,
|
||||
unit=unit,
|
||||
is_sorted=self.is_sorted,
|
||||
)
|
||||
def rescale_to_unit(self, new_unit: spux.Unit | None) -> typ.Self:
|
||||
"""Rescale the `ArrayFlow` to be expressed in the given unit.
|
||||
|
||||
if unit is None:
|
||||
return self
|
||||
|
||||
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}'
|
||||
raise ValueError(msg)
|
||||
|
||||
def rescale(
|
||||
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
|
||||
) -> typ.Self:
|
||||
# Compile JAX-Compatible Rescale Function
|
||||
a = sp.Symbol('a')
|
||||
rescale_expr = (
|
||||
spux.scale_to_unit(rescale_func(a * self.unit), new_unit)
|
||||
if self.unit is not None
|
||||
else rescale_func(a * self.unit)
|
||||
)
|
||||
_rescale_func = sp.lambdify(a, rescale_expr, 'jax')
|
||||
values = _rescale_func(self.values)
|
||||
|
||||
# Return ArrayFlow
|
||||
return ArrayFlow(
|
||||
values=values[::-1] if reverse else values,
|
||||
unit=new_unit,
|
||||
is_sorted=self.is_sorted,
|
||||
)
|
||||
|
||||
def __getitem__(self, subscript: slice):
|
||||
if isinstance(subscript, slice):
|
||||
return ArrayFlow(
|
||||
values=self.values[subscript],
|
||||
unit=self.unit,
|
||||
is_sorted=self.is_sorted,
|
||||
)
|
||||
Parameters:
|
||||
corrected_unit: The new unit to insert.
|
||||
**MUST** be associable with a well-defined `PhysicalType`.
|
||||
"""
|
||||
return self.rescale(lambda v: v, new_unit=new_unit)
|
||||
|
||||
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
# blender_maxwell
|
||||
# Copyright (C) 2024 blender_maxwell Project Contributors
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import typing as typ
|
||||
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
|
||||
from . import FlowKind
|
||||
|
||||
|
||||
class ExprInfo(typ.TypedDict):
|
||||
active_kind: FlowKind
|
||||
size: spux.NumberSize1D
|
||||
mathtype: spux.MathType
|
||||
physical_type: spux.PhysicalType
|
||||
|
||||
# Value
|
||||
default_value: spux.SympyExpr
|
||||
|
||||
# Range
|
||||
default_min: spux.SympyExpr
|
||||
default_max: spux.SympyExpr
|
||||
default_steps: int
|
|
@ -19,6 +19,7 @@ import typing as typ
|
|||
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
from blender_maxwell.utils import logger
|
||||
from blender_maxwell.utils.staticproperty import staticproperty
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
@ -40,9 +41,9 @@ class FlowKind(enum.StrEnum):
|
|||
Array: An object with dimensions, and possibly a unit.
|
||||
Whenever a `Value` is defined, a single-element `list` will also be generated by default as `Array`
|
||||
However, for any other array-like variants (or sockets that only represent array-like objects), `Array` should be defined manually.
|
||||
LazyValueFunc: A composable function.
|
||||
Func: A composable function.
|
||||
Can be used to represent computations for which all data is not yet known, or for which just-in-time compilation can drastically increase performance.
|
||||
LazyArrayRange: An object that generates an `Array` from range information (start/stop/step/spacing).
|
||||
Range: An object that generates an `Array` from range information (start/stop/step/spacing).
|
||||
This should be used instead of `Array` whenever possible.
|
||||
Param: A dictionary providing particular parameters for a lazy value.
|
||||
Info: An dictionary providing extra context about any aspect of flow.
|
||||
|
@ -51,16 +52,71 @@ class FlowKind(enum.StrEnum):
|
|||
Capabilities = enum.auto()
|
||||
|
||||
# Values
|
||||
Value = enum.auto()
|
||||
Array = enum.auto()
|
||||
Value = enum.auto() ## 'value'
|
||||
Array = enum.auto() ## 'array'
|
||||
|
||||
# Lazy
|
||||
LazyValueFunc = enum.auto()
|
||||
LazyArrayRange = enum.auto()
|
||||
Func = enum.auto() ## 'lazy_func'
|
||||
Range = enum.auto() ## 'lazy_range'
|
||||
|
||||
# Auxiliary
|
||||
Params = enum.auto()
|
||||
Info = enum.auto()
|
||||
Params = enum.auto() ## 'params'
|
||||
Info = enum.auto() ## 'info'
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
@staticmethod
|
||||
def to_name(v: typ.Self) -> str:
|
||||
return {
|
||||
FlowKind.Capabilities: 'Capabilities',
|
||||
# Values
|
||||
FlowKind.Value: 'Value',
|
||||
FlowKind.Array: 'Array',
|
||||
# Lazy
|
||||
FlowKind.Range: 'Range',
|
||||
FlowKind.Func: 'Func',
|
||||
# Auxiliary
|
||||
FlowKind.Params: 'Params',
|
||||
FlowKind.Info: 'Info',
|
||||
}[v]
|
||||
|
||||
@staticmethod
|
||||
def to_icon(_: typ.Self) -> str:
|
||||
return ''
|
||||
|
||||
####################
|
||||
# - Static Properties
|
||||
####################
|
||||
@staticproperty
|
||||
def active_kinds() -> list[typ.Self]:
|
||||
"""Return a list of `FlowKind`s that are able to be considered "active".
|
||||
|
||||
"Active" `FlowKind`s are considered the primary data type of a socket's flow.
|
||||
For example, for sockets to be linkeable, their active `FlowKind` must generally match.
|
||||
"""
|
||||
return [
|
||||
FlowKind.Value,
|
||||
FlowKind.Array,
|
||||
FlowKind.Range,
|
||||
FlowKind.Func,
|
||||
]
|
||||
|
||||
@property
|
||||
def socket_shape(self) -> str:
|
||||
"""Return the socket shape associated with this `FlowKind`.
|
||||
|
||||
**ONLY** valid for `FlowKind`s that can be considered "active".
|
||||
|
||||
Raises:
|
||||
ValueError: If this `FlowKind` cannot ever be considered "active".
|
||||
"""
|
||||
return {
|
||||
FlowKind.Value: 'CIRCLE',
|
||||
FlowKind.Array: 'SQUARE',
|
||||
FlowKind.Range: 'SQUARE',
|
||||
FlowKind.Func: 'DIAMOND',
|
||||
}[self]
|
||||
|
||||
####################
|
||||
# - Class Methods
|
||||
|
@ -69,7 +125,7 @@ class FlowKind(enum.StrEnum):
|
|||
def scale_to_unit_system(
|
||||
cls,
|
||||
kind: typ.Self,
|
||||
flow_obj,
|
||||
flow_obj: spux.SympyExpr,
|
||||
unit_system: spux.UnitSystem,
|
||||
):
|
||||
# log.debug('%s: Scaling "%s" to Unit System', kind, str(flow_obj))
|
||||
|
@ -79,7 +135,7 @@ class FlowKind(enum.StrEnum):
|
|||
flow_obj,
|
||||
unit_system,
|
||||
)
|
||||
if kind == FlowKind.LazyArrayRange:
|
||||
if kind == FlowKind.Range:
|
||||
return flow_obj.rescale_to_unit_system(unit_system)
|
||||
|
||||
if kind == FlowKind.Params:
|
||||
|
@ -87,43 +143,3 @@ class FlowKind(enum.StrEnum):
|
|||
|
||||
msg = 'Tried to scale unknown kind'
|
||||
raise ValueError(msg)
|
||||
|
||||
####################
|
||||
# - Computed
|
||||
####################
|
||||
@property
|
||||
def flow_kind(self) -> str:
|
||||
return {
|
||||
FlowKind.Value: FlowKind.Value,
|
||||
FlowKind.Array: FlowKind.Array,
|
||||
FlowKind.LazyValueFunc: FlowKind.LazyValueFunc,
|
||||
FlowKind.LazyArrayRange: FlowKind.LazyArrayRange,
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def socket_shape(self) -> str:
|
||||
return {
|
||||
FlowKind.Value: 'CIRCLE',
|
||||
FlowKind.Array: 'SQUARE',
|
||||
FlowKind.LazyArrayRange: 'SQUARE',
|
||||
FlowKind.LazyValueFunc: 'DIAMOND',
|
||||
}[self]
|
||||
|
||||
####################
|
||||
# - Blender Enum
|
||||
####################
|
||||
@staticmethod
|
||||
def to_name(v: typ.Self) -> str:
|
||||
return {
|
||||
FlowKind.Capabilities: 'Capabilities',
|
||||
FlowKind.Value: 'Value',
|
||||
FlowKind.Array: 'Array',
|
||||
FlowKind.LazyArrayRange: 'Range',
|
||||
FlowKind.LazyValueFunc: 'Func',
|
||||
FlowKind.Params: 'Parameters',
|
||||
FlowKind.Info: 'Information',
|
||||
}[v]
|
||||
|
||||
@staticmethod
|
||||
def to_icon(_: typ.Self) -> str:
|
||||
return ''
|
||||
|
|
|
@ -18,246 +18,247 @@ import dataclasses
|
|||
import functools
|
||||
import typing as typ
|
||||
|
||||
import jax
|
||||
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
from blender_maxwell.utils import logger
|
||||
from blender_maxwell.utils import logger, sim_symbols
|
||||
|
||||
from .array import ArrayFlow
|
||||
from .lazy_array_range import LazyArrayRangeFlow
|
||||
from .lazy_range import RangeFlow
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
LabelArray: typ.TypeAlias = list[str]
|
||||
|
||||
# IndexArray: Identifies Discrete Dimension Values
|
||||
## -> ArrayFlow (rat|real): Index by particular, not-guaranteed-uniform index.
|
||||
## -> RangeFlow (rat|real): Index by unrealized array scaled between boundaries.
|
||||
## -> LabelArray (int): For int-index arrays, interpret using these labels.
|
||||
## -> None: Non-Discrete/unrealized indexing; use 'dim.domain'.
|
||||
IndexArray: typ.TypeAlias = ArrayFlow | RangeFlow | LabelArray | None
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class InfoFlow:
|
||||
####################
|
||||
# - Covariant Input
|
||||
####################
|
||||
dim_names: list[str] = dataclasses.field(default_factory=list)
|
||||
dim_idx: dict[str, ArrayFlow | LazyArrayRangeFlow] = dataclasses.field(
|
||||
default_factory=dict
|
||||
) ## TODO: Rename to dim_idxs
|
||||
"""Contains dimension and output information characterizing the array produced by a parallel `FuncFlow`.
|
||||
|
||||
@functools.cached_property
|
||||
def dim_has_coords(self) -> dict[str, int]:
|
||||
return {
|
||||
dim_name: not (
|
||||
isinstance(dim_idx, LazyArrayRangeFlow)
|
||||
and (dim_idx.start.is_infinite or dim_idx.stop.is_infinite)
|
||||
)
|
||||
for dim_name, dim_idx in self.dim_idx.items()
|
||||
}
|
||||
Functionally speaking, `InfoFlow` provides essential mathematical and physical context to raw array data, with terminology adapted from multilinear algebra.
|
||||
|
||||
@functools.cached_property
|
||||
def dim_lens(self) -> dict[str, int]:
|
||||
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
|
||||
# From Arrays to Tensors
|
||||
The best way to illustrate how it works is to specify how raw-array concepts map to an array described by an `InfoFlow`:
|
||||
|
||||
@functools.cached_property
|
||||
def dim_mathtypes(self) -> dict[str, spux.MathType]:
|
||||
return {
|
||||
dim_name: dim_idx.mathtype for dim_name, dim_idx in self.dim_idx.items()
|
||||
}
|
||||
- **Index**: In raw arrays, the "index" is generally constrained to an integer ring, and has no semantic meaning.
|
||||
**(Covariant) Dimension**: The "dimension" is an named "index array", which assigns each integer index a **scalar** value of particular mathematical type, name, and unit (if not unitless).
|
||||
- **Value**: In raw arrays, the "value" is some particular computational type, or another raw array.
|
||||
**(Contravariant) Output**: The "output" is a strictly named, sized object that can only be produced
|
||||
|
||||
@functools.cached_property
|
||||
def dim_units(self) -> dict[str, spux.Unit]:
|
||||
return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()}
|
||||
In essence, `InfoFlow` allows us to treat raw data as a tensor, then operate on its dimensionality as split into parts whose transform varies _with_ the output (aka. a _covariant_ index), and parts whose transform varies _against_ the output (aka. _contravariant_ value).
|
||||
|
||||
@functools.cached_property
|
||||
def dim_physical_types(self) -> dict[str, spux.PhysicalType]:
|
||||
return {
|
||||
dim_name: spux.PhysicalType.from_unit(dim_idx.unit)
|
||||
for dim_name, dim_idx in self.dim_idx.items()
|
||||
}
|
||||
## Benefits
|
||||
The reasons to do this are numerous:
|
||||
|
||||
@functools.cached_property
|
||||
def dim_idx_arrays(self) -> list[jax.Array]:
|
||||
return [
|
||||
dim_idx.realize().values
|
||||
if isinstance(dim_idx, LazyArrayRangeFlow)
|
||||
else dim_idx.values
|
||||
for dim_idx in self.dim_idx.values()
|
||||
]
|
||||
- **Clarity**: Using `InfoFlow`, it's easy to understand what the data is, and what can be done to it, making it much easier to implement complex operations in math nodes without sacrificing the user's mental model.
|
||||
- **Zero-Cost Operations**: Transforming indices, "folding" dimensions into the output, and other such operations don't actually do anything to the data, enabling a lot of operations to feel "free" in terms of performance.
|
||||
- **Semantic Indexing**: Using `InfoFlow`, it's easy to index and slice arrays using ex. nanometer vacuum wavelengths, instead of arbitrary integers.
|
||||
"""
|
||||
|
||||
####################
|
||||
# - Contravariant Output
|
||||
# - Dimensions: Covariant Index
|
||||
####################
|
||||
# Output Information
|
||||
## TODO: Add PhysicalType
|
||||
output_name: str = dataclasses.field(default_factory=list)
|
||||
output_shape: tuple[int, ...] | None = dataclasses.field(default=None)
|
||||
output_mathtype: spux.MathType = dataclasses.field()
|
||||
output_unit: spux.Unit | None = dataclasses.field()
|
||||
|
||||
@property
|
||||
def output_shape_len(self) -> int:
|
||||
if self.output_shape is None:
|
||||
return 0
|
||||
return len(self.output_shape)
|
||||
|
||||
# Pinned Dimension Information
|
||||
## TODO: Add PhysicalType
|
||||
pinned_dim_names: list[str] = dataclasses.field(default_factory=list)
|
||||
pinned_dim_values: dict[str, float | complex] = dataclasses.field(
|
||||
dims: dict[sim_symbols.SimSymbol, IndexArray] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
pinned_dim_mathtypes: dict[str, spux.MathType] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
pinned_dim_units: dict[str, spux.Unit] = dataclasses.field(default_factory=dict)
|
||||
|
||||
@functools.cached_property
|
||||
def last_dim(self) -> sim_symbols.SimSymbol | None:
|
||||
"""The integer axis occupied by the dimension.
|
||||
|
||||
Can be used to index `.shape` of the represented raw array.
|
||||
"""
|
||||
if self.dims:
|
||||
return next(iter(self.dims.keys()))
|
||||
return None
|
||||
|
||||
@functools.cached_property
|
||||
def last_dim(self) -> sim_symbols.SimSymbol | None:
|
||||
"""The integer axis occupied by the dimension.
|
||||
|
||||
Can be used to index `.shape` of the represented raw array.
|
||||
"""
|
||||
if self.dims:
|
||||
return list(self.dims.keys())[-1]
|
||||
return None
|
||||
|
||||
def dim_axis(self, dim: sim_symbols.SimSymbol) -> int:
|
||||
"""The integer axis occupied by the dimension.
|
||||
|
||||
Can be used to index `.shape` of the represented raw array.
|
||||
"""
|
||||
return list(self.dims.keys()).index(dim)
|
||||
|
||||
def has_idx_cont(self, dim: sim_symbols.SimSymbol) -> bool:
|
||||
"""Whether the dim's index is continuous, and therefore index array.
|
||||
|
||||
This happens when the dimension is generated from a symbolic function, as opposed to from discrete observations.
|
||||
In these cases, the `SimSymbol.domain` of the dimension should be used to determine the overall domain of validity.
|
||||
|
||||
Other than that, it's up to the user to select a particular way of indexing.
|
||||
"""
|
||||
return self.dims[dim] is None
|
||||
|
||||
def has_idx_discrete(self, dim: sim_symbols.SimSymbol) -> bool:
|
||||
"""Whether the (rat|real) dim is indexed by an `ArrayFlow` / `RangeFlow`."""
|
||||
return isinstance(self.dims[dim], ArrayFlow | RangeFlow)
|
||||
|
||||
def has_idx_labels(self, dim: sim_symbols.SimSymbol) -> bool:
|
||||
"""Whether the (int) dim is indexed by a `LabelArray`."""
|
||||
if dim.mathtype is spux.MathType.Integer:
|
||||
return isinstance(self.dims[dim], list)
|
||||
return False
|
||||
|
||||
####################
|
||||
# - Methods
|
||||
# - Output: Contravariant Value
|
||||
####################
|
||||
def slice_dim(self, dim_name: str, slice_tuple: tuple[int, int, int]) -> typ.Self:
|
||||
output: sim_symbols.SimSymbol
|
||||
|
||||
####################
|
||||
# - Pinned Dimension Values
|
||||
####################
|
||||
## -> Whenever a dimension is deleted, we retain what that index value was.
|
||||
## -> This proves to be very helpful for clear visualization.
|
||||
pinned_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
|
||||
####################
|
||||
# - Operations: Dimensions
|
||||
####################
|
||||
def prepend_dim(
|
||||
self, dim: sim_symbols.SimSymbol, dim_idx: sim_symbols.SimSymbol
|
||||
) -> typ.Self:
|
||||
"""Insert a new dimension at index 0."""
|
||||
return InfoFlow(
|
||||
# Dimensions
|
||||
dim_names=self.dim_names,
|
||||
dim_idx={
|
||||
_dim_name: (
|
||||
dim_idx
|
||||
if _dim_name != dim_name
|
||||
else dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
|
||||
)
|
||||
for _dim_name, dim_idx in self.dim_idx.items()
|
||||
dims={dim: dim_idx} | self.dims,
|
||||
output=self.output,
|
||||
pinned_values=self.pinned_values,
|
||||
)
|
||||
|
||||
def slice_dim(
|
||||
self, dim: sim_symbols.SimSymbol, slice_tuple: tuple[int, int, int]
|
||||
) -> typ.Self:
|
||||
"""Slice a dimensional array by-index along a particular dimension."""
|
||||
return InfoFlow(
|
||||
dims={
|
||||
_dim: dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
|
||||
if _dim == dim
|
||||
else _dim
|
||||
for _dim, dim_idx in self.dims.items()
|
||||
},
|
||||
# Outputs
|
||||
output_name=self.output_name,
|
||||
output_shape=self.output_shape,
|
||||
output_mathtype=self.output_mathtype,
|
||||
output_unit=self.output_unit,
|
||||
output=self.output,
|
||||
pinned_values=self.pinned_values,
|
||||
)
|
||||
|
||||
def replace_dim(
|
||||
self, old_dim_name: str, new_dim_idx: tuple[str, ArrayFlow | LazyArrayRangeFlow]
|
||||
self,
|
||||
old_dim: sim_symbols.SimSymbol,
|
||||
new_dim: sim_symbols.SimSymbol,
|
||||
new_dim_idx: IndexArray,
|
||||
) -> typ.Self:
|
||||
"""Replace a dimension (and its indexing) with a new name and index array/range."""
|
||||
"""Replace a dimension entirely, in-place, including symbol and index array."""
|
||||
return InfoFlow(
|
||||
# Dimensions
|
||||
dim_names=[
|
||||
dim_name if dim_name != old_dim_name else new_dim_idx[0]
|
||||
for dim_name in self.dim_names
|
||||
],
|
||||
dim_idx={
|
||||
(dim_name if dim_name != old_dim_name else new_dim_idx[0]): (
|
||||
dim_idx if dim_name != old_dim_name else new_dim_idx[1]
|
||||
dims={
|
||||
(new_dim if _dim == old_dim else _dim): (
|
||||
new_dim_idx if _dim == old_dim else _dim
|
||||
)
|
||||
for dim_name, dim_idx in self.dim_idx.items()
|
||||
for _dim, dim_idx in self.dims.items()
|
||||
},
|
||||
# Outputs
|
||||
output_name=self.output_name,
|
||||
output_shape=self.output_shape,
|
||||
output_mathtype=self.output_mathtype,
|
||||
output_unit=self.output_unit,
|
||||
output=self.output,
|
||||
pinned_values=self.pinned_values,
|
||||
)
|
||||
|
||||
def rescale_dim_idxs(self, new_dim_idxs: dict[str, LazyArrayRangeFlow]) -> typ.Self:
|
||||
def replace_dims(
|
||||
self, new_dims: dict[sim_symbols.SimSymbol, IndexArray]
|
||||
) -> typ.Self:
|
||||
"""Replace several dimensional indices with new index arrays/ranges."""
|
||||
return InfoFlow(
|
||||
# Dimensions
|
||||
dim_names=self.dim_names,
|
||||
dim_idx={
|
||||
_dim_name: new_dim_idxs.get(_dim_name, dim_idx)
|
||||
for _dim_name, dim_idx in self.dim_idx.items()
|
||||
dims={
|
||||
dim: new_dims.get(dim, dim_idx) for dim, dim_idx in self.dim_idx.items()
|
||||
},
|
||||
# Outputs
|
||||
output_name=self.output_name,
|
||||
output_shape=self.output_shape,
|
||||
output_mathtype=self.output_mathtype,
|
||||
output_unit=self.output_unit,
|
||||
output=self.output,
|
||||
pinned_values=self.pinned_values,
|
||||
)
|
||||
|
||||
def delete_dimension(self, dim_name: str) -> typ.Self:
|
||||
"""Delete a dimension."""
|
||||
def delete_dim(
|
||||
self, dim_to_remove: sim_symbols.SimSymbol, pin_idx: int | None = None
|
||||
) -> typ.Self:
|
||||
"""Delete a dimension, optionally pinning the value of an index from that dimension."""
|
||||
new_pin = (
|
||||
{dim_to_remove: self.dims[dim_to_remove][pin_idx]}
|
||||
if pin_idx is not None
|
||||
else {}
|
||||
)
|
||||
return InfoFlow(
|
||||
# Dimensions
|
||||
dim_names=[
|
||||
_dim_name for _dim_name in self.dim_names if _dim_name != dim_name
|
||||
],
|
||||
dim_idx={
|
||||
_dim_name: dim_idx
|
||||
for _dim_name, dim_idx in self.dim_idx.items()
|
||||
if _dim_name != dim_name
|
||||
dims={
|
||||
dim: dim_idx
|
||||
for dim, dim_idx in self.dims.items()
|
||||
if dim != dim_to_remove
|
||||
},
|
||||
# Outputs
|
||||
output_name=self.output_name,
|
||||
output_shape=self.output_shape,
|
||||
output_mathtype=self.output_mathtype,
|
||||
output_unit=self.output_unit,
|
||||
output=self.output,
|
||||
pinned_values=self.pinned_values | new_pin,
|
||||
)
|
||||
|
||||
def swap_dimensions(self, dim_0_name: str, dim_1_name: str) -> typ.Self:
|
||||
"""Swap the position of two dimensions."""
|
||||
def swap_dimensions(self, dim_0: str, dim_1: str) -> typ.Self:
|
||||
"""Swap the positions of two dimensions."""
|
||||
|
||||
# Compute Swapped Dimension Name List
|
||||
# Swapped Dimension Keys
|
||||
def name_swapper(dim_name):
|
||||
return (
|
||||
dim_name
|
||||
if dim_name not in [dim_0_name, dim_1_name]
|
||||
else {dim_0_name: dim_1_name, dim_1_name: dim_0_name}[dim_name]
|
||||
if dim_name not in [dim_0, dim_1]
|
||||
else {dim_0: dim_1, dim_1: dim_0}[dim_name]
|
||||
)
|
||||
|
||||
dim_names = [name_swapper(dim_name) for dim_name in self.dim_names]
|
||||
swapped_dim_keys = [name_swapper(dim) for dim in self.dims]
|
||||
|
||||
# Compute Info
|
||||
return InfoFlow(
|
||||
# Dimensions
|
||||
dim_names=dim_names,
|
||||
dim_idx={dim_name: self.dim_idx[dim_name] for dim_name in dim_names},
|
||||
# Outputs
|
||||
output_name=self.output_name,
|
||||
output_shape=self.output_shape,
|
||||
output_mathtype=self.output_mathtype,
|
||||
output_unit=self.output_unit,
|
||||
dims={dim_key: self.dims[dim_key] for dim_key in swapped_dim_keys},
|
||||
output=self.output,
|
||||
pinned_values=self.pinned_values,
|
||||
)
|
||||
|
||||
def set_output_mathtype(self, output_mathtype: spux.MathType) -> typ.Self:
|
||||
"""Set the MathType of the output."""
|
||||
####################
|
||||
# - Operations: Output
|
||||
####################
|
||||
def update_output(self, **kwargs) -> typ.Self:
|
||||
"""Passthrough to `SimSymbol.update()` method on `self.output`."""
|
||||
return InfoFlow(
|
||||
dim_names=self.dim_names,
|
||||
dim_idx=self.dim_idx,
|
||||
# Outputs
|
||||
output_name=self.output_name,
|
||||
output_shape=self.output_shape,
|
||||
output_mathtype=output_mathtype,
|
||||
output_unit=self.output_unit,
|
||||
dims=self.dims,
|
||||
output=self.output.update(**kwargs),
|
||||
pinned_values=self.pinned_values,
|
||||
)
|
||||
|
||||
def collapse_output(
|
||||
self,
|
||||
collapsed_name: str,
|
||||
collapsed_mathtype: spux.MathType,
|
||||
collapsed_unit: spux.Unit,
|
||||
) -> typ.Self:
|
||||
"""Replace the (scalar) output with the given corrected values."""
|
||||
return InfoFlow(
|
||||
# Dimensions
|
||||
dim_names=self.dim_names,
|
||||
dim_idx=self.dim_idx,
|
||||
output_name=collapsed_name,
|
||||
output_shape=None,
|
||||
output_mathtype=collapsed_mathtype,
|
||||
output_unit=collapsed_unit,
|
||||
)
|
||||
####################
|
||||
# - Operations: Fold
|
||||
####################
|
||||
def fold_last_input(self):
|
||||
"""Fold the last input dimension into the output."""
|
||||
last_key = list(self.dims.keys())[-1]
|
||||
last_idx = list(self.dims.values())[-1]
|
||||
|
||||
rows = self.output.rows
|
||||
cols = self.output.cols
|
||||
match (rows, cols):
|
||||
case (1, 1):
|
||||
new_output = self.output.set_size(len(last_idx), 1)
|
||||
case (_, 1):
|
||||
new_output = self.output.set_size(rows, len(last_idx))
|
||||
case (1, _):
|
||||
new_output = self.output.set_size(len(last_idx), cols)
|
||||
case (_, _):
|
||||
raise NotImplementedError ## Not yet :)
|
||||
|
||||
@functools.cached_property
|
||||
def shift_last_input(self):
|
||||
"""Shift the last input dimension to the output."""
|
||||
return InfoFlow(
|
||||
# Dimensions
|
||||
dim_names=self.dim_names[:-1],
|
||||
dim_idx={
|
||||
dim_name: dim_idx
|
||||
for dim_name, dim_idx in self.dim_idx.items()
|
||||
if dim_name != self.dim_names[-1]
|
||||
dims={
|
||||
dim: dim_idx for dim, dim_idx in self.dims.items() if dim != last_key
|
||||
},
|
||||
# Outputs
|
||||
output_name=self.output_name,
|
||||
output_shape=(
|
||||
(self.dim_lens[self.dim_names[-1]],)
|
||||
if self.output_shape is None
|
||||
else (self.dim_lens[self.dim_names[-1]], *self.output_shape)
|
||||
),
|
||||
output_mathtype=self.output_mathtype,
|
||||
output_unit=self.output_unit,
|
||||
output=new_output,
|
||||
pinned_values=self.pinned_values,
|
||||
)
|
||||
|
|
|
@ -0,0 +1,440 @@
|
|||
# blender_maxwell
|
||||
# Copyright (C) 2024 blender_maxwell Project Contributors
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import typing as typ
|
||||
from types import MappingProxyType
|
||||
|
||||
import jax
|
||||
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
from blender_maxwell.utils import logger
|
||||
|
||||
from .params import ParamsFlow
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class FuncFlow:
|
||||
r"""Defines a flow of data as incremental function composition.
|
||||
|
||||
For specific math system usage instructions, please consult the documentation of relevant nodes.
|
||||
|
||||
# Introduction
|
||||
When using nodes to do math, it becomes immediately obvious to express **flows of data as composed function chains**.
|
||||
Doing so has several advantages:
|
||||
|
||||
- **Interactive**: Since no large-array math is being done, the UI can be designed to feel fast and snappy.
|
||||
- **Symbolic**: Since no numerical math is being done yet, we can choose to keep our input parameters as symbolic variables with no performance impact.
|
||||
- **Performant**: Since no operations are happening, the UI feels fast and snappy.
|
||||
|
||||
## Strongly Related FlowKinds
|
||||
For doing math, `Func` relies on two other `FlowKind`s, which must run in parallel:
|
||||
|
||||
- `FlowKind.Info`: Tracks the name, `spux.MathType`, unit (if any), length, and index coordinates for the raw data object produced by `Func`.
|
||||
- `FlowKind.Params`: Tracks the particular values of input parameters to the lazy function, each of which can also be symbolic.
|
||||
|
||||
For more, please see the documentation for each.
|
||||
|
||||
## Non-Mathematical Use
|
||||
Of course, there are many interesting uses of incremental function composition that aren't mathematical.
|
||||
|
||||
For such cases, the usage is identical, but the complexity is lessened; for example, `Info` no longer effectively needs to flow in parallel.
|
||||
|
||||
|
||||
|
||||
# Lazy Math: Theoretical Foundation
|
||||
This `FlowKind` is the critical component of a functional-inspired system for lazy multilinear math.
|
||||
Thus, it makes sense to describe the math system here.
|
||||
|
||||
## `depth=0`: Root Function
|
||||
To start a composition chain, a function with no inputs must be defined as the "root", or "bottom".
|
||||
|
||||
$$
|
||||
f_0:\ \ \ \ \biggl(
|
||||
\underbrace{a_1, a_2, ..., a_p}_{\texttt{args}},\
|
||||
\underbrace{
|
||||
\begin{bmatrix} k_1 \\ v_1\end{bmatrix},
|
||||
\begin{bmatrix} k_2 \\ v_2\end{bmatrix},
|
||||
...,
|
||||
\begin{bmatrix} k_q \\ v_q\end{bmatrix}
|
||||
}_{\texttt{kwargs}}
|
||||
\biggr) \to \text{output}_0
|
||||
$$
|
||||
|
||||
In Python, such a construction would look like this:
|
||||
|
||||
```python
|
||||
# Presume 'A0', 'KV0' contain only the args/kwargs for f_0
|
||||
## 'A0', 'KV0' are of length 'p' and 'q'
|
||||
def f_0(*args, **kwargs): ...
|
||||
|
||||
lazy_func_0 = FuncFlow(
|
||||
func=f_0,
|
||||
func_args=[(a_i, type(a_i)) for a_i in A0],
|
||||
func_kwargs={k: v for k,v in KV0},
|
||||
)
|
||||
output_0 = lazy_func.func(*A0_computed, **KV0_computed)
|
||||
```
|
||||
|
||||
## `depth>0`: Composition Chaining
|
||||
So far, so easy.
|
||||
Now, let's add a function that uses the result of $f_0$, without yet computing it.
|
||||
|
||||
$$
|
||||
f_1:\ \ \ \ \biggl(
|
||||
f_0(...),\ \
|
||||
\underbrace{\{a_i\}_p^{p+r}}_{\texttt{args[p:]}},\
|
||||
\underbrace{\biggl\{
|
||||
\begin{bmatrix} k_i \\ v_i\end{bmatrix}
|
||||
\biggr\}_q^{q+s}}_{\texttt{kwargs[p:]}}
|
||||
\biggr) \to \text{output}_1
|
||||
$$
|
||||
|
||||
Note:
|
||||
- $f_1$ must take the arguments of both $f_0$ and $f_1$.
|
||||
- The complexity is getting notationally complex; we already have to use `...` to represent "the last function's arguments".
|
||||
|
||||
In other words, **there's suddenly a lot to manage**.
|
||||
Even worse, the bigger the $n$, the more complexity we must real with.
|
||||
|
||||
This is where the Python version starts to show its purpose:
|
||||
|
||||
```python
|
||||
# Presume 'A1', 'K1' contain only the args/kwarg names for f_1
|
||||
## 'A1', 'KV1' are therefore of length 'r' and 's'
|
||||
def f_1(output_0, *args, **kwargs): ...
|
||||
|
||||
lazy_func_1 = lazy_func_0.compose_within(
|
||||
enclosing_func=f_1,
|
||||
enclosing_func_args=[(a_i, type(a_i)) for a_i in A1],
|
||||
enclosing_func_kwargs={k: type(v) for k,v in K1},
|
||||
)
|
||||
|
||||
A_computed = A0_computed + A1_computed
|
||||
KW_computed = KV0_computed + KV1_computed
|
||||
output_1 = lazy_func_1.func(*A_computed, **KW_computed)
|
||||
```
|
||||
|
||||
By using `Func`, we've guaranteed that even hugely deep $n$s won't ever look more complicated than this.
|
||||
|
||||
## `max depth`: "Realization"
|
||||
So, we've composed a bunch of functions of functions of ...
|
||||
We've also tracked their arguments, either manually (as above), or with the help of a handy `ParamsFlow` object.
|
||||
|
||||
But it'd be pointless to just compose away forever.
|
||||
We do actually need the data that they claim to compute now:
|
||||
|
||||
```python
|
||||
# A_all and KW_all must be tracked on the side.
|
||||
output_n = lazy_func_n.func(*A_all, **KW_all)
|
||||
```
|
||||
|
||||
Of course, this comes with enormous overhead.
|
||||
Aside from the function calls themselves (which can be non-trivial), we must also contend with the enormous inefficiency of performing array operations sequentially.
|
||||
|
||||
That brings us to the killer feature of `FuncFlow`, and the motivating reason for doing any of this at all:
|
||||
|
||||
```python
|
||||
output_n = lazy_func_n.func_jax(*A_all, **KW_all)
|
||||
```
|
||||
|
||||
What happened was, **the entire pipeline** was compiled, optimized, and computed with bare-metal performance on either a CPU, GPU, or TPU.
|
||||
With the help of the `jax` library (and its underlying OpenXLA bytecode), all of that inefficiency has been optimized based on _what we're trying to do_, not _exactly how we're doing it_, in order to maximize the use of modern massively-parallel devices.
|
||||
|
||||
See the documentation of `Func.func_jax()` for more information on this process.
|
||||
|
||||
|
||||
|
||||
# Lazy Math: Practical Considerations
|
||||
By using nodes to express a lazily-composed chain of mathematical operations on tensor-like data, we strike a difficult balance between UX, flexibility, and performance.
|
||||
|
||||
## UX
|
||||
UX is often more a matter of art/taste than science, so don't trust these philosophies too much - a lot of the analysis is entirely personal and subjective.
|
||||
|
||||
The goal for our UX is to minimize the "frictions" that cause cascading, small-scale _user anxiety_.
|
||||
|
||||
Especially of concern in a visual math system on large data volumes is **UX latency** - also known as **lag**.
|
||||
In particular, the most important facet to minimize is _emotional burden_ rather than quantitative milliseconds.
|
||||
Any repeated moment-to-moment friction can be very damaging to a user's ability to be productive in a piece of software.
|
||||
|
||||
Unfortunately, in a node-based architecture, data must generally be computed one step at a time, whenever any part of it is needed, and it must do so before any feedback can be provided.
|
||||
In a math system like this, that data is presumed "big", and as such we're left with the unfortunate experience of even the most well-cached, high-performance operations causing _just about anything_ to **feel** like a highly unpleasant slog as soon as the data gets big enough.
|
||||
**This discomfort scales with the size of data**, by the way, which might just cause users to never even attempt working with the data volume that they actually need.
|
||||
|
||||
For electrodynamic field analysis, it's not uncommon for toy examples to expend hundreds of megabytes of memory, all of which needs all manner of interesting things done to it.
|
||||
It can therefore be very easy to stumble across that feeling of "slogging through" any program that does real-world EM field analysis.
|
||||
This has consequences: The user tries fewer ideas, becomes more easily frustrated, and might ultimately accomplish less.
|
||||
|
||||
Lazy evaluation allows _delaying_ a computation to a point in time where the user both expects and understands the time that the computation takes.
|
||||
For example, the user experience of pressing a button clearly marked with terminology like "load", "save", "compute", "run", seems to be paired to a greatly increased emotional tolerance towards the latency introduced by pressing that button (so long as it is only clickable when it works).
|
||||
To a lesser degree, attaching a node link also seems to have this property, though that tolerance seems to fall as proficiency with the node-based tool rises.
|
||||
As a more nuanced example, when lag occurs due to the computing an image-based plot based on live-computed math, then the visual feedback of _the plot actually changing_ seems to have a similar effect, not least because it's emotionally well-understood that detaching the `Viewer` node would also remove the lag.
|
||||
|
||||
In short: Even if lazy evaluation didn't make any math faster, it will still _feel_ faster (to a point - raw performance obviously still matters).
|
||||
Without `FuncFlow`, the point of evaluation cannot be chosen at all, which is a huge issue for all the named reasons.
|
||||
With `FuncFlow`, better-chosen evaluation points can be chosen to cause the _user experience_ of high performance, simply because we were able to shift the exact same computation to a point in time where the user either understands or tolerates the delay better.
|
||||
|
||||
## Flexibility
|
||||
Large-scale math is done on tensors, whether one knows (or likes!) it or not.
|
||||
To this end, the indexed arrays produced by `FuncFlow.func_jax` aren't quite sufficient for most operations we want to do:
|
||||
|
||||
- **Naming**: What _is_ each axis?
|
||||
Unnamed index axes are sometimes easy to decode, but in general, names have an unexpectedly critical function when operating on arrays.
|
||||
Lack of names is a huge part of why perfectly elegant array math in ex. `MATLAB` or `numpy` can so easily feel so incredibly convoluted.
|
||||
_Sometimes arrays with named axes are called "structured arrays".
|
||||
|
||||
- **Coordinates**: What do the indices of each axis really _mean_?
|
||||
For example, an array of $500$ by-wavelength observations of power (watts) can't be limited to between $200nm$ to $700nm$.
|
||||
But they can be limited to between index `23` to `298`.
|
||||
I'm **just supposed to know** that `23` means $200nm$, and that `298` indicates the observation just after $700nm$, and _hope_ that this is exact enough.
|
||||
|
||||
Not only do we endeavor to track these, but we also introduce unit-awareness to the coordinates, and design the entire math system to visually communicate the state of arrays before/after every single computation, as well as only expose operations that this tracked data indicates possible.
|
||||
|
||||
In practice, this happens in `FlowKind.Info`, which due to having its own `FlowKind` "lane" can be adjusted without triggering changes to (and therefore recompilation of) the `FlowKind.Func` chain.
|
||||
**Please consult the `InfoFlow` documentation for more**.
|
||||
|
||||
## Performance
|
||||
All values introduced while processing are kept in a seperate `FlowKind` lane, with its own incremental caching: `FlowKind.Params`.
|
||||
|
||||
It's a simple mechanism, but for the cost of introducing an extra `FlowKind` "lane", all of the values used to process data can be live-adjusted without the overhead of recompiling the entire `Func` every time anything changes.
|
||||
Moreover, values used to process data don't even have to be numbers yet: They can be expressions of symbolic variables, complete with units, which are only realized at the very end of the chain, by the node that absolutely cannot function without the actual numerical data.
|
||||
|
||||
See the `ParamFlow` documentation for more information.
|
||||
|
||||
|
||||
|
||||
# Conclusion
|
||||
There is, of course, a lot more to say about the math system in general.
|
||||
A few teasers of what nodes can do with this system:
|
||||
|
||||
**Auto-Differentiation**: `jax.jit` isn't even really the killer feature of `jax`.
|
||||
`jax` can automatically differentiate `FuncFlow.func_jax` with respect to any input parameter, including for fwd/bck jacobians/hessians, with robust numerical stability.
|
||||
When used in
|
||||
**Symbolic Interop**: Any `sympy` expression containing symbolic variables can be compiled, by `sympy`, into a `jax`-compatible function which takes
|
||||
We make use of this in the `Expr` socket, enabling true symbolic math to be used in high-performance lazy `jax` computations.
|
||||
**Tidy3D Interop**: For some parameters of some simulation objects, `tidy3d` actually supports adjoint-driven differentiation _through the cloud simulation_.
|
||||
This enables our humble interface to implement fully functional **inverse design** of parameterized structures, using only nodes.
|
||||
|
||||
But above all, we hope that this math system is fun, practical, and maybe even interesting.
|
||||
|
||||
Attributes:
|
||||
func: The function that generates the represented value.
|
||||
func_args: The constrained identity of all positional arguments to the function.
|
||||
func_kwargs: The constrained identity of all keyword arguments to the function.
|
||||
supports_jax: Whether `self.func` can be compiled with JAX's JIT compiler.
|
||||
See the documentation of `self.func_jax()`.
|
||||
"""
|
||||
|
||||
func: LazyFunction
|
||||
func_args: list[spux.MathType | spux.PhysicalType] = dataclasses.field(
|
||||
default_factory=list
|
||||
)
|
||||
func_kwargs: dict[str, spux.MathType | spux.PhysicalType] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
## TODO: Use SimSymbol instead of the MathType|PT union.
|
||||
## -- SimSymbol is an ideal pivot point for both, as well as valid domains.
|
||||
## -- SimSymbol has more semantic meaning, including a name.
|
||||
## -- If desired, SimSymbols could maybe even require a specific unit.
|
||||
## It could greatly simplify a whole lot of pain associated with func_args.
|
||||
supports_jax: bool = False
|
||||
|
||||
####################
|
||||
# - Functions
|
||||
####################
|
||||
@functools.cached_property
|
||||
def func_jax(self) -> LazyFunction:
|
||||
"""Compile `self.func` into optimized XLA bytecode using `jax.jit`.
|
||||
|
||||
Not all functions can be compiled like this by `jax`.
|
||||
A few critical criteria include:
|
||||
|
||||
- **Only JAX Ops**: All operations performed within the function must be explicitly compatible with `jax`, which generally means only using functions in `jax.lax`, `jax.numpy`
|
||||
- **Known Shape**: The exact dimensions of the output, and of the inputs, must be known at `jit`-time.
|
||||
|
||||
In return, one receives:
|
||||
|
||||
- **Automatic Differentiation**: `jax` can robustly differentiate this function with respect to _any_ parameter.
|
||||
This includes Jacobians and Hessians, forwards and backwards, real or complex, all with good numerical stability.
|
||||
Since `tidy3d`'s simulator registers itself as `jax`-differentiable (using the adjoint method), this "autodiff" support can extend all the way from parameters in the simulation definition, to gradients of the simulation output.
|
||||
When using these gradients for optimization, one achieves what is called "inverse design", where the desired properties of the output fields are used to automatically select simulation input parameters.
|
||||
|
||||
- **Performance**: XLA is a cross-industry project with the singular goal of providing a high-performance compilation target for data-driven programs.
|
||||
Published architects of OpenXLA include Alibaba, Amazon Web Services, AMD, Apple, Arm, Google, Intel, Meta, and NVIDIA.
|
||||
|
||||
- **Device Agnosticism**: XLA bytecode runs not just on CPUs, but on massively parallel devices like GPUs and TPUs as well.
|
||||
This enables massive speedups, and greatly expands the amount of data that is practical to work with at one time.
|
||||
|
||||
|
||||
Notes:
|
||||
The property `self.supports_jax` manually tracks whether these criteria are satisfied.
|
||||
|
||||
**As much as possible**, the _entirety of `blender_maxwell`_ is designed to maximize the ability to set `self.supports_jax = True` as often as possible.
|
||||
|
||||
**However**, there are many cases where a lazily-evaluated value is desirable, but `jax` isn't supported.
|
||||
These include design space exploration, where any particular parameter might vary for the purpose of producing batched simulations.
|
||||
In these cases, trying to compile a `self.func_jax` will raise a `ValueError`.
|
||||
|
||||
Returns:
|
||||
The `jit`-compiled function, ready to run on CPU, GPU, or XLA.
|
||||
|
||||
Raises:
|
||||
ValueError: If `self.supports_jax` is `False`.
|
||||
|
||||
References:
|
||||
JAX JIT: <https://jax.readthedocs.io/en/latest/jit-compilation.html>
|
||||
OpenXLA: <https://openxla.org/xla>
|
||||
"""
|
||||
if self.supports_jax:
|
||||
return jax.jit(self.func)
|
||||
|
||||
msg = 'Can\'t express FuncFlow as JAX function (using jax.jit), since "self.supports_jax" is False'
|
||||
raise ValueError(msg)
|
||||
|
||||
####################
|
||||
# - Realization
|
||||
####################
|
||||
def realize(
|
||||
self,
|
||||
params: ParamsFlow,
|
||||
unit_system: spux.UnitSystem | None = None,
|
||||
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
|
||||
) -> typ.Self:
|
||||
if self.supports_jax:
|
||||
return self.func_jax(
|
||||
*params.scaled_func_args(unit_system, symbol_values),
|
||||
*params.scaled_func_kwargs(unit_system, symbol_values),
|
||||
)
|
||||
return self.func(
|
||||
*params.scaled_func_args(unit_system, symbol_values),
|
||||
*params.scaled_func_kwargs(unit_system, symbol_values),
|
||||
)
|
||||
|
||||
####################
|
||||
# - Composition Operations
|
||||
####################
|
||||
def compose_within(
|
||||
self,
|
||||
enclosing_func: LazyFunction,
|
||||
enclosing_func_args: list[type] = (),
|
||||
enclosing_func_kwargs: dict[str, type] = MappingProxyType({}),
|
||||
supports_jax: bool = False,
|
||||
) -> typ.Self:
|
||||
"""Compose `self.func` within the given enclosing function, which itself takes arguments, and create a new `FuncFlow` to contain it.
|
||||
|
||||
This is the fundamental operation used to "chain" functions together.
|
||||
|
||||
Examples:
|
||||
Consider a simple composition based on two expressions:
|
||||
```python
|
||||
R = spux.MathType.Real
|
||||
C = spux.MathType.Complex
|
||||
x, y = sp.symbols('x y', real=True)
|
||||
|
||||
# Prepare "Root" FuncFlow w/x,y args
|
||||
expr_root = 3*x + y**2 - 100
|
||||
expr_root_func = sp.lambdify([x, y], expr, 'jax')
|
||||
|
||||
func_root = FuncFlow(func=expr_root_func, func_args=[R,R], supports_jax=True)
|
||||
|
||||
# Compose "Enclosing" FuncFlow w/z arg
|
||||
r = sp.Symbol('z', real=True)
|
||||
z = sp.Symbol('z', complex=True)
|
||||
expr = 10*sp.re(z) / (z + r)
|
||||
expr_func = sp.lambdify([r, z], expr, 'jax')
|
||||
|
||||
func = func_root.compose_within(enclosing_func=expr_func, enclosing_func_args=[C])
|
||||
|
||||
# Compute 'expr_func(expr_root_func(10.0, -500.0), 1+8j)'
|
||||
f.func_jax(10.0, -500.0, 1+8j)
|
||||
```
|
||||
|
||||
Using this function, it's easy to "keep adding" symbolic functions of any kind to the chain, without introducing extraneous complexity or compromising the ease of calling the final function.
|
||||
|
||||
Returns:
|
||||
A lazy function that takes both the enclosed and enclosing arguments, and returns the value of the enclosing function (whose first argument is the output value of the enclosed function).
|
||||
"""
|
||||
return FuncFlow(
|
||||
func=lambda *args, **kwargs: enclosing_func(
|
||||
self.func(
|
||||
*list(args[: len(self.func_args)]),
|
||||
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
|
||||
),
|
||||
*args[len(self.func_args) :],
|
||||
**{k: v for k, v in kwargs.items() if k not in self.func_kwargs},
|
||||
),
|
||||
func_args=self.func_args + list(enclosing_func_args),
|
||||
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
|
||||
supports_jax=self.supports_jax and supports_jax,
|
||||
)
|
||||
|
||||
def __or__(
|
||||
self,
|
||||
other: typ.Self,
|
||||
) -> typ.Self:
|
||||
"""Create a lazy function that takes all arguments of both lazy-function inputs, and itself promises to return a 2-tuple containing the outputs of both inputs.
|
||||
|
||||
Generally, `self.func` produces a single array as output (when doing math, at least).
|
||||
But sometimes (as in the `OperateMathNode`), we need to perform a binary operation between two arrays, like say, $+$.
|
||||
Without realizing both `FuncFlow`s, it's not immediately obvious how one might accomplish this.
|
||||
|
||||
This overloaded function of the `|` operator (used as `left | right`) solves that problem.
|
||||
A new `FuncFlow` is created, which takes the arguments of both inputs, and which produces a single output value: A 2-tuple, where each element if the output of each function.
|
||||
|
||||
Examples:
|
||||
Consider this illustrative (pseudocode) example:
|
||||
```python
|
||||
# Presume a,b are values, and that A,B are their identifiers.
|
||||
func_1 = FuncFlow(func=compute_big_data_1, func_args=[A])
|
||||
func_2 = FuncFlow(func=compute_big_data_2, func_args=[B])
|
||||
|
||||
f = (func_1 | func_2).compose_within(func=lambda D: D[0] + D[1])
|
||||
|
||||
f.func(a, b) ## Computes big_data_1 + big_data_2 @A=a, B=b
|
||||
```
|
||||
|
||||
Because of `__or__` (the operator `|`), the difficult and non-obvious task of adding the outputs of these unrealized functions because quite simple.
|
||||
|
||||
Notes:
|
||||
**Order matters**.
|
||||
`self` will be available in the new function's output as index `0`, while `other` will be available as index `1`.
|
||||
|
||||
As with anything lazy-composition-y, it can seem a bit strange at first.
|
||||
When reading the source code, pay special attention to the way that `args` is sliced to segment the positional arguments.
|
||||
|
||||
Returns:
|
||||
A lazy function that takes all arguments of both inputs, and returns a 2-tuple containing both output arguments.
|
||||
"""
|
||||
return FuncFlow(
|
||||
func=lambda *args, **kwargs: (
|
||||
self.func(
|
||||
*list(args[: len(self.func_args)]),
|
||||
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
|
||||
),
|
||||
other.func(
|
||||
*list(args[len(self.func_args) :]),
|
||||
**{k: v for k, v in kwargs.items() if k in other.func_kwargs},
|
||||
),
|
||||
),
|
||||
func_args=self.func_args + other.func_args,
|
||||
func_kwargs=self.func_kwargs | other.func_kwargs,
|
||||
supports_jax=self.supports_jax and other.supports_jax,
|
||||
)
|
|
@ -28,13 +28,20 @@ from blender_maxwell.utils import extra_sympy_units as spux
|
|||
from blender_maxwell.utils import logger
|
||||
|
||||
from .array import ArrayFlow
|
||||
from .flow_kinds import FlowKind
|
||||
from .lazy_value_func import LazyValueFuncFlow
|
||||
from .lazy_func import FuncFlow
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
||||
class ScalingMode(enum.StrEnum):
|
||||
"""Identifier for how to space steps between two boundaries.
|
||||
|
||||
Attributes:
|
||||
Lin: Uniform spacing between two endpoints.
|
||||
Geom: Log spacing between two endpoints, given as values.
|
||||
Log: Log spacing between two endpoints, given as powers of a common base.
|
||||
"""
|
||||
|
||||
Lin = enum.auto()
|
||||
Geom = enum.auto()
|
||||
Log = enum.auto()
|
||||
|
@ -54,37 +61,21 @@ class ScalingMode(enum.StrEnum):
|
|||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class LazyArrayRangeFlow:
|
||||
r"""Represents a linearly/logarithmically spaced array using symbolic boundary expressions, with support for units and lazy evaluation.
|
||||
class RangeFlow:
|
||||
r"""Represents a spaced array using symbolic boundary expressions.
|
||||
|
||||
# Advantages
|
||||
Whenever an array can be represented like this, the advantages over an `ArrayFlow` are numerous.
|
||||
|
||||
## Memory
|
||||
# Memory Scaling
|
||||
`ArrayFlow` generally has a memory scaling of $O(n)$.
|
||||
Naturally, `LazyArrayRangeFlow` is always constant, since only the boundaries and steps are stored.
|
||||
Naturally, `RangeFlow` is always constant, since only the boundaries and steps are stored.
|
||||
|
||||
## Symbolic
|
||||
Both boundary points are symbolic expressions, within which pre-defined `sp.Symbol`s can participate in a constrained manner (ex. an integer symbol).
|
||||
# Symbolic Bounds
|
||||
`self.start` and `self.stop` boundary points are symbolic expressions, within which any element of `self.symbols` can participate.
|
||||
|
||||
One need not know the value of the symbols immediately - such decisions can be deferred until later in the computational flow.
|
||||
**It is the user's responsibility** to ensure that `self.start < self.stop`.
|
||||
|
||||
## Performant Unit-Aware Operations
|
||||
While `ArrayFlow`s are also unit-aware, the time-cost of _any_ unit-scaling operation scales with $O(n)$.
|
||||
`LazyArrayRangeFlow`, by contrast, scales as $O(1)$.
|
||||
|
||||
As a result, more complicated operations (like symbolic or unit-based) that might be difficult to perform interactively in real-time on an `ArrayFlow` will work perfectly with this object, even with added complexity
|
||||
|
||||
## High-Performance Composition and Gradiant
|
||||
With `self.as_func`, a `jax` function is produced that generates the array according to the symbolic `start`, `stop` and `steps`.
|
||||
There are two nice things about this:
|
||||
|
||||
- **Gradient**: The gradient of the output array, with respect to any symbols used to define the input bounds, can easily be found using `jax.grad` over `self.as_func`.
|
||||
- **JIT**: When `self.as_func` is composed with other `jax` functions, and `jax.jit` is run to optimize the entire thing, the "cost of array generation" _will often be optimized away significantly or entirely_.
|
||||
|
||||
Thus, as part of larger computations, the performance properties of `LazyArrayRangeFlow` is extremely favorable.
|
||||
|
||||
## Numerical Properties
|
||||
# Numerical Properties
|
||||
Since the bounds support exact (ex. rational) calculations and symbolic manipulations (_by virtue of being symbolic expressions_), the opportunities for certain kinds of numerical instability are mitigated.
|
||||
|
||||
Attributes:
|
||||
|
@ -108,8 +99,11 @@ class LazyArrayRangeFlow:
|
|||
|
||||
unit: spux.Unit | None = None
|
||||
|
||||
symbols: frozenset[spux.IntSymbol] = frozenset()
|
||||
symbols: frozenset[spux.Symbol] = frozenset()
|
||||
|
||||
####################
|
||||
# - Computed Properties
|
||||
####################
|
||||
@functools.cached_property
|
||||
def sorted_symbols(self) -> list[sp.Symbol]:
|
||||
"""Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name.
|
||||
|
@ -121,6 +115,19 @@ class LazyArrayRangeFlow:
|
|||
"""
|
||||
return sorted(self.symbols, key=lambda sym: sym.name)
|
||||
|
||||
@property
|
||||
def is_symbolic(self) -> bool:
|
||||
"""Whether the `RangeFlow` has unrealized symbols."""
|
||||
return len(self.symbols) > 0
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Compute the length of the array that would be realized.
|
||||
|
||||
Returns:
|
||||
The number of steps.
|
||||
"""
|
||||
return self.steps
|
||||
|
||||
@functools.cached_property
|
||||
def mathtype(self) -> spux.MathType:
|
||||
"""Conservatively compute the most stringent `spux.MathType` that can represent both `self.start` and `self.stop`.
|
||||
|
@ -156,128 +163,28 @@ class LazyArrayRangeFlow:
|
|||
)
|
||||
return combined_mathtype
|
||||
|
||||
def __len__(self):
|
||||
"""Compute the length of the array to be realized.
|
||||
|
||||
Returns:
|
||||
The number of steps.
|
||||
"""
|
||||
return self.steps
|
||||
|
||||
####################
|
||||
# - Units
|
||||
####################
|
||||
def correct_unit(self, corrected_unit: spux.Unit) -> typ.Self:
|
||||
"""Replaces the unit without rescaling the unitless bounds.
|
||||
|
||||
Parameters:
|
||||
corrected_unit: The unit to replace the current unit with.
|
||||
|
||||
Returns:
|
||||
A new `LazyArrayRangeFlow` with replaced unit.
|
||||
|
||||
Raises:
|
||||
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
||||
"""
|
||||
if self.unit is not None:
|
||||
log.debug(
|
||||
'%s: Corrected unit to %s',
|
||||
self,
|
||||
corrected_unit,
|
||||
)
|
||||
return LazyArrayRangeFlow(
|
||||
start=self.start,
|
||||
stop=self.stop,
|
||||
steps=self.steps,
|
||||
scaling=self.scaling,
|
||||
unit=corrected_unit,
|
||||
symbols=self.symbols,
|
||||
)
|
||||
|
||||
msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"'
|
||||
raise ValueError(msg)
|
||||
|
||||
def rescale_to_unit(self, unit: spux.Unit) -> typ.Self:
|
||||
"""Replaces the unit, **with** rescaling of the bounds.
|
||||
|
||||
Parameters:
|
||||
unit: The unit to convert the bounds to.
|
||||
|
||||
Returns:
|
||||
A new `LazyArrayRangeFlow` with replaced unit.
|
||||
|
||||
Raises:
|
||||
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
||||
"""
|
||||
if self.unit is not None:
|
||||
log.debug(
|
||||
'%s: Scaled to unit %s',
|
||||
self,
|
||||
unit,
|
||||
)
|
||||
return LazyArrayRangeFlow(
|
||||
start=spux.scale_to_unit(self.start * self.unit, unit),
|
||||
stop=spux.scale_to_unit(self.stop * self.unit, unit),
|
||||
steps=self.steps,
|
||||
scaling=self.scaling,
|
||||
unit=unit,
|
||||
symbols=self.symbols,
|
||||
)
|
||||
|
||||
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}'
|
||||
raise ValueError(msg)
|
||||
|
||||
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self:
|
||||
"""Replaces the units, **with** rescaling of the bounds.
|
||||
|
||||
Parameters:
|
||||
unit: The unit to convert the bounds to.
|
||||
|
||||
Returns:
|
||||
A new `LazyArrayRangeFlow` with replaced unit.
|
||||
|
||||
Raises:
|
||||
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
||||
"""
|
||||
if self.unit is not None:
|
||||
log.debug(
|
||||
'%s: Scaled to new unit system (new unit = %s)',
|
||||
self,
|
||||
unit_system[spux.PhysicalType.from_unit(self.unit)],
|
||||
)
|
||||
return LazyArrayRangeFlow(
|
||||
start=spux.strip_unit_system(
|
||||
spux.convert_to_unit_system(self.start * self.unit, unit_system),
|
||||
unit_system,
|
||||
),
|
||||
stop=spux.strip_unit_system(
|
||||
spux.convert_to_unit_system(self.stop * self.unit, unit_system),
|
||||
unit_system,
|
||||
),
|
||||
steps=self.steps,
|
||||
scaling=self.scaling,
|
||||
unit=unit_system[spux.PhysicalType.from_unit(self.unit)],
|
||||
symbols=self.symbols,
|
||||
)
|
||||
|
||||
msg = (
|
||||
f'Tried to rescale unitless LazyDataValueRange to unit system {unit_system}'
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
####################
|
||||
# - Bound Operations
|
||||
# - Methods
|
||||
####################
|
||||
def rescale(
|
||||
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
|
||||
) -> typ.Self:
|
||||
"""Apply an order-preserving function to each bound, then (optionally) transform the result w/new unit and/or order.
|
||||
|
||||
An optimized expression will be built and applied to `self.values` using `sympy.lambdify()`.
|
||||
|
||||
Parameters:
|
||||
rescale_func: An **order-preserving** function to apply to each array element.
|
||||
reverse: Whether to reverse the order of the result.
|
||||
new_unit: An (optional) new unit to scale the result to.
|
||||
"""
|
||||
new_pre_start = self.start if not reverse else self.stop
|
||||
new_pre_stop = self.stop if not reverse else self.start
|
||||
|
||||
new_start = rescale_func(new_pre_start * self.unit)
|
||||
new_stop = rescale_func(new_pre_stop * self.unit)
|
||||
|
||||
return LazyArrayRangeFlow(
|
||||
return RangeFlow(
|
||||
start=(
|
||||
spux.scale_to_unit(new_start, new_unit)
|
||||
if new_unit is not None
|
||||
|
@ -294,39 +201,11 @@ class LazyArrayRangeFlow:
|
|||
symbols=self.symbols,
|
||||
)
|
||||
|
||||
def rescale_bounds(
|
||||
self,
|
||||
rescale_func: typ.Callable[
|
||||
[spux.ScalarUnitlessComplexExpr], spux.ScalarUnitlessComplexExpr
|
||||
],
|
||||
reverse: bool = False,
|
||||
) -> typ.Self:
|
||||
"""Apply a function to the bounds, effectively rescaling the represented array.
|
||||
|
||||
Notes:
|
||||
**It is presumed that the bounds are scaled with the same factor**.
|
||||
Breaking this presumption may have unexpected results.
|
||||
|
||||
The scalar, unitless, complex-valuedness of the bounds must also be respected; additionally, new symbols must not be introduced.
|
||||
|
||||
Parameters:
|
||||
scaler: The function that scales each bound.
|
||||
reverse: Whether to reverse the bounds after running the `scaler`.
|
||||
|
||||
Returns:
|
||||
A rescaled `LazyArrayRangeFlow`.
|
||||
"""
|
||||
return LazyArrayRangeFlow(
|
||||
start=rescale_func(self.start if not reverse else self.stop),
|
||||
stop=rescale_func(self.stop if not reverse else self.start),
|
||||
steps=self.steps,
|
||||
scaling=self.scaling,
|
||||
unit=self.unit,
|
||||
symbols=self.symbols,
|
||||
)
|
||||
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
####################
|
||||
# - Lazy Representation
|
||||
# - Exporters
|
||||
####################
|
||||
@functools.cached_property
|
||||
def array_generator(
|
||||
|
@ -345,10 +224,10 @@ class LazyArrayRangeFlow:
|
|||
ScalingMode.Geom: jnp.geomspace,
|
||||
ScalingMode.Log: jnp.logspace,
|
||||
}.get(self.scaling)
|
||||
|
||||
if jnp_nspace is None:
|
||||
msg = f'ArrayFlow scaling method {self.scaling} is unsupported'
|
||||
raise RuntimeError(msg)
|
||||
|
||||
return jnp_nspace
|
||||
|
||||
@functools.cached_property
|
||||
|
@ -361,12 +240,12 @@ class LazyArrayRangeFlow:
|
|||
The ordering of the symbols is identical to `self.symbols`, which is guaranteed to be a deterministically sorted list of symbols.
|
||||
|
||||
Returns:
|
||||
A `LazyValueFuncFlow` that, given the input symbols defined in `self.symbols`,
|
||||
A `FuncFlow` that, given the input symbols defined in `self.symbols`,
|
||||
"""
|
||||
# Compile JAX Functions for Start/End Expressions
|
||||
## FYI, JAX-in-JAX works perfectly fine.
|
||||
start_jax = sp.lambdify(self.symbols, self.start, 'jax')
|
||||
stop_jax = sp.lambdify(self.symbols, self.stop, 'jax')
|
||||
## -> FYI, JAX-in-JAX works perfectly fine.
|
||||
start_jax = sp.lambdify(self.sorted_symbols, self.start, 'jax')
|
||||
stop_jax = sp.lambdify(self.sorted_symbols, self.stop, 'jax')
|
||||
|
||||
# Compile ArrayGen Function
|
||||
def gen_array(
|
||||
|
@ -378,18 +257,18 @@ class LazyArrayRangeFlow:
|
|||
return gen_array
|
||||
|
||||
@functools.cached_property
|
||||
def as_lazy_value_func(self) -> LazyValueFuncFlow:
|
||||
"""Creates a `LazyValueFuncFlow` using the output of `self.as_func`.
|
||||
def as_lazy_func(self) -> FuncFlow:
|
||||
"""Creates a `FuncFlow` using the output of `self.as_func`.
|
||||
|
||||
This is useful for ex. parameterizing the first array in the node graph, without binding an entire computed array.
|
||||
|
||||
Notes:
|
||||
The the function enclosed in the `LazyValueFuncFlow` is identical to the one returned by `self.as_func`.
|
||||
The the function enclosed in the `FuncFlow` is identical to the one returned by `self.as_func`.
|
||||
|
||||
Returns:
|
||||
A `LazyValueFuncFlow` containing `self.as_func`, as well as appropriate supporting settings.
|
||||
A `FuncFlow` containing `self.as_func`, as well as appropriate supporting settings.
|
||||
"""
|
||||
return LazyValueFuncFlow(
|
||||
return FuncFlow(
|
||||
func=self.as_func,
|
||||
func_args=[(spux.MathType.from_expr(sym)) for sym in self.symbols],
|
||||
supports_jax=True,
|
||||
|
@ -401,7 +280,8 @@ class LazyArrayRangeFlow:
|
|||
def realize_start(
|
||||
self,
|
||||
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
|
||||
) -> ArrayFlow | LazyValueFuncFlow:
|
||||
) -> int | float | complex:
|
||||
"""Realize the start-bound by inserting particular values for each symbol."""
|
||||
return spux.sympy_to_python(
|
||||
self.start.subs({sym: symbol_values[sym.name] for sym in self.symbols})
|
||||
)
|
||||
|
@ -409,7 +289,8 @@ class LazyArrayRangeFlow:
|
|||
def realize_stop(
|
||||
self,
|
||||
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
|
||||
) -> ArrayFlow | LazyValueFuncFlow:
|
||||
) -> int | float | complex:
|
||||
"""Realize the stop-bound by inserting particular values for each symbol."""
|
||||
return spux.sympy_to_python(
|
||||
self.stop.subs({sym: symbol_values[sym.name] for sym in self.symbols})
|
||||
)
|
||||
|
@ -417,7 +298,11 @@ class LazyArrayRangeFlow:
|
|||
def realize_step_size(
|
||||
self,
|
||||
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
|
||||
) -> ArrayFlow | LazyValueFuncFlow:
|
||||
) -> int | float | complex:
|
||||
"""Realize the stop-bound by inserting particular values for each symbol."""
|
||||
if self.scaling is not ScalingMode.Lin:
|
||||
raise NotImplementedError('Non-linear scaling mode not yet suported')
|
||||
|
||||
raw_step_size = (self.realize_stop() - self.realize_start() + 1) / self.steps
|
||||
|
||||
if self.mathtype is spux.MathType.Integer and raw_step_size.is_integer():
|
||||
|
@ -427,48 +312,34 @@ class LazyArrayRangeFlow:
|
|||
def realize(
|
||||
self,
|
||||
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
|
||||
kind: typ.Literal[FlowKind.Array, FlowKind.LazyValueFunc] = FlowKind.Array,
|
||||
) -> ArrayFlow | LazyValueFuncFlow:
|
||||
"""Apply a function to the bounds, effectively rescaling the represented array.
|
||||
|
||||
Notes:
|
||||
**It is presumed that the bounds are scaled with the same factor**.
|
||||
Breaking this presumption may have unexpected results.
|
||||
|
||||
The scalar, unitless, complex-valuedness of the bounds must also be respected; additionally, new symbols must not be introduced.
|
||||
) -> ArrayFlow:
|
||||
"""Realize the array represented by this `RangeFlow` by realizing each bound, then generating all intermediate values as an array.
|
||||
|
||||
Parameters:
|
||||
scaler: The function that scales each bound.
|
||||
reverse: Whether to reverse the bounds after running the `scaler`.
|
||||
symbol_values: The particular values for each symbol, which will be inserted into the expression of each bound to realize them.
|
||||
|
||||
Returns:
|
||||
A rescaled `LazyArrayRangeFlow`.
|
||||
An `ArrayFlow` containing this realized `RangeFlow`.
|
||||
"""
|
||||
if not set(self.symbols).issubset(set(symbol_values.keys())):
|
||||
msg = f'Provided symbols ({set(symbol_values.keys())}) do not provide values for all expression symbols ({self.symbols}) that may be found in the boundary expressions (start={self.start}, end={self.end})'
|
||||
raise ValueError(msg)
|
||||
## TODO: Check symbol values for coverage.
|
||||
|
||||
# Realize Symbols
|
||||
realized_start = self.realize_start(symbol_values)
|
||||
realized_stop = self.realize_stop(symbol_values)
|
||||
|
||||
# Return Linspace / Logspace
|
||||
def gen_array() -> jtyp.Inexact[jtyp.Array, ' steps']:
|
||||
return self.array_generator(realized_start, realized_stop, self.steps)
|
||||
|
||||
if kind == FlowKind.Array:
|
||||
return ArrayFlow(values=gen_array(), unit=self.unit, is_sorted=True)
|
||||
if kind == FlowKind.LazyValueFunc:
|
||||
return LazyValueFuncFlow(func=gen_array, supports_jax=True)
|
||||
|
||||
msg = f'Invalid kind: {kind}'
|
||||
raise TypeError(msg)
|
||||
return ArrayFlow(
|
||||
values=self.as_func(*[symbol_values[sym] for sym in self.sorted_symbols]),
|
||||
unit=self.unit,
|
||||
is_sorted=True,
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def realize_array(self) -> ArrayFlow:
|
||||
"""Standardized access to `self.realize()` when there are no symbols."""
|
||||
return self.realize()
|
||||
|
||||
def __getitem__(self, subscript: slice):
|
||||
"""Implement indexing and slicing in a sane way.
|
||||
|
||||
- **Integer Index**: Not yet implemented.
|
||||
- **Slice**: Return the `RangeFlow` that creates the same `ArrayFlow` as would be created by computing `self.realize_array`, then slicing that.
|
||||
"""
|
||||
if isinstance(subscript, slice) and self.scaling == ScalingMode.Lin:
|
||||
# Parse Slice
|
||||
start = subscript.start if subscript.start is not None else 0
|
||||
|
@ -482,7 +353,7 @@ class LazyArrayRangeFlow:
|
|||
new_start = step_size * start
|
||||
new_stop = new_start + step_size * slice_steps
|
||||
|
||||
return LazyArrayRangeFlow(
|
||||
return RangeFlow(
|
||||
start=sp.S(new_start),
|
||||
stop=sp.S(new_stop),
|
||||
steps=slice_steps,
|
||||
|
@ -492,3 +363,104 @@ class LazyArrayRangeFlow:
|
|||
)
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
####################
|
||||
# - Units
|
||||
####################
|
||||
def correct_unit(self, corrected_unit: spux.Unit) -> typ.Self:
|
||||
"""Replaces the unit without rescaling the unitless bounds.
|
||||
|
||||
Parameters:
|
||||
corrected_unit: The unit to replace the current unit with.
|
||||
|
||||
Returns:
|
||||
A new `RangeFlow` with replaced unit.
|
||||
|
||||
Raises:
|
||||
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
||||
"""
|
||||
if self.unit is not None:
|
||||
log.debug(
|
||||
'%s: Corrected unit to %s',
|
||||
self,
|
||||
corrected_unit,
|
||||
)
|
||||
return RangeFlow(
|
||||
start=self.start,
|
||||
stop=self.stop,
|
||||
steps=self.steps,
|
||||
scaling=self.scaling,
|
||||
unit=corrected_unit,
|
||||
symbols=self.symbols,
|
||||
)
|
||||
|
||||
msg = f'Tried to correct unit of unitless LazyDataValueRange "{corrected_unit}"'
|
||||
raise ValueError(msg)
|
||||
|
||||
def rescale_to_unit(self, unit: spux.Unit) -> typ.Self:
|
||||
"""Replaces the unit, **with** rescaling of the bounds.
|
||||
|
||||
Parameters:
|
||||
unit: The unit to convert the bounds to.
|
||||
|
||||
Returns:
|
||||
A new `RangeFlow` with replaced unit.
|
||||
|
||||
Raises:
|
||||
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
||||
"""
|
||||
if self.unit is not None:
|
||||
log.debug(
|
||||
'%s: Scaled to unit %s',
|
||||
self,
|
||||
unit,
|
||||
)
|
||||
return RangeFlow(
|
||||
start=spux.scale_to_unit(self.start * self.unit, unit),
|
||||
stop=spux.scale_to_unit(self.stop * self.unit, unit),
|
||||
steps=self.steps,
|
||||
scaling=self.scaling,
|
||||
unit=unit,
|
||||
symbols=self.symbols,
|
||||
)
|
||||
|
||||
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}'
|
||||
raise ValueError(msg)
|
||||
|
||||
def rescale_to_unit_system(self, unit_system: spux.Unit) -> typ.Self:
|
||||
"""Replaces the units, **with** rescaling of the bounds.
|
||||
|
||||
Parameters:
|
||||
unit: The unit to convert the bounds to.
|
||||
|
||||
Returns:
|
||||
A new `RangeFlow` with replaced unit.
|
||||
|
||||
Raises:
|
||||
ValueError: If the existing unit is `None`, indicating that there is no unit to correct.
|
||||
"""
|
||||
if self.unit is not None:
|
||||
log.debug(
|
||||
'%s: Scaled to new unit system (new unit = %s)',
|
||||
self,
|
||||
unit_system[spux.PhysicalType.from_unit(self.unit)],
|
||||
)
|
||||
return RangeFlow(
|
||||
start=spux.strip_unit_system(
|
||||
spux.convert_to_unit_system(self.start * self.unit, unit_system),
|
||||
unit_system,
|
||||
),
|
||||
stop=spux.strip_unit_system(
|
||||
spux.convert_to_unit_system(self.stop * self.unit, unit_system),
|
||||
unit_system,
|
||||
),
|
||||
steps=self.steps,
|
||||
scaling=self.scaling,
|
||||
unit=unit_system[spux.PhysicalType.from_unit(self.unit)],
|
||||
symbols=self.symbols,
|
||||
)
|
||||
|
||||
msg = (
|
||||
f'Tried to rescale unitless LazyDataValueRange to unit system {unit_system}'
|
||||
)
|
||||
raise ValueError(msg)
|
|
@ -1,210 +0,0 @@
|
|||
# blender_maxwell
|
||||
# Copyright (C) 2024 blender_maxwell Project Contributors
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import typing as typ
|
||||
from types import MappingProxyType
|
||||
|
||||
import jax
|
||||
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
from blender_maxwell.utils import logger
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
LazyFunction: typ.TypeAlias = typ.Callable[[typ.Any, ...], typ.Any]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class LazyValueFuncFlow:
|
||||
r"""Wraps a composable function, providing useful information and operations.
|
||||
|
||||
# Data Flow as Function Composition
|
||||
When using nodes to do math, it can be a good idea to express a **flow of data as the composition of functions**.
|
||||
|
||||
Each node creates a new function, which uses the still-unknown (aka. **lazy**) output of the previous function to plan some calculations.
|
||||
Some new arguments may also be added, of course.
|
||||
|
||||
## Root Function
|
||||
Of course, one needs to select a "bottom" function, which has no previous function as input.
|
||||
Thus, the first step is to define this **root function**:
|
||||
|
||||
$$
|
||||
f_0:\ \ \ \ \biggl(
|
||||
\underbrace{a_1, a_2, ..., a_p}_{\texttt{args}},\
|
||||
\underbrace{
|
||||
\begin{bmatrix} k_1 \\ v_1\end{bmatrix},
|
||||
\begin{bmatrix} k_2 \\ v_2\end{bmatrix},
|
||||
...,
|
||||
\begin{bmatrix} k_q \\ v_q\end{bmatrix}
|
||||
}_{\texttt{kwargs}}
|
||||
\biggr) \to \text{output}_0
|
||||
$$
|
||||
|
||||
We'll express this simple snippet like so:
|
||||
|
||||
```python
|
||||
# Presume 'A0', 'KV0' contain only the args/kwargs for f_0
|
||||
## 'A0', 'KV0' are of length 'p' and 'q'
|
||||
def f_0(*args, **kwargs): ...
|
||||
|
||||
lazy_value_func_0 = LazyValueFuncFlow(
|
||||
func=f_0,
|
||||
func_args=[(a_i, type(a_i)) for a_i in A0],
|
||||
func_kwargs={k: v for k,v in KV0},
|
||||
)
|
||||
output_0 = lazy_value_func.func(*A0_computed, **KV0_computed)
|
||||
```
|
||||
|
||||
So far so good.
|
||||
But of course, nothing interesting has really happened yet.
|
||||
|
||||
## Composing Functions
|
||||
The key thing is the next step: The function that uses the result of $f_0$!
|
||||
|
||||
$$
|
||||
f_1:\ \ \ \ \biggl(
|
||||
f_0(...),\ \
|
||||
\underbrace{\{a_i\}_p^{p+r}}_{\texttt{args[p:]}},\
|
||||
\underbrace{\biggl\{
|
||||
\begin{bmatrix} k_i \\ v_i\end{bmatrix}
|
||||
\biggr\}_q^{q+s}}_{\texttt{kwargs[p:]}}
|
||||
\biggr) \to \text{output}_1
|
||||
$$
|
||||
|
||||
Notice that _$f_1$ needs the arguments of both $f_0$ and $f_1$_.
|
||||
Tracking arguments is already getting out of hand; we already have to use `...` to keep it readeable!
|
||||
|
||||
But doing so with `LazyValueFunc` is not so complex:
|
||||
|
||||
```python
|
||||
# Presume 'A1', 'K1' contain only the args/kwarg names for f_1
|
||||
## 'A1', 'KV1' are therefore of length 'r' and 's'
|
||||
def f_1(output_0, *args, **kwargs): ...
|
||||
|
||||
lazy_value_func_1 = lazy_value_func_0.compose_within(
|
||||
enclosing_func=f_1,
|
||||
enclosing_func_args=[(a_i, type(a_i)) for a_i in A1],
|
||||
enclosing_func_kwargs={k: type(v) for k,v in K1},
|
||||
)
|
||||
|
||||
A_computed = A0_computed + A1_computed
|
||||
KW_computed = KV0_computed + KV1_computed
|
||||
output_1 = lazy_value_func_1.func(*A_computed, **KW_computed)
|
||||
```
|
||||
|
||||
We only need the arguments to $f_1$, and `LazyValueFunc` figures out how to make one function with enough arguments to call both.
|
||||
|
||||
## Isn't Laying Functions Slow/Hard?
|
||||
Imagine that each function represents the action of a node, each of which performs expensive calculations on huge `numpy` arrays (**as one does when processing electromagnetic field data**).
|
||||
At the end, a node might run the entire procedure with all arguments:
|
||||
|
||||
```python
|
||||
output_n = lazy_value_func_n.func(*A_all, **KW_all)
|
||||
```
|
||||
|
||||
It's rough: Most non-trivial pipelines drown in the time/memory overhead of incremental `numpy` operations - individually fast, but collectively iffy.
|
||||
|
||||
The killer feature of `LazyValueFuncFlow` is a sprinkle of black magic:
|
||||
|
||||
```python
|
||||
func_n_jax = lazy_value_func_n.func_jax
|
||||
output_n = func_n_jax(*A_all, **KW_all) ## Runs on your GPU
|
||||
```
|
||||
|
||||
What happened was, **the entire pipeline** was compiled and optimized for high performance on not just your CPU, _but also (possibly) your GPU_.
|
||||
All the layered function calls and inefficient incremental processing is **transformed into a high-performance program**.
|
||||
|
||||
Thank `jax` - specifically, `jax.jit` (https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), which internally enables this magic with a single function call.
|
||||
|
||||
## Other Considerations
|
||||
**Auto-Differentiation**: Incredibly, `jax.jit` isn't the killer feature of `jax`. The function that comes out of `LazyValueFuncFlow` can also be differentiated with `jax.grad` (read: high-performance Jacobians for optimizing input parameters).
|
||||
|
||||
Though designed for machine learning, there's no reason other fields can't enjoy their inventions!
|
||||
|
||||
**Impact of Independent Caching**: JIT'ing can be slow.
|
||||
That's why `LazyValueFuncFlow` has its own `FlowKind` "lane", which means that **only changes to the processing procedures will cause recompilation**.
|
||||
|
||||
Generally, adjustable values that affect the output will flow via the `Param` "lane", which has its own incremental caching, and only meets the compiled function when it's "plugged in" for final evaluation.
|
||||
The effect is a feeling of snappiness and interactivity, even as the volume of data grows.
|
||||
|
||||
Attributes:
|
||||
func: The function that the object encapsulates.
|
||||
bound_args: Arguments that will be packaged into function, which can't be later modifier.
|
||||
func_kwargs: Arguments to be specified by the user at the time of use.
|
||||
supports_jax: Whether the contained `self.function` can be compiled with JAX's JIT compiler.
|
||||
"""
|
||||
|
||||
func: LazyFunction
|
||||
func_args: list[spux.MathType | spux.PhysicalType] = dataclasses.field(
|
||||
default_factory=list
|
||||
)
|
||||
func_kwargs: dict[str, spux.MathType | spux.PhysicalType] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
supports_jax: bool = False
|
||||
|
||||
# Merging
|
||||
def __or__(
|
||||
self,
|
||||
other: typ.Self,
|
||||
):
|
||||
return LazyValueFuncFlow(
|
||||
func=lambda *args, **kwargs: (
|
||||
self.func(
|
||||
*list(args[: len(self.func_args)]),
|
||||
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
|
||||
),
|
||||
other.func(
|
||||
*list(args[len(self.func_args) :]),
|
||||
**{k: v for k, v in kwargs.items() if k in other.func_kwargs},
|
||||
),
|
||||
),
|
||||
func_args=self.func_args + other.func_args,
|
||||
func_kwargs=self.func_kwargs | other.func_kwargs,
|
||||
supports_jax=self.supports_jax and other.supports_jax,
|
||||
)
|
||||
|
||||
# Composition
|
||||
def compose_within(
|
||||
self,
|
||||
enclosing_func: LazyFunction,
|
||||
enclosing_func_args: list[type] = (),
|
||||
enclosing_func_kwargs: dict[str, type] = MappingProxyType({}),
|
||||
supports_jax: bool = False,
|
||||
) -> typ.Self:
|
||||
return LazyValueFuncFlow(
|
||||
func=lambda *args, **kwargs: enclosing_func(
|
||||
self.func(
|
||||
*list(args[: len(self.func_args)]),
|
||||
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
|
||||
),
|
||||
*args[len(self.func_args) :],
|
||||
**{k: v for k, v in kwargs.items() if k not in self.func_kwargs},
|
||||
),
|
||||
func_args=self.func_args + list(enclosing_func_args),
|
||||
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
|
||||
supports_jax=self.supports_jax and supports_jax,
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def func_jax(self) -> LazyFunction:
|
||||
if self.supports_jax:
|
||||
return jax.jit(self.func)
|
||||
|
||||
msg = 'Can\'t express LazyValueFuncFlow as JAX function (using jax.jit), since "self.supports_jax" is False'
|
||||
raise ValueError(msg)
|
|
@ -22,7 +22,12 @@ from types import MappingProxyType
|
|||
import sympy as sp
|
||||
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
from blender_maxwell.utils import logger
|
||||
from blender_maxwell.utils import logger, sim_symbols
|
||||
|
||||
from .expr_info import ExprInfo
|
||||
from .flow_kinds import FlowKind
|
||||
|
||||
# from .info import InfoFlow
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
@ -32,7 +37,7 @@ class ParamsFlow:
|
|||
func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list)
|
||||
func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict)
|
||||
|
||||
symbols: frozenset[spux.Symbol] = frozenset()
|
||||
symbols: frozenset[sim_symbols.SimSymbol] = frozenset()
|
||||
|
||||
@functools.cached_property
|
||||
def sorted_symbols(self) -> list[sp.Symbol]:
|
||||
|
@ -44,14 +49,22 @@ class ParamsFlow:
|
|||
return sorted(self.symbols, key=lambda sym: sym.name)
|
||||
|
||||
####################
|
||||
# - Scaled Func Args
|
||||
# - Realize Arguments
|
||||
####################
|
||||
def scaled_func_args(
|
||||
self,
|
||||
unit_system: spux.UnitSystem,
|
||||
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
|
||||
unit_system: spux.UnitSystem | None = None,
|
||||
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||
{}
|
||||
),
|
||||
):
|
||||
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments."""
|
||||
"""Realize the function arguments contained in this `ParamsFlow`, making it ready for insertion into `Func.func()`.
|
||||
|
||||
For all `arg`s in `self.func_args`, the following operations are performed.
|
||||
|
||||
Notes:
|
||||
This method is created for the purpose of being able to make this exact call in an `events.on_value_changed` method:
|
||||
"""
|
||||
if not all(sym in self.symbols for sym in symbol_values):
|
||||
msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}"
|
||||
raise ValueError(msg)
|
||||
|
@ -68,7 +81,7 @@ class ParamsFlow:
|
|||
|
||||
def scaled_func_kwargs(
|
||||
self,
|
||||
unit_system: spux.UnitSystem,
|
||||
unit_system: spux.UnitSystem | None = None,
|
||||
symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}),
|
||||
):
|
||||
"""Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments."""
|
||||
|
@ -92,7 +105,7 @@ class ParamsFlow:
|
|||
):
|
||||
"""Combine two function parameter lists, such that the LHS will be concatenated with the RHS.
|
||||
|
||||
Just like its neighbor in `LazyValueFunc`, this effectively combines two functions with unique parameters.
|
||||
Just like its neighbor in `Func`, this effectively combines two functions with unique parameters.
|
||||
The next composed function will receive a tuple of two arrays, instead of just one, allowing binary operations to occur.
|
||||
"""
|
||||
return ParamsFlow(
|
||||
|
@ -112,3 +125,61 @@ class ParamsFlow:
|
|||
func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs),
|
||||
symbols=self.symbols | enclosing_symbols,
|
||||
)
|
||||
|
||||
####################
|
||||
# - Generate ExprSocketDef
|
||||
####################
|
||||
def sym_expr_infos(self, info, use_range: bool = False) -> dict[str, ExprInfo]:
|
||||
"""Generate all information needed to define expressions that realize all symbolic parameters in this `ParamsFlow`.
|
||||
|
||||
Many nodes need actual data, and as such, they require that the user select actual values for any symbols in the `ParamsFlow`.
|
||||
The best way to do this is to create one `ExprSocket` for each symbol that needs realizing.
|
||||
|
||||
Notes:
|
||||
This method is created for the purpose of being able to make this exact call in an `events.on_value_changed` method:
|
||||
```
|
||||
self.loose_input_sockets = {
|
||||
sym_name: sockets.ExprSocketDef(**expr_info)
|
||||
for sym_name, expr_info in params.sym_expr_infos(info).items()
|
||||
}
|
||||
```
|
||||
|
||||
Parameters:
|
||||
info: The InfoFlow associated with the `Expr` being realized.
|
||||
Each symbol in `self.symbols` **must** have an associated same-named dimension in `info`.
|
||||
use_range: Causes the
|
||||
|
||||
The `ExprInfo`s can be directly defererenced `**expr_info`)
|
||||
"""
|
||||
for sim_sym in self.sorted_symbols:
|
||||
if use_range and sim_sym.mathtype is spux.MathType.Complex:
|
||||
msg = 'No support for complex range in ExprInfo'
|
||||
raise NotImplementedError(msg)
|
||||
if use_range and (sim_sym.rows > 1 or sim_sym.cols > 1):
|
||||
msg = 'No support for non-scalar elements of range in ExprInfo'
|
||||
raise NotImplementedError(msg)
|
||||
if sim_sym.rows > 3 or sim_sym.cols > 1:
|
||||
msg = 'No support for >Vec3 / Matrix values in ExprInfo'
|
||||
raise NotImplementedError(msg)
|
||||
return {
|
||||
sim_sym.name: {
|
||||
# Declare Kind/Size
|
||||
## -> Kind: Value prevents user-alteration of config.
|
||||
## -> Size: Always scalar, since symbols are scalar (for now).
|
||||
'active_kind': FlowKind.Value if not use_range else FlowKind.Range,
|
||||
'size': spux.NumberSize1D.Scalar,
|
||||
# Declare MathType/PhysicalType
|
||||
## -> MathType: Lookup symbol name in info dimensions.
|
||||
## -> PhysicalType: Same.
|
||||
'mathtype': self.dims[sim_sym].mathtype,
|
||||
'physical_type': self.dims[sim_sym].physical_type,
|
||||
# TODO: Default Value
|
||||
# FlowKind.Value: Default Value
|
||||
#'default_value':
|
||||
# FlowKind.Range: Default Min/Max/Steps
|
||||
'default_min': sim_sym.domain.start,
|
||||
'default_max': sim_sym.domain.end,
|
||||
'default_steps': 50,
|
||||
}
|
||||
for sim_sym in self.sorted_symbols
|
||||
}
|
||||
|
|
|
@ -50,6 +50,7 @@ class NodeType(blender_type_enum.BlenderTypeEnum):
|
|||
# Outputs
|
||||
Viewer = enum.auto()
|
||||
## Outputs / File Exporters
|
||||
DataFileExporter = enum.auto()
|
||||
Tidy3DWebExporter = enum.auto()
|
||||
## Outputs / Web Exporters
|
||||
JSONFileExporter = enum.auto()
|
||||
|
|
|
@ -19,13 +19,21 @@
|
|||
import dataclasses
|
||||
import enum
|
||||
import typing as typ
|
||||
from pathlib import Path
|
||||
|
||||
import jax.numpy as jnp
|
||||
import sympy as sp
|
||||
import jaxtyping as jtyp
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
import tidy3d as td
|
||||
|
||||
from blender_maxwell.contracts import BLEnumElement
|
||||
from blender_maxwell.services import tdcloud
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
from blender_maxwell.utils import logger
|
||||
|
||||
from .flow_kinds.info import InfoFlow
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
||||
####################
|
||||
|
@ -293,3 +301,321 @@ class NewSimCloudTask:
|
|||
|
||||
task_name: tdcloud.CloudTaskName
|
||||
cloud_folder: tdcloud.CloudFolder
|
||||
|
||||
|
||||
####################
|
||||
# - Data File
|
||||
####################
|
||||
_DATA_FILE_EXTS = {
|
||||
'.txt',
|
||||
'.txt.gz',
|
||||
'.csv',
|
||||
'.npy',
|
||||
}
|
||||
|
||||
|
||||
class DataFileFormat(enum.StrEnum):
|
||||
"""Abstraction of a data file format, providing a regularized way of interacting with filesystem data.
|
||||
|
||||
Import/export interacts closely with the `Expr` socket's `FlowKind` semantics:
|
||||
- `FlowKind.Func`: Generally realized on-import/export.
|
||||
- **Import**: Loading data is generally eager, but memory-mapped file loading would be manageable using this interface.
|
||||
- **Export**: The function is realized and only the array is inserted into the file.
|
||||
- `FlowKind.Params`: Generally consumed.
|
||||
- **Import**: A new, empty `ParamsFlow` object is created.
|
||||
- **Export**: The `ParamsFlow` is consumed when realizing the `Func`.
|
||||
- `FlowKind.Info`: As the most important element, it is kept in an (optional) sidecar metadata file.
|
||||
- **Import**: The sidecar file is loaded, checked, and used, if it exists. A warning about further processing may show if it doesn't.
|
||||
- **Export**: The sidecar file is written next to the canonical data file, in such a manner that it can be both read and loaded.
|
||||
|
||||
Notes:
|
||||
This enum is UI Compatible, ex. for nodes/sockets desiring a dropdown menu of data file formats.
|
||||
|
||||
Attributes:
|
||||
Txt: Simple no-header text file.
|
||||
Only supports 1D/2D data.
|
||||
TxtGz: Identical to `Txt`, but compressed with `gzip`.
|
||||
Csv: Unspecific "Comma Separated Values".
|
||||
For loading, `pandas`-default semantics are used.
|
||||
For saving, very opinionated defaults are used.
|
||||
Customization is disabled on purpose.
|
||||
Npy: Generic numpy representation.
|
||||
Supports all kinds of numpy objects.
|
||||
Better laziness support via `jax`.
|
||||
"""
|
||||
|
||||
Csv = enum.auto()
|
||||
Npy = enum.auto()
|
||||
Txt = enum.auto()
|
||||
TxtGz = enum.auto()
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
@staticmethod
|
||||
def to_name(v: typ.Self) -> str:
|
||||
"""The extension name of the given `DataFileFormat`.
|
||||
|
||||
Notes:
|
||||
Called by the UI when creating an `EnumProperty` dropdown.
|
||||
"""
|
||||
return DataFileFormat(v).extension
|
||||
|
||||
@staticmethod
|
||||
def to_icon(v: typ.Self) -> str:
|
||||
"""No icon.
|
||||
|
||||
Notes:
|
||||
Called by the UI when creating an `EnumProperty` dropdown.
|
||||
"""
|
||||
return ''
|
||||
|
||||
def bl_enum_element(self, i: int) -> BLEnumElement:
|
||||
"""Produce a fully functional Blender enum element, given a particular integer index."""
|
||||
return (
|
||||
str(self),
|
||||
DataFileFormat.to_name(self),
|
||||
DataFileFormat.to_name(self),
|
||||
DataFileFormat.to_icon(self),
|
||||
i,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def bl_enum_elements() -> list[BLEnumElement]:
|
||||
"""Produce an immediately usable list of Blender enum elements, correctly indexed."""
|
||||
return [
|
||||
data_file_format.bl_enum_element(i)
|
||||
for i, data_file_format in enumerate(list(DataFileFormat))
|
||||
]
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
@property
|
||||
def extension(self) -> str:
|
||||
"""Map to the actual string extension."""
|
||||
E = DataFileFormat
|
||||
return {
|
||||
E.Csv: '.csv',
|
||||
E.Npy: '.npy',
|
||||
E.Txt: '.txt',
|
||||
E.TxtGz: '.txt.gz',
|
||||
}[self]
|
||||
|
||||
####################
|
||||
# - Creation: Compatibility
|
||||
####################
|
||||
@staticmethod
|
||||
def valid_exts() -> list[str]:
|
||||
return _DATA_FILE_EXTS
|
||||
|
||||
@staticmethod
|
||||
def ext_has_valid_format(ext: str) -> bool:
|
||||
return ext in _DATA_FILE_EXTS
|
||||
|
||||
@staticmethod
|
||||
def path_has_valid_format(path: Path) -> bool:
|
||||
return path.is_file() and DataFileFormat.ext_has_valid_format(
|
||||
''.join(path.suffixes)
|
||||
)
|
||||
|
||||
def is_path_compatible(
|
||||
self, path: Path, must_exist: bool = False, can_exist: bool = True
|
||||
) -> bool:
|
||||
ext_matches = self.extension == ''.join(path.suffixes)
|
||||
match (must_exist, can_exist):
|
||||
case (False, False):
|
||||
return ext_matches and not path.is_file() and path.parent.is_dir()
|
||||
|
||||
case (True, False):
|
||||
msg = f'DataFileFormat: Path {path} cannot both be required to exist (must_exist=True), but also not be allowed to exist (can_exist=False)'
|
||||
raise ValueError(msg)
|
||||
|
||||
case (False, True):
|
||||
return ext_matches and path.parent.is_dir()
|
||||
|
||||
case (True, True):
|
||||
return ext_matches and path.is_file()
|
||||
|
||||
####################
|
||||
# - Creation
|
||||
####################
|
||||
@staticmethod
|
||||
def from_ext(ext: str) -> typ.Self | None:
|
||||
return {
|
||||
_ext: _data_file_ext
|
||||
for _data_file_ext, _ext in {
|
||||
k: k.extension for k in list(DataFileFormat)
|
||||
}.items()
|
||||
}.get(ext)
|
||||
|
||||
@staticmethod
|
||||
def from_path(path: Path) -> typ.Self | None:
|
||||
if DataFileFormat.path_has_valid_format(path):
|
||||
data_file_ext = DataFileFormat.from_ext(''.join(path.suffixes))
|
||||
if data_file_ext is not None:
|
||||
return data_file_ext
|
||||
|
||||
msg = f'DataFileFormat: Path "{path}" is compatible, but could not find valid extension'
|
||||
raise RuntimeError(msg)
|
||||
|
||||
return None
|
||||
|
||||
####################
|
||||
# - Functions: Metadata
|
||||
####################
|
||||
def supports_metadata(self) -> bool:
|
||||
E = DataFileFormat
|
||||
return {
|
||||
E.Csv: False, ## No RFC 4180 Support for Comments
|
||||
E.Npy: False, ## Quite simply no support
|
||||
E.Txt: True, ## Use # Comments
|
||||
E.TxtGz: True, ## Same as Txt
|
||||
}[self]
|
||||
|
||||
## TODO: Sidecar Metadata
|
||||
## - The vision is that 'saver' also writes metadata.
|
||||
## - This metadata is essentially a straight serialization of the InfoFlow.
|
||||
## - On-load, the metadata is used to re-generate the InfoFlow.
|
||||
## - This allows interpreting saved data without a ton of shenanigans.
|
||||
## - These sidecars could also be hand-writable for external data.
|
||||
## - When sidecars aren't found, the user would "fill in the blanks".
|
||||
## - ...Thus achieving the same result as if there were a sidecar.
|
||||
|
||||
####################
|
||||
# - Functions: Saver
|
||||
####################
|
||||
def is_info_compatible(self, info: InfoFlow) -> bool:
|
||||
E = DataFileFormat
|
||||
match self:
|
||||
case E.Csv:
|
||||
return len(info.dims) + info.output.rows + info.output.cols - 1 <= 2
|
||||
case E.Npy:
|
||||
return True
|
||||
case E.Txt | E.TxtGz:
|
||||
return len(info.dims) + info.output.rows + info.output.cols - 1 <= 2
|
||||
|
||||
@property
|
||||
def saver(
|
||||
self,
|
||||
) -> typ.Callable[[Path, jtyp.Shaped[jtyp.Array, '...'], InfoFlow], None]:
|
||||
def save_txt(path, data, info):
|
||||
np.savetxt(path, data)
|
||||
|
||||
def save_txt_gz(path, data, info):
|
||||
np.savetxt(path, data)
|
||||
|
||||
def save_csv(path, data, info):
|
||||
data_np = np.array(data)
|
||||
|
||||
# Extract Input Coordinates
|
||||
dim_columns = {
|
||||
dim.name: np.array(dim_idx.realize_array)
|
||||
for i, (dim, dim_idx) in enumerate(info.dims)
|
||||
} ## TODO: realize_array might not be defined on some index arrays
|
||||
|
||||
# Declare Function to Extract Output Values
|
||||
output_columns = {}
|
||||
|
||||
def declare_output_col(data_col, output_idx=0, use_output_idx=False):
|
||||
nonlocal output_columns
|
||||
|
||||
# Complex: Split to Two Columns
|
||||
output_idx_str = f'[{output_idx}]' if use_output_idx else ''
|
||||
if bool(np.any(np.iscomplex(data_col))):
|
||||
output_columns |= {
|
||||
f'{info.output.name}{output_idx_str}_re': np.real(data_col),
|
||||
f'{info.output.name}{output_idx_str}_im': np.imag(data_col),
|
||||
}
|
||||
|
||||
# Else: Use Array Directly
|
||||
else:
|
||||
output_columns |= {
|
||||
f'{info.output.name}{output_idx_str}': data_col,
|
||||
}
|
||||
|
||||
## TODO: Maybe a check to ensure dtype!=object?
|
||||
|
||||
# Extract Output Values
|
||||
## -> 2D: Iterate over columns by-index.
|
||||
## -> 1D: Declare the array as the only column.
|
||||
if len(data_np.shape) == 2:
|
||||
for output_idx in data_np.shape[1]:
|
||||
declare_output_col(data_np[:, output_idx], output_idx, True)
|
||||
else:
|
||||
declare_output_col(data_np)
|
||||
|
||||
# Compute DataFrame & Write CSV
|
||||
df = pl.DataFrame(dim_columns | output_columns)
|
||||
|
||||
log.debug('Writing Polars DataFrame to CSV:')
|
||||
log.debug(df)
|
||||
df.write_csv(path)
|
||||
|
||||
def save_npy(path, data, info):
|
||||
jnp.save(path, data)
|
||||
|
||||
E = DataFileFormat
|
||||
return {
|
||||
E.Csv: save_csv,
|
||||
E.Npy: save_npy,
|
||||
E.Txt: save_txt,
|
||||
E.TxtGz: save_txt_gz,
|
||||
}[self]
|
||||
|
||||
####################
|
||||
# - Functions: Loader
|
||||
####################
|
||||
@property
|
||||
def loader_is_jax_compatible(self) -> bool:
|
||||
E = DataFileFormat
|
||||
return {
|
||||
E.Csv: False,
|
||||
E.Npy: True,
|
||||
E.Txt: True,
|
||||
E.TxtGz: True,
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def loader(
|
||||
self,
|
||||
) -> typ.Callable[[Path], tuple[jtyp.Shaped[jtyp.Array, '...'], InfoFlow]]:
|
||||
def load_txt(path: Path):
|
||||
return jnp.asarray(np.loadtxt(path))
|
||||
|
||||
def load_csv(path: Path):
|
||||
return jnp.asarray(pl.read_csv(path).to_numpy())
|
||||
## TODO: The very next Polars (0.20.27) has a '.to_jax' method!
|
||||
|
||||
def load_npy(path: Path):
|
||||
return jnp.load(path)
|
||||
|
||||
E = DataFileFormat
|
||||
return {
|
||||
E.Csv: load_csv,
|
||||
E.Npy: load_npy,
|
||||
E.Txt: load_txt,
|
||||
E.TxtGz: load_txt,
|
||||
}[self]
|
||||
|
||||
####################
|
||||
# - Metadata: Compatibility
|
||||
####################
|
||||
def is_info_compatible(self, info: InfoFlow) -> bool:
|
||||
E = DataFileFormat
|
||||
match self:
|
||||
case E.Csv:
|
||||
return len(info.dims) + (info.output.rows + input.outputs.cols - 1) <= 2
|
||||
case E.Npy:
|
||||
return True
|
||||
case E.Txt | E.TxtGz:
|
||||
return len(info.dims) + (info.output.rows + info.output.cols - 1) <= 2
|
||||
|
||||
def supports_metadata(self) -> bool:
|
||||
E = DataFileFormat
|
||||
return {
|
||||
E.Csv: False, ## No RFC 4180 Support for Comments
|
||||
E.Npy: False, ## Quite simply no support
|
||||
E.Txt: True, ## Use # Comments
|
||||
E.TxtGz: True, ## Same as Txt
|
||||
}[self]
|
||||
|
|
|
@ -48,7 +48,6 @@ UNITS_BLENDER: spux.UnitSystem = spux.UNITS_SI | {
|
|||
# Electrodynamics
|
||||
_PT.CurrentDensity: spu.ampere / spu.um**2,
|
||||
_PT.Conductivity: spu.siemens / spu.um,
|
||||
_PT.PoyntingVector: spu.watt / spu.um**2,
|
||||
_PT.EField: spu.volt / spu.um,
|
||||
_PT.HField: spu.ampere / spu.um,
|
||||
# Mechanical
|
||||
|
@ -58,7 +57,6 @@ UNITS_BLENDER: spux.UnitSystem = spux.UNITS_SI | {
|
|||
_PT.Force: spux.micronewton,
|
||||
# Luminal
|
||||
# Optics
|
||||
_PT.PoyntingVector: spu.watt / spu.um**2,
|
||||
} ## TODO: Load (dynamically?) from addon preferences
|
||||
|
||||
UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | {
|
||||
|
@ -75,11 +73,9 @@ UNITS_TIDY3D: spux.UnitSystem = spux.UNITS_SI | {
|
|||
# Electrodynamics
|
||||
_PT.CurrentDensity: spu.ampere / spu.um**2,
|
||||
_PT.Conductivity: spu.siemens / spu.um,
|
||||
_PT.PoyntingVector: spu.watt / spu.um**2,
|
||||
_PT.EField: spu.volt / spu.um,
|
||||
_PT.HField: spu.ampere / spu.um,
|
||||
# Luminal
|
||||
# Optics
|
||||
_PT.PoyntingVector: spu.watt / spu.um**2,
|
||||
## NOTE: w/o source normalization, EField/HField/Modal amps are * 1/Hz
|
||||
}
|
||||
|
|
|
@ -17,15 +17,15 @@
|
|||
"""Implements `ExtractDataNode`."""
|
||||
|
||||
import enum
|
||||
import functools
|
||||
import typing as typ
|
||||
|
||||
import bpy
|
||||
import jax
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import sympy.physics.units as spu
|
||||
import tidy3d as td
|
||||
|
||||
from blender_maxwell.utils import bl_cache, logger
|
||||
from blender_maxwell.utils import bl_cache, logger, sim_symbols
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
|
||||
from ... import contracts as ct
|
||||
|
@ -37,6 +37,176 @@ log = logger.get(__name__)
|
|||
TDMonitorData: typ.TypeAlias = td.components.data.monitor_data.MonitorData
|
||||
|
||||
|
||||
####################
|
||||
# - Monitor Label Arrays
|
||||
####################
|
||||
def valid_monitor_attrs(sim_data: td.SimulationData, monitor_name: str) -> list[str]:
|
||||
"""Retrieve the valid attributes of `sim_data.monitor_data' from a valid `sim_data` of type `td.SimulationData`.
|
||||
|
||||
Parameters:
|
||||
monitor_type: The name of the monitor type, with the 'Data' prefix removed.
|
||||
"""
|
||||
monitor_data = sim_data.monitor_data[monitor_name]
|
||||
monitor_type = monitor_data.type
|
||||
|
||||
match monitor_type:
|
||||
case 'Field' | 'FieldTime' | 'Mode':
|
||||
## TODO: flux, poynting, intensity
|
||||
return [
|
||||
field_component
|
||||
for field_component in ['Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz']
|
||||
if getattr(monitor_data, field_component, None) is not None
|
||||
]
|
||||
|
||||
case 'Permittivity':
|
||||
return ['eps_xx', 'eps_yy', 'eps_zz']
|
||||
|
||||
case 'Flux' | 'FluxTime':
|
||||
return ['flux']
|
||||
|
||||
case (
|
||||
'FieldProjectionAngle'
|
||||
| 'FieldProjectionCartesian'
|
||||
| 'FieldProjectionKSpace'
|
||||
| 'Diffraction'
|
||||
):
|
||||
return [
|
||||
'Er',
|
||||
'Etheta',
|
||||
'Ephi',
|
||||
'Hr',
|
||||
'Htheta',
|
||||
'Hphi',
|
||||
]
|
||||
|
||||
|
||||
def extract_info(monitor_data, monitor_attr: str) -> ct.InfoFlow | None: # noqa: PLR0911
|
||||
"""Extract an InfoFlow encapsulating raw data contained in an attribute of the given monitor data."""
|
||||
xarr = getattr(monitor_data, monitor_attr, None)
|
||||
if xarr is None:
|
||||
return None
|
||||
|
||||
def mk_idx_array(axis: str) -> ct.ArrayFlow:
|
||||
return ct.ArrayFlow(
|
||||
values=xarr.get_index(axis).values,
|
||||
unit=symbols[axis].unit,
|
||||
is_sorted=True,
|
||||
)
|
||||
|
||||
# Compute InfoFlow from XArray
|
||||
symbols = {
|
||||
# Cartesian
|
||||
'x': sim_symbols.space_x(spu.micrometer),
|
||||
'y': sim_symbols.space_y(spu.micrometer),
|
||||
'z': sim_symbols.space_z(spu.micrometer),
|
||||
# Spherical
|
||||
'r': sim_symbols.ang_r(spu.micrometer),
|
||||
'theta': sim_symbols.ang_theta(spu.radian),
|
||||
'phi': sim_symbols.ang_phi(spu.radian),
|
||||
# Freq|Time
|
||||
'f': sim_symbols.freq(spu.hertz),
|
||||
't': sim_symbols.t(spu.second),
|
||||
# Power Flux
|
||||
'flux': sim_symbols.flux(spu.watt),
|
||||
# Cartesian Fields
|
||||
'Ex': sim_symbols.field_ex(spu.volt / spu.micrometer),
|
||||
'Ey': sim_symbols.field_ey(spu.volt / spu.micrometer),
|
||||
'Ez': sim_symbols.field_ez(spu.volt / spu.micrometer),
|
||||
'Hx': sim_symbols.field_hx(spu.volt / spu.micrometer),
|
||||
'Hy': sim_symbols.field_hy(spu.volt / spu.micrometer),
|
||||
'Hz': sim_symbols.field_hz(spu.volt / spu.micrometer),
|
||||
# Spherical Fields
|
||||
'Er': sim_symbols.field_er(spu.volt / spu.micrometer),
|
||||
'Etheta': sim_symbols.ang_theta(spu.volt / spu.micrometer),
|
||||
'Ephi': sim_symbols.field_ez(spu.volt / spu.micrometer),
|
||||
'Hr': sim_symbols.field_hr(spu.volt / spu.micrometer),
|
||||
'Htheta': sim_symbols.field_hy(spu.volt / spu.micrometer),
|
||||
'Hphi': sim_symbols.field_hz(spu.volt / spu.micrometer),
|
||||
# Wavevector
|
||||
'ux': sim_symbols.dir_x(spu.watt),
|
||||
'uy': sim_symbols.dir_y(spu.watt),
|
||||
# Diffraction Orders
|
||||
'orders_x': sim_symbols.diff_order_x(None),
|
||||
'orders_y': sim_symbols.diff_order_y(None),
|
||||
}
|
||||
|
||||
match monitor_data.type:
|
||||
case 'Field' | 'FieldProjectionCartesian' | 'Permittivity' | 'Mode':
|
||||
return ct.InfoFlow(
|
||||
dims={
|
||||
symbols['x']: mk_idx_array('x'),
|
||||
symbols['y']: mk_idx_array('y'),
|
||||
symbols['z']: mk_idx_array('z'),
|
||||
symbols['f']: mk_idx_array('f'),
|
||||
},
|
||||
output=symbols[monitor_attr],
|
||||
)
|
||||
|
||||
case 'FieldTime':
|
||||
return ct.InfoFlow(
|
||||
dims={
|
||||
symbols['x']: mk_idx_array('x'),
|
||||
symbols['y']: mk_idx_array('y'),
|
||||
symbols['z']: mk_idx_array('z'),
|
||||
symbols['t']: mk_idx_array('t'),
|
||||
},
|
||||
output=symbols[monitor_attr],
|
||||
)
|
||||
|
||||
case 'Flux':
|
||||
return ct.InfoFlow(
|
||||
dims={
|
||||
symbols['f']: mk_idx_array('f'),
|
||||
},
|
||||
output=symbols[monitor_attr],
|
||||
)
|
||||
|
||||
case 'FluxTime':
|
||||
return ct.InfoFlow(
|
||||
dims={
|
||||
symbols['t']: mk_idx_array('t'),
|
||||
},
|
||||
output=symbols[monitor_attr],
|
||||
)
|
||||
|
||||
case 'FieldProjectionAngle':
|
||||
return ct.InfoFlow(
|
||||
dims={
|
||||
symbols['r']: mk_idx_array('r'),
|
||||
symbols['theta']: mk_idx_array('theta'),
|
||||
symbols['phi']: mk_idx_array('phi'),
|
||||
symbols['f']: mk_idx_array('f'),
|
||||
},
|
||||
output=symbols[monitor_attr],
|
||||
)
|
||||
|
||||
case 'FieldProjectionKSpace':
|
||||
return ct.InfoFlow(
|
||||
dims={
|
||||
symbols['ux']: mk_idx_array('ux'),
|
||||
symbols['uy']: mk_idx_array('uy'),
|
||||
symbols['r']: mk_idx_array('r'),
|
||||
symbols['f']: mk_idx_array('f'),
|
||||
},
|
||||
output=symbols[monitor_attr],
|
||||
)
|
||||
|
||||
case 'Diffraction':
|
||||
return ct.InfoFlow(
|
||||
dims={
|
||||
symbols['orders_x']: mk_idx_array('orders_x'),
|
||||
symbols['orders_y']: mk_idx_array('orders_y'),
|
||||
symbols['f']: mk_idx_array('f'),
|
||||
},
|
||||
output=symbols[monitor_attr],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
####################
|
||||
# - Node
|
||||
####################
|
||||
class ExtractDataNode(base.MaxwellSimNode):
|
||||
"""Extract data from sockets for further analysis.
|
||||
|
||||
|
@ -45,33 +215,21 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
Monitor Data: Extract `Expr`s from monitor data by-component.
|
||||
|
||||
Attributes:
|
||||
extract_filter: Identifier for data to extract from the input.
|
||||
monitor_attr: Identifier for data to extract from the input.
|
||||
"""
|
||||
|
||||
node_type = ct.NodeType.ExtractData
|
||||
bl_label = 'Extract'
|
||||
|
||||
input_socket_sets: typ.ClassVar = {
|
||||
'Sim Data': {'Sim Data': sockets.MaxwellFDTDSimDataSocketDef()},
|
||||
'Monitor Data': {'Monitor Data': sockets.MaxwellMonitorDataSocketDef()},
|
||||
'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(),
|
||||
}
|
||||
output_socket_sets: typ.ClassVar = {
|
||||
'Sim Data': {'Monitor Data': sockets.MaxwellMonitorDataSocketDef()},
|
||||
'Monitor Data': {
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc)
|
||||
},
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
|
||||
}
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
extract_filter: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_extract_filters(),
|
||||
cb_depends_on={'sim_data_monitor_nametype', 'monitor_data_type'},
|
||||
)
|
||||
|
||||
####################
|
||||
# - Computed: Sim Data
|
||||
# - Properties: Monitor Name
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
socket_name='Sim Data',
|
||||
|
@ -101,198 +259,49 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
|
||||
@bl_cache.cached_bl_property(depends_on={'sim_data'})
|
||||
def sim_data_monitor_nametype(self) -> dict[str, str] | None:
|
||||
"""For simulation data, deduces a map from the monitor name to the monitor "type".
|
||||
"""Dictionary from monitor names on `self.sim_data` to their associated type name (with suffix 'Data' removed).
|
||||
|
||||
Return:
|
||||
The name to type of monitors in the simulation data.
|
||||
"""
|
||||
if self.sim_data is not None:
|
||||
return {
|
||||
monitor_name: monitor_data.type
|
||||
monitor_name: monitor_data.type.removesuffix('Data')
|
||||
for monitor_name, monitor_data in self.sim_data.monitor_data.items()
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
####################
|
||||
# - Computed Properties: Monitor Data
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
socket_name='Monitor Data',
|
||||
input_sockets={'Monitor Data'},
|
||||
input_sockets_optional={'Monitor Data': True},
|
||||
monitor_name: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_monitor_names(),
|
||||
cb_depends_on={'sim_data_monitor_nametype'},
|
||||
)
|
||||
def on_monitor_data_changed(self, input_sockets) -> None: # noqa: D102
|
||||
has_monitor_data = not ct.FlowSignal.check(input_sockets['Monitor Data'])
|
||||
if has_monitor_data:
|
||||
self.monitor_data = bl_cache.Signal.InvalidateCache
|
||||
|
||||
@bl_cache.cached_bl_property()
|
||||
def monitor_data(self) -> TDMonitorData | None:
|
||||
"""Extracts the monitor data from the input socket.
|
||||
|
||||
Return:
|
||||
Either the monitor data, if available, or None.
|
||||
"""
|
||||
monitor_data = self._compute_input(
|
||||
'Monitor Data', kind=ct.FlowKind.Value, optional=True
|
||||
)
|
||||
has_monitor_data = not ct.FlowSignal.check(monitor_data)
|
||||
if has_monitor_data:
|
||||
return monitor_data
|
||||
|
||||
return None
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'monitor_data'})
|
||||
def monitor_data_type(self) -> str | None:
|
||||
r"""For monitor data, deduces the monitor "type".
|
||||
|
||||
- **Field(Time)**: A monitor storing values/pixels/voxels with electromagnetic field values, on the time or frequency domain.
|
||||
- **Permittivity**: A monitor storing values/pixels/voxels containing the diagonal of the relative permittivity tensor.
|
||||
- **Flux(Time)**: A monitor storing the directional flux on the time or frequency domain.
|
||||
For planes, an explicit direction is defined.
|
||||
For volumes, the the integral of all outgoing energy is stored.
|
||||
- **FieldProjection(...)**: A monitor storing the spherical-coordinate electromagnetic field components of a near-to-far-field projection.
|
||||
- **Diffraction**: A monitor storing a near-to-far-field projection by diffraction order.
|
||||
def search_monitor_names(self) -> list[ct.BLEnumElement]:
|
||||
"""Compute valid values for `self.monitor_attr`, for a dynamic `EnumProperty`.
|
||||
|
||||
Notes:
|
||||
Should be invalidated with (before) `self.monitor_data_attrs`.
|
||||
|
||||
Return:
|
||||
The "type" of the monitor, if available, else None.
|
||||
"""
|
||||
if self.monitor_data is not None:
|
||||
return self.monitor_data.type.removesuffix('Data')
|
||||
|
||||
return None
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'monitor_data_type'})
|
||||
def monitor_data_attrs(self) -> list[str] | None:
|
||||
r"""For monitor data, deduces the valid data-containing attributes.
|
||||
|
||||
The output depends entirely on the output of `self.monitor_data_type`, since the valid attributes of each monitor type is well-defined without needing to perform dynamic lookups.
|
||||
|
||||
- **Field(Time)**: Whichever `[E|H][x|y|z]` are not `None` on the monitor.
|
||||
- **Permittivity**: Specifically `['xx', 'yy', 'zz']`.
|
||||
- **Flux(Time)**: Only `['flux']`.
|
||||
- **FieldProjection(...)**: All of $r$, $\theta$, $\phi$ for both `E` and `H`.
|
||||
- **Diffraction**: Same as `FieldProjection`.
|
||||
|
||||
Notes:
|
||||
Should be invalidated after with `self.monitor_data_type`.
|
||||
|
||||
Return:
|
||||
The "type" of the monitor, if available, else None.
|
||||
"""
|
||||
if self.monitor_data is not None:
|
||||
# Field/FieldTime
|
||||
if self.monitor_data_type in ['Field', 'FieldTime']:
|
||||
return [
|
||||
field_component
|
||||
for field_component in ['Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz']
|
||||
if hasattr(self.monitor_data, field_component)
|
||||
]
|
||||
|
||||
# Permittivity
|
||||
if self.monitor_data_type == 'Permittivity':
|
||||
return ['xx', 'yy', 'zz']
|
||||
|
||||
# Flux/FluxTime
|
||||
if self.monitor_data_type in ['Flux', 'FluxTime']:
|
||||
return ['flux']
|
||||
|
||||
# FieldProjection(Angle/Cartesian/KSpace)/Diffraction
|
||||
if self.monitor_data_type in [
|
||||
'FieldProjectionAngle',
|
||||
'FieldProjectionCartesian',
|
||||
'FieldProjectionKSpace',
|
||||
'Diffraction',
|
||||
]:
|
||||
return [
|
||||
'Er',
|
||||
'Etheta',
|
||||
'Ephi',
|
||||
'Hr',
|
||||
'Htheta',
|
||||
'Hphi',
|
||||
]
|
||||
|
||||
return None
|
||||
|
||||
####################
|
||||
# - Extraction Filter Search
|
||||
####################
|
||||
def search_extract_filters(self) -> list[ct.BLEnumElement]:
|
||||
"""Compute valid values for `self.extract_filter`, for a dynamic `EnumProperty`.
|
||||
|
||||
Notes:
|
||||
Should be reset (via `self.extract_filter`) with (after) `self.sim_data_monitor_nametype`, `self.monitor_data_attrs`, and (implicitly) `self.monitor_type`.
|
||||
Should be reset (via `self.monitor_attr`) with (after) `self.sim_data_monitor_nametype`, `self.monitor_data_attrs`, and (implicitly) `self.monitor_type`.
|
||||
|
||||
See `bl_cache.BLField` for more on dynamic `EnumProperty`.
|
||||
|
||||
Returns:
|
||||
Valid `self.extract_filter` in a format compatible with dynamic `EnumProperty`.
|
||||
Valid `self.monitor_attr` in a format compatible with dynamic `EnumProperty`.
|
||||
"""
|
||||
if self.sim_data_monitor_nametype is not None:
|
||||
return [
|
||||
(monitor_name, monitor_name, monitor_type.removesuffix('Data'), '', i)
|
||||
(
|
||||
monitor_name,
|
||||
monitor_name,
|
||||
monitor_type + ' Monitor Data',
|
||||
'',
|
||||
i,
|
||||
)
|
||||
for i, (monitor_name, monitor_type) in enumerate(
|
||||
self.sim_data_monitor_nametype.items()
|
||||
)
|
||||
]
|
||||
|
||||
if self.monitor_data_attrs is not None:
|
||||
# Field/FieldTime
|
||||
if self.monitor_data_type in ['Field', 'FieldTime']:
|
||||
return [
|
||||
(
|
||||
monitor_attr,
|
||||
monitor_attr,
|
||||
f'ℂ {monitor_attr[1]}-polarization of the {"electric" if monitor_attr[0] == "E" else "magnetic"} field',
|
||||
'',
|
||||
i,
|
||||
)
|
||||
for i, monitor_attr in enumerate(self.monitor_data_attrs)
|
||||
]
|
||||
|
||||
# Permittivity
|
||||
if self.monitor_data_type == 'Permittivity':
|
||||
return [
|
||||
(monitor_attr, monitor_attr, f'ℂ ε_{monitor_attr}', '', i)
|
||||
for i, monitor_attr in enumerate(self.monitor_data_attrs)
|
||||
]
|
||||
|
||||
# Flux/FluxTime
|
||||
if self.monitor_data_type in ['Flux', 'FluxTime']:
|
||||
return [
|
||||
(
|
||||
monitor_attr,
|
||||
monitor_attr,
|
||||
'Power flux integral through the plane / out of the volume',
|
||||
'',
|
||||
i,
|
||||
)
|
||||
for i, monitor_attr in enumerate(self.monitor_data_attrs)
|
||||
]
|
||||
|
||||
# FieldProjection(Angle/Cartesian/KSpace)/Diffraction
|
||||
if self.monitor_data_type in [
|
||||
'FieldProjectionAngle',
|
||||
'FieldProjectionCartesian',
|
||||
'FieldProjectionKSpace',
|
||||
'Diffraction',
|
||||
]:
|
||||
return [
|
||||
(
|
||||
monitor_attr,
|
||||
monitor_attr,
|
||||
f'ℂ {monitor_attr[1]}-component of the spherical {"electric" if monitor_attr[0] == "E" else "magnetic"} field',
|
||||
'',
|
||||
i,
|
||||
)
|
||||
for i, monitor_attr in enumerate(self.monitor_data_attrs)
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
####################
|
||||
|
@ -305,10 +314,9 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
Called by Blender to determine the text to place in the node's header.
|
||||
"""
|
||||
has_sim_data = self.sim_data_monitor_nametype is not None
|
||||
has_monitor_data = self.monitor_data_attrs is not None
|
||||
|
||||
if has_sim_data or has_monitor_data:
|
||||
return f'Extract: {self.extract_filter}'
|
||||
if has_sim_data:
|
||||
return f'Extract: {self.monitor_name}'
|
||||
|
||||
return self.bl_label
|
||||
|
||||
|
@ -318,340 +326,115 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
Parameters:
|
||||
col: UI target for drawing.
|
||||
"""
|
||||
col.prop(self, self.blfields['extract_filter'], text='')
|
||||
col.prop(self, self.blfields['monitor_name'], text='')
|
||||
|
||||
####################
|
||||
# - FlowKind.Value: Sim Data -> Monitor Data
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Monitor Data',
|
||||
kind=ct.FlowKind.Value,
|
||||
# Loaded
|
||||
props={'extract_filter'},
|
||||
input_sockets={'Sim Data'},
|
||||
input_sockets_optional={'Sim Data': True},
|
||||
)
|
||||
def compute_monitor_data(
|
||||
self, props: dict, input_sockets: dict
|
||||
) -> TDMonitorData | ct.FlowSignal:
|
||||
"""Compute `Monitor Data` by querying the attribute of `Sim Data` referenced by the property `self.extract_filter`.
|
||||
|
||||
Returns:
|
||||
Monitor data, if available, else `ct.FlowSignal.FlowPending`.
|
||||
"""
|
||||
extract_filter = props['extract_filter']
|
||||
sim_data = input_sockets['Sim Data']
|
||||
has_sim_data = not ct.FlowSignal.check(sim_data)
|
||||
|
||||
if has_sim_data and extract_filter is not None:
|
||||
return sim_data.monitor_data[extract_filter]
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Array|LazyValueFunc: Monitor Data -> Expr
|
||||
# - FlowKind.Func
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Array,
|
||||
kind=ct.FlowKind.Func,
|
||||
# Loaded
|
||||
props={'extract_filter'},
|
||||
input_sockets={'Monitor Data'},
|
||||
input_socket_kinds={'Monitor Data': ct.FlowKind.Value},
|
||||
input_sockets_optional={'Monitor Data': True},
|
||||
props={'monitor_name'},
|
||||
input_sockets={'Sim Data'},
|
||||
input_socket_kinds={'Sim Data': ct.FlowKind.Value},
|
||||
)
|
||||
def compute_expr(
|
||||
self, props: dict, input_sockets: dict
|
||||
) -> jax.Array | ct.FlowSignal:
|
||||
"""Compute `Expr:Array` by querying an array-like attribute of `Monitor Data`, then constructing an `ct.ArrayFlow` around it.
|
||||
) -> ct.FuncFlow | ct.FlowSignal:
|
||||
sim_data = input_sockets['Sim Data']
|
||||
monitor_name = props['monitor_name']
|
||||
|
||||
Uses the internal `xarray` data returned by Tidy3D.
|
||||
By using `np.array` on the `.data` attribute of the `xarray`, instead of the usual JAX array constructor, we should save a (possibly very big) copy.
|
||||
has_sim_data = not ct.FlowSignal.check(sim_data)
|
||||
|
||||
Returns:
|
||||
The data array, if available, else `ct.FlowSignal.FlowPending`.
|
||||
"""
|
||||
extract_filter = props['extract_filter']
|
||||
monitor_data = input_sockets['Monitor Data']
|
||||
has_monitor_data = not ct.FlowSignal.check(monitor_data)
|
||||
if has_sim_data and monitor_name is not None:
|
||||
monitor_data = sim_data.get(monitor_name)
|
||||
if monitor_data is not None:
|
||||
# Extract Valid Index Labels
|
||||
## -> The first output axis will be integer-indexed.
|
||||
## -> Each integer will have a string label.
|
||||
## -> Those string labels explain the integer as ex. Ex, Ey, Hy.
|
||||
idx_labels = valid_monitor_attrs(sim_data, monitor_name)
|
||||
|
||||
if has_monitor_data and extract_filter is not None:
|
||||
xarray_data = getattr(monitor_data, extract_filter)
|
||||
return ct.ArrayFlow(values=np.array(xarray_data.data), unit=None)
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
@events.computes_output_socket(
|
||||
# Trigger
|
||||
'Expr',
|
||||
kind=ct.FlowKind.LazyValueFunc,
|
||||
# Loaded
|
||||
output_sockets={'Expr'},
|
||||
output_socket_kinds={'Expr': ct.FlowKind.Array},
|
||||
output_sockets_optional={'Expr': True},
|
||||
)
|
||||
def compute_extracted_data_lazy(
|
||||
self, output_sockets: dict
|
||||
) -> ct.LazyValueFuncFlow | None:
|
||||
"""Declare `Expr:LazyValueFunc` by creating a simple function that directly wraps `Expr:Array`.
|
||||
|
||||
Returns:
|
||||
The composable function array, if available, else `ct.FlowSignal.FlowPending`.
|
||||
"""
|
||||
output_expr = output_sockets['Expr']
|
||||
has_output_expr = not ct.FlowSignal.check(output_expr)
|
||||
|
||||
if has_output_expr:
|
||||
return ct.LazyValueFuncFlow(
|
||||
func=lambda: output_expr.values, supports_jax=True
|
||||
)
|
||||
# Generate FuncFlow Per Index Label
|
||||
## -> We extract each XArray as an attribute of monitor_data.
|
||||
## -> We then bind its values into a unique func_flow.
|
||||
## -> This lets us 'stack' then all along the first axis.
|
||||
func_flows = []
|
||||
for idx_label in idx_labels:
|
||||
xarr = getattr(monitor_data, idx_label)
|
||||
func_flows.append(
|
||||
ct.FuncFlow(
|
||||
func=lambda xarr=xarr: xarr.values,
|
||||
supports_jax=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Concatenate and Stack Unified FuncFlow
|
||||
## -> First, 'reduce' lets us __or__ all the FuncFlows together.
|
||||
## -> Then, 'compose_within' lets us stack them along axis=0.
|
||||
## -> The "new" axis=0 is int-indexed axis w/idx_labels labels!
|
||||
return functools.reduce(lambda a, b: a | b, func_flows).compose_within(
|
||||
enclosing_func=lambda data: jnp.stack(data, axis=0)
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Params: Monitor Data -> Expr
|
||||
# - FlowKind.Params
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Params,
|
||||
input_sockets={'Sim Data'},
|
||||
input_socket_kinds={'Sim Data': ct.FlowKind.Params},
|
||||
)
|
||||
def compute_data_params(self) -> ct.ParamsFlow:
|
||||
def compute_data_params(self, input_sockets) -> ct.ParamsFlow:
|
||||
"""Declare an empty `Data:Params`, to indicate the start of a function-composition pipeline.
|
||||
|
||||
Returns:
|
||||
A completely empty `ParamsFlow`, ready to be composed.
|
||||
"""
|
||||
sim_params = input_sockets['Sim Data']
|
||||
has_sim_params = not ct.FlowSignal.check(sim_params)
|
||||
|
||||
if has_sim_params:
|
||||
return sim_params
|
||||
return ct.ParamsFlow()
|
||||
|
||||
####################
|
||||
# - FlowKind.Info: Monitor Data -> Expr
|
||||
# - FlowKind.Info
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Info,
|
||||
# Loaded
|
||||
props={'monitor_data_type', 'extract_filter'},
|
||||
input_sockets={'Monitor Data'},
|
||||
input_socket_kinds={'Monitor Data': ct.FlowKind.Value},
|
||||
input_sockets_optional={'Monitor Data': True},
|
||||
props={'monitor_name'},
|
||||
input_sockets={'Sim Data'},
|
||||
input_socket_kinds={'Sim Data': ct.FlowKind.Value},
|
||||
)
|
||||
def compute_extracted_data_info(
|
||||
self, props: dict, input_sockets: dict
|
||||
) -> ct.InfoFlow:
|
||||
def compute_extracted_data_info(self, props, input_sockets) -> ct.InfoFlow:
|
||||
"""Declare `Data:Info` by manually selecting appropriate axes, units, etc. for each monitor type.
|
||||
|
||||
Returns:
|
||||
Information describing the `Data:LazyValueFunc`, if available, else `ct.FlowSignal.FlowPending`.
|
||||
Information describing the `Data:Func`, if available, else `ct.FlowSignal.FlowPending`.
|
||||
"""
|
||||
monitor_data = input_sockets['Monitor Data']
|
||||
monitor_data_type = props['monitor_data_type']
|
||||
extract_filter = props['extract_filter']
|
||||
sim_data = input_sockets['Sim Data']
|
||||
monitor_name = props['monitor_name']
|
||||
|
||||
has_monitor_data = not ct.FlowSignal.check(monitor_data)
|
||||
has_sim_data = not ct.FlowSignal.check(sim_data)
|
||||
|
||||
# Edge Case: Dangling 'flux' Access on 'FieldMonitor'
|
||||
## -> Sometimes works - UNLESS the FieldMonitor doesn't have all fields.
|
||||
## -> We don't allow 'flux' attribute access, but it can dangle.
|
||||
## -> (The method is called when updating each depschain component.)
|
||||
if monitor_data_type == 'Field' and extract_filter == 'flux':
|
||||
if not has_sim_data or monitor_name is None:
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
# Retrieve XArray
|
||||
if has_monitor_data and extract_filter is not None:
|
||||
xarr = getattr(monitor_data, extract_filter, None)
|
||||
if xarr is None:
|
||||
return ct.FlowSignal.FlowPending
|
||||
else:
|
||||
return ct.FlowSignal.FlowPending
|
||||
# Extract Data
|
||||
## -> All monitor_data.<idx_label> have the exact same InfoFlow.
|
||||
## -> So, just construct an InfoFlow w/prepended labelled dimension.
|
||||
monitor_data = sim_data.get(monitor_name)
|
||||
idx_labels = valid_monitor_attrs(sim_data, monitor_name)
|
||||
info = extract_info(monitor_data, idx_labels[0])
|
||||
|
||||
# Compute InfoFlow from XArray
|
||||
## XYZF: Field / Permittivity / FieldProjectionCartesian
|
||||
if monitor_data_type in {
|
||||
'Field',
|
||||
'Permittivity',
|
||||
#'FieldProjectionCartesian',
|
||||
}:
|
||||
return ct.InfoFlow(
|
||||
dim_names=['x', 'y', 'z', 'f'],
|
||||
dim_idx={
|
||||
axis: ct.ArrayFlow(
|
||||
values=xarr.get_index(axis).values, unit=spu.um, is_sorted=True
|
||||
)
|
||||
for axis in ['x', 'y', 'z']
|
||||
}
|
||||
| {
|
||||
'f': ct.ArrayFlow(
|
||||
values=xarr.get_index('f').values,
|
||||
unit=spu.hertz,
|
||||
is_sorted=True,
|
||||
),
|
||||
},
|
||||
output_name=extract_filter,
|
||||
output_shape=None,
|
||||
output_mathtype=spux.MathType.Complex,
|
||||
output_unit=(
|
||||
spu.volt / spu.micrometer if monitor_data_type == 'Field' else None
|
||||
),
|
||||
)
|
||||
|
||||
## XYZT: FieldTime
|
||||
if monitor_data_type == 'FieldTime':
|
||||
return ct.InfoFlow(
|
||||
dim_names=['x', 'y', 'z', 't'],
|
||||
dim_idx={
|
||||
axis: ct.ArrayFlow(
|
||||
values=xarr.get_index(axis).values, unit=spu.um, is_sorted=True
|
||||
)
|
||||
for axis in ['x', 'y', 'z']
|
||||
}
|
||||
| {
|
||||
't': ct.ArrayFlow(
|
||||
values=xarr.get_index('t').values,
|
||||
unit=spu.second,
|
||||
is_sorted=True,
|
||||
),
|
||||
},
|
||||
output_name=extract_filter,
|
||||
output_shape=None,
|
||||
output_mathtype=spux.MathType.Complex,
|
||||
output_unit=(
|
||||
spu.volt / spu.micrometer if monitor_data_type == 'Field' else None
|
||||
),
|
||||
)
|
||||
|
||||
## F: Flux
|
||||
if monitor_data_type == 'Flux':
|
||||
return ct.InfoFlow(
|
||||
dim_names=['f'],
|
||||
dim_idx={
|
||||
'f': ct.ArrayFlow(
|
||||
values=xarr.get_index('f').values,
|
||||
unit=spu.hertz,
|
||||
is_sorted=True,
|
||||
),
|
||||
},
|
||||
output_name=extract_filter,
|
||||
output_shape=None,
|
||||
output_mathtype=spux.MathType.Real,
|
||||
output_unit=spu.watt,
|
||||
)
|
||||
|
||||
## T: FluxTime
|
||||
if monitor_data_type == 'FluxTime':
|
||||
return ct.InfoFlow(
|
||||
dim_names=['t'],
|
||||
dim_idx={
|
||||
't': ct.ArrayFlow(
|
||||
values=xarr.get_index('t').values,
|
||||
unit=spu.hertz,
|
||||
is_sorted=True,
|
||||
),
|
||||
},
|
||||
output_name=extract_filter,
|
||||
output_shape=None,
|
||||
output_mathtype=spux.MathType.Real,
|
||||
output_unit=spu.watt,
|
||||
)
|
||||
|
||||
## RThetaPhiF: FieldProjectionAngle
|
||||
if monitor_data_type == 'FieldProjectionAngle':
|
||||
return ct.InfoFlow(
|
||||
dim_names=['r', 'theta', 'phi', 'f'],
|
||||
dim_idx={
|
||||
'r': ct.ArrayFlow(
|
||||
values=xarr.get_index('r').values,
|
||||
unit=spu.micrometer,
|
||||
is_sorted=True,
|
||||
),
|
||||
}
|
||||
| {
|
||||
c: ct.ArrayFlow(
|
||||
values=xarr.get_index(c).values,
|
||||
unit=spu.radian,
|
||||
is_sorted=True,
|
||||
)
|
||||
for c in ['r', 'theta', 'phi']
|
||||
}
|
||||
| {
|
||||
'f': ct.ArrayFlow(
|
||||
values=xarr.get_index('f').values,
|
||||
unit=spu.hertz,
|
||||
is_sorted=True,
|
||||
),
|
||||
},
|
||||
output_name=extract_filter,
|
||||
output_shape=None,
|
||||
output_mathtype=spux.MathType.Real,
|
||||
output_unit=(
|
||||
spu.volt / spu.micrometer
|
||||
if extract_filter.startswith('E')
|
||||
else spu.ampere / spu.micrometer
|
||||
),
|
||||
)
|
||||
|
||||
## UxUyRF: FieldProjectionKSpace
|
||||
if monitor_data_type == 'FieldProjectionKSpace':
|
||||
return ct.InfoFlow(
|
||||
dim_names=['ux', 'uy', 'r', 'f'],
|
||||
dim_idx={
|
||||
c: ct.ArrayFlow(
|
||||
values=xarr.get_index(c).values, unit=None, is_sorted=True
|
||||
)
|
||||
for c in ['ux', 'uy']
|
||||
}
|
||||
| {
|
||||
'r': ct.ArrayFlow(
|
||||
values=xarr.get_index('r').values,
|
||||
unit=spu.micrometer,
|
||||
is_sorted=True,
|
||||
),
|
||||
'f': ct.ArrayFlow(
|
||||
values=xarr.get_index('f').values,
|
||||
unit=spu.hertz,
|
||||
is_sorted=True,
|
||||
),
|
||||
},
|
||||
output_name=extract_filter,
|
||||
output_shape=None,
|
||||
output_mathtype=spux.MathType.Real,
|
||||
output_unit=(
|
||||
spu.volt / spu.micrometer
|
||||
if extract_filter.startswith('E')
|
||||
else spu.ampere / spu.micrometer
|
||||
),
|
||||
)
|
||||
|
||||
## OrderxOrderyF: Diffraction
|
||||
if monitor_data_type == 'Diffraction':
|
||||
return ct.InfoFlow(
|
||||
dim_names=['orders_x', 'orders_y', 'f'],
|
||||
dim_idx={
|
||||
f'orders_{c}': ct.ArrayFlow(
|
||||
values=xarr.get_index(f'orders_{c}').values,
|
||||
unit=None,
|
||||
is_sorted=True,
|
||||
)
|
||||
for c in ['x', 'y']
|
||||
}
|
||||
| {
|
||||
'f': ct.ArrayFlow(
|
||||
values=xarr.get_index('f').values,
|
||||
unit=spu.hertz,
|
||||
is_sorted=True,
|
||||
),
|
||||
},
|
||||
output_name=extract_filter,
|
||||
output_shape=None,
|
||||
output_mathtype=spux.MathType.Real,
|
||||
output_unit=(
|
||||
spu.volt / spu.micrometer
|
||||
if extract_filter.startswith('E')
|
||||
else spu.ampere / spu.micrometer
|
||||
),
|
||||
)
|
||||
|
||||
msg = f'Unsupported Monitor Data Type {monitor_data_type} in "FlowKind.Info" of "{self.bl_label}"'
|
||||
raise RuntimeError(msg)
|
||||
return info.prepend_dim(sim_symbols.idx, idx_labels)
|
||||
|
||||
|
||||
####################
|
||||
|
|
|
@ -98,29 +98,29 @@ class FilterOperation(enum.StrEnum):
|
|||
operations = []
|
||||
|
||||
# Slice
|
||||
if info.dim_names:
|
||||
if info.dims:
|
||||
operations.append(FO.SliceIdx)
|
||||
|
||||
# Pin
|
||||
## PinLen1
|
||||
## -> There must be a dimension with length 1.
|
||||
if 1 in list(info.dim_lens.values()):
|
||||
if 1 in [dim_idx for dim_idx in info.dims.values() if dim_idx is not None]:
|
||||
operations.append(FO.PinLen1)
|
||||
|
||||
## Pin | PinIdx
|
||||
## -> There must be a dimension, full stop.
|
||||
if info.dim_names:
|
||||
if info.dims:
|
||||
operations += [FO.Pin, FO.PinIdx]
|
||||
|
||||
# Reinterpret
|
||||
## Swap
|
||||
## -> There must be at least two dimensions.
|
||||
if len(info.dim_names) >= 2: # noqa: PLR2004
|
||||
if len(info.dims) >= 2: # noqa: PLR2004
|
||||
operations.append(FO.Swap)
|
||||
|
||||
## SetDim
|
||||
## -> There must be a dimension to correct.
|
||||
if info.dim_names:
|
||||
if info.dims:
|
||||
operations.append(FO.SetDim)
|
||||
|
||||
return operations
|
||||
|
@ -158,33 +158,33 @@ class FilterOperation(enum.StrEnum):
|
|||
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
|
||||
FO = FilterOperation
|
||||
match self:
|
||||
case FO.SliceIdx:
|
||||
return info.dim_names
|
||||
case FO.SliceIdx | FO.Swap:
|
||||
return info.dims
|
||||
|
||||
# PinLen1: Only allow dimensions with length=1.
|
||||
case FO.PinLen1:
|
||||
return [
|
||||
dim_name
|
||||
for dim_name in info.dim_names
|
||||
if info.dim_lens[dim_name] == 1
|
||||
dim
|
||||
for dim, dim_idx in info.dims.items()
|
||||
if dim_idx is not None and len(dim_idx) == 1
|
||||
]
|
||||
|
||||
# Pin: Only allow dimensions with known indexing.
|
||||
case FO.Pin:
|
||||
# Pin: Only allow dimensions with discrete index.
|
||||
## TODO: Shouldn't 'Pin' be allowed to index continuous indices too?
|
||||
case FO.Pin | FO.PinIdx:
|
||||
return [
|
||||
dim_name
|
||||
for dim_name in info.dim_names
|
||||
if info.dim_has_coords[dim_name] != 0
|
||||
dim
|
||||
for dim, dim_idx in info.dims
|
||||
if dim_idx is not None and len(dim_idx) > 0
|
||||
]
|
||||
|
||||
case FO.PinIdx | FO.Swap:
|
||||
return info.dim_names
|
||||
|
||||
case FO.SetDim:
|
||||
return [
|
||||
dim_name
|
||||
for dim_name in info.dim_names
|
||||
if info.dim_mathtypes[dim_name] == spux.MathType.Integer
|
||||
dim
|
||||
for dim, dim_idx in info.dims
|
||||
if dim_idx is not None
|
||||
and not isinstance(dim_idx, list)
|
||||
and dim_idx.mathtype == spux.MathType.Integer
|
||||
]
|
||||
|
||||
return []
|
||||
|
@ -224,22 +224,22 @@ class FilterOperation(enum.StrEnum):
|
|||
def transform_info(
|
||||
self,
|
||||
info: ct.InfoFlow,
|
||||
dim_0: str,
|
||||
dim_1: str,
|
||||
dim_0: sim_symbols.SimSymbol,
|
||||
dim_1: sim_symbols.SimSymbol,
|
||||
pin_idx: int | None = None,
|
||||
slice_tuple: tuple[int, int, int] | None = None,
|
||||
corrected_dim: tuple[str, tuple[str, ct.ArrayFlow | ct.LazyArrayRangeFlow]]
|
||||
| None = None,
|
||||
replaced_dim: tuple[str, tuple[str, ct.ArrayFlow | ct.RangeFlow]] | None = None,
|
||||
):
|
||||
FO = FilterOperation
|
||||
return {
|
||||
FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple),
|
||||
# Pin
|
||||
FO.PinLen1: lambda: info.delete_dimension(dim_0),
|
||||
FO.Pin: lambda: info.delete_dimension(dim_0),
|
||||
FO.PinIdx: lambda: info.delete_dimension(dim_0),
|
||||
FO.PinLen1: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
|
||||
FO.Pin: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
|
||||
FO.PinIdx: lambda: info.delete_dim(dim_0, pin_idx=pin_idx),
|
||||
# Reinterpret
|
||||
FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
|
||||
FO.SetDim: lambda: info.replace_dim(*corrected_dim),
|
||||
FO.SetDim: lambda: info.replace_dim(*replaced_dim),
|
||||
}[self]()
|
||||
|
||||
|
||||
|
@ -265,10 +265,10 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
bl_label = 'Filter Math'
|
||||
|
||||
input_sockets: typ.ClassVar = {
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc),
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc),
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
|
||||
}
|
||||
|
||||
####################
|
||||
|
@ -318,11 +318,11 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
####################
|
||||
# - Properties: Dimension Selection
|
||||
####################
|
||||
dim_0: enum.StrEnum = bl_cache.BLField(
|
||||
active_dim_0: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_dims(),
|
||||
cb_depends_on={'operation', 'expr_info'},
|
||||
)
|
||||
dim_1: enum.StrEnum = bl_cache.BLField(
|
||||
active_dim_1: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_dims(),
|
||||
cb_depends_on={'operation', 'expr_info'},
|
||||
)
|
||||
|
@ -335,40 +335,23 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
]
|
||||
return []
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'active_dim_0'})
|
||||
def dim_0(self) -> sim_symbols.SimSymbol | None:
|
||||
if self.expr_info is not None and self.active_dim_0 is not None:
|
||||
return self.expr_info.dim_by_name(self.active_dim_0)
|
||||
return None
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'active_dim_1'})
|
||||
def dim_1(self) -> sim_symbols.SimSymbol | None:
|
||||
if self.expr_info is not None and self.active_dim_1 is not None:
|
||||
return self.expr_info.dim_by_name(self.active_dim_1)
|
||||
return None
|
||||
|
||||
####################
|
||||
# - Properties: Slice
|
||||
####################
|
||||
slice_tuple: tuple[int, int, int] = bl_cache.BLField([0, 1, 1])
|
||||
|
||||
####################
|
||||
# - Properties: Unit
|
||||
####################
|
||||
set_dim_symbol: sim_symbols.CommonSimSymbol = bl_cache.BLField(
|
||||
sim_symbols.CommonSimSymbol.X
|
||||
)
|
||||
|
||||
set_dim_active_unit: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_valid_units(),
|
||||
cb_depends_on={'set_dim_symbol'},
|
||||
)
|
||||
|
||||
def search_valid_units(self) -> list[ct.BLEnumElement]:
|
||||
"""Compute Blender enum elements of valid units for the current `physical_type`."""
|
||||
physical_type = self.set_dim_symbol.sim_symbol.physical_type
|
||||
if physical_type is not spux.PhysicalType.NonPhysical:
|
||||
return [
|
||||
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
|
||||
for i, unit in enumerate(physical_type.valid_units)
|
||||
]
|
||||
return []
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'set_dim_active_unit'})
|
||||
def set_dim_unit(self) -> spux.Unit | None:
|
||||
if self.set_dim_active_unit is not None:
|
||||
return spux.unit_str_to_unit(self.set_dim_active_unit)
|
||||
|
||||
return None
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
|
@ -378,27 +361,27 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
# Slice
|
||||
case FO.SliceIdx:
|
||||
slice_str = ':'.join([str(v) for v in self.slice_tuple])
|
||||
return f'Filter: {self.dim_0}[{slice_str}]'
|
||||
return f'Filter: {self.active_dim_0}[{slice_str}]'
|
||||
|
||||
# Pin
|
||||
case FO.PinLen1:
|
||||
return f'Filter: Pin {self.dim_0}[0]'
|
||||
return f'Filter: Pin {self.active_dim_0}[0]'
|
||||
case FO.Pin:
|
||||
return f'Filter: Pin {self.dim_0}[...]'
|
||||
return f'Filter: Pin {self.active_dim_0}[...]'
|
||||
case FO.PinIdx:
|
||||
pin_idx_axis = self._compute_input(
|
||||
'Axis', kind=ct.FlowKind.Value, optional=True
|
||||
)
|
||||
has_pin_idx_axis = not ct.FlowSignal.check(pin_idx_axis)
|
||||
if has_pin_idx_axis:
|
||||
return f'Filter: Pin {self.dim_0}[{pin_idx_axis}]'
|
||||
return f'Filter: Pin {self.active_dim_0}[{pin_idx_axis}]'
|
||||
return self.bl_label
|
||||
|
||||
# Reinterpret
|
||||
case FO.Swap:
|
||||
return f'Filter: Swap [{self.dim_0}]|[{self.dim_1}]'
|
||||
return f'Filter: Swap [{self.active_dim_0}]|[{self.active_dim_1}]'
|
||||
case FO.SetDim:
|
||||
return f'Filter: Set [{self.dim_0}]'
|
||||
return f'Filter: Set [{self.active_dim_0}]'
|
||||
|
||||
case _:
|
||||
return self.bl_label
|
||||
|
@ -409,20 +392,15 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
if self.operation is not None:
|
||||
match self.operation.num_dim_inputs:
|
||||
case 1:
|
||||
layout.prop(self, self.blfields['dim_0'], text='')
|
||||
layout.prop(self, self.blfields['active_dim_0'], text='')
|
||||
case 2:
|
||||
row = layout.row(align=True)
|
||||
row.prop(self, self.blfields['dim_0'], text='')
|
||||
row.prop(self, self.blfields['dim_1'], text='')
|
||||
row.prop(self, self.blfields['active_dim_0'], text='')
|
||||
row.prop(self, self.blfields['active_dim_1'], text='')
|
||||
|
||||
if self.operation is FilterOperation.SliceIdx:
|
||||
layout.prop(self, self.blfields['slice_tuple'], text='')
|
||||
|
||||
if self.operation is FilterOperation.SetDim:
|
||||
row = layout.row(align=True)
|
||||
row.prop(self, self.blfields['set_dim_symbol'], text='')
|
||||
row.prop(self, self.blfields['set_dim_active_unit'], text='')
|
||||
|
||||
####################
|
||||
# - Events
|
||||
####################
|
||||
|
@ -450,50 +428,47 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
if not has_info:
|
||||
return
|
||||
|
||||
# Pin Dim by-Value: Synchronize Input Socket
|
||||
## -> The user will be given a socket w/correct mathtype, unit, etc. .
|
||||
## -> Internally, this value will map to a particular index.
|
||||
if props['operation'] is FilterOperation.Pin and props['dim_0'] is not None:
|
||||
# Deduce Pinned Information
|
||||
pinned_unit = info.dim_units[props['dim_0']]
|
||||
pinned_mathtype = info.dim_mathtypes[props['dim_0']]
|
||||
pinned_physical_type = spux.PhysicalType.from_unit(pinned_unit)
|
||||
wanted_mathtype = (
|
||||
spux.MathType.Complex
|
||||
if pinned_mathtype == spux.MathType.Complex
|
||||
and spux.MathType.Complex in pinned_physical_type.valid_mathtypes
|
||||
else spux.MathType.Real
|
||||
)
|
||||
dim_0 = props['dim_0']
|
||||
|
||||
# Get Current and Wanted Socket Defs
|
||||
## -> 'Value' may already exist. If not, all is well.
|
||||
# Loose Sockets: Pin Dim by-Value
|
||||
## -> Works with continuous / discrete indexes.
|
||||
## -> The user will be given a socket w/correct mathtype, unit, etc. .
|
||||
if (
|
||||
props['operation'] is FilterOperation.Pin
|
||||
and dim_0 is not None
|
||||
and (info.has_idx_cont(dim_0) or info.has_idx_discrete(dim_0))
|
||||
):
|
||||
dim = dim_0
|
||||
current_bl_socket = self.loose_input_sockets.get('Value')
|
||||
|
||||
# Determine Whether to Construct
|
||||
## -> If nothing needs to change, then nothing changes.
|
||||
if (
|
||||
current_bl_socket is None
|
||||
or current_bl_socket.active_kind != ct.FlowKind.Value
|
||||
or current_bl_socket.size is not spux.NumberSize1D.Scalar
|
||||
or current_bl_socket.physical_type != pinned_physical_type
|
||||
or current_bl_socket.mathtype != wanted_mathtype
|
||||
or current_bl_socket.physical_type != dim.physical_type
|
||||
or current_bl_socket.mathtype != dim.mathtype
|
||||
):
|
||||
self.loose_input_sockets = {
|
||||
'Value': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.Value,
|
||||
physical_type=pinned_physical_type,
|
||||
mathtype=wanted_mathtype,
|
||||
default_unit=pinned_unit,
|
||||
physical_type=dim.physical_type,
|
||||
mathtype=dim.mathtype,
|
||||
default_unit=dim.unit,
|
||||
),
|
||||
}
|
||||
|
||||
# Pin Dim by-Index: Synchronize Input Socket
|
||||
## -> The user will be given a simple integer socket.
|
||||
# Loose Sockets: Pin Dim by-Value
|
||||
## -> Works with discrete points / labelled integers.
|
||||
elif (
|
||||
props['operation'] is FilterOperation.PinIdx and props['dim_0'] is not None
|
||||
props['operation'] is FilterOperation.PinIdx
|
||||
and dim_0 is not None
|
||||
and (info.has_idx_discrete(dim_0) or info.has_idx_labels(dim_0))
|
||||
):
|
||||
dim = dim_0
|
||||
current_bl_socket = self.loose_input_sockets.get('Axis')
|
||||
if (
|
||||
current_bl_socket is None
|
||||
or current_bl_socket.active_kind != ct.FlowKind.Value
|
||||
or current_bl_socket.size is not spux.NumberSize1D.Scalar
|
||||
or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical
|
||||
or current_bl_socket.mathtype != spux.MathType.Integer
|
||||
|
@ -505,28 +480,26 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
)
|
||||
}
|
||||
|
||||
# Set Dim: Synchronize Input Socket
|
||||
# Loose Sockets: Set Dim
|
||||
## -> The user must provide a (ℤ) -> ℝ array.
|
||||
## -> It must be of identical length to the replaced axis.
|
||||
elif (
|
||||
props['operation'] is FilterOperation.SetDim
|
||||
and props['dim_0'] is not None
|
||||
and info.dim_mathtypes[props['dim_0']] is spux.MathType.Integer
|
||||
and info.dim_physical_types[props['dim_0']] is spux.PhysicalType.NonPhysical
|
||||
):
|
||||
# Deduce Axis Information
|
||||
elif props['operation'] is FilterOperation.SetDim and dim_0 is not None:
|
||||
dim = dim_0
|
||||
current_bl_socket = self.loose_input_sockets.get('Dim')
|
||||
if (
|
||||
current_bl_socket is None
|
||||
or current_bl_socket.active_kind != ct.FlowKind.LazyValueFunc
|
||||
or current_bl_socket.mathtype != spux.MathType.Real
|
||||
or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical
|
||||
or current_bl_socket.active_kind != ct.FlowKind.Func
|
||||
or current_bl_socket.size is not spux.NumberSize1D.Scalar
|
||||
or current_bl_socket.mathtype != dim.mathtype
|
||||
or current_bl_socket.physical_type != dim.physical_type
|
||||
):
|
||||
self.loose_input_sockets = {
|
||||
'Dim': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.LazyValueFunc,
|
||||
mathtype=spux.MathType.Real,
|
||||
physical_type=spux.PhysicalType.NonPhysical,
|
||||
active_kind=ct.FlowKind.Func,
|
||||
physical_type=dim.physical_type,
|
||||
mathtype=dim.mathtype,
|
||||
default_unit=dim.unit,
|
||||
show_func_ui=False,
|
||||
show_info_columns=True,
|
||||
)
|
||||
}
|
||||
|
@ -536,42 +509,37 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
self.loose_input_sockets = {}
|
||||
|
||||
####################
|
||||
# - FlowKind.Value|LazyValueFunc
|
||||
# - FlowKind.Value|Func
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.LazyValueFunc,
|
||||
kind=ct.FlowKind.Func,
|
||||
props={'operation', 'dim_0', 'dim_1', 'slice_tuple'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info}},
|
||||
input_socket_kinds={'Expr': {ct.FlowKind.Func, ct.FlowKind.Info}},
|
||||
)
|
||||
def compute_lazy_value_func(self, props: dict, input_sockets: dict):
|
||||
def compute_lazy_func(self, props: dict, input_sockets: dict):
|
||||
operation = props['operation']
|
||||
lazy_value_func = input_sockets['Expr'][ct.FlowKind.LazyValueFunc]
|
||||
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
|
||||
info = input_sockets['Expr'][ct.FlowKind.Info]
|
||||
|
||||
has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func)
|
||||
has_lazy_func = not ct.FlowSignal.check(lazy_func)
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
|
||||
# Dimension(s)
|
||||
dim_0 = props['dim_0']
|
||||
dim_1 = props['dim_1']
|
||||
slice_tuple = props['slice_tuple']
|
||||
if (
|
||||
has_lazy_value_func
|
||||
has_lazy_func
|
||||
and has_info
|
||||
and operation is not None
|
||||
and operation.are_dims_valid(info, dim_0, dim_1)
|
||||
):
|
||||
axis_0 = info.dim_names.index(dim_0) if dim_0 is not None else None
|
||||
axis_1 = info.dim_names.index(dim_1) if dim_1 is not None else None
|
||||
slice_tuple = (
|
||||
props['slice_tuple']
|
||||
if self.operation is FilterOperation.SliceIdx
|
||||
else None
|
||||
)
|
||||
axis_0 = info.dim_axis(dim_0) if dim_0 is not None else None
|
||||
axis_1 = info.dim_axis(dim_1) if dim_1 is not None else None
|
||||
|
||||
return lazy_value_func.compose_within(
|
||||
operation.jax_func(axis_0, axis_1, slice_tuple),
|
||||
return lazy_func.compose_within(
|
||||
operation.jax_func(axis_0, axis_1, slice_tuple=slice_tuple),
|
||||
enclosing_func_args=operation.func_args,
|
||||
supports_jax=True,
|
||||
)
|
||||
|
@ -588,27 +556,26 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
'dim_1',
|
||||
'operation',
|
||||
'slice_tuple',
|
||||
'set_dim_symbol',
|
||||
'set_dim_active_unit',
|
||||
},
|
||||
input_sockets={'Expr', 'Dim'},
|
||||
input_socket_kinds={
|
||||
'Expr': ct.FlowKind.Info,
|
||||
'Dim': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params, ct.FlowKind.Info},
|
||||
'Dim': {ct.FlowKind.Func, ct.FlowKind.Params, ct.FlowKind.Info},
|
||||
},
|
||||
input_sockets_optional={'Dim': True},
|
||||
)
|
||||
def compute_info(self, props, input_sockets) -> ct.InfoFlow:
|
||||
operation = props['operation']
|
||||
info = input_sockets['Expr']
|
||||
dim_coords = input_sockets['Dim'][ct.FlowKind.LazyValueFunc]
|
||||
dim_params = input_sockets['Dim'][ct.FlowKind.Params]
|
||||
dim_info = input_sockets['Dim'][ct.FlowKind.Info]
|
||||
dim_symbol = props['set_dim_symbol']
|
||||
dim_active_unit = props['set_dim_active_unit']
|
||||
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
has_dim_coords = not ct.FlowSignal.check(dim_coords)
|
||||
|
||||
# Dim (Op.SetDim)
|
||||
dim_func = input_sockets['Dim'][ct.FlowKind.Func]
|
||||
dim_params = input_sockets['Dim'][ct.FlowKind.Params]
|
||||
dim_info = input_sockets['Dim'][ct.FlowKind.Info]
|
||||
|
||||
has_dim_func = not ct.FlowSignal.check(dim_func)
|
||||
has_dim_params = not ct.FlowSignal.check(dim_params)
|
||||
has_dim_info = not ct.FlowSignal.check(dim_info)
|
||||
|
||||
|
@ -619,44 +586,42 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
if has_info and operation is not None:
|
||||
# Set Dimension: Retrieve Array
|
||||
if props['operation'] is FilterOperation.SetDim:
|
||||
new_dim = (
|
||||
next(dim_info.dims.keys()) if len(dim_info.dims) >= 1 else None
|
||||
)
|
||||
|
||||
if (
|
||||
dim_0 is not None
|
||||
# Check Replaced Dimension
|
||||
and has_dim_coords
|
||||
and len(dim_coords.func_args) == 1
|
||||
and dim_coords.func_args[0] is spux.MathType.Integer
|
||||
and not dim_coords.func_kwargs
|
||||
and dim_coords.supports_jax
|
||||
# Check Params
|
||||
and has_dim_params
|
||||
and len(dim_params.func_args) == 1
|
||||
and not dim_params.func_kwargs
|
||||
# Check Info
|
||||
and new_dim is not None
|
||||
and has_dim_info
|
||||
and has_dim_params
|
||||
# Check New Dimension Index Array Sizing
|
||||
and len(dim_info.dims) == 1
|
||||
and dim_info.output.rows == 1
|
||||
and dim_info.output.cols == 1
|
||||
# Check Lack of Params Symbols
|
||||
and not dim_params.symbols
|
||||
# Check Expr Dim | New Dim Compatibility
|
||||
and info.has_idx_discrete(dim_0)
|
||||
and dim_info.has_idx_discrete(new_dim)
|
||||
and len(info.dims[dim_0]) == len(dim_info.dims[new_dim])
|
||||
):
|
||||
# Retrieve Dimension Coordinate Array
|
||||
## -> It must be strictly compatible.
|
||||
values = dim_coords.func_jax(int(dim_params.func_args[0]))
|
||||
if (
|
||||
len(values.shape) != 1
|
||||
or values.shape[0] != info.dim_lens[dim_0]
|
||||
):
|
||||
return ct.FlowSignal.FlowPending
|
||||
values = dim_func.realize(dim_params, spux.UNITS_SI)
|
||||
|
||||
# Transform Info w/Corrected Dimension
|
||||
## -> The existing dimension will be replaced.
|
||||
if dim_active_unit is not None:
|
||||
dim_unit = spux.unit_str_to_unit(dim_active_unit)
|
||||
else:
|
||||
dim_unit = None
|
||||
|
||||
new_dim_idx = ct.ArrayFlow(
|
||||
values=values,
|
||||
unit=dim_unit,
|
||||
)
|
||||
corrected_dim = [dim_0, (dim_symbol.name, new_dim_idx)]
|
||||
unit=spux.convert_to_unit_system(
|
||||
dim_info.output.unit, spux.UNITS_SI
|
||||
),
|
||||
).rescale_to_unit(dim_info.output.unit)
|
||||
|
||||
replaced_dim = [dim_0, (dim_info.output.name, new_dim_idx)]
|
||||
return operation.transform_info(
|
||||
info, dim_0, dim_1, corrected_dim=corrected_dim
|
||||
info, dim_0, dim_1, replaced_dim=replaced_dim
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
return operation.transform_info(info, dim_0, dim_1, slice_tuple=slice_tuple)
|
||||
|
@ -702,7 +667,7 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
# Pin by-Value: Compute Nearest IDX
|
||||
## -> Presume a sorted index array to be able to use binary search.
|
||||
if props['operation'] is FilterOperation.Pin and has_pinned_value:
|
||||
nearest_idx_to_value = info.dim_idx[dim_0].nearest_idx_of(
|
||||
nearest_idx_to_value = info.dims[dim_0].nearest_idx_of(
|
||||
pinned_value, require_sorted=True
|
||||
)
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ import bpy
|
|||
import jax.numpy as jnp
|
||||
import sympy as sp
|
||||
|
||||
from blender_maxwell.utils import bl_cache, logger
|
||||
from blender_maxwell.utils import bl_cache, logger, sim_symbols
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
|
||||
from .... import contracts as ct
|
||||
|
@ -153,40 +153,38 @@ class MapOperation(enum.StrEnum):
|
|||
# - Ops from Shape
|
||||
####################
|
||||
@staticmethod
|
||||
def by_element_shape(shape: tuple[int, ...] | None) -> list[typ.Self]:
|
||||
def by_expr_info(info: ct.InfoFlow) -> list[typ.Self]:
|
||||
## TODO: By info, not shape.
|
||||
## TODO: Check valid domains/mathtypes for some functions.
|
||||
MO = MapOperation
|
||||
element_ops = [
|
||||
MO.Real,
|
||||
MO.Imag,
|
||||
MO.Abs,
|
||||
MO.Sq,
|
||||
MO.Sqrt,
|
||||
MO.InvSqrt,
|
||||
MO.Cos,
|
||||
MO.Sin,
|
||||
MO.Tan,
|
||||
MO.Acos,
|
||||
MO.Asin,
|
||||
MO.Atan,
|
||||
MO.Sinc,
|
||||
]
|
||||
|
||||
match shape:
|
||||
case 'noshape':
|
||||
return []
|
||||
match (info.output.rows, info.output.cols):
|
||||
case (1, 1):
|
||||
return element_ops
|
||||
|
||||
# By Number
|
||||
case None:
|
||||
return [
|
||||
MO.Real,
|
||||
MO.Imag,
|
||||
MO.Abs,
|
||||
MO.Sq,
|
||||
MO.Sqrt,
|
||||
MO.InvSqrt,
|
||||
MO.Cos,
|
||||
MO.Sin,
|
||||
MO.Tan,
|
||||
MO.Acos,
|
||||
MO.Asin,
|
||||
MO.Atan,
|
||||
MO.Sinc,
|
||||
]
|
||||
case (_, 1):
|
||||
return [*element_ops, MO.Norm2]
|
||||
|
||||
match len(shape):
|
||||
# By Vector
|
||||
case 1:
|
||||
return [
|
||||
MO.Norm2,
|
||||
]
|
||||
# By Matrix
|
||||
case 2:
|
||||
case (rows, cols) if rows == cols:
|
||||
## TODO: Check hermitian/posdef for cholesky.
|
||||
## - Can we even do this with just the output symbol approach?
|
||||
return [
|
||||
*element_ops,
|
||||
MO.Det,
|
||||
MO.Cond,
|
||||
MO.NormFro,
|
||||
|
@ -201,6 +199,18 @@ class MapOperation(enum.StrEnum):
|
|||
MO.Svd,
|
||||
]
|
||||
|
||||
case (rows, cols):
|
||||
return [
|
||||
*element_ops,
|
||||
MO.Cond,
|
||||
MO.NormFro,
|
||||
MO.Rank,
|
||||
MO.SvdVals,
|
||||
MO.Inv,
|
||||
MO.Tra,
|
||||
MO.Svd,
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
####################
|
||||
|
@ -288,41 +298,76 @@ class MapOperation(enum.StrEnum):
|
|||
|
||||
def transform_info(self, info: ct.InfoFlow):
|
||||
MO = MapOperation
|
||||
|
||||
return {
|
||||
# By Number
|
||||
MO.Real: lambda: info.set_output_mathtype(spux.MathType.Real),
|
||||
MO.Imag: lambda: info.set_output_mathtype(spux.MathType.Real),
|
||||
MO.Abs: lambda: info.set_output_mathtype(spux.MathType.Real),
|
||||
MO.Real: lambda: info.update_output(mathtype=spux.MathType.Real),
|
||||
MO.Imag: lambda: info.update_output(mathtype=spux.MathType.Real),
|
||||
MO.Abs: lambda: info.update_output(mathtype=spux.MathType.Real),
|
||||
MO.Sq: lambda: info,
|
||||
MO.Sqrt: lambda: info,
|
||||
MO.InvSqrt: lambda: info,
|
||||
MO.Cos: lambda: info,
|
||||
MO.Sin: lambda: info,
|
||||
MO.Tan: lambda: info,
|
||||
MO.Acos: lambda: info,
|
||||
MO.Asin: lambda: info,
|
||||
MO.Atan: lambda: info,
|
||||
MO.Sinc: lambda: info,
|
||||
# By Vector
|
||||
MO.Norm2: lambda: info.collapse_output(
|
||||
collapsed_name=MO.to_name(self).replace('v', info.output_name),
|
||||
collapsed_mathtype=spux.MathType.Real,
|
||||
collapsed_unit=info.output_unit,
|
||||
MO.Norm2: lambda: info.update_output(
|
||||
mathtype=spux.MathType.Real,
|
||||
rows=1,
|
||||
cols=1,
|
||||
# Interval
|
||||
interval_finite_re=(0, sim_symbols.float_max),
|
||||
interval_inf=(False, True),
|
||||
interval_closed=(True, False),
|
||||
),
|
||||
# By Matrix
|
||||
MO.Det: lambda: info.collapse_output(
|
||||
collapsed_name=MO.to_name(self).replace('V', info.output_name),
|
||||
collapsed_mathtype=info.output_mathtype,
|
||||
collapsed_unit=info.output_unit,
|
||||
MO.Det: lambda: info.update_output(
|
||||
rows=1,
|
||||
cols=1,
|
||||
),
|
||||
MO.Cond: lambda: info.collapse_output(
|
||||
collapsed_name=MO.to_name(self).replace('V', info.output_name),
|
||||
collapsed_mathtype=spux.MathType.Real,
|
||||
collapsed_unit=None,
|
||||
MO.Cond: lambda: info.update_output(
|
||||
mathtype=spux.MathType.Real,
|
||||
rows=1,
|
||||
cols=1,
|
||||
physical_type=spux.PhysicalType.NonPhysical,
|
||||
unit=None,
|
||||
),
|
||||
MO.NormFro: lambda: info.collapse_output(
|
||||
collapsed_name=MO.to_name(self).replace('V', info.output_name),
|
||||
collapsed_mathtype=spux.MathType.Real,
|
||||
collapsed_unit=info.output_unit,
|
||||
MO.NormFro: lambda: info.update_output(
|
||||
mathtype=spux.MathType.Real,
|
||||
rows=1,
|
||||
cols=1,
|
||||
# Interval
|
||||
interval_finite_re=(0, sim_symbols.float_max),
|
||||
interval_inf=(False, True),
|
||||
interval_closed=(True, False),
|
||||
),
|
||||
MO.Rank: lambda: info.collapse_output(
|
||||
collapsed_name=MO.to_name(self).replace('V', info.output_name),
|
||||
collapsed_mathtype=spux.MathType.Integer,
|
||||
collapsed_unit=None,
|
||||
MO.Rank: lambda: info.update_output(
|
||||
mathtype=spux.MathType.Integer,
|
||||
rows=1,
|
||||
cols=1,
|
||||
physical_type=spux.PhysicalType.NonPhysical,
|
||||
unit=None,
|
||||
# Interval
|
||||
interval_finite_re=(0, sim_symbols.int_max),
|
||||
interval_inf=(False, True),
|
||||
interval_closed=(True, False),
|
||||
),
|
||||
## TODO: Matrix -> Vec
|
||||
## TODO: Matrix -> Matrices
|
||||
}.get(self, lambda: info)()
|
||||
# Matrix -> Vector ## TODO: ALL OF THESE
|
||||
MO.Diag: lambda: info,
|
||||
MO.EigVals: lambda: info,
|
||||
MO.SvdVals: lambda: info,
|
||||
# Matrix -> Matrix ## TODO: ALL OF THESE
|
||||
MO.Inv: lambda: info,
|
||||
MO.Tra: lambda: info,
|
||||
# Matrix -> Matrices ## TODO: ALL OF THESE
|
||||
MO.Qr: lambda: info,
|
||||
MO.Chol: lambda: info,
|
||||
MO.Svd: lambda: info,
|
||||
}[self]()
|
||||
|
||||
|
||||
####################
|
||||
|
@ -401,7 +446,7 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
The name and type of the available symbol is clearly shown, and most valid `sympy` expressions that you would expect to work, should work.
|
||||
|
||||
Use of expressions generally imposes no performance penalty: Just like the baked-in operations, it is compiled to a high-performance `jax` function.
|
||||
Thus, it participates in the `ct.FlowKind.LazyValueFunc` composition chain.
|
||||
Thus, it participates in the `ct.FlowKind.Func` composition chain.
|
||||
|
||||
|
||||
Attributes:
|
||||
|
@ -412,10 +457,10 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
bl_label = 'Map Math'
|
||||
|
||||
input_sockets: typ.ClassVar = {
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc),
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc),
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
|
||||
}
|
||||
|
||||
####################
|
||||
|
@ -435,29 +480,26 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
)
|
||||
|
||||
if has_info and not info_pending:
|
||||
self.expr_output_shape = bl_cache.Signal.InvalidateCache
|
||||
self.expr_info = bl_cache.Signal.InvalidateCache
|
||||
|
||||
@bl_cache.cached_bl_property()
|
||||
def expr_output_shape(self) -> ct.InfoFlow | None:
|
||||
def expr_info(self) -> ct.InfoFlow | None:
|
||||
info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True)
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
if has_info:
|
||||
return info.output_shape
|
||||
|
||||
return 'noshape'
|
||||
return info
|
||||
return None
|
||||
|
||||
operation: MapOperation = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_operations(),
|
||||
cb_depends_on={'expr_output_shape'},
|
||||
cb_depends_on={'expr_info'},
|
||||
)
|
||||
|
||||
def search_operations(self) -> list[ct.BLEnumElement]:
|
||||
if self.expr_output_shape != 'noshape':
|
||||
if self.info is not None:
|
||||
return [
|
||||
operation.bl_enum_element(i)
|
||||
for i, operation in enumerate(
|
||||
MapOperation.by_element_shape(self.expr_output_shape)
|
||||
)
|
||||
for i, operation in enumerate(MapOperation.by_expr_info(self.expr_info))
|
||||
]
|
||||
return []
|
||||
|
||||
|
@ -474,7 +516,7 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
layout.prop(self, self.blfields['operation'], text='')
|
||||
|
||||
####################
|
||||
# - FlowKind.Value|LazyValueFunc
|
||||
# - FlowKind.Value
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
|
@ -495,18 +537,19 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Func
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.LazyValueFunc,
|
||||
kind=ct.FlowKind.Func,
|
||||
props={'operation'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={
|
||||
'Expr': ct.FlowKind.LazyValueFunc,
|
||||
'Expr': ct.FlowKind.Func,
|
||||
},
|
||||
)
|
||||
def compute_func(
|
||||
self, props, input_sockets
|
||||
) -> ct.LazyValueFuncFlow | ct.FlowSignal:
|
||||
def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal:
|
||||
operation = props['operation']
|
||||
expr = input_sockets['Expr']
|
||||
|
||||
|
@ -520,7 +563,7 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Info|Params
|
||||
# - FlowKind.Info
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
|
@ -540,6 +583,9 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Params
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Params,
|
||||
|
|
|
@ -261,12 +261,12 @@ class OperateMathNode(base.MaxwellSimNode):
|
|||
bl_label = 'Operate Math'
|
||||
|
||||
input_sockets: typ.ClassVar = {
|
||||
'Expr L': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc),
|
||||
'Expr R': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc),
|
||||
'Expr L': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
|
||||
'Expr R': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Expr': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.LazyValueFunc, show_info_columns=True
|
||||
active_kind=ct.FlowKind.Func, show_info_columns=True
|
||||
),
|
||||
}
|
||||
|
||||
|
@ -344,7 +344,7 @@ class OperateMathNode(base.MaxwellSimNode):
|
|||
layout.prop(self, self.blfields['operation'], text='')
|
||||
|
||||
####################
|
||||
# - FlowKind.Value|LazyValueFunc
|
||||
# - FlowKind.Value|Func
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
|
@ -373,12 +373,12 @@ class OperateMathNode(base.MaxwellSimNode):
|
|||
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.LazyValueFunc,
|
||||
kind=ct.FlowKind.Func,
|
||||
props={'operation'},
|
||||
input_sockets={'Expr L', 'Expr R'},
|
||||
input_socket_kinds={
|
||||
'Expr L': ct.FlowKind.LazyValueFunc,
|
||||
'Expr R': ct.FlowKind.LazyValueFunc,
|
||||
'Expr L': ct.FlowKind.Func,
|
||||
'Expr R': ct.FlowKind.Func,
|
||||
},
|
||||
)
|
||||
def compose_func(self, props: dict, input_sockets: dict):
|
||||
|
|
|
@ -97,7 +97,7 @@ class ReduceMathNode(base.MaxwellSimNode):
|
|||
'Data',
|
||||
props={'active_socket_set', 'operation'},
|
||||
input_sockets={'Data', 'Axis', 'Reducer'},
|
||||
input_socket_kinds={'Reducer': ct.FlowKind.LazyValueFunc},
|
||||
input_socket_kinds={'Reducer': ct.FlowKind.Func},
|
||||
input_sockets_optional={'Reducer': True},
|
||||
)
|
||||
def compute_data(self, props: dict, input_sockets: dict):
|
||||
|
|
|
@ -107,32 +107,31 @@ class TransformOperation(enum.StrEnum):
|
|||
|
||||
# Covariant Transform
|
||||
## Freq <-> VacWL
|
||||
for dim_name in info.dim_names:
|
||||
if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq:
|
||||
for dim in info.dims:
|
||||
if dim.physical_type == spux.PhysicalType.Freq:
|
||||
operations.append(TO.FreqToVacWL)
|
||||
|
||||
if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq:
|
||||
if dim.physical_type == spux.PhysicalType.Freq:
|
||||
operations.append(TO.VacWLToFreq)
|
||||
|
||||
# Fold
|
||||
## (Last) Int Dim (=2) to Complex
|
||||
if len(info.dim_names) >= 1:
|
||||
last_dim_name = info.dim_names[-1]
|
||||
if info.dim_lens[last_dim_name] == 2: # noqa: PLR2004
|
||||
if len(info.dims) >= 1:
|
||||
if not info.has_idx_labels(info.last_dim) and len(info.last_dim) == 2: # noqa: PLR2004
|
||||
operations.append(TO.IntDimToComplex)
|
||||
|
||||
## To Vector
|
||||
if len(info.dim_names) >= 1:
|
||||
if len(info.dims) >= 1:
|
||||
operations.append(TO.DimToVec)
|
||||
|
||||
## To Matrix
|
||||
if len(info.dim_names) >= 2: # noqa: PLR2004
|
||||
if len(info.dims) >= 2: # noqa: PLR2004
|
||||
operations.append(TO.DimsToMat)
|
||||
|
||||
# Fourier
|
||||
## 1D Fourier
|
||||
if info.dim_names:
|
||||
last_physical_type = info.dim_physical_types[info.dim_names[-1]]
|
||||
if info.dims:
|
||||
last_physical_type = info.last_dim.physical_type
|
||||
if last_physical_type == spux.PhysicalType.Time:
|
||||
operations.append(TO.FFT1D)
|
||||
if last_physical_type == spux.PhysicalType.Freq:
|
||||
|
@ -188,15 +187,15 @@ class TransformOperation(enum.StrEnum):
|
|||
unit: spux.Unit | None = None,
|
||||
) -> ct.InfoFlow | None:
|
||||
TO = TransformOperation
|
||||
if not info.dim_names:
|
||||
if not info.dims:
|
||||
return None
|
||||
return {
|
||||
# Index
|
||||
# Covariant Transform
|
||||
TO.FreqToVacWL: lambda: info.replace_dim(
|
||||
(f_dim := info.dim_names[-1]),
|
||||
(f_dim := info.last_dim),
|
||||
[
|
||||
'wl',
|
||||
info.dim_idx[f_dim].rescale(
|
||||
sim_symbols.wl(spu.nanometer),
|
||||
info.dims[f_dim].rescale(
|
||||
lambda el: sci_constants.vac_speed_of_light / el,
|
||||
reverse=True,
|
||||
new_unit=spu.nanometer,
|
||||
|
@ -204,10 +203,10 @@ class TransformOperation(enum.StrEnum):
|
|||
],
|
||||
),
|
||||
TO.VacWLToFreq: lambda: info.replace_dim(
|
||||
(wl_dim := info.dim_names[-1]),
|
||||
(wl_dim := info.last_dim),
|
||||
[
|
||||
'f',
|
||||
info.dim_idx[wl_dim].rescale(
|
||||
sim_symbols.freq(spux.THz),
|
||||
info.dims[wl_dim].rescale(
|
||||
lambda el: sci_constants.vac_speed_of_light / el,
|
||||
reverse=True,
|
||||
new_unit=spux.THz,
|
||||
|
@ -215,26 +214,24 @@ class TransformOperation(enum.StrEnum):
|
|||
],
|
||||
),
|
||||
# Fold
|
||||
TO.IntDimToComplex: lambda: info.delete_dimension(
|
||||
info.dim_names[-1]
|
||||
).set_output_mathtype(spux.MathType.Complex),
|
||||
TO.DimToVec: lambda: info.shift_last_input,
|
||||
TO.DimsToMat: lambda: info.shift_last_input.shift_last_input,
|
||||
TO.IntDimToComplex: lambda: info.delete_dim(info.last_dim).update_output(
|
||||
mathtype=spux.MathType.Complex
|
||||
),
|
||||
TO.DimToVec: lambda: info.fold_last_input(),
|
||||
TO.DimsToMat: lambda: info.fold_last_input().fold_last_input(),
|
||||
# Fourier
|
||||
TO.FFT1D: lambda: info.replace_dim(
|
||||
info.dim_names[-1],
|
||||
info.last_dim,
|
||||
[
|
||||
'f',
|
||||
ct.LazyArrayRangeFlow(start=0, stop=sp.oo, steps=0, unit=spu.hertz),
|
||||
sim_symbols.freq(spux.THz),
|
||||
None,
|
||||
],
|
||||
),
|
||||
TO.InvFFT1D: info.replace_dim(
|
||||
info.dim_names[-1],
|
||||
info.last_dim,
|
||||
[
|
||||
't',
|
||||
ct.LazyArrayRangeFlow(
|
||||
start=0, stop=sp.oo, steps=0, unit=spu.second
|
||||
),
|
||||
sim_symbols.t(spu.second),
|
||||
None,
|
||||
],
|
||||
),
|
||||
}.get(self, lambda: info)()
|
||||
|
@ -260,10 +257,10 @@ class TransformMathNode(base.MaxwellSimNode):
|
|||
bl_label = 'Transform Math'
|
||||
|
||||
input_sockets: typ.ClassVar = {
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc),
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc),
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
|
||||
}
|
||||
|
||||
####################
|
||||
|
@ -325,7 +322,7 @@ class TransformMathNode(base.MaxwellSimNode):
|
|||
layout.prop(self, self.blfields['operation'], text='')
|
||||
|
||||
####################
|
||||
# - Compute: LazyValueFunc / Array
|
||||
# - Compute: Func / Array
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
|
@ -348,16 +345,14 @@ class TransformMathNode(base.MaxwellSimNode):
|
|||
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.LazyValueFunc,
|
||||
kind=ct.FlowKind.Func,
|
||||
props={'operation'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={
|
||||
'Expr': ct.FlowKind.LazyValueFunc,
|
||||
'Expr': ct.FlowKind.Func,
|
||||
},
|
||||
)
|
||||
def compute_func(
|
||||
self, props, input_sockets
|
||||
) -> ct.LazyValueFuncFlow | ct.FlowSignal:
|
||||
def compute_func(self, props, input_sockets) -> ct.FuncFlow | ct.FlowSignal:
|
||||
operation = props['operation']
|
||||
expr = input_sockets['Expr']
|
||||
|
||||
|
|
|
@ -38,7 +38,6 @@ class VizMode(enum.StrEnum):
|
|||
**NOTE**: >1D output dimensions currently have no viz.
|
||||
|
||||
Plots for `() -> ℝ`:
|
||||
- Hist1D: Bin-summed distribution.
|
||||
- BoxPlot1D: Box-plot describing the distribution.
|
||||
|
||||
Plots for `(ℤ) -> ℝ`:
|
||||
|
@ -61,7 +60,6 @@ class VizMode(enum.StrEnum):
|
|||
- Heatmap3D: Colormapped field with value at each voxel.
|
||||
"""
|
||||
|
||||
Hist1D = enum.auto()
|
||||
BoxPlot1D = enum.auto()
|
||||
|
||||
Curve2D = enum.auto()
|
||||
|
@ -78,42 +76,38 @@ class VizMode(enum.StrEnum):
|
|||
|
||||
@staticmethod
|
||||
def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None:
|
||||
EMPTY = ()
|
||||
Z = spux.MathType.Integer
|
||||
R = spux.MathType.Real
|
||||
VM = VizMode
|
||||
|
||||
valid_viz_modes = {
|
||||
(EMPTY, (None, R)): [VM.Hist1D, VM.BoxPlot1D],
|
||||
((Z), (None, R)): [
|
||||
VM.Hist1D,
|
||||
return {
|
||||
((Z), (1, 1, R)): [
|
||||
VM.BoxPlot1D,
|
||||
],
|
||||
((R,), (None, R)): [
|
||||
((R,), (1, 1, R)): [
|
||||
VM.Curve2D,
|
||||
VM.Points2D,
|
||||
VM.Bar,
|
||||
],
|
||||
((R, Z), (None, R)): [
|
||||
((R, Z), (1, 1, R)): [
|
||||
VM.Curves2D,
|
||||
VM.FilledCurves2D,
|
||||
],
|
||||
((R, R), (None, R)): [
|
||||
((R, R), (1, 1, R)): [
|
||||
VM.Heatmap2D,
|
||||
],
|
||||
((R, R, R), (None, R)): [VM.SqueezedHeatmap2D, VM.Heatmap3D],
|
||||
((R, R, R), (1, 1, R)): [
|
||||
VM.SqueezedHeatmap2D,
|
||||
VM.Heatmap3D,
|
||||
],
|
||||
}.get(
|
||||
(
|
||||
tuple(info.dim_mathtypes.values()),
|
||||
(info.output_shape, info.output_mathtype),
|
||||
)
|
||||
tuple([dim.mathtype for dim in info.dims.values()]),
|
||||
(info.output.rows, info.output.cols, info.output.mathtype),
|
||||
),
|
||||
[],
|
||||
)
|
||||
|
||||
if valid_viz_modes is None:
|
||||
return []
|
||||
|
||||
return valid_viz_modes
|
||||
|
||||
@staticmethod
|
||||
def to_plotter(
|
||||
value: typ.Self,
|
||||
|
@ -121,7 +115,6 @@ class VizMode(enum.StrEnum):
|
|||
[jtyp.Float32[jtyp.Array, '...'], ct.InfoFlow, mpl_ax.Axis], None
|
||||
]:
|
||||
return {
|
||||
VizMode.Hist1D: image_ops.plot_hist_1d,
|
||||
VizMode.BoxPlot1D: image_ops.plot_box_plot_1d,
|
||||
VizMode.Curve2D: image_ops.plot_curve_2d,
|
||||
VizMode.Points2D: image_ops.plot_points_2d,
|
||||
|
@ -136,7 +129,6 @@ class VizMode(enum.StrEnum):
|
|||
@staticmethod
|
||||
def to_name(value: typ.Self) -> str:
|
||||
return {
|
||||
VizMode.Hist1D: 'Histogram',
|
||||
VizMode.BoxPlot1D: 'Box Plot',
|
||||
VizMode.Curve2D: 'Curve',
|
||||
VizMode.Points2D: 'Points',
|
||||
|
@ -164,7 +156,6 @@ class VizTarget(enum.StrEnum):
|
|||
@staticmethod
|
||||
def valid_targets_for(viz_mode: VizMode) -> list[typ.Self] | None:
|
||||
return {
|
||||
VizMode.Hist1D: [VizTarget.Plot2D],
|
||||
VizMode.BoxPlot1D: [VizTarget.Plot2D],
|
||||
VizMode.Curve2D: [VizTarget.Plot2D],
|
||||
VizMode.Points2D: [VizTarget.Plot2D],
|
||||
|
@ -209,7 +200,7 @@ class VizNode(base.MaxwellSimNode):
|
|||
####################
|
||||
input_sockets: typ.ClassVar = {
|
||||
'Expr': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.LazyValueFunc,
|
||||
active_kind=ct.FlowKind.Func,
|
||||
default_symbols=[sim_symbols.x],
|
||||
default_value=2 * sim_symbols.x.sp_symbol,
|
||||
),
|
||||
|
@ -333,35 +324,20 @@ class VizNode(base.MaxwellSimNode):
|
|||
## -> This happens if Params contains not-yet-realized symbols.
|
||||
if has_info and has_params and params.symbols:
|
||||
if set(self.loose_input_sockets) != {
|
||||
sym.name for sym in params.symbols if sym.name in info.dim_names
|
||||
dim.name for dim in params.symbols if dim in info.dims
|
||||
}:
|
||||
self.loose_input_sockets = {
|
||||
sym.name: sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.LazyArrayRange,
|
||||
size=spux.NumberSize1D.Scalar,
|
||||
mathtype=info.dim_mathtypes[sym.name],
|
||||
physical_type=info.dim_physical_types[sym.name],
|
||||
default_min=(
|
||||
info.dim_idx[sym.name].start
|
||||
if not sp.S(info.dim_idx[sym.name].start).is_infinite
|
||||
else sp.S(0)
|
||||
),
|
||||
default_max=(
|
||||
info.dim_idx[sym.name].start
|
||||
if not sp.S(info.dim_idx[sym.name].stop).is_infinite
|
||||
else sp.S(1)
|
||||
),
|
||||
default_steps=50,
|
||||
)
|
||||
for sym in params.sorted_symbols
|
||||
if sym.name in info.dim_names
|
||||
dim_name: sockets.ExprSocketDef(**expr_info)
|
||||
for dim_name, expr_info in params.sym_expr_infos(
|
||||
info, use_range=True
|
||||
).items()
|
||||
}
|
||||
|
||||
elif self.loose_input_sockets:
|
||||
self.loose_input_sockets = {}
|
||||
|
||||
#####################
|
||||
## - Plotting
|
||||
## - FlowKind.Value
|
||||
#####################
|
||||
@events.computes_output_socket(
|
||||
'Preview',
|
||||
|
@ -370,37 +346,38 @@ class VizNode(base.MaxwellSimNode):
|
|||
props={'viz_mode', 'viz_target', 'colormap'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={
|
||||
'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info, ct.FlowKind.Params}
|
||||
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
|
||||
},
|
||||
all_loose_input_sockets=True,
|
||||
)
|
||||
def compute_dummy_value(self, props, input_sockets, loose_input_sockets):
|
||||
"""Needed for the plot to regenerate in the viewer."""
|
||||
return ct.FlowSignal.NoFlow
|
||||
|
||||
#####################
|
||||
## - On Show Plot
|
||||
#####################
|
||||
@events.on_show_plot(
|
||||
managed_objs={'plot'},
|
||||
props={'viz_mode', 'viz_target', 'colormap'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={
|
||||
'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info, ct.FlowKind.Params}
|
||||
'Expr': {ct.FlowKind.Func, ct.FlowKind.Info, ct.FlowKind.Params}
|
||||
},
|
||||
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
|
||||
all_loose_input_sockets=True,
|
||||
stop_propagation=True,
|
||||
)
|
||||
def on_show_plot(
|
||||
self, managed_objs, props, input_sockets, loose_input_sockets, unit_systems
|
||||
):
|
||||
self, managed_objs, props, input_sockets, loose_input_sockets
|
||||
) -> None:
|
||||
# Retrieve Inputs
|
||||
lazy_value_func = input_sockets['Expr'][ct.FlowKind.LazyValueFunc]
|
||||
lazy_func = input_sockets['Expr'][ct.FlowKind.Func]
|
||||
info = input_sockets['Expr'][ct.FlowKind.Info]
|
||||
params = input_sockets['Expr'][ct.FlowKind.Params]
|
||||
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
has_params = not ct.FlowSignal.check(params)
|
||||
|
||||
# Invalid Mode | Target
|
||||
## -> To limit branching, return now if things aren't right.
|
||||
if (
|
||||
not has_info
|
||||
or not has_params
|
||||
|
@ -409,54 +386,43 @@ class VizNode(base.MaxwellSimNode):
|
|||
):
|
||||
return
|
||||
|
||||
# Compute LazyArrayRanges for Symbols from Loose Sockets
|
||||
## -> These are the concrete values of the symbol for plotting.
|
||||
# Compute Ranges for Symbols from Loose Sockets
|
||||
## -> In a quite nice turn of events, all this is cached lookups.
|
||||
## -> ...Unless something changed, in which case, well. It changed.
|
||||
symbol_values = {
|
||||
sym: (
|
||||
loose_input_sockets[sym.name]
|
||||
.realize_array.rescale_to_unit(info.dim_units[sym.name])
|
||||
.values
|
||||
symbol_array_values = {
|
||||
sim_syms: (
|
||||
loose_input_sockets[sim_syms]
|
||||
.rescale_to_unit(sim_syms.unit)
|
||||
.realize_array
|
||||
)
|
||||
for sym in params.sorted_symbols
|
||||
for sim_syms in params.sorted_symbols
|
||||
}
|
||||
|
||||
# Realize LazyValueFunc w/Symbolic Values, Unit System
|
||||
## -> This gives us the actual plot data!
|
||||
data = lazy_value_func.func_jax(
|
||||
*params.scaled_func_args(
|
||||
unit_systems['BlenderUnits'], symbol_values=symbol_values
|
||||
),
|
||||
**params.scaled_func_kwargs(
|
||||
unit_systems['BlenderUnits'], symbol_values=symbol_values
|
||||
),
|
||||
)
|
||||
data = lazy_func.realize(params, symbol_values=symbol_array_values)
|
||||
|
||||
# Replace InfoFlow Indices w/Realized Symbolic Ranges
|
||||
## -> This ensures correct axis scaling.
|
||||
if params.symbols:
|
||||
info = info.rescale_dim_idxs(loose_input_sockets)
|
||||
info = info.replace_dims(symbol_array_values)
|
||||
|
||||
# Visualize by-Target
|
||||
if props['viz_target'] == VizTarget.Plot2D:
|
||||
managed_objs['plot'].mpl_plot_to_image(
|
||||
lambda ax: VizMode.to_plotter(props['viz_mode'])(data, info, ax),
|
||||
bl_select=True,
|
||||
)
|
||||
match props['viz_target']:
|
||||
case VizTarget.Plot2D:
|
||||
managed_objs['plot'].mpl_plot_to_image(
|
||||
lambda ax: VizMode.to_plotter(props['viz_mode'])(data, info, ax),
|
||||
bl_select=True,
|
||||
)
|
||||
|
||||
if props['viz_target'] == VizTarget.Pixels:
|
||||
managed_objs['plot'].map_2d_to_image(
|
||||
data,
|
||||
colormap=props['colormap'],
|
||||
bl_select=True,
|
||||
)
|
||||
case VizTarget.Pixels:
|
||||
managed_objs['plot'].map_2d_to_image(
|
||||
data,
|
||||
colormap=props['colormap'],
|
||||
bl_select=True,
|
||||
)
|
||||
|
||||
if props['viz_target'] == VizTarget.PixelsPlane:
|
||||
raise NotImplementedError
|
||||
case VizTarget.PixelsPlane:
|
||||
raise NotImplementedError
|
||||
|
||||
if props['viz_target'] == VizTarget.Voxels:
|
||||
raise NotImplementedError
|
||||
case VizTarget.Voxels:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
####################
|
||||
|
|
|
@ -67,7 +67,7 @@ ManagedObjName: typ.TypeAlias = str
|
|||
PropName: typ.TypeAlias = str
|
||||
|
||||
|
||||
def event_decorator(
|
||||
def event_decorator( # noqa: PLR0913
|
||||
event: ct.FlowEvent,
|
||||
callback_info: EventCallbackInfo | None,
|
||||
stop_propagation: bool = False,
|
||||
|
@ -91,31 +91,42 @@ def event_decorator(
|
|||
scale_input_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}),
|
||||
scale_output_sockets: dict[ct.SocketName, UnitSystemID] = MappingProxyType({}),
|
||||
):
|
||||
"""Returns a decorator for a method of `MaxwellSimNode`, declaring it as able respond to events passing through a node.
|
||||
"""Low-level decorator declaring a special "event method" of `MaxwellSimNode`, which is able to handle `ct.FlowEvent`s passing through.
|
||||
|
||||
Should generally be used via a high-level decorator such as `on_value_changed`.
|
||||
|
||||
For more about how event methods are actually registered and run, please refer to the documentation of `MaxwellSimNode`.
|
||||
|
||||
Parameters:
|
||||
event: A name describing which event the decorator should respond to.
|
||||
Set to `return_method.event`
|
||||
callback_info: A dictionary that provides the caller with additional per-`event` information.
|
||||
This might include parameters to help select the most appropriate method(s) to respond to an event with, or events to take after running the callback.
|
||||
props: Set of `props` to compute, then pass to the decorated method.
|
||||
stop_propagation: Whether or stop propagating the event through the graph after encountering this method.
|
||||
Other methods defined on the same node will still run.
|
||||
managed_objs: Set of `managed_objs` to retrieve, then pass to the decorated method.
|
||||
props: Set of `props` to compute, then pass to the decorated method.
|
||||
input_sockets: Set of `input_sockets` to compute, then pass to the decorated method.
|
||||
input_sockets_optional: Whether an input socket is required to exist.
|
||||
When True, lack of socket will produce `ct.FlowSignal.NoFlow`, instead of throwing an error.
|
||||
input_socket_kinds: The `ct.FlowKind` to compute per-input-socket.
|
||||
If an input socket isn't specified, it defaults to `ct.FlowKind.Value`.
|
||||
output_sockets: Set of `output_sockets` to compute, then pass to the decorated method.
|
||||
output_sockets_optional: Whether an output socket is required to exist.
|
||||
When True, lack of socket will produce `ct.FlowSignal.NoFlow`, instead of throwing an error.
|
||||
output_socket_kinds: The `ct.FlowKind` to compute per-output-socket.
|
||||
If an output socket isn't specified, it defaults to `ct.FlowKind.Value`.
|
||||
all_loose_input_sockets: Whether to compute all loose input sockets and pass them to the decorated method.
|
||||
Used when the names of the loose input sockets are unknown, but all of their values are needed.
|
||||
all_loose_output_sockets: Whether to compute all loose output sockets and pass them to the decorated method.
|
||||
Used when the names of the loose output sockets are unknown, but all of their values are needed.
|
||||
unit_systems: String identifiers under which to load a unit system, made available to the method.
|
||||
scale_input_sockets: A mapping of input sockets to unit system string idenfiers, which causes the output of that input socket to be scaled to the given unit system.
|
||||
This greatly simplifies the conformance of particular sockets to particular unit systems, when the socket value must be used in a unit-unaware manner.
|
||||
scale_output_sockets: A mapping of output sockets to unit system string idenfiers, which causes the output of that input socket to be scaled to the given unit system.
|
||||
This greatly simplifies the conformance of particular sockets to particular unit systems, when the socket value must be used in a unit-unaware manner.
|
||||
|
||||
Returns:
|
||||
A decorator, which can be applied to a method of `MaxwellSimNode`.
|
||||
When a `MaxwellSimNode` subclass initializes, such a decorated method will be picked up on.
|
||||
|
||||
When `event` passes through the node, then `callback_info` is used to determine
|
||||
A decorator, which can be applied to a method of `MaxwellSimNode` to make it an "event method".
|
||||
"""
|
||||
req_params = (
|
||||
{'self'}
|
||||
|
@ -375,7 +386,6 @@ def on_value_changed(
|
|||
)
|
||||
|
||||
|
||||
## TODO: Change name to 'on_output_requested'
|
||||
def computes_output_socket(
|
||||
output_socket_name: ct.SocketName | None,
|
||||
kind: ct.FlowKind = ct.FlowKind.Value,
|
||||
|
|
|
@ -29,12 +29,12 @@ class ExprConstantNode(base.MaxwellSimNode):
|
|||
|
||||
input_sockets: typ.ClassVar = {
|
||||
'Expr': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.LazyValueFunc,
|
||||
active_kind=ct.FlowKind.Func,
|
||||
),
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Expr': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.LazyValueFunc,
|
||||
active_kind=ct.FlowKind.Func,
|
||||
show_info_columns=True,
|
||||
),
|
||||
}
|
||||
|
@ -58,12 +58,12 @@ class ExprConstantNode(base.MaxwellSimNode):
|
|||
@events.computes_output_socket(
|
||||
# Trigger
|
||||
'Expr',
|
||||
kind=ct.FlowKind.LazyValueFunc,
|
||||
kind=ct.FlowKind.Func,
|
||||
# Loaded
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': ct.FlowKind.LazyValueFunc},
|
||||
input_socket_kinds={'Expr': ct.FlowKind.Func},
|
||||
)
|
||||
def compute_lazy_value_func(self, input_sockets: dict) -> typ.Any:
|
||||
def compute_lazy_func(self, input_sockets: dict) -> typ.Any:
|
||||
return input_sockets['Expr']
|
||||
|
||||
####################
|
||||
|
|
|
@ -19,14 +19,10 @@ import typing as typ
|
|||
from pathlib import Path
|
||||
|
||||
import bpy
|
||||
import jax.numpy as jnp
|
||||
import jaxtyping as jtyp
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import sympy as sp
|
||||
import tidy3d as td
|
||||
|
||||
from blender_maxwell.utils import bl_cache, logger
|
||||
from blender_maxwell.utils import bl_cache, logger, sim_symbols
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
|
||||
from .... import contracts as ct
|
||||
|
@ -35,112 +31,6 @@ from ... import base, events
|
|||
|
||||
log = logger.get(__name__)
|
||||
|
||||
####################
|
||||
# - Data File Extensions
|
||||
####################
|
||||
_DATA_FILE_EXTS = {
|
||||
'.txt',
|
||||
'.txt.gz',
|
||||
'.csv',
|
||||
'.npy',
|
||||
}
|
||||
|
||||
|
||||
class DataFileExt(enum.StrEnum):
|
||||
Txt = enum.auto()
|
||||
TxtGz = enum.auto()
|
||||
Csv = enum.auto()
|
||||
Npy = enum.auto()
|
||||
|
||||
####################
|
||||
# - Enum Elements
|
||||
####################
|
||||
@staticmethod
|
||||
def to_name(v: typ.Self) -> str:
|
||||
return DataFileExt(v).extension
|
||||
|
||||
@staticmethod
|
||||
def to_icon(v: typ.Self) -> str:
|
||||
return ''
|
||||
|
||||
####################
|
||||
# - Computed Properties
|
||||
####################
|
||||
@property
|
||||
def extension(self) -> str:
|
||||
"""Map to the actual string extension."""
|
||||
E = DataFileExt
|
||||
return {
|
||||
E.Txt: '.txt',
|
||||
E.TxtGz: '.txt.gz',
|
||||
E.Csv: '.csv',
|
||||
E.Npy: '.npy',
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def loader(self) -> typ.Callable[[Path], jtyp.Shaped[jtyp.Array, '...']]:
|
||||
def load_txt(path: Path):
|
||||
return jnp.asarray(np.loadtxt(path))
|
||||
|
||||
def load_csv(path: Path):
|
||||
return jnp.asarray(pd.read_csv(path).values)
|
||||
|
||||
def load_npy(path: Path):
|
||||
return jnp.load(path)
|
||||
|
||||
E = DataFileExt
|
||||
return {
|
||||
E.Txt: load_txt,
|
||||
E.TxtGz: load_txt,
|
||||
E.Csv: load_csv,
|
||||
E.Npy: load_npy,
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def loader_is_jax_compatible(self) -> bool:
|
||||
E = DataFileExt
|
||||
return {
|
||||
E.Txt: True,
|
||||
E.TxtGz: True,
|
||||
E.Csv: False,
|
||||
E.Npy: True,
|
||||
}[self]
|
||||
|
||||
####################
|
||||
# - Creation
|
||||
####################
|
||||
@staticmethod
|
||||
def from_ext(ext: str) -> typ.Self | None:
|
||||
return {
|
||||
_ext: _data_file_ext
|
||||
for _data_file_ext, _ext in {
|
||||
k: k.extension for k in list(DataFileExt)
|
||||
}.items()
|
||||
}.get(ext)
|
||||
|
||||
@staticmethod
|
||||
def from_path(path: Path) -> typ.Self | None:
|
||||
if DataFileExt.is_path_compatible(path):
|
||||
data_file_ext = DataFileExt.from_ext(''.join(path.suffixes))
|
||||
if data_file_ext is not None:
|
||||
return data_file_ext
|
||||
|
||||
msg = f'DataFileExt: Path "{path}" is compatible, but could not find valid extension'
|
||||
raise RuntimeError(msg)
|
||||
|
||||
return None
|
||||
|
||||
####################
|
||||
# - Compatibility
|
||||
####################
|
||||
@staticmethod
|
||||
def is_ext_compatible(ext: str):
|
||||
return ext in _DATA_FILE_EXTS
|
||||
|
||||
@staticmethod
|
||||
def is_path_compatible(path: Path):
|
||||
return path.is_file() and DataFileExt.is_ext_compatible(''.join(path.suffixes))
|
||||
|
||||
|
||||
####################
|
||||
# - Node
|
||||
|
@ -153,7 +43,7 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
|||
'File Path': sockets.FilePathSocketDef(),
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc),
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
|
||||
}
|
||||
|
||||
####################
|
||||
|
@ -168,10 +58,6 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
|||
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
|
||||
has_file_path = not ct.FlowSignal.check(input_sockets['File Path'])
|
||||
|
||||
has_file_path = ct.FlowSignal.check_single(
|
||||
input_sockets['File Path'], ct.FlowSignal.FlowPending
|
||||
)
|
||||
|
||||
if has_file_path:
|
||||
self.file_path = bl_cache.Signal.InvalidateCache
|
||||
|
||||
|
@ -188,10 +74,10 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
|||
return None
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'file_path'})
|
||||
def data_file_ext(self) -> DataFileExt | None:
|
||||
def data_file_format(self) -> ct.DataFileFormat | None:
|
||||
"""Retrieve the file extension by concatenating all suffixes."""
|
||||
if self.file_path is not None:
|
||||
return DataFileExt.from_path(self.file_path)
|
||||
return ct.DataFileFormat.from_path(self.file_path)
|
||||
return None
|
||||
|
||||
####################
|
||||
|
@ -201,11 +87,93 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
|||
def expr_info(self) -> ct.InfoFlow | None:
|
||||
"""Retrieve the output expression's `InfoFlow`."""
|
||||
info = self.compute_output('Expr', kind=ct.FlowKind.Info)
|
||||
has_info = not ct.FlowKind.check(info)
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
if has_info:
|
||||
return info
|
||||
return None
|
||||
|
||||
####################
|
||||
# - Info Guides
|
||||
####################
|
||||
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(sim_symbols.SimSymbolName)
|
||||
output_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
|
||||
output_physical_type: spux.PhysicalType = bl_cache.BLField(
|
||||
spux.PhysicalType.NonPhysical
|
||||
)
|
||||
output_unit: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type),
|
||||
cb_depends_on={'output_physical_type'},
|
||||
)
|
||||
|
||||
dim_0_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||
sim_symbols.SimSymbolName.LowerA
|
||||
)
|
||||
dim_0_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
|
||||
dim_0_physical_type: spux.PhysicalType = bl_cache.BLField(
|
||||
spux.PhysicalType.NonPhysical
|
||||
)
|
||||
dim_0_unit: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_units(self.dim_0_physical_type),
|
||||
cb_depends_on={'dim_0_physical_type'},
|
||||
)
|
||||
|
||||
dim_1_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||
sim_symbols.SimSymbolName.LowerB
|
||||
)
|
||||
dim_1_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
|
||||
dim_1_physical_type: spux.PhysicalType = bl_cache.BLField(
|
||||
spux.PhysicalType.NonPhysical
|
||||
)
|
||||
dim_1_unit: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_units(self.dim_1_physical_type),
|
||||
cb_depends_on={'dim_1_physical_type'},
|
||||
)
|
||||
|
||||
dim_2_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||
sim_symbols.SimSymbolName.LowerC
|
||||
)
|
||||
dim_2_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
|
||||
dim_2_physical_type: spux.PhysicalType = bl_cache.BLField(
|
||||
spux.PhysicalType.NonPhysical
|
||||
)
|
||||
dim_2_unit: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_units(self.dim_2_physical_type),
|
||||
cb_depends_on={'dim_2_physical_type'},
|
||||
)
|
||||
|
||||
dim_3_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||
sim_symbols.SimSymbolName.LowerD
|
||||
)
|
||||
dim_3_mathtype: sim_symbols.MathType = bl_cache.BLField(sim_symbols.Real)
|
||||
dim_3_physical_type: spux.PhysicalType = bl_cache.BLField(
|
||||
spux.PhysicalType.NonPhysical
|
||||
)
|
||||
dim_3_unit: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_units(self.dim_3_physical_type),
|
||||
cb_depends_on={'dim_3_physical_type'},
|
||||
)
|
||||
|
||||
def search_units(self, physical_type: spux.PhysicalType) -> list[ct.BLEnumElement]:
|
||||
if physical_type is not spux.PhysicalType.NonPhysical:
|
||||
return [
|
||||
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
|
||||
for i, unit in enumerate(physical_type.valid_units)
|
||||
]
|
||||
return []
|
||||
|
||||
def dim(self, i: int):
|
||||
dim_name = getattr(self, f'dim_{i}_name')
|
||||
dim_mathtype = getattr(self, f'dim_{i}_mathtype')
|
||||
dim_physical_type = getattr(self, f'dim_{i}_physical_type')
|
||||
dim_unit = getattr(self, f'dim_{i}_unit')
|
||||
|
||||
return sim_symbols.SimSymbol(
|
||||
sym_name=dim_name,
|
||||
mathtype=dim_mathtype,
|
||||
physical_type=dim_physical_type,
|
||||
unit=spux.unit_str_to_unit(dim_unit),
|
||||
)
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
|
@ -216,13 +184,13 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
|||
Called by Blender to determine the text to place in the node's header.
|
||||
"""
|
||||
if self.file_path is not None:
|
||||
return 'Load File: ' + self.file_path.name
|
||||
return 'Load: ' + self.file_path.name
|
||||
|
||||
return self.bl_label
|
||||
|
||||
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
"""Show information about the loaded file."""
|
||||
if self.data_file_ext is not None:
|
||||
if self.data_file_format is not None:
|
||||
box = layout.box()
|
||||
row = box.row()
|
||||
row.alignment = 'CENTER'
|
||||
|
@ -233,24 +201,27 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
|||
row.label(text=self.file_path.name)
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
pass
|
||||
"""Draw loaded properties."""
|
||||
for i in range(len(self.expr_info.dims)):
|
||||
col = layout.column(align=True)
|
||||
row = col.row(align=True)
|
||||
row.alignment = 'CENTER'
|
||||
row.label(text=f'Load Dim {i}')
|
||||
|
||||
row = col.row(align=True)
|
||||
row.prop(self, self.blfields[f'dim_{i}_name'], text='')
|
||||
row.prop(self, self.blfields[f'dim_{i}_mathtype'], text='')
|
||||
|
||||
row = col.row(align=True)
|
||||
row.prop(self, self.blfields[f'dim_{i}_physical_type'], text='')
|
||||
row.prop(self, self.blfields[f'dim_{i}_unit'], text='')
|
||||
|
||||
####################
|
||||
# - Events
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
socket_name='File Path',
|
||||
input_sockets={'File Path'},
|
||||
)
|
||||
def on_file_changed(self, input_sockets) -> None:
|
||||
pass
|
||||
|
||||
####################
|
||||
# - FlowKind.Array|LazyValueFunc
|
||||
# - FlowKind.Array|Func
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.LazyValueFunc,
|
||||
kind=ct.FlowKind.Func,
|
||||
input_sockets={'File Path'},
|
||||
)
|
||||
def compute_func(self, input_sockets: dict) -> td.Simulation:
|
||||
|
@ -264,20 +235,20 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
|||
has_file_path = not ct.FlowSignal.check(input_sockets['File Path'])
|
||||
|
||||
if has_file_path:
|
||||
data_file_ext = DataFileExt.from_path(file_path)
|
||||
if data_file_ext is not None:
|
||||
data_file_format = ct.DataFileFormat.from_path(file_path)
|
||||
if data_file_format is not None:
|
||||
# Jax Compatibility: Lazy Data Loading
|
||||
## -> Delay loading of data from file as long as we can.
|
||||
if data_file_ext.loader_is_jax_compatible:
|
||||
return ct.LazyValueFuncFlow(
|
||||
func=lambda: data_file_ext.loader(file_path),
|
||||
if data_file_format.loader_is_jax_compatible:
|
||||
return ct.FuncFlow(
|
||||
func=lambda: data_file_format.loader(file_path),
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
# No Jax Compatibility: Eager Data Loading
|
||||
## -> Load the data now and bind it.
|
||||
data = data_file_ext.loader(file_path)
|
||||
return ct.LazyValueFuncFlow(func=lambda: data, supports_jax=True)
|
||||
data = data_file_format.loader(file_path)
|
||||
return ct.FuncFlow(func=lambda: data, supports_jax=True)
|
||||
return ct.FlowSignal.FlowPending
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
@ -299,10 +270,12 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
|||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Info,
|
||||
# Loaded
|
||||
props={'output_name', 'output_physical_type', 'output_unit'},
|
||||
output_sockets={'Expr'},
|
||||
output_socket_kinds={'Expr': ct.FlowKind.LazyValueFunc},
|
||||
output_socket_kinds={'Expr': ct.FlowKind.Func},
|
||||
)
|
||||
def compute_info(self, output_sockets) -> ct.InfoFlow:
|
||||
def compute_info(self, props, output_sockets) -> ct.InfoFlow:
|
||||
"""Declare an `InfoFlow` based on the data shape.
|
||||
|
||||
This currently requires computing the data.
|
||||
|
@ -321,26 +294,24 @@ class DataFileImporterNode(base.MaxwellSimNode):
|
|||
# Deduce Dimensionality
|
||||
_shape = data.shape
|
||||
shape = _shape if _shape is not None else ()
|
||||
dim_names = [f'a{i}' for i in range(len(shape))]
|
||||
dim_syms = [self.dim(i) for i in range(len(shape))]
|
||||
|
||||
# Return InfoFlow
|
||||
## -> TODO: How to interpret the data should be user-defined.
|
||||
## -> -- This may require those nice dynamic symbols.
|
||||
return ct.InfoFlow(
|
||||
dim_names=dim_names, ## TODO: User
|
||||
dim_idx={
|
||||
dim_name: ct.LazyArrayRangeFlow(
|
||||
start=sp.S(0), ## TODO: User
|
||||
stop=sp.S(shape[i] - 1), ## TODO: User
|
||||
steps=shape[dim_names.index(dim_name)],
|
||||
unit=None, ## TODO: User
|
||||
dims={
|
||||
dim_sym: ct.RangeFlow(
|
||||
start=sp.S(0),
|
||||
stop=sp.S(shape[i] - 1),
|
||||
steps=shape[i],
|
||||
unit=self.dim(i).unit,
|
||||
)
|
||||
for i, dim_name in enumerate(dim_names)
|
||||
for i, dim_sym in enumerate(dim_syms)
|
||||
},
|
||||
output_name='_',
|
||||
output_shape=None,
|
||||
output_mathtype=spux.MathType.Real, ## TODO: User
|
||||
output_unit=None, ## TODO: User
|
||||
output=sim_symbols.SimSymbol(
|
||||
sym_name=props['output_name'],
|
||||
mathtype=props['output_mathtype'],
|
||||
physical_type=props['output_physical_type'],
|
||||
),
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
|
|
@ -74,11 +74,11 @@ class SceneNode(base.MaxwellSimNode):
|
|||
return bpy.context.scene.frame_current
|
||||
|
||||
@property
|
||||
def scene_frame_range(self) -> ct.LazyArrayRangeFlow:
|
||||
def scene_frame_range(self) -> ct.RangeFlow:
|
||||
"""Retrieve the current start/end frame of the scene, with `steps` corresponding to single-frame steps."""
|
||||
frame_start = bpy.context.scene.frame_start
|
||||
frame_stop = bpy.context.scene.frame_end
|
||||
return ct.LazyArrayRangeFlow(
|
||||
return ct.RangeFlow(
|
||||
start=frame_start,
|
||||
stop=frame_stop,
|
||||
steps=frame_stop - frame_start + 1,
|
||||
|
|
|
@ -100,30 +100,26 @@ class WaveConstantNode(base.MaxwellSimNode):
|
|||
run_on_init=True,
|
||||
)
|
||||
def on_use_range_changed(self, props: dict) -> None:
|
||||
"""Synchronize the `active_kind` of input/output sockets, to either produce a `ct.FlowKind.Value` or a `ct.FlowKind.LazyArrayRange`."""
|
||||
"""Synchronize the `active_kind` of input/output sockets, to either produce a `ct.FlowKind.Value` or a `ct.FlowKind.Range`."""
|
||||
if self.inputs.get('WL') is not None:
|
||||
active_input = self.inputs['WL']
|
||||
else:
|
||||
active_input = self.inputs['Freq']
|
||||
|
||||
# Modify Active Kind(s)
|
||||
## Input active_kind -> Value/LazyArrayRange
|
||||
active_input_uses_range = active_input.active_kind == ct.FlowKind.LazyArrayRange
|
||||
## Input active_kind -> Value/Range
|
||||
active_input_uses_range = active_input.active_kind == ct.FlowKind.Range
|
||||
if active_input_uses_range != props['use_range']:
|
||||
active_input.active_kind = (
|
||||
ct.FlowKind.LazyArrayRange if props['use_range'] else ct.FlowKind.Value
|
||||
ct.FlowKind.Range if props['use_range'] else ct.FlowKind.Value
|
||||
)
|
||||
|
||||
## Output active_kind -> Value/LazyArrayRange
|
||||
## Output active_kind -> Value/Range
|
||||
for active_output in self.outputs.values():
|
||||
active_output_uses_range = (
|
||||
active_output.active_kind == ct.FlowKind.LazyArrayRange
|
||||
)
|
||||
active_output_uses_range = active_output.active_kind == ct.FlowKind.Range
|
||||
if active_output_uses_range != props['use_range']:
|
||||
active_output.active_kind = (
|
||||
ct.FlowKind.LazyArrayRange
|
||||
if props['use_range']
|
||||
else ct.FlowKind.Value
|
||||
ct.FlowKind.Range if props['use_range'] else ct.FlowKind.Value
|
||||
)
|
||||
|
||||
####################
|
||||
|
@ -161,11 +157,11 @@ class WaveConstantNode(base.MaxwellSimNode):
|
|||
|
||||
@events.computes_output_socket(
|
||||
'WL',
|
||||
kind=ct.FlowKind.LazyArrayRange,
|
||||
kind=ct.FlowKind.Range,
|
||||
input_sockets={'WL', 'Freq'},
|
||||
input_socket_kinds={
|
||||
'WL': ct.FlowKind.LazyArrayRange,
|
||||
'Freq': ct.FlowKind.LazyArrayRange,
|
||||
'WL': ct.FlowKind.Range,
|
||||
'Freq': ct.FlowKind.Range,
|
||||
},
|
||||
input_sockets_optional={'WL': True, 'Freq': True},
|
||||
)
|
||||
|
@ -176,7 +172,7 @@ class WaveConstantNode(base.MaxwellSimNode):
|
|||
return input_sockets['WL']
|
||||
|
||||
freq = input_sockets['Freq']
|
||||
return ct.LazyArrayRangeFlow(
|
||||
return ct.RangeFlow(
|
||||
start=spux.scale_to_unit(
|
||||
sci_constants.vac_speed_of_light / (freq.stop * freq.unit), spu.um
|
||||
),
|
||||
|
@ -190,11 +186,11 @@ class WaveConstantNode(base.MaxwellSimNode):
|
|||
|
||||
@events.computes_output_socket(
|
||||
'Freq',
|
||||
kind=ct.FlowKind.LazyArrayRange,
|
||||
kind=ct.FlowKind.Range,
|
||||
input_sockets={'WL', 'Freq'},
|
||||
input_socket_kinds={
|
||||
'WL': ct.FlowKind.LazyArrayRange,
|
||||
'Freq': ct.FlowKind.LazyArrayRange,
|
||||
'WL': ct.FlowKind.Range,
|
||||
'Freq': ct.FlowKind.Range,
|
||||
},
|
||||
input_sockets_optional={'WL': True, 'Freq': True},
|
||||
)
|
||||
|
@ -205,7 +201,7 @@ class WaveConstantNode(base.MaxwellSimNode):
|
|||
return input_sockets['Freq']
|
||||
|
||||
wl = input_sockets['WL']
|
||||
return ct.LazyArrayRangeFlow(
|
||||
return ct.RangeFlow(
|
||||
start=spux.scale_to_unit(
|
||||
sci_constants.vac_speed_of_light / (wl.stop * wl.unit), spux.THz
|
||||
),
|
||||
|
|
|
@ -115,11 +115,11 @@ class LibraryMediumNode(base.MaxwellSimNode):
|
|||
output_sockets: typ.ClassVar = {
|
||||
'Medium': sockets.MaxwellMediumSocketDef(),
|
||||
'Valid Freqs': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.LazyArrayRange,
|
||||
active_kind=ct.FlowKind.Range,
|
||||
physical_type=spux.PhysicalType.Freq,
|
||||
),
|
||||
'Valid WLs': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.LazyArrayRange,
|
||||
active_kind=ct.FlowKind.Range,
|
||||
physical_type=spux.PhysicalType.Length,
|
||||
),
|
||||
}
|
||||
|
@ -254,11 +254,11 @@ class LibraryMediumNode(base.MaxwellSimNode):
|
|||
|
||||
@events.computes_output_socket(
|
||||
'Valid Freqs',
|
||||
kind=ct.FlowKind.LazyArrayRange,
|
||||
kind=ct.FlowKind.Range,
|
||||
props={'freq_range'},
|
||||
)
|
||||
def compute_valid_freqs_lazy(self, props) -> sp.Expr:
|
||||
return ct.LazyArrayRangeFlow(
|
||||
return ct.RangeFlow(
|
||||
start=props['freq_range'][0] / spux.THz,
|
||||
stop=props['freq_range'][1] / spux.THz,
|
||||
steps=0,
|
||||
|
@ -268,11 +268,11 @@ class LibraryMediumNode(base.MaxwellSimNode):
|
|||
|
||||
@events.computes_output_socket(
|
||||
'Valid WLs',
|
||||
kind=ct.FlowKind.LazyArrayRange,
|
||||
kind=ct.FlowKind.Range,
|
||||
props={'wl_range'},
|
||||
)
|
||||
def compute_valid_wls_lazy(self, props) -> sp.Expr:
|
||||
return ct.LazyArrayRangeFlow(
|
||||
return ct.RangeFlow(
|
||||
start=props['wl_range'][0] / spu.nm,
|
||||
stop=props['wl_range'][0] / spu.nm,
|
||||
steps=0,
|
||||
|
|
|
@ -63,7 +63,7 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
|
|||
input_socket_sets: typ.ClassVar = {
|
||||
'Freq Domain': {
|
||||
'Freqs': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.LazyArrayRange,
|
||||
active_kind=ct.FlowKind.Range,
|
||||
physical_type=spux.PhysicalType.Freq,
|
||||
default_unit=spux.THz,
|
||||
default_min=374.7406, ## 800nm
|
||||
|
@ -73,7 +73,7 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
|
|||
},
|
||||
'Time Domain': {
|
||||
't Range': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.LazyArrayRange,
|
||||
active_kind=ct.FlowKind.Range,
|
||||
physical_type=spux.PhysicalType.Time,
|
||||
default_unit=spu.picosecond,
|
||||
default_min=0,
|
||||
|
@ -119,7 +119,7 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
|
|||
'Freqs',
|
||||
},
|
||||
input_socket_kinds={
|
||||
'Freqs': ct.FlowKind.LazyArrayRange,
|
||||
'Freqs': ct.FlowKind.Range,
|
||||
},
|
||||
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
|
||||
scale_input_sockets={
|
||||
|
@ -160,7 +160,7 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
|
|||
't Stride',
|
||||
},
|
||||
input_socket_kinds={
|
||||
't Range': ct.FlowKind.LazyArrayRange,
|
||||
't Range': ct.FlowKind.Range,
|
||||
},
|
||||
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
|
||||
scale_input_sockets={
|
||||
|
|
|
@ -63,7 +63,7 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
|
|||
input_socket_sets: typ.ClassVar = {
|
||||
'Freq Domain': {
|
||||
'Freqs': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.LazyArrayRange,
|
||||
active_kind=ct.FlowKind.Range,
|
||||
physical_type=spux.PhysicalType.Freq,
|
||||
default_unit=spux.THz,
|
||||
default_min=374.7406, ## 800nm
|
||||
|
@ -73,7 +73,7 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
|
|||
},
|
||||
'Time Domain': {
|
||||
't Range': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.LazyArrayRange,
|
||||
active_kind=ct.FlowKind.Range,
|
||||
physical_type=spux.PhysicalType.Time,
|
||||
default_unit=spu.picosecond,
|
||||
default_min=0,
|
||||
|
@ -137,7 +137,7 @@ class PowerFluxMonitorNode(base.MaxwellSimNode):
|
|||
'Freqs',
|
||||
},
|
||||
input_socket_kinds={
|
||||
'Freqs': ct.FlowKind.LazyArrayRange,
|
||||
'Freqs': ct.FlowKind.Range,
|
||||
},
|
||||
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
|
||||
scale_input_sockets={
|
||||
|
|
|
@ -58,7 +58,7 @@ class PermittivityMonitorNode(base.MaxwellSimNode):
|
|||
abs_min=0,
|
||||
),
|
||||
'Freqs': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.LazyArrayRange,
|
||||
active_kind=ct.FlowKind.Range,
|
||||
physical_type=spux.PhysicalType.Freq,
|
||||
default_unit=spux.THz,
|
||||
default_min=374.7406, ## 800nm
|
||||
|
@ -87,7 +87,7 @@ class PermittivityMonitorNode(base.MaxwellSimNode):
|
|||
'Freqs',
|
||||
},
|
||||
input_socket_kinds={
|
||||
'Freqs': ct.FlowKind.LazyArrayRange,
|
||||
'Freqs': ct.FlowKind.Range,
|
||||
},
|
||||
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
|
||||
scale_input_sockets={
|
||||
|
|
|
@ -14,16 +14,15 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
# from . import file_exporters, viewer, web_exporters
|
||||
from . import viewer, web_exporters
|
||||
from . import file_exporters, viewer, web_exporters
|
||||
|
||||
BL_REGISTER = [
|
||||
*viewer.BL_REGISTER,
|
||||
# *file_exporters.BL_REGISTER,
|
||||
*file_exporters.BL_REGISTER,
|
||||
*web_exporters.BL_REGISTER,
|
||||
]
|
||||
BL_NODES = {
|
||||
**viewer.BL_NODES,
|
||||
# **file_exporters.BL_NODES,
|
||||
**file_exporters.BL_NODES,
|
||||
**web_exporters.BL_NODES,
|
||||
}
|
||||
|
|
|
@ -14,11 +14,15 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from . import json_file_exporter
|
||||
from . import data_file_exporter
|
||||
|
||||
# from . import json_file_exporter
|
||||
|
||||
BL_REGISTER = [
|
||||
*json_file_exporter.BL_REGISTER,
|
||||
*data_file_exporter.BL_REGISTER,
|
||||
# *json_file_exporter.BL_REGISTER,
|
||||
]
|
||||
BL_NODES = {
|
||||
**json_file_exporter.BL_NODES,
|
||||
**data_file_exporter.BL_NODES,
|
||||
# **json_file_exporter.BL_NODES,
|
||||
}
|
||||
|
|
|
@ -0,0 +1,252 @@
|
|||
# blender_maxwell
|
||||
# Copyright (C) 2024 blender_maxwell Project Contributors
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import typing as typ
|
||||
from pathlib import Path
|
||||
|
||||
import bpy
|
||||
|
||||
from blender_maxwell.utils import bl_cache, logger
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
|
||||
from .... import contracts as ct
|
||||
from .... import sockets
|
||||
from ... import base, events
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
||||
####################
|
||||
# - Operators
|
||||
####################
|
||||
class ExportDataFile(bpy.types.Operator):
|
||||
"""Exports data from the input to `DataFileExporterNode` to the file path given on the same node, if the path is compatible with the chosen export format (a property on the node)."""
|
||||
|
||||
bl_idname = ct.OperatorType.NodeExportDataFile
|
||||
bl_label = 'Save Data File'
|
||||
bl_description = 'Save a file with the contents, name, and format indicated by a NodeExportDataFile'
|
||||
|
||||
@classmethod
|
||||
def poll(cls, context):
|
||||
return (
|
||||
# Check Node
|
||||
hasattr(context, 'node')
|
||||
and hasattr(context.node, 'node_type')
|
||||
and (node := context.node).node_type == ct.NodeType.DataFileExporter
|
||||
# Check Expr
|
||||
and node.is_file_path_compatible_with_export_format
|
||||
)
|
||||
|
||||
def execute(self, context: bpy.types.Context):
|
||||
node = context.node
|
||||
|
||||
node.export_format.saver(node.file_path, node.expr_data, node.expr_info)
|
||||
return {'FINISHED'}
|
||||
|
||||
|
||||
####################
|
||||
# - Node
|
||||
####################
|
||||
class DataFileExporterNode(base.MaxwellSimNode):
|
||||
# """Export input data to a supported
|
||||
node_type = ct.NodeType.DataFileExporter
|
||||
bl_label = 'Data File Importer'
|
||||
|
||||
input_sockets: typ.ClassVar = {
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Func),
|
||||
'File Path': sockets.FilePathSocketDef(),
|
||||
}
|
||||
|
||||
####################
|
||||
# - Properties: Expr Info
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
socket_name={'Expr'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||
)
|
||||
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
|
||||
has_expr = not ct.FlowSignal.check(input_sockets['Expr'])
|
||||
|
||||
if has_expr:
|
||||
self.expr_info = bl_cache.Signal.InvalidateCache
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'file_path'})
|
||||
def expr_info(self) -> ct.InfoFlow | None:
|
||||
"""Retrieve the input expression's `InfoFlow`."""
|
||||
info = self._compute_input('Expr', kind=ct.FlowKind.Info)
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
if has_info:
|
||||
return info
|
||||
return None
|
||||
|
||||
@property
|
||||
def expr_data(self) -> typ.Any | None:
|
||||
"""Retrieve the input expression's data by evaluating its `Func`."""
|
||||
func = self._compute_input('Expr', kind=ct.FlowKind.Func)
|
||||
params = self._compute_input('Expr', kind=ct.FlowKind.Params)
|
||||
|
||||
has_func = not ct.FlowSignal.check(func)
|
||||
has_params = not ct.FlowSignal.check(params)
|
||||
if has_func and has_params:
|
||||
symbol_values = {
|
||||
sym.name: self._compute_input(sym.name, kind=ct.FlowKind.Value)
|
||||
for sym in params.sorted_symbols
|
||||
}
|
||||
return func.func_jax(
|
||||
*params.scaled_func_args(spux.UNITS_SI, symbol_values=symbol_values),
|
||||
**params.scaled_func_kwargs(spux.UNITS_SI, symbol_values=symbol_values),
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
####################
|
||||
# - Properties: File Path
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
socket_name={'File Path'},
|
||||
input_sockets={'File Path'},
|
||||
input_socket_kinds={'File Path': ct.FlowKind.Value},
|
||||
input_sockets_optional={'File Path': True},
|
||||
)
|
||||
def on_file_path_changed(self, input_sockets) -> None: # noqa: D102
|
||||
has_file_path = not ct.FlowSignal.check(input_sockets['File Path'])
|
||||
if has_file_path:
|
||||
self.file_path = bl_cache.Signal.InvalidateCache
|
||||
|
||||
@bl_cache.cached_bl_property()
|
||||
def file_path(self) -> Path:
|
||||
"""Retrieve the input file path."""
|
||||
file_path = self._compute_input(
|
||||
'File Path', kind=ct.FlowKind.Value, optional=True
|
||||
)
|
||||
has_file_path = not ct.FlowSignal.check(file_path)
|
||||
if has_file_path:
|
||||
return file_path
|
||||
|
||||
return None
|
||||
|
||||
####################
|
||||
# - Properties: Export Format
|
||||
####################
|
||||
export_format: ct.DataFileFormat = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_export_formats(),
|
||||
cb_depends_on={'expr_info'},
|
||||
)
|
||||
|
||||
def search_export_formats(self):
|
||||
if self.expr_info is not None:
|
||||
return [
|
||||
data_file_format.bl_enum_element(i)
|
||||
for i, data_file_format in enumerate(list(ct.DataFileFormat))
|
||||
if data_file_format.is_info_compatible(self.expr_info)
|
||||
]
|
||||
return ct.DataFileFormat.bl_enum_elements()
|
||||
|
||||
####################
|
||||
# - Properties: File Path Compatibility
|
||||
####################
|
||||
@bl_cache.cached_bl_property(depends_on={'file_path', 'export_format'})
|
||||
def is_file_path_compatible_with_export_format(self) -> bool | None:
|
||||
"""Determine whether the given file path is actually compatible with the desired export format."""
|
||||
if self.file_path is not None and self.export_format is not None:
|
||||
return self.export_format.is_path_compatible(self.file_path)
|
||||
return None
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
def draw_label(self):
|
||||
"""Show the extracted file name (w/extension) in the node's header label.
|
||||
|
||||
Notes:
|
||||
Called by Blender to determine the text to place in the node's header.
|
||||
"""
|
||||
if self.file_path is not None:
|
||||
return 'Save: ' + self.file_path.name
|
||||
|
||||
return self.bl_label
|
||||
|
||||
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
"""Show information about the loaded file."""
|
||||
if self.export_format is not None:
|
||||
box = layout.box()
|
||||
row = box.row()
|
||||
row.alignment = 'CENTER'
|
||||
row.label(text='Data File')
|
||||
|
||||
row = box.row()
|
||||
row.alignment = 'CENTER'
|
||||
row.label(text=self.file_path.name)
|
||||
|
||||
compatibility = self.is_file_path_compatible_with_export_format
|
||||
if compatibility is not None:
|
||||
row = box.row()
|
||||
row.alignment = 'CENTER'
|
||||
if compatibility:
|
||||
row.label(text='Valid Path | Format', icon='CHECKMARK')
|
||||
else:
|
||||
row.label(text='Invalid Path | Format', icon='ERROR')
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
layout.prop(self, self.blfields['export_format'], text='')
|
||||
layout.operator(ct.OperatorType.NodeExportDataFile, text='Save Data File')
|
||||
|
||||
####################
|
||||
# - Events
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
# Trigger
|
||||
socket_name='Expr',
|
||||
run_on_init=True,
|
||||
# Loaded
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': {ct.FlowKind.Info, ct.FlowKind.Params}},
|
||||
input_sockets_optional={'Expr': True},
|
||||
)
|
||||
def on_expr_changed(self, input_sockets: dict) -> None:
|
||||
"""Declare any loose input sockets needed to realize the input expr's symbols."""
|
||||
info = input_sockets['Expr'][ct.FlowKind.Info]
|
||||
params = input_sockets['Expr'][ct.FlowKind.Params]
|
||||
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
has_params = not ct.FlowSignal.check(params)
|
||||
|
||||
# Provide Sockets for Symbol Realization
|
||||
## -> Only happens if Params contains not-yet-realized symbols.
|
||||
if has_info and has_params and params.symbols:
|
||||
if set(self.loose_input_sockets) != {
|
||||
dim.name for dim in params.symbols if dim in info.dims
|
||||
}:
|
||||
self.loose_input_sockets = {
|
||||
dim_name: sockets.ExprSocketDef(**expr_info)
|
||||
for dim_name, expr_info in params.sym_expr_infos(info).items()
|
||||
}
|
||||
|
||||
elif self.loose_input_sockets:
|
||||
self.loose_input_sockets = {}
|
||||
|
||||
|
||||
####################
|
||||
# - Blender Registration
|
||||
####################
|
||||
BL_REGISTER = [
|
||||
ExportDataFile,
|
||||
DataFileExporterNode,
|
||||
]
|
||||
BL_NODES = {
|
||||
ct.NodeType.DataFileExporter: (ct.NodeCategory.MAXWELLSIM_OUTPUTS_FILEEXPORTERS)
|
||||
}
|
|
@ -74,7 +74,7 @@ class TemporalShapeNode(base.MaxwellSimNode):
|
|||
},
|
||||
'Symbolic': {
|
||||
't Range': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.LazyArrayRange,
|
||||
active_kind=ct.FlowKind.Range,
|
||||
physical_type=spux.PhysicalType.Time,
|
||||
default_unit=spu.picosecond,
|
||||
default_min=0,
|
||||
|
@ -132,8 +132,8 @@ class TemporalShapeNode(base.MaxwellSimNode):
|
|||
'Envelope',
|
||||
},
|
||||
input_socket_kinds={
|
||||
't Range': ct.FlowKind.LazyArrayRange,
|
||||
'Envelope': ct.FlowKind.LazyValueFunc,
|
||||
't Range': ct.FlowKind.Range,
|
||||
'Envelope': ct.FlowKind.Func,
|
||||
},
|
||||
input_sockets_optional={
|
||||
'max E': True,
|
||||
|
|
|
@ -525,9 +525,9 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
|||
msg = f'Socket {self.bl_label} {self.socket_type}): Tried to set "ct.FlowKind.Array", but socket does not define it'
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
# LazyValueFunc
|
||||
# Func
|
||||
@property
|
||||
def lazy_value_func(self) -> ct.LazyValueFuncFlow:
|
||||
def lazy_func(self) -> ct.FuncFlow:
|
||||
"""Throws a descriptive error.
|
||||
|
||||
Notes:
|
||||
|
@ -538,8 +538,8 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
|||
"""
|
||||
return ct.FlowSignal.NoFlow
|
||||
|
||||
@lazy_value_func.setter
|
||||
def lazy_value_func(self, lazy_value_func: ct.LazyValueFuncFlow) -> None:
|
||||
@lazy_func.setter
|
||||
def lazy_func(self, lazy_func: ct.FuncFlow) -> None:
|
||||
"""Throws a descriptive error.
|
||||
|
||||
Notes:
|
||||
|
@ -548,12 +548,12 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
|||
Raises:
|
||||
NotImplementedError: When used without being overridden.
|
||||
"""
|
||||
msg = f'Socket {self.bl_label} {self.socket_type}): Tried to set "ct.FlowKind.LazyValueFunc", but socket does not define it'
|
||||
msg = f'Socket {self.bl_label} {self.socket_type}): Tried to set "ct.FlowKind.Func", but socket does not define it'
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
# LazyArrayRange
|
||||
# Range
|
||||
@property
|
||||
def lazy_array_range(self) -> ct.LazyArrayRangeFlow:
|
||||
def lazy_range(self) -> ct.RangeFlow:
|
||||
"""Throws a descriptive error.
|
||||
|
||||
Notes:
|
||||
|
@ -564,8 +564,8 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
|||
"""
|
||||
return ct.FlowSignal.NoFlow
|
||||
|
||||
@lazy_array_range.setter
|
||||
def lazy_array_range(self, value: ct.LazyArrayRangeFlow) -> None:
|
||||
@lazy_range.setter
|
||||
def lazy_range(self, value: ct.RangeFlow) -> None:
|
||||
"""Throws a descriptive error.
|
||||
|
||||
Notes:
|
||||
|
@ -574,7 +574,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
|||
Raises:
|
||||
NotImplementedError: When used without being overridden.
|
||||
"""
|
||||
msg = f'Socket {self.bl_label} {self.socket_type}): Tried to set "ct.FlowKind.LazyArrayRange", but socket does not define it'
|
||||
msg = f'Socket {self.bl_label} {self.socket_type}): Tried to set "ct.FlowKind.Range", but socket does not define it'
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
####################
|
||||
|
@ -595,8 +595,8 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
|||
kind_data_map = {
|
||||
ct.FlowKind.Value: lambda: self.value,
|
||||
ct.FlowKind.Array: lambda: self.array,
|
||||
ct.FlowKind.LazyValueFunc: lambda: self.lazy_value_func,
|
||||
ct.FlowKind.LazyArrayRange: lambda: self.lazy_array_range,
|
||||
ct.FlowKind.Func: lambda: self.lazy_func,
|
||||
ct.FlowKind.Range: lambda: self.lazy_range,
|
||||
ct.FlowKind.Params: lambda: self.params,
|
||||
ct.FlowKind.Info: lambda: self.info,
|
||||
}
|
||||
|
@ -783,8 +783,8 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
|||
{
|
||||
ct.FlowKind.Value: self.draw_value,
|
||||
ct.FlowKind.Array: self.draw_array,
|
||||
ct.FlowKind.LazyArrayRange: self.draw_lazy_array_range,
|
||||
ct.FlowKind.LazyValueFunc: self.draw_lazy_value_func,
|
||||
ct.FlowKind.Range: self.draw_lazy_range,
|
||||
ct.FlowKind.Func: self.draw_lazy_func,
|
||||
}[self.active_kind](col)
|
||||
|
||||
# Info Drawing
|
||||
|
@ -894,11 +894,11 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
|||
col: Target for defining UI elements.
|
||||
"""
|
||||
|
||||
def draw_lazy_array_range(self, col: bpy.types.UILayout) -> None:
|
||||
def draw_lazy_range(self, col: bpy.types.UILayout) -> None:
|
||||
"""Draws the socket lazy array range on its own line.
|
||||
|
||||
Notes:
|
||||
Should be overriden by individual socket classes, if they have an editable `FlowKind.LazyArrayRange`.
|
||||
Should be overriden by individual socket classes, if they have an editable `FlowKind.Range`.
|
||||
|
||||
Parameters:
|
||||
col: Target for defining UI elements.
|
||||
|
@ -914,11 +914,11 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
|||
col: Target for defining UI elements.
|
||||
"""
|
||||
|
||||
def draw_lazy_value_func(self, col: bpy.types.UILayout) -> None:
|
||||
def draw_lazy_func(self, col: bpy.types.UILayout) -> None:
|
||||
"""Draws the socket lazy value function UI on its own line.
|
||||
|
||||
Notes:
|
||||
Should be overriden by individual socket classes, if they have an editable `FlowKind.LazyValueFunc`.
|
||||
Should be overriden by individual socket classes, if they have an editable `FlowKind.Func`.
|
||||
|
||||
Parameters:
|
||||
col: Target for defining UI elements.
|
||||
|
|
|
@ -100,7 +100,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
When active, `self.active_unit` can be used via the UI to select valid unit of the given `self.physical_type`, and `self.unit` works.
|
||||
The enum itself can be dynamically altered, ex. via its UI dropdown support.
|
||||
symbols: The symbolic variables valid in the context of the expression.
|
||||
Various features, including `LazyValueFunc` support, become available when symbols are in use.
|
||||
Various features, including `Func` support, become available when symbols are in use.
|
||||
The presence of symbols forces fallback to a string-based `sympy` expression UI.
|
||||
|
||||
active_unit: The currently active unit, as a dropdown.
|
||||
|
@ -118,13 +118,15 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
physical_type: spux.PhysicalType = bl_cache.BLField(spux.PhysicalType.NonPhysical)
|
||||
|
||||
# Symbols
|
||||
# active_symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([])
|
||||
symbols: frozenset[sp.Symbol] = bl_cache.BLField(frozenset())
|
||||
output_name: sim_symbols.SimSymbolName = bl_cache.BLField(
|
||||
sim_symbols.SimSymbolName.Expr
|
||||
)
|
||||
active_symbols: list[sim_symbols.SimSymbol] = bl_cache.BLField([])
|
||||
|
||||
# @property
|
||||
# def symbols(self) -> set[sp.Symbol]:
|
||||
# """Current symbols as an unordered set."""
|
||||
# return {sim_symbol.sp_symbol for sim_symbol in self.active_symbols}
|
||||
@property
|
||||
def symbols(self) -> set[sp.Symbol]:
|
||||
"""Current symbols as an unordered set."""
|
||||
return {sim_symbol.sp_symbol for sim_symbol in self.active_symbols}
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'symbols'})
|
||||
def sorted_symbols(self) -> list[sp.Symbol]:
|
||||
|
@ -169,7 +171,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
((0.0, 0.0), (0.0, 0.0), (0.0, 0.0)), float_prec=4
|
||||
)
|
||||
|
||||
# UI: LazyArrayRange
|
||||
# UI: Range
|
||||
steps: int = bl_cache.BLField(2, soft_min=2, abs_min=0)
|
||||
scaling: ct.ScalingMode = bl_cache.BLField(ct.ScalingMode.Lin)
|
||||
## Expression
|
||||
|
@ -184,6 +186,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
)
|
||||
|
||||
# UI: Info
|
||||
show_func_ui: bool = bl_cache.BLField(True)
|
||||
show_info_columns: bool = bl_cache.BLField(False)
|
||||
info_columns: set[InfoDisplayCol] = bl_cache.BLField(
|
||||
{InfoDisplayCol.Length, InfoDisplayCol.MathType}
|
||||
|
@ -248,7 +251,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
and not self.symbols
|
||||
):
|
||||
self.value = self.value.subs({self.unit: prev_unit})
|
||||
self.lazy_array_range = self.lazy_array_range.correct_unit(prev_unit)
|
||||
self.lazy_range = self.lazy_range.correct_unit(prev_unit)
|
||||
|
||||
self.prev_unit = self.active_unit
|
||||
|
||||
|
@ -454,20 +457,20 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
)
|
||||
|
||||
####################
|
||||
# - FlowKind: LazyArrayRange
|
||||
# - FlowKind: Range
|
||||
####################
|
||||
@property
|
||||
def lazy_array_range(self) -> ct.LazyArrayRangeFlow:
|
||||
def lazy_range(self) -> ct.RangeFlow:
|
||||
"""Return the not-yet-computed uniform array defined by the socket.
|
||||
|
||||
Notes:
|
||||
Called to compute the internal `FlowKind.LazyArrayRange` of this socket.
|
||||
Called to compute the internal `FlowKind.Range` of this socket.
|
||||
|
||||
Return:
|
||||
The range of lengths, which uses no symbols.
|
||||
"""
|
||||
if self.symbols:
|
||||
return ct.LazyArrayRangeFlow(
|
||||
return ct.RangeFlow(
|
||||
start=self.raw_min_sp,
|
||||
stop=self.raw_max_sp,
|
||||
steps=self.steps,
|
||||
|
@ -493,7 +496,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
],
|
||||
}[self.mathtype]()
|
||||
|
||||
return ct.LazyArrayRangeFlow(
|
||||
return ct.RangeFlow(
|
||||
start=min_bound,
|
||||
stop=max_bound,
|
||||
steps=self.steps,
|
||||
|
@ -501,12 +504,12 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
unit=self.unit,
|
||||
)
|
||||
|
||||
@lazy_array_range.setter
|
||||
def lazy_array_range(self, value: ct.LazyArrayRangeFlow) -> None:
|
||||
@lazy_range.setter
|
||||
def lazy_range(self, value: ct.RangeFlow) -> None:
|
||||
"""Set the not-yet-computed uniform array defined by the socket.
|
||||
|
||||
Notes:
|
||||
Called to compute the internal `FlowKind.LazyArrayRange` of this socket.
|
||||
Called to compute the internal `FlowKind.Range` of this socket.
|
||||
"""
|
||||
self.steps = value.steps
|
||||
self.scaling = value.scaling
|
||||
|
@ -544,20 +547,20 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
]
|
||||
|
||||
####################
|
||||
# - FlowKind: LazyValueFunc (w/Params if Constant)
|
||||
# - FlowKind: Func (w/Params if Constant)
|
||||
####################
|
||||
@property
|
||||
def lazy_value_func(self) -> ct.LazyValueFuncFlow:
|
||||
def lazy_func(self) -> ct.FuncFlow:
|
||||
"""Returns a lazy value that computes the expression returned by `self.value`.
|
||||
|
||||
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be the arguments of the `LazyValueFuncFlow`.
|
||||
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be the arguments of the `FuncFlow`.
|
||||
Otherwise, the returned lazy value function will be a simple excuse for `self.params` to pass the verbatim `self.value`.
|
||||
"""
|
||||
# Symbolic
|
||||
## -> `self.value` is guaranteed to be an expression with unknowns.
|
||||
## -> The function computes `self.value` with unknowns as arguments.
|
||||
if self.symbols:
|
||||
return ct.LazyValueFuncFlow(
|
||||
return ct.FuncFlow(
|
||||
func=sp.lambdify(
|
||||
self.sorted_symbols,
|
||||
spux.scale_to_unit(self.value, self.unit),
|
||||
|
@ -572,7 +575,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
## -> ("Dummy" as in returns the same argument that it takes).
|
||||
## -> This is an excuse to let `ParamsFlow` pass `self.value` verbatim.
|
||||
## -> Generally only useful for operations with other expressions.
|
||||
return ct.LazyValueFuncFlow(
|
||||
return ct.FuncFlow(
|
||||
func=lambda v: v,
|
||||
func_args=[
|
||||
self.physical_type if self.physical_type is not None else self.mathtype
|
||||
|
@ -582,7 +585,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
|
||||
@property
|
||||
def params(self) -> ct.ParamsFlow:
|
||||
"""Returns parameter symbols/values to accompany `self.lazy_value_func`.
|
||||
"""Returns parameter symbols/values to accompany `self.lazy_func`.
|
||||
|
||||
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be passed into `ParamsFlow`, which will thus be parameterized (and require realization before use).
|
||||
Otherwise, `self.value` is passed verbatim as the only `ParamsFlow.func_arg`.
|
||||
|
@ -605,45 +608,34 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
|
||||
@property
|
||||
def info(self) -> ct.ArrayFlow:
|
||||
r"""Returns parameter symbols/values to accompany `self.lazy_value_func`.
|
||||
r"""Returns parameter symbols/values to accompany `self.lazy_func`.
|
||||
|
||||
The output name/size/mathtype/unit corresponds directly the `ExprSocket`.
|
||||
|
||||
If `self.symbols` has entries, then these will propagate as dimensions with unresolvable `LazyArrayRangeFlow` index descriptions.
|
||||
If `self.symbols` has entries, then these will propagate as dimensions with unresolvable `RangeFlow` index descriptions.
|
||||
The index range will be $(-\infty,\infty)$, with $0$ steps and no unit.
|
||||
The order/naming matches `self.params` and `self.lazy_value_func`.
|
||||
The order/naming matches `self.params` and `self.lazy_func`.
|
||||
|
||||
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
|
||||
"""
|
||||
output_sim_sym = (
|
||||
sim_symbols.SimSymbol(
|
||||
sym_name=self.output_name,
|
||||
mathtype=self.mathtype,
|
||||
physical_type=self.physical_type,
|
||||
unit=self.unit,
|
||||
rows=self.size.rows,
|
||||
cols=self.size.cols,
|
||||
),
|
||||
)
|
||||
if self.symbols:
|
||||
return ct.InfoFlow(
|
||||
dim_names=[sym.name for sym in self.sorted_symbols],
|
||||
dim_idx={
|
||||
sym.name: ct.LazyArrayRangeFlow(
|
||||
start=-sp.oo if _check_sym_oo(sym) else -sp.zoo,
|
||||
stop=sp.oo if _check_sym_oo(sym) else sp.zoo,
|
||||
steps=0,
|
||||
unit=None, ## Symbols alone are unitless.
|
||||
)
|
||||
## TODO: PhysicalTypes for symbols? Or nah?
|
||||
## TODO: Can we parse some sp.Interval for explicit domains?
|
||||
## -> We investigated sp.Symbol(..., domain=...).
|
||||
## -> It's no good. We can't re-extract the interval given to domain.
|
||||
for sym in self.sorted_symbols
|
||||
},
|
||||
output_name='_', ## Use node:socket name? Or something? Ahh
|
||||
output_shape=self.size.shape,
|
||||
output_mathtype=self.mathtype,
|
||||
output_unit=self.unit,
|
||||
dims={sim_sym: None for sim_sym in self.active_symbols},
|
||||
output=output_sim_sym,
|
||||
)
|
||||
|
||||
# Constant
|
||||
return ct.InfoFlow(
|
||||
output_name='_', ## Use node:socket name? Or something? Ahh
|
||||
output_shape=self.size.shape,
|
||||
output_mathtype=self.mathtype,
|
||||
output_unit=self.unit,
|
||||
)
|
||||
return ct.InfoFlow(output=output_sim_sym)
|
||||
|
||||
####################
|
||||
# - FlowKind: Capabilities
|
||||
|
@ -805,13 +797,13 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
for sym in self.symbols:
|
||||
col.label(text=spux.pretty_symbol(sym))
|
||||
|
||||
def draw_lazy_array_range(self, col: bpy.types.UILayout) -> None:
|
||||
def draw_lazy_range(self, col: bpy.types.UILayout) -> None:
|
||||
"""Draw the socket body for a simple, uniform range of values between two values/expressions.
|
||||
|
||||
Drawn when `self.active_kind == FlowKind.LazyArrayRange`.
|
||||
Drawn when `self.active_kind == FlowKind.Range`.
|
||||
|
||||
Notes:
|
||||
If `self.steps == 0`, then the `LazyArrayRange` is considered to have a to-be-determined number of steps.
|
||||
If `self.steps == 0`, then the `Range` is considered to have a to-be-determined number of steps.
|
||||
As such, `self.steps` won't be exposed in the UI.
|
||||
"""
|
||||
if self.symbols:
|
||||
|
@ -835,44 +827,49 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
if self.steps != 0:
|
||||
col.prop(self, self.blfields['steps'], text='')
|
||||
|
||||
def draw_lazy_value_func(self, col: bpy.types.UILayout) -> None:
|
||||
def draw_lazy_func(self, col: bpy.types.UILayout) -> None:
|
||||
"""Draw the socket body for a single flexible value/expression, for down-chain lazy evaluation.
|
||||
|
||||
This implements the most flexible variant of the `ExprSocket` UI, providing the user with full runtime-configuration of the exact `self.size`, `self.mathtype`, `self.physical_type`, and `self.symbols` of the expression.
|
||||
|
||||
Notes:
|
||||
Drawn when `self.active_kind == FlowKind.LazyValueFunc`.
|
||||
Drawn when `self.active_kind == FlowKind.Func`.
|
||||
|
||||
This is an ideal choice for ex. math nodes that need to accept arbitrary expressions as inputs, with an eye towards lazy evaluation of ex. symbolic terms.
|
||||
|
||||
Uses `draw_value` to draw the base UI
|
||||
"""
|
||||
# Physical Type Selector
|
||||
## -> Determines whether/which unit-dropdown will be shown.
|
||||
col.prop(self, self.blfields['physical_type'], text='')
|
||||
if self.show_func_ui:
|
||||
# Output Name Selector
|
||||
## -> The name of the output
|
||||
col.prop(self, self.blfields['output_name'], text='')
|
||||
|
||||
# Non-Symbolic: Size/Mathtype Selector
|
||||
## -> Symbols imply str expr input.
|
||||
## -> For arbitrary str exprs, size/mathtype are derived from the expr.
|
||||
## -> Otherwise, size/mathtype must be pre-specified for a nice UI.
|
||||
if not self.symbols:
|
||||
row = col.row(align=True)
|
||||
row.prop(self, self.blfields['size'], text='')
|
||||
row.prop(self, self.blfields['mathtype'], text='')
|
||||
# Physical Type Selector
|
||||
## -> Determines whether/which unit-dropdown will be shown.
|
||||
col.prop(self, self.blfields['physical_type'], text='')
|
||||
|
||||
# Base UI
|
||||
## -> Draws the UI appropriate for the above choice of constraints.
|
||||
self.draw_value(col)
|
||||
# Non-Symbolic: Size/Mathtype Selector
|
||||
## -> Symbols imply str expr input.
|
||||
## -> For arbitrary str exprs, size/mathtype are derived from the expr.
|
||||
## -> Otherwise, size/mathtype must be pre-specified for a nice UI.
|
||||
if not self.symbols:
|
||||
row = col.row(align=True)
|
||||
row.prop(self, self.blfields['size'], text='')
|
||||
row.prop(self, self.blfields['mathtype'], text='')
|
||||
|
||||
# Symbol UI
|
||||
## -> Draws the UI appropriate for the above choice of constraints.
|
||||
## -> TODO
|
||||
# Base UI
|
||||
## -> Draws the UI appropriate for the above choice of constraints.
|
||||
self.draw_value(col)
|
||||
|
||||
# Symbol UI
|
||||
## -> Draws the UI appropriate for the above choice of constraints.
|
||||
## -> TODO
|
||||
|
||||
####################
|
||||
# - UI: InfoFlow
|
||||
####################
|
||||
def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None:
|
||||
if self.active_kind == ct.FlowKind.LazyValueFunc and self.show_info_columns:
|
||||
if self.active_kind == ct.FlowKind.Func and self.show_info_columns:
|
||||
row = col.row()
|
||||
box = row.box()
|
||||
grid = box.grid_flow(
|
||||
|
@ -884,9 +881,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
)
|
||||
|
||||
# Dimensions
|
||||
for dim_name in info.dim_names:
|
||||
dim_idx = info.dim_idx[dim_name]
|
||||
grid.label(text=dim_name)
|
||||
for dim in info.dims:
|
||||
dim_idx = info.dims[dim]
|
||||
grid.label(text=dim.name_pretty)
|
||||
if InfoDisplayCol.Length in self.info_columns:
|
||||
grid.label(text=str(len(dim_idx)))
|
||||
if InfoDisplayCol.MathType in self.info_columns:
|
||||
|
@ -895,27 +892,27 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
grid.label(text=spux.sp_to_str(dim_idx.unit))
|
||||
|
||||
# Outputs
|
||||
grid.label(text=info.output_name)
|
||||
grid.label(text=info.output.name_pretty)
|
||||
if InfoDisplayCol.Length in self.info_columns:
|
||||
grid.label(text='', icon=ct.Icon.DataSocketOutput)
|
||||
if InfoDisplayCol.MathType in self.info_columns:
|
||||
grid.label(
|
||||
text=(
|
||||
spux.MathType.to_str(info.output_mathtype)
|
||||
spux.MathType.to_str(info.output.mathtype)
|
||||
+ (
|
||||
'ˣ'.join(
|
||||
[
|
||||
unicode_superscript(out_axis)
|
||||
for out_axis in info.output_shape
|
||||
for out_axis in info.output.shape
|
||||
]
|
||||
)
|
||||
if info.output_shape
|
||||
if info.output.shape
|
||||
else ''
|
||||
)
|
||||
)
|
||||
)
|
||||
if InfoDisplayCol.Unit in self.info_columns:
|
||||
grid.label(text=f'{spux.sp_to_str(info.output_unit)}')
|
||||
grid.label(text=f'{spux.sp_to_str(info.output.unit)}')
|
||||
|
||||
|
||||
####################
|
||||
|
@ -925,10 +922,11 @@ class ExprSocketDef(base.SocketDef):
|
|||
socket_type: ct.SocketType = ct.SocketType.Expr
|
||||
active_kind: typ.Literal[
|
||||
ct.FlowKind.Value,
|
||||
ct.FlowKind.LazyArrayRange,
|
||||
ct.FlowKind.Range,
|
||||
ct.FlowKind.Array,
|
||||
ct.FlowKind.LazyValueFunc,
|
||||
ct.FlowKind.Func,
|
||||
] = ct.FlowKind.Value
|
||||
output_name: sim_symbols.SimSymbolName = sim_symbols.SimSymbolName
|
||||
|
||||
# Socket Interface
|
||||
size: spux.NumberSize1D = spux.NumberSize1D.Scalar
|
||||
|
@ -938,22 +936,19 @@ class ExprSocketDef(base.SocketDef):
|
|||
default_unit: spux.Unit | None = None
|
||||
default_symbols: list[sim_symbols.SimSymbol] = pyd.Field(default_factory=list)
|
||||
|
||||
@property
|
||||
def symbols(self) -> set[sp.Symbol]:
|
||||
return {sim_symbol.sp_symbol for sim_symbol in self.default_symbols}
|
||||
|
||||
# FlowKind: Value
|
||||
default_value: spux.SympyExpr = 0
|
||||
abs_min: spux.SympyExpr | None = None
|
||||
abs_max: spux.SympyExpr | None = None
|
||||
|
||||
# FlowKind: LazyArrayRange
|
||||
# FlowKind: Range
|
||||
default_min: spux.SympyExpr = 0
|
||||
default_max: spux.SympyExpr = 1
|
||||
default_steps: int = 2
|
||||
default_scaling: ct.ScalingMode = ct.ScalingMode.Lin
|
||||
|
||||
# UI
|
||||
show_func_ui: bool = True
|
||||
show_info_columns: bool = False
|
||||
|
||||
####################
|
||||
|
@ -1107,7 +1102,7 @@ class ExprSocketDef(base.SocketDef):
|
|||
return self
|
||||
|
||||
####################
|
||||
# - Parse FlowKind.LazyArrayRange
|
||||
# - Parse FlowKind.Range
|
||||
####################
|
||||
@pyd.field_validator('default_steps')
|
||||
@classmethod
|
||||
|
@ -1120,8 +1115,8 @@ class ExprSocketDef(base.SocketDef):
|
|||
return v
|
||||
|
||||
@pyd.model_validator(mode='after')
|
||||
def parse_default_lazy_array_range_numbers(self) -> typ.Self:
|
||||
"""Guarantees that the default `ct.LazyArrayRange` bounds are sympy expressions.
|
||||
def parse_default_lazy_range_numbers(self) -> typ.Self:
|
||||
"""Guarantees that the default `ct.Range` bounds are sympy expressions.
|
||||
|
||||
If `self.default_value` is a scalar Python type, it will be coerced into the corresponding Sympy type using `sp.S`.
|
||||
|
||||
|
@ -1150,9 +1145,17 @@ class ExprSocketDef(base.SocketDef):
|
|||
if mathtype_guide == 'expr':
|
||||
dv_mathtype = spux.MathType.from_expr(bound)
|
||||
if not self.mathtype.is_compatible(dv_mathtype):
|
||||
msg = f'ExprSocket: Mathtype {dv_mathtype} of a default LazyArrayRange min or max expression {bound} (type {type(self.default_value)}) is incompatible with socket MathType {self.mathtype}'
|
||||
msg = f'ExprSocket: Mathtype {dv_mathtype} of a default Range min or max expression {bound} (type {type(self.default_value)}) is incompatible with socket MathType {self.mathtype}'
|
||||
raise ValueError(msg)
|
||||
|
||||
# Coerce from Infinite
|
||||
if bound.is_infinite and self.mathtype is spux.MathType.Integer:
|
||||
new_bounds[i] = sp.S(-1) if i == 0 else sp.S(1)
|
||||
if bound.is_infinite and self.mathtype is spux.MathType.Rational:
|
||||
new_bounds[i] = sp.Rational(-1, 1) if i == 0 else sp.Rational(1, 1)
|
||||
if bound.is_infinite and self.mathtype is spux.MathType.Real:
|
||||
new_bounds[i] = sp.S(-1.0) if i == 0 else sp.S(1.0)
|
||||
|
||||
if new_bounds[0] is not None:
|
||||
self.default_min = new_bounds[0]
|
||||
if new_bounds[1] is not None:
|
||||
|
@ -1161,8 +1164,8 @@ class ExprSocketDef(base.SocketDef):
|
|||
return self
|
||||
|
||||
@pyd.model_validator(mode='after')
|
||||
def parse_default_lazy_array_range_size(self) -> typ.Self:
|
||||
"""Guarantees that the default `ct.LazyArrayRange` bounds are unshaped.
|
||||
def parse_default_lazy_range_size(self) -> typ.Self:
|
||||
"""Guarantees that the default `ct.Range` bounds are unshaped.
|
||||
|
||||
Raises:
|
||||
ValueError: If `self.default_min` or `self.default_max` are shaped.
|
||||
|
@ -1170,16 +1173,16 @@ class ExprSocketDef(base.SocketDef):
|
|||
# Check ActiveKind and Size
|
||||
## -> NOTE: This doesn't protect against dynamic changes to either.
|
||||
if (
|
||||
self.active_kind == ct.FlowKind.LazyArrayRange
|
||||
self.active_kind == ct.FlowKind.Range
|
||||
and self.size is not spux.NumberSize1D.Scalar
|
||||
):
|
||||
msg = "Can't have a non-Scalar size when LazyArrayRange is set as the active kind."
|
||||
msg = "Can't have a non-Scalar size when Range is set as the active kind."
|
||||
raise ValueError(msg)
|
||||
|
||||
# Check that Bounds are Shapeless
|
||||
for bound in [self.default_min, self.default_max]:
|
||||
if hasattr(bound, 'shape'):
|
||||
msg = f'ExprSocket: A default bound {bound} (type {type(bound)}) has a shape, but LazyArrayRange supports no shape in ExprSockets.'
|
||||
msg = f'ExprSocket: A default bound {bound} (type {type(bound)}) has a shape, but Range supports no shape in ExprSockets.'
|
||||
raise ValueError(msg)
|
||||
|
||||
return self
|
||||
|
@ -1217,13 +1220,14 @@ class ExprSocketDef(base.SocketDef):
|
|||
####################
|
||||
def init(self, bl_socket: ExprBLSocket) -> None:
|
||||
bl_socket.active_kind = self.active_kind
|
||||
bl_socket.output_name = self.output_name
|
||||
|
||||
# Socket Interface
|
||||
## -> Recall that auto-updates are turned off during init()
|
||||
bl_socket.size = self.size
|
||||
bl_socket.mathtype = self.mathtype
|
||||
bl_socket.physical_type = self.physical_type
|
||||
bl_socket.symbols = self.symbols
|
||||
bl_socket.active_symbols = self.symbols
|
||||
|
||||
# FlowKind.Value
|
||||
## -> We must take units into account when setting bl_socket.value
|
||||
|
@ -1235,9 +1239,9 @@ class ExprSocketDef(base.SocketDef):
|
|||
|
||||
bl_socket.prev_unit = bl_socket.active_unit
|
||||
|
||||
# FlowKind.LazyArrayRange
|
||||
# FlowKind.Range
|
||||
## -> We can directly pass None to unit.
|
||||
bl_socket.lazy_array_range = ct.LazyArrayRangeFlow(
|
||||
bl_socket.lazy_range = ct.RangeFlow(
|
||||
start=self.default_min,
|
||||
stop=self.default_max,
|
||||
steps=self.default_steps,
|
||||
|
@ -1246,6 +1250,7 @@ class ExprSocketDef(base.SocketDef):
|
|||
)
|
||||
|
||||
# UI
|
||||
bl_socket.show_func_ui = self.show_func_ui
|
||||
bl_socket.show_info_columns = self.show_info_columns
|
||||
|
||||
# Info Draw
|
||||
|
|
|
@ -61,7 +61,6 @@ SympyType = (
|
|||
class MathType(enum.StrEnum):
|
||||
"""Type identifiers that encompass common sets of mathematical objects."""
|
||||
|
||||
Bool = enum.auto()
|
||||
Integer = enum.auto()
|
||||
Rational = enum.auto()
|
||||
Real = enum.auto()
|
||||
|
@ -77,8 +76,6 @@ class MathType(enum.StrEnum):
|
|||
return MathType.Rational
|
||||
if MathType.Integer in mathtypes:
|
||||
return MathType.Integer
|
||||
if MathType.Bool in mathtypes:
|
||||
return MathType.Bool
|
||||
|
||||
msg = f"Can't combine mathtypes {mathtypes}"
|
||||
raise ValueError(msg)
|
||||
|
@ -88,7 +85,6 @@ class MathType(enum.StrEnum):
|
|||
return (
|
||||
other
|
||||
in {
|
||||
MT.Bool: [MT.Bool],
|
||||
MT.Integer: [MT.Integer],
|
||||
MT.Rational: [MT.Integer, MT.Rational],
|
||||
MT.Real: [MT.Integer, MT.Rational, MT.Real],
|
||||
|
@ -98,11 +94,9 @@ class MathType(enum.StrEnum):
|
|||
|
||||
def coerce_compatible_pyobj(
|
||||
self, pyobj: bool | int | Fraction | float | complex
|
||||
) -> bool | int | Fraction | float | complex:
|
||||
) -> int | Fraction | float | complex:
|
||||
MT = MathType
|
||||
match self:
|
||||
case MT.Bool:
|
||||
return pyobj
|
||||
case MT.Integer:
|
||||
return int(pyobj)
|
||||
case MT.Rational if isinstance(pyobj, int):
|
||||
|
@ -123,8 +117,6 @@ class MathType(enum.StrEnum):
|
|||
*[MathType.from_expr(v) for v in sp.flatten(sp_obj)]
|
||||
)
|
||||
|
||||
if isinstance(sp_obj, sp.logic.boolalg.Boolean):
|
||||
return MathType.Bool
|
||||
if sp_obj.is_integer:
|
||||
return MathType.Integer
|
||||
if sp_obj.is_rational:
|
||||
|
@ -146,7 +138,6 @@ class MathType(enum.StrEnum):
|
|||
@staticmethod
|
||||
def from_pytype(dtype: type) -> type:
|
||||
return {
|
||||
bool: MathType.Bool,
|
||||
int: MathType.Integer,
|
||||
Fraction: MathType.Rational,
|
||||
float: MathType.Real,
|
||||
|
@ -166,7 +157,6 @@ class MathType(enum.StrEnum):
|
|||
def pytype(self) -> type:
|
||||
MT = MathType
|
||||
return {
|
||||
MT.Bool: bool,
|
||||
MT.Integer: int,
|
||||
MT.Rational: Fraction,
|
||||
MT.Real: float,
|
||||
|
@ -177,17 +167,25 @@ class MathType(enum.StrEnum):
|
|||
def symbolic_set(self) -> type:
|
||||
MT = MathType
|
||||
return {
|
||||
MT.Bool: sp.Set([sp.S(False), sp.S(True)]),
|
||||
MT.Integer: sp.Integers,
|
||||
MT.Rational: sp.Rationals,
|
||||
MT.Real: sp.Reals,
|
||||
MT.Complex: sp.Complexes,
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def sp_symbol_a(self) -> type:
|
||||
MT = MathType
|
||||
return {
|
||||
MT.Integer: sp.Symbol('a', integer=True),
|
||||
MT.Rational: sp.Symbol('a', rational=True),
|
||||
MT.Real: sp.Symbol('a', real=True),
|
||||
MT.Complex: sp.Symbol('a', complex=True),
|
||||
}[self]
|
||||
|
||||
@staticmethod
|
||||
def to_str(value: typ.Self) -> type:
|
||||
return {
|
||||
MathType.Bool: 'T|F',
|
||||
MathType.Integer: 'ℤ',
|
||||
MathType.Rational: 'ℚ',
|
||||
MathType.Real: 'ℝ',
|
||||
|
@ -212,6 +210,9 @@ class MathType(enum.StrEnum):
|
|||
)
|
||||
|
||||
|
||||
####################
|
||||
# - Size: 1D
|
||||
####################
|
||||
class NumberSize1D(enum.StrEnum):
|
||||
"""Valid 1D-constrained shape."""
|
||||
|
||||
|
@ -278,6 +279,20 @@ class NumberSize1D(enum.StrEnum):
|
|||
(4, 1): NS.Vec4,
|
||||
}[shape]
|
||||
|
||||
@property
|
||||
def rows(self):
|
||||
NS = NumberSize1D
|
||||
return {
|
||||
NS.Scalar: 1,
|
||||
NS.Vec2: 2,
|
||||
NS.Vec3: 3,
|
||||
NS.Vec4: 4,
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def cols(self):
|
||||
return 1
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
NS = NumberSize1D
|
||||
|
@ -297,6 +312,30 @@ def symbol_range(sym: sp.Symbol) -> str:
|
|||
)
|
||||
|
||||
|
||||
####################
|
||||
# - Symbol Sizes
|
||||
####################
|
||||
class SimpleSize2D(enum.StrEnum):
|
||||
"""Simple subset of sizes for rank-2 tensors."""
|
||||
|
||||
Scalar = enum.auto()
|
||||
|
||||
# Vectors
|
||||
Vec2 = enum.auto() ## 2x1
|
||||
Vec3 = enum.auto() ## 3x1
|
||||
Vec4 = enum.auto() ## 4x1
|
||||
|
||||
# Covectors
|
||||
CoVec2 = enum.auto() ## 1x2
|
||||
CoVec3 = enum.auto() ## 1x3
|
||||
CoVec4 = enum.auto() ## 1x4
|
||||
|
||||
# Square Matrices
|
||||
Mat22 = enum.auto() ## 2x2
|
||||
Mat33 = enum.auto() ## 3x3
|
||||
Mat44 = enum.auto() ## 4x4
|
||||
|
||||
|
||||
####################
|
||||
# - Unit Dimensions
|
||||
####################
|
||||
|
@ -382,6 +421,8 @@ UNIT_BY_SYMBOL: dict[sp.Symbol, spu.Quantity] = {
|
|||
unit.name: unit for unit in spu.__dict__.values() if isinstance(unit, spu.Quantity)
|
||||
} | {unit.name: unit for unit in globals().values() if isinstance(unit, spu.Quantity)}
|
||||
|
||||
UNIT_TO_1: dict[spu.Quantity, 1] = {unit: 1 for unit in UNIT_BY_SYMBOL.values()}
|
||||
|
||||
|
||||
####################
|
||||
# - Expr Analysis: Units
|
||||
|
@ -907,10 +948,6 @@ class PhysicalType(enum.StrEnum):
|
|||
LumIntensity = enum.auto()
|
||||
LumFlux = enum.auto()
|
||||
Illuminance = enum.auto()
|
||||
# Optics
|
||||
OrdinaryWaveVector = enum.auto()
|
||||
AngularWaveVector = enum.auto()
|
||||
PoyntingVector = enum.auto()
|
||||
|
||||
@functools.cached_property
|
||||
def unit_dim(self):
|
||||
|
@ -956,10 +993,6 @@ class PhysicalType(enum.StrEnum):
|
|||
PT.LumIntensity: Dims.luminous_intensity,
|
||||
PT.LumFlux: Dims.luminous_intensity * spu.steradian.dimension,
|
||||
PT.Illuminance: Dims.luminous_intensity / Dims.length**2,
|
||||
# Optics
|
||||
PT.OrdinaryWaveVector: Dims.frequency,
|
||||
PT.AngularWaveVector: Dims.angle * Dims.frequency,
|
||||
PT.PoyntingVector: Dims.power / Dims.length**2,
|
||||
}[self]
|
||||
|
||||
@functools.cached_property
|
||||
|
@ -1196,10 +1229,6 @@ class PhysicalType(enum.StrEnum):
|
|||
PT.HField: [None, (2,), (3,)],
|
||||
# Luminal
|
||||
PT.LumFlux: [None, (2,), (3,)],
|
||||
# Optics
|
||||
PT.OrdinaryWaveVector: [None, (2,), (3,)],
|
||||
PT.AngularWaveVector: [None, (2,), (3,)],
|
||||
PT.PoyntingVector: [None, (2,), (3,)],
|
||||
}
|
||||
|
||||
return overrides.get(self, [None])
|
||||
|
@ -1222,7 +1251,6 @@ class PhysicalType(enum.StrEnum):
|
|||
- **Charge**: Generally, it is real.
|
||||
However, an imaginary phase term seems to have research applications when dealing with high-order harmonics in high-energy pulsed lasers: <https://iopscience.iop.org/article/10.1088/1361-6455/aac787>
|
||||
- **Conductance**: The imaginary part represents the extinction, in the Drude-model sense.
|
||||
- **Poynting**: The imaginary part represents the oscillation in the power flux over time.
|
||||
|
||||
"""
|
||||
MT = MathType
|
||||
|
@ -1249,10 +1277,6 @@ class PhysicalType(enum.StrEnum):
|
|||
PT.EField: [MT.Real, MT.Complex], ## Im -> Phase
|
||||
PT.HField: [MT.Real, MT.Complex], ## Im -> Phase
|
||||
# Luminal
|
||||
# Optics
|
||||
PT.OrdinaryWaveVector: [MT.Real, MT.Complex], ## Im -> Phase
|
||||
PT.AngularWaveVector: [MT.Real, MT.Complex], ## Im -> Phase
|
||||
PT.PoyntingVector: [MT.Real, MT.Complex], ## Im -> Reactive Power
|
||||
}
|
||||
|
||||
return overrides.get(self, [MT.Real])
|
||||
|
@ -1323,10 +1347,6 @@ UNITS_SI: UnitSystem = {
|
|||
_PT.LumIntensity: spu.candela,
|
||||
_PT.LumFlux: lumen,
|
||||
_PT.Illuminance: spu.lux,
|
||||
# Optics
|
||||
_PT.OrdinaryWaveVector: spu.hertz,
|
||||
_PT.AngularWaveVector: spu.radian * spu.hertz,
|
||||
_PT.PoyntingVector: spu.watt / spu.meter**2,
|
||||
}
|
||||
|
||||
|
||||
|
@ -1380,15 +1400,20 @@ def sympy_to_python(
|
|||
####################
|
||||
# - Convert to Unit System
|
||||
####################
|
||||
def convert_to_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr:
|
||||
def convert_to_unit_system(
|
||||
sp_obj: SympyExpr, unit_system: UnitSystem | None
|
||||
) -> SympyExpr:
|
||||
"""Convert an expression to the units of a given unit system, with appropriate scaling."""
|
||||
if unit_system is None:
|
||||
return sp_obj
|
||||
|
||||
return spu.convert_to(
|
||||
sp_obj,
|
||||
{unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)},
|
||||
)
|
||||
|
||||
|
||||
def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr:
|
||||
def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem | None) -> SympyExpr:
|
||||
"""Strip units occurring in the given unit system from the expression.
|
||||
|
||||
Unit stripping is a "dumb" operation: "Substitute any `sympy` object in `unit_system.values()` with `1`".
|
||||
|
@ -1397,11 +1422,13 @@ def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr:
|
|||
Notes:
|
||||
You should probably use `scale_to_unit_system()` or `convert_to_unit_system()`.
|
||||
"""
|
||||
if unit_system is None:
|
||||
return sp_obj.subs(UNIT_TO_1)
|
||||
return sp_obj.subs({unit: 1 for unit in unit_system.values() if unit is not None})
|
||||
|
||||
|
||||
def scale_to_unit_system(
|
||||
sp_obj: SympyExpr, unit_system: UnitSystem, use_jax_array: bool = False
|
||||
sp_obj: SympyExpr, unit_system: UnitSystem | None, use_jax_array: bool = False
|
||||
) -> int | float | complex | tuple | jax.Array:
|
||||
"""Convert an expression to the units of a given unit system, then strip all units of the unit system.
|
||||
|
||||
|
|
|
@ -29,11 +29,13 @@ import matplotlib.axis as mpl_ax
|
|||
import matplotlib.backends.backend_agg
|
||||
import matplotlib.figure
|
||||
import matplotlib.style as mplstyle
|
||||
import seaborn as sns
|
||||
|
||||
from blender_maxwell import contracts as ct
|
||||
from blender_maxwell.utils import logger
|
||||
|
||||
mplstyle.use('fast') ## TODO: Does this do anything?
|
||||
sns.set_theme()
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
@ -149,125 +151,98 @@ def mpl_fig_canvas_ax(width_inches: float, height_inches: float, dpi: int):
|
|||
####################
|
||||
# - Plotters
|
||||
####################
|
||||
# () -> ℝ
|
||||
def plot_hist_1d(
|
||||
data: jtyp.Float32[jtyp.Array, ' size'], info, ax: mpl_ax.Axis
|
||||
) -> None:
|
||||
y_name = info.output_name
|
||||
y_unit = info.output_unit
|
||||
|
||||
ax.hist(data, bins=30, alpha=0.75)
|
||||
ax.set_title('Histogram')
|
||||
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
|
||||
|
||||
|
||||
# (ℤ) -> ℝ
|
||||
def plot_box_plot_1d(
|
||||
data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis
|
||||
) -> None:
|
||||
x_name = info.dim_names[0]
|
||||
y_name = info.output_name
|
||||
y_unit = info.output_unit
|
||||
x_sym = info.last_dim
|
||||
y_sym = info.output
|
||||
|
||||
ax.boxplot(data)
|
||||
ax.set_title('Box Plot')
|
||||
ax.set_xlabel(f'{x_name}')
|
||||
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
|
||||
ax.boxplot([data])
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
||||
ax.set_xlabel(x_sym.plot_label)
|
||||
ax.set_xlabel(y_sym.plot_label)
|
||||
|
||||
|
||||
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None:
|
||||
x_sym = info.last_dim
|
||||
y_sym = info.output
|
||||
|
||||
p = ax.bar(info.dims[x_sym], data)
|
||||
ax.bar_label(p, label_type='center')
|
||||
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
||||
ax.set_xlabel(x_sym.plot_label)
|
||||
ax.set_xlabel(y_sym.plot_label)
|
||||
|
||||
|
||||
# (ℝ) -> ℝ
|
||||
def plot_curve_2d(
|
||||
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis
|
||||
) -> None:
|
||||
times = [time.perf_counter()]
|
||||
x_sym = info.last_dim
|
||||
y_sym = info.output
|
||||
|
||||
x_name = info.dim_names[0]
|
||||
x_unit = info.dim_units[x_name]
|
||||
y_name = info.output_name
|
||||
y_unit = info.output_unit
|
||||
|
||||
times.append(time.perf_counter() - times[0])
|
||||
ax.plot(info.dim_idx_arrays[0], data)
|
||||
times.append(time.perf_counter() - times[0])
|
||||
ax.set_title('2D Curve')
|
||||
times.append(time.perf_counter() - times[0])
|
||||
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
|
||||
times.append(time.perf_counter() - times[0])
|
||||
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
|
||||
times.append(time.perf_counter() - times[0])
|
||||
# log.critical('Timing of Curve2D: %s', str(times))
|
||||
ax.plot(info.dims[x_sym].realize_array.values, data)
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
||||
ax.set_xlabel(x_sym.plot_label)
|
||||
ax.set_xlabel(y_sym.plot_label)
|
||||
|
||||
|
||||
def plot_points_2d(
|
||||
data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis
|
||||
) -> None:
|
||||
x_name = info.dim_names[0]
|
||||
x_unit = info.dim_units[x_name]
|
||||
y_name = info.output_name
|
||||
y_unit = info.output_unit
|
||||
x_sym = info.last_dim
|
||||
y_sym = info.output
|
||||
|
||||
ax.scatter(info.dim_idx_arrays[0], data, alpha=0.6)
|
||||
ax.set_title('2D Points')
|
||||
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
|
||||
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
|
||||
|
||||
|
||||
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None:
|
||||
x_name = info.dim_names[0]
|
||||
x_unit = info.dim_units[x_name]
|
||||
y_name = info.output_name
|
||||
y_unit = info.output_unit
|
||||
|
||||
ax.bar(info.dim_idx_arrays[0], data, alpha=0.7)
|
||||
ax.set_title('2D Bar')
|
||||
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
|
||||
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
|
||||
ax.scatter(x_sym.realize_array.values, data)
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
||||
ax.set_xlabel(x_sym.plot_label)
|
||||
ax.set_xlabel(y_sym.plot_label)
|
||||
|
||||
|
||||
# (ℝ, ℤ) -> ℝ
|
||||
def plot_curves_2d(
|
||||
data: jtyp.Float32[jtyp.Array, 'x_size categories'], info, ax: mpl_ax.Axis
|
||||
) -> None:
|
||||
x_name = info.dim_names[0]
|
||||
x_unit = info.dim_units[x_name]
|
||||
y_name = info.output_name
|
||||
y_unit = info.output_unit
|
||||
x_sym = info.first_dim
|
||||
y_sym = info.output
|
||||
|
||||
for category in range(data.shape[1]):
|
||||
ax.plot(info.dim_idx_arrays[0], data[:, category])
|
||||
for i, category in enumerate(info.dims[info.last_dim]):
|
||||
ax.plot(info.dims[x_sym], data[:, i], label=category)
|
||||
|
||||
ax.set_title('2D Curves')
|
||||
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
|
||||
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
||||
ax.set_xlabel(x_sym.plot_label)
|
||||
ax.set_xlabel(y_sym.plot_label)
|
||||
ax.legend()
|
||||
|
||||
|
||||
def plot_filled_curves_2d(
|
||||
data: jtyp.Float32[jtyp.Array, 'x_size 2'], info, ax: mpl_ax.Axis
|
||||
) -> None:
|
||||
x_name = info.dim_names[0]
|
||||
x_unit = info.dim_units[x_name]
|
||||
y_name = info.output_name
|
||||
y_unit = info.output_unit
|
||||
x_sym = info.first_dim
|
||||
y_sym = info.output
|
||||
|
||||
shared_x_idx = info.dim_idx_arrays[0]
|
||||
shared_x_idx = info.dims[info.last_dim]
|
||||
ax.fill_between(shared_x_idx, data[:, 0], shared_x_idx, data[:, 1])
|
||||
ax.set_title('2D Filled Curves')
|
||||
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
|
||||
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
|
||||
ax.set_title(f'{x_sym.name_pretty} -> {y_sym.name}')
|
||||
ax.set_xlabel(x_sym.plot_label)
|
||||
ax.set_xlabel(y_sym.plot_label)
|
||||
ax.legend()
|
||||
|
||||
|
||||
# (ℝ, ℝ) -> ℝ
|
||||
def plot_heatmap_2d(
|
||||
data: jtyp.Float32[jtyp.Array, 'x_size y_size'], info, ax: mpl_ax.Axis
|
||||
) -> None:
|
||||
x_name = info.dim_names[0]
|
||||
x_unit = info.dim_units[x_name]
|
||||
y_name = info.dim_names[1]
|
||||
y_unit = info.dim_units[y_name]
|
||||
x_sym = info.first_dim
|
||||
y_sym = info.last_dim
|
||||
c_sym = info.output
|
||||
|
||||
heatmap = ax.imshow(data, aspect='auto', interpolation='none')
|
||||
# ax.figure.colorbar(heatmap, ax=ax)
|
||||
ax.set_title('Heatmap')
|
||||
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
|
||||
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
|
||||
heatmap = ax.imshow(data, aspect='equal', interpolation='none')
|
||||
ax.figure.colorbar(heatmap, cax=ax)
|
||||
|
||||
ax.set_title(f'({x_sym.name_pretty}, {y_sym.name_pretty}) -> {c_sym.plot_label}')
|
||||
ax.set_xlabel(x_sym.plot_label)
|
||||
ax.set_xlabel(y_sym.plot_label)
|
||||
ax.legend()
|
||||
|
|
|
@ -18,26 +18,67 @@ import dataclasses
|
|||
import enum
|
||||
import sys
|
||||
import typing as typ
|
||||
from fractions import Fraction
|
||||
|
||||
import sympy as sp
|
||||
|
||||
from . import extra_sympy_units as spux
|
||||
|
||||
int_min = -(2**64)
|
||||
int_max = 2**64
|
||||
float_min = sys.float_info.min
|
||||
float_max = sys.float_info.max
|
||||
|
||||
|
||||
####################
|
||||
# - Simulation Symbols
|
||||
# - Simulation Symbol Names
|
||||
####################
|
||||
class SimSymbolName(enum.StrEnum):
|
||||
# Lower
|
||||
LowerA = enum.auto()
|
||||
LowerB = enum.auto()
|
||||
LowerC = enum.auto()
|
||||
LowerD = enum.auto()
|
||||
LowerI = enum.auto()
|
||||
LowerT = enum.auto()
|
||||
LowerX = enum.auto()
|
||||
LowerY = enum.auto()
|
||||
LowerZ = enum.auto()
|
||||
|
||||
# Physics
|
||||
# Fields
|
||||
Ex = enum.auto()
|
||||
Ey = enum.auto()
|
||||
Ez = enum.auto()
|
||||
Hx = enum.auto()
|
||||
Hy = enum.auto()
|
||||
Hz = enum.auto()
|
||||
|
||||
Er = enum.auto()
|
||||
Etheta = enum.auto()
|
||||
Ephi = enum.auto()
|
||||
Hr = enum.auto()
|
||||
Htheta = enum.auto()
|
||||
Hphi = enum.auto()
|
||||
|
||||
# Optics
|
||||
Wavelength = enum.auto()
|
||||
Frequency = enum.auto()
|
||||
|
||||
Flux = enum.auto()
|
||||
|
||||
PermXX = enum.auto()
|
||||
PermYY = enum.auto()
|
||||
PermZZ = enum.auto()
|
||||
|
||||
DiffOrderX = enum.auto()
|
||||
DiffOrderY = enum.auto()
|
||||
|
||||
# Generic
|
||||
Expr = enum.auto()
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
@staticmethod
|
||||
def to_name(v: typ.Self) -> str:
|
||||
"""Convert the enum value to a human-friendly name.
|
||||
|
@ -50,27 +91,6 @@ class SimSymbolName(enum.StrEnum):
|
|||
"""
|
||||
return SimSymbolName(v).name
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
SSN = SimSymbolName
|
||||
return {
|
||||
SSN.LowerA: 'a',
|
||||
SSN.LowerT: 't',
|
||||
SSN.LowerX: 'x',
|
||||
SSN.LowerY: 'y',
|
||||
SSN.LowerZ: 'z',
|
||||
SSN.Wavelength: 'wl',
|
||||
SSN.Frequency: 'freq',
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def name_pretty(self) -> str:
|
||||
SSN = SimSymbolName
|
||||
return {
|
||||
SSN.Wavelength: 'λ',
|
||||
SSN.Frequency: '𝑓',
|
||||
}.get(self, self.name)
|
||||
|
||||
@staticmethod
|
||||
def to_icon(_: typ.Self) -> str:
|
||||
"""Convert the enum value to a Blender icon.
|
||||
|
@ -83,6 +103,75 @@ class SimSymbolName(enum.StrEnum):
|
|||
"""
|
||||
return ''
|
||||
|
||||
####################
|
||||
# - Computed Properties
|
||||
####################
|
||||
@property
|
||||
def name(self) -> str:
|
||||
SSN = SimSymbolName
|
||||
return {
|
||||
# Lower
|
||||
SSN.LowerA: 'a',
|
||||
SSN.LowerB: 'b',
|
||||
SSN.LowerC: 'c',
|
||||
SSN.LowerD: 'd',
|
||||
SSN.LowerI: 'i',
|
||||
SSN.LowerT: 't',
|
||||
SSN.LowerX: 'x',
|
||||
SSN.LowerY: 'y',
|
||||
SSN.LowerZ: 'z',
|
||||
# Fields
|
||||
SSN.Ex: 'Ex',
|
||||
SSN.Ey: 'Ey',
|
||||
SSN.Ez: 'Ez',
|
||||
SSN.Hx: 'Hx',
|
||||
SSN.Hy: 'Hy',
|
||||
SSN.Hz: 'Hz',
|
||||
SSN.Er: 'Ex',
|
||||
SSN.Etheta: 'Ey',
|
||||
SSN.Ephi: 'Ez',
|
||||
SSN.Hr: 'Hx',
|
||||
SSN.Htheta: 'Hy',
|
||||
SSN.Hphi: 'Hz',
|
||||
# Optics
|
||||
SSN.Wavelength: 'wl',
|
||||
SSN.Frequency: 'freq',
|
||||
SSN.Flux: 'flux',
|
||||
SSN.PermXX: 'eps_xx',
|
||||
SSN.PermYY: 'eps_yy',
|
||||
SSN.PermZZ: 'eps_zz',
|
||||
SSN.DiffOrderX: 'order_x',
|
||||
SSN.DiffOrderY: 'order_y',
|
||||
# Generic
|
||||
SSN.Expr: 'expr',
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def name_pretty(self) -> str:
|
||||
SSN = SimSymbolName
|
||||
return {
|
||||
SSN.Wavelength: 'λ',
|
||||
SSN.Frequency: '𝑓',
|
||||
}.get(self, self.name)
|
||||
|
||||
|
||||
####################
|
||||
# - Simulation Symbol
|
||||
####################
|
||||
def mk_interval(
|
||||
interval_finite: tuple[int | Fraction | float, int | Fraction | float],
|
||||
interval_inf: tuple[bool, bool],
|
||||
interval_closed: tuple[bool, bool],
|
||||
unit_factor: typ.Literal[1] | spux.Unit,
|
||||
) -> sp.Interval:
|
||||
"""Create a symbolic interval from the tuples (and unit) defining it."""
|
||||
return sp.Interval(
|
||||
start=(interval_finite[0] * unit_factor if not interval_inf[0] else -sp.oo),
|
||||
end=(interval_finite[1] * unit_factor if not interval_inf[1] else sp.oo),
|
||||
left_open=(True if interval_inf[0] else not interval_closed[0]),
|
||||
right_open=(True if interval_inf[1] else not interval_closed[1]),
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(kw_only=True, frozen=True)
|
||||
class SimSymbol:
|
||||
|
@ -94,66 +183,145 @@ class SimSymbol:
|
|||
It's easy to persist, easy to transport, and has many helpful properties which greatly simplify working with symbols.
|
||||
"""
|
||||
|
||||
sim_node_name: SimSymbolName = SimSymbolName.LowerX
|
||||
sym_name: SimSymbolName
|
||||
mathtype: spux.MathType = spux.MathType.Real
|
||||
|
||||
physical_type: spux.PhysicalType = spux.PhysicalType.NonPhysical
|
||||
|
||||
## TODO: Shape/size support? Incl. MatrixSymbol.
|
||||
# Units
|
||||
## -> 'None' indicates that no particular unit has yet been chosen.
|
||||
## -> Not exposed in the UI; must be set some other way.
|
||||
unit: spux.Unit | None = None
|
||||
|
||||
# Domain
|
||||
interval_finite: tuple[float, float] = (0, 1)
|
||||
# Size
|
||||
## -> All SimSymbol sizes are "2D", but interpreted by convention.
|
||||
## -> 1x1: "Scalar".
|
||||
## -> nx1: "Vector".
|
||||
## -> 1xn: "Covector".
|
||||
## -> nxn: "Matrix".
|
||||
rows: int = 1
|
||||
cols: int = 1
|
||||
|
||||
# Scalar Domain: "Interval"
|
||||
## -> NOTE: interval_finite_*[0] must be strictly smaller than [1].
|
||||
## -> See self.domain.
|
||||
## -> We have to deconstruct symbolic interval semantics a bit for UI.
|
||||
interval_finite_z: tuple[int, int] = (0, 1)
|
||||
interval_finite_q: tuple[tuple[int, int], tuple[int, int]] = ((0, 1), (1, 1))
|
||||
interval_finite_re: tuple[float, float] = (0, 1)
|
||||
interval_inf: tuple[bool, bool] = (True, True)
|
||||
interval_closed: tuple[bool, bool] = (False, False)
|
||||
|
||||
interval_finite_im: tuple[float, float] = (0, 1)
|
||||
interval_inf_im: tuple[bool, bool] = (True, True)
|
||||
interval_closed_im: tuple[bool, bool] = (False, False)
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.sim_node_name.name
|
||||
"""Usable name for the symbol."""
|
||||
return self.sym_name.name
|
||||
|
||||
@property
|
||||
def name_pretty(self) -> str:
|
||||
"""Pretty (possibly unicode) name for the thing."""
|
||||
return self.sym_name.name_pretty
|
||||
## TODO: Formatting conventions for bolding/etc. of vectors/mats/...
|
||||
|
||||
@property
|
||||
def plot_label(self) -> str:
|
||||
"""Pretty plot-oriented label."""
|
||||
return f'{self.name_pretty}' + (
|
||||
f'({self.unit})' if self.unit is not None else ''
|
||||
)
|
||||
|
||||
@property
|
||||
def unit_factor(self) -> spux.SympyExpr:
|
||||
"""Factor corresponding to the tracked unit, which can be multiplied onto exported values without `None`-checking."""
|
||||
return self.unit if self.unit is not None else sp.S(1)
|
||||
|
||||
@property
|
||||
def shape(self) -> tuple[int, ...]:
|
||||
match (self.rows, self.cols):
|
||||
case (1, 1):
|
||||
return ()
|
||||
case (_, 1):
|
||||
return (self.rows,)
|
||||
case (1, _):
|
||||
return (1, self.rows)
|
||||
case (_, _):
|
||||
return (self.rows, self.cols)
|
||||
|
||||
@property
|
||||
def domain(self) -> sp.Interval | sp.Set:
|
||||
"""Return the domain of valid values for the symbol.
|
||||
"""Return the scalar domain of valid values for each element of the symbol.
|
||||
|
||||
For integer/rational/real symbols, the domain is an interval defined using the `interval_*` properties.
|
||||
This interval **must** have the property`start <= stop`.
|
||||
|
||||
Otherwise, the domain is the symbolic set corresponding to `self.mathtype`.
|
||||
"""
|
||||
if self.mathtype in [
|
||||
spux.MathType.Integer,
|
||||
spux.MathType.Rational,
|
||||
spux.MathType.Real,
|
||||
]:
|
||||
return sp.Interval(
|
||||
start=self.interval_finite[0] if not self.interval_inf[0] else -sp.oo,
|
||||
end=self.interval_finite[1] if not self.interval_inf[1] else sp.oo,
|
||||
left_open=(
|
||||
True if self.interval_inf[0] else not self.interval_closed[0]
|
||||
),
|
||||
right_open=(
|
||||
True if self.interval_inf[1] else not self.interval_closed[1]
|
||||
),
|
||||
)
|
||||
match self.mathtype:
|
||||
case spux.MathType.Integer:
|
||||
return mk_interval(
|
||||
self.interval_finite_z,
|
||||
self.interval_inf,
|
||||
self.interval_closed,
|
||||
self.unit_factor,
|
||||
)
|
||||
|
||||
return self.mathtype.symbolic_set
|
||||
case spux.MathType.Rational:
|
||||
return mk_interval(
|
||||
Fraction(*self.interval_finite_q),
|
||||
self.interval_inf,
|
||||
self.interval_closed,
|
||||
self.unit_factor,
|
||||
)
|
||||
|
||||
case spux.MathType.Real:
|
||||
return mk_interval(
|
||||
self.interval_finite_re,
|
||||
self.interval_inf,
|
||||
self.interval_closed,
|
||||
self.unit_factor,
|
||||
)
|
||||
|
||||
case spux.MathType.Complex:
|
||||
return (
|
||||
mk_interval(
|
||||
self.interval_finite_re,
|
||||
self.interval_inf,
|
||||
self.interval_closed,
|
||||
self.unit_factor,
|
||||
),
|
||||
mk_interval(
|
||||
self.interval_finite_im,
|
||||
self.interval_inf_im,
|
||||
self.interval_closed_im,
|
||||
self.unit_factor,
|
||||
),
|
||||
)
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
@property
|
||||
def sp_symbol(self) -> sp.Symbol:
|
||||
"""Return a symbolic variable corresponding to this `SimSymbol`.
|
||||
"""Return a symbolic variable w/unit, corresponding to this `SimSymbol`.
|
||||
|
||||
As much as possible, appropriate `assumptions` are set in the constructor of `sp.Symbol`, insofar as they can be determined.
|
||||
|
||||
However, the assumptions system alone is rather limited, and implementations should therefore also strongly consider transporting `SimSymbols` directly, instead of `sp.Symbol`.
|
||||
This allows making use of other properties like `self.domain`, when appropriate.
|
||||
- **MathType**: Depending on `self.mathtype`.
|
||||
- **Positive/Negative**: Depending on `self.domain`.
|
||||
- **Nonzero**: Depending on `self.domain`, including open/closed boundary specifications.
|
||||
|
||||
Notes:
|
||||
**The assumptions system is rather limited**, and implementations should strongly consider transporting `SimSymbols` instead of `sp.Symbol`.
|
||||
|
||||
This allows tracking ex. the valid interval domain for a symbol.
|
||||
"""
|
||||
# MathType Domain Constraint
|
||||
## -> We must feed the assumptions system.
|
||||
# MathType Assumption
|
||||
mathtype_kwargs = {}
|
||||
match self.mathtype:
|
||||
case spux.MathType.Integer:
|
||||
|
@ -165,53 +333,138 @@ class SimSymbol:
|
|||
case spux.MathType.Complex:
|
||||
mathtype_kwargs |= {'complex': True}
|
||||
|
||||
# Interval Constraints
|
||||
if isinstance(self.domain, sp.Interval):
|
||||
# Assumption: Non-Zero
|
||||
if (
|
||||
(
|
||||
self.domain.left == 0
|
||||
and self.domain.left_open
|
||||
or self.domain.right == 0
|
||||
and self.domain.right_open
|
||||
)
|
||||
or self.domain.left > 0
|
||||
or self.domain.right < 0
|
||||
):
|
||||
mathtype_kwargs |= {'nonzero': True}
|
||||
# Non-Zero Assumption
|
||||
if (
|
||||
(
|
||||
self.domain.left == 0
|
||||
and self.domain.left_open
|
||||
or self.domain.right == 0
|
||||
and self.domain.right_open
|
||||
)
|
||||
or self.domain.left > 0
|
||||
or self.domain.right < 0
|
||||
):
|
||||
mathtype_kwargs |= {'nonzero': True}
|
||||
|
||||
# Assumption: Positive/Negative
|
||||
if self.domain.left >= 0:
|
||||
mathtype_kwargs |= {'positive': True}
|
||||
elif self.domain.right <= 0:
|
||||
mathtype_kwargs |= {'negative': True}
|
||||
# Positive/Negative Assumption
|
||||
if self.domain.left >= 0:
|
||||
mathtype_kwargs |= {'positive': True}
|
||||
elif self.domain.right <= 0:
|
||||
mathtype_kwargs |= {'negative': True}
|
||||
|
||||
# Construct the Symbol
|
||||
return sp.Symbol(self.sim_node_name.name, **mathtype_kwargs)
|
||||
return sp.Symbol(self.sym_name.name, **mathtype_kwargs) * self.unit_factor
|
||||
|
||||
####################
|
||||
# - Operations
|
||||
####################
|
||||
def update(self, **kwargs) -> typ.Self:
|
||||
def get_attr(attr: str):
|
||||
_notfound = 'notfound'
|
||||
if kwargs.get(attr, _notfound) is _notfound:
|
||||
return getattr(self, attr)
|
||||
return kwargs[attr]
|
||||
|
||||
return SimSymbol(
|
||||
sym_name=get_attr('sym_name'),
|
||||
mathtype=get_attr('mathtype'),
|
||||
physical_type=get_attr('physical_type'),
|
||||
unit=get_attr('unit'),
|
||||
rows=get_attr('rows'),
|
||||
cols=get_attr('cols'),
|
||||
interval_finite_z=get_attr('interval_finite_z'),
|
||||
interval_finite_q=get_attr('interval_finite_q'),
|
||||
interval_finite_re=get_attr('interval_finite_q'),
|
||||
interval_inf=get_attr('interval_inf'),
|
||||
interval_closed=get_attr('interval_closed'),
|
||||
interval_finite_im=get_attr('interval_finite_im'),
|
||||
interval_inf_im=get_attr('interval_inf_im'),
|
||||
interval_closed_im=get_attr('interval_closed_im'),
|
||||
)
|
||||
|
||||
def set_size(self, rows: int, cols: int) -> typ.Self:
|
||||
return SimSymbol(
|
||||
sym_name=self.sym_name,
|
||||
mathtype=self.mathtype,
|
||||
physical_type=self.physical_type,
|
||||
unit=self.unit,
|
||||
rows=rows,
|
||||
cols=cols,
|
||||
interval_finite_z=self.interval_finite_z,
|
||||
interval_finite_q=self.interval_finite_q,
|
||||
interval_finite_re=self.interval_finite_re,
|
||||
interval_inf=self.interval_inf,
|
||||
interval_closed=self.interval_closed,
|
||||
interval_finite_im=self.interval_finite_im,
|
||||
interval_inf_im=self.interval_inf_im,
|
||||
interval_closed_im=self.interval_closed_im,
|
||||
)
|
||||
|
||||
|
||||
####################
|
||||
# - Common Sim Symbols
|
||||
####################
|
||||
class CommonSimSymbol(enum.StrEnum):
|
||||
"""A set of pre-defined symbols that might commonly be used in the context of physical simulation.
|
||||
"""Identifiers for commonly used `SimSymbol`s, with all information about ex. `MathType`, `PhysicalType`, and (in general) valid intervals all pre-loaded.
|
||||
|
||||
Each entry maps directly to a particular `SimSymbol`.
|
||||
|
||||
The enum is compatible with `BLField`, making it easy to declare a UI-driven dropdown of symbols that behave as expected.
|
||||
The enum is UI-compatible making it easy to declare a UI-driven dropdown of commonly used symbols that will all behave as expected.
|
||||
|
||||
Attributes:
|
||||
X:
|
||||
Time: A symbol representing a real-valued wavelength.
|
||||
Wavelength: A symbol representing a real-valued wavelength.
|
||||
Implicitly, this symbol often represents "vacuum wavelength" in particular.
|
||||
Wavelength: A symbol representing a real-valued frequency.
|
||||
Generally, this is the non-angular frequency.
|
||||
"""
|
||||
|
||||
X = enum.auto()
|
||||
Index = enum.auto()
|
||||
|
||||
# Space|Time
|
||||
SpaceX = enum.auto()
|
||||
SpaceY = enum.auto()
|
||||
SpaceZ = enum.auto()
|
||||
|
||||
AngR = enum.auto()
|
||||
AngTheta = enum.auto()
|
||||
AngPhi = enum.auto()
|
||||
|
||||
DirX = enum.auto()
|
||||
DirY = enum.auto()
|
||||
DirZ = enum.auto()
|
||||
|
||||
Time = enum.auto()
|
||||
|
||||
# Fields
|
||||
FieldEx = enum.auto()
|
||||
FieldEy = enum.auto()
|
||||
FieldEz = enum.auto()
|
||||
FieldHx = enum.auto()
|
||||
FieldHy = enum.auto()
|
||||
FieldHz = enum.auto()
|
||||
|
||||
FieldEr = enum.auto()
|
||||
FieldEtheta = enum.auto()
|
||||
FieldEphi = enum.auto()
|
||||
FieldHr = enum.auto()
|
||||
FieldHtheta = enum.auto()
|
||||
FieldHphi = enum.auto()
|
||||
|
||||
# Optics
|
||||
Wavelength = enum.auto()
|
||||
Frequency = enum.auto()
|
||||
|
||||
DiffOrderX = enum.auto()
|
||||
DiffOrderY = enum.auto()
|
||||
|
||||
Flux = enum.auto()
|
||||
|
||||
WaveVecX = enum.auto()
|
||||
WaveVecY = enum.auto()
|
||||
WaveVecZ = enum.auto()
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
@staticmethod
|
||||
def to_name(v: typ.Self) -> str:
|
||||
"""Convert the enum value to a human-friendly name.
|
||||
|
@ -222,7 +475,7 @@ class CommonSimSymbol(enum.StrEnum):
|
|||
Returns:
|
||||
A human-friendly name corresponding to the enum value.
|
||||
"""
|
||||
return CommonSimSymbol(v).sim_symbol_name.name
|
||||
return CommonSimSymbol(v).name
|
||||
|
||||
@staticmethod
|
||||
def to_icon(_: typ.Self) -> str:
|
||||
|
@ -241,55 +494,125 @@ class CommonSimSymbol(enum.StrEnum):
|
|||
####################
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.sim_symbol.name
|
||||
|
||||
@property
|
||||
def sim_symbol_name(self) -> str:
|
||||
SSN = SimSymbolName
|
||||
CSS = CommonSimSymbol
|
||||
return {
|
||||
CSS.X: SSN.LowerX,
|
||||
CSS.Index: SSN.LowerI,
|
||||
# Space|Time
|
||||
CSS.SpaceX: SSN.LowerX,
|
||||
CSS.SpaceY: SSN.LowerY,
|
||||
CSS.SpaceZ: SSN.LowerZ,
|
||||
CSS.AngR: SSN.LowerR,
|
||||
CSS.AngTheta: SSN.LowerTheta,
|
||||
CSS.AngPhi: SSN.LowerPhi,
|
||||
CSS.DirX: SSN.LowerX,
|
||||
CSS.DirY: SSN.LowerY,
|
||||
CSS.DirZ: SSN.LowerZ,
|
||||
CSS.Time: SSN.LowerT,
|
||||
CSS.Wavelength: SSN.Wavelength,
|
||||
# Fields
|
||||
CSS.FieldEx: SSN.Ex,
|
||||
CSS.FieldEy: SSN.Ey,
|
||||
CSS.FieldEz: SSN.Ez,
|
||||
CSS.FieldHx: SSN.Hx,
|
||||
CSS.FieldHy: SSN.Hy,
|
||||
CSS.FieldHz: SSN.Hz,
|
||||
CSS.FieldEr: SSN.Er,
|
||||
CSS.FieldHr: SSN.Hr,
|
||||
# Optics
|
||||
CSS.Frequency: SSN.Frequency,
|
||||
CSS.Wavelength: SSN.Wavelength,
|
||||
CSS.DiffOrderX: SSN.DiffOrderX,
|
||||
CSS.DiffOrderY: SSN.DiffOrderY,
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def sim_symbol(self) -> SimSymbol:
|
||||
def sim_symbol(self, unit: spux.Unit | None) -> SimSymbol:
|
||||
"""Retrieve the `SimSymbol` associated with the `CommonSimSymbol`."""
|
||||
CSS = CommonSimSymbol
|
||||
|
||||
# Space
|
||||
sym_space = SimSymbol(
|
||||
sym_name=self.name,
|
||||
physical_type=spux.PhysicalType.Length,
|
||||
unit=unit,
|
||||
)
|
||||
sym_ang = SimSymbol(
|
||||
sym_name=self.name,
|
||||
physical_type=spux.PhysicalType.Angle,
|
||||
unit=unit,
|
||||
)
|
||||
|
||||
# Fields
|
||||
def sym_field(eh: typ.Literal['e', 'h']) -> SimSymbol:
|
||||
return SimSymbol(
|
||||
sym_name=self.name,
|
||||
physical_type=spux.PhysicalType.EField
|
||||
if eh == 'e'
|
||||
else spux.PhysicalType.HField,
|
||||
unit=unit,
|
||||
interval_finite_re=(0, sys.float_info.max),
|
||||
interval_inf_re=(False, True),
|
||||
interval_closed_re=(True, False),
|
||||
interval_finite_im=(sys.float_info.min, sys.float_info.max),
|
||||
interval_inf_im=(True, True),
|
||||
)
|
||||
|
||||
return {
|
||||
CSS.X: SimSymbol(
|
||||
sim_node_name=self.sim_symbol_name,
|
||||
mathtype=spux.MathType.Real,
|
||||
physical_type=spux.PhysicalType.NonPhysical,
|
||||
## TODO: Unit of Picosecond
|
||||
interval_finite=(sys.float_info.min, sys.float_info.max),
|
||||
interval_inf=(True, True),
|
||||
interval_closed=(False, False),
|
||||
),
|
||||
CSS.Time: SimSymbol(
|
||||
sim_node_name=self.sim_symbol_name,
|
||||
mathtype=spux.MathType.Real,
|
||||
physical_type=spux.PhysicalType.Time,
|
||||
## TODO: Unit of Picosecond
|
||||
interval_finite=(0, sys.float_info.max),
|
||||
CSS.Index: SimSymbol(
|
||||
sym_name=self.name,
|
||||
mathtype=spux.MathType.Integer,
|
||||
interval_finite_z=(0, 2**64),
|
||||
interval_inf=(False, True),
|
||||
interval_closed=(True, False),
|
||||
),
|
||||
# Space|Time
|
||||
CSS.SpaceX: sym_space,
|
||||
CSS.SpaceY: sym_space,
|
||||
CSS.SpaceZ: sym_space,
|
||||
CSS.AngR: sym_space,
|
||||
CSS.AngTheta: sym_ang,
|
||||
CSS.AngPhi: sym_ang,
|
||||
CSS.Time: SimSymbol(
|
||||
sym_name=self.name,
|
||||
physical_type=spux.PhysicalType.Time,
|
||||
unit=unit,
|
||||
interval_finite_re=(0, sys.float_info.max),
|
||||
interval_inf=(False, True),
|
||||
interval_closed=(True, False),
|
||||
),
|
||||
# Fields
|
||||
CSS.FieldEx: sym_field('e'),
|
||||
CSS.FieldEy: sym_field('e'),
|
||||
CSS.FieldEz: sym_field('e'),
|
||||
CSS.FieldHx: sym_field('h'),
|
||||
CSS.FieldHy: sym_field('h'),
|
||||
CSS.FieldHz: sym_field('h'),
|
||||
CSS.FieldEr: sym_field('e'),
|
||||
CSS.FieldEtheta: sym_field('e'),
|
||||
CSS.FieldEphi: sym_field('e'),
|
||||
CSS.FieldHr: sym_field('h'),
|
||||
CSS.FieldHtheta: sym_field('h'),
|
||||
CSS.FieldHphi: sym_field('h'),
|
||||
CSS.Flux: SimSymbol(
|
||||
sym_name=SimSymbolName.Flux,
|
||||
mathtype=spux.MathType.Real,
|
||||
physical_type=spux.PhysicalType.Power,
|
||||
unit=unit,
|
||||
),
|
||||
# Optics
|
||||
CSS.Wavelength: SimSymbol(
|
||||
sim_node_name=self.sim_symbol_name,
|
||||
sym_name=self.name,
|
||||
mathtype=spux.MathType.Real,
|
||||
physical_type=spux.PhysicalType.Length,
|
||||
## TODO: Unit of Picosecond
|
||||
unit=unit,
|
||||
interval_finite=(0, sys.float_info.max),
|
||||
interval_inf=(False, True),
|
||||
interval_closed=(False, False),
|
||||
),
|
||||
CSS.Frequency: SimSymbol(
|
||||
sim_node_name=self.sim_symbol_name,
|
||||
sym_name=self.name,
|
||||
mathtype=spux.MathType.Real,
|
||||
physical_type=spux.PhysicalType.Freq,
|
||||
unit=unit,
|
||||
interval_finite=(0, sys.float_info.max),
|
||||
interval_inf=(False, True),
|
||||
interval_closed=(False, False),
|
||||
|
@ -298,9 +621,33 @@ class CommonSimSymbol(enum.StrEnum):
|
|||
|
||||
|
||||
####################
|
||||
# - Selected Direct Access
|
||||
# - Selected Direct-Access to SimSymbols
|
||||
####################
|
||||
x = CommonSimSymbol.X.sim_symbol
|
||||
idx = CommonSimSymbol.Index.sim_symbol
|
||||
t = CommonSimSymbol.Time.sim_symbol
|
||||
wl = CommonSimSymbol.Wavelength.sim_symbol
|
||||
freq = CommonSimSymbol.Frequency.sim_symbol
|
||||
|
||||
space_x = CommonSimSymbol.SpaceX.sim_symbol
|
||||
space_y = CommonSimSymbol.SpaceY.sim_symbol
|
||||
space_z = CommonSimSymbol.SpaceZ.sim_symbol
|
||||
|
||||
dir_x = CommonSimSymbol.DirX.sim_symbol
|
||||
dir_y = CommonSimSymbol.DirY.sim_symbol
|
||||
dir_z = CommonSimSymbol.DirZ.sim_symbol
|
||||
|
||||
ang_r = CommonSimSymbol.AngR.sim_symbol
|
||||
ang_theta = CommonSimSymbol.AngTheta.sim_symbol
|
||||
ang_phi = CommonSimSymbol.AngPhi.sim_symbol
|
||||
|
||||
field_ex = CommonSimSymbol.FieldEx.sim_symbol
|
||||
field_ey = CommonSimSymbol.FieldEy.sim_symbol
|
||||
field_ez = CommonSimSymbol.FieldEz.sim_symbol
|
||||
field_hx = CommonSimSymbol.FieldHx.sim_symbol
|
||||
field_hy = CommonSimSymbol.FieldHx.sim_symbol
|
||||
field_hz = CommonSimSymbol.FieldHx.sim_symbol
|
||||
|
||||
flux = CommonSimSymbol.Flux.sim_symbol
|
||||
|
||||
diff_order_x = CommonSimSymbol.DiffOrderX.sim_symbol
|
||||
diff_order_y = CommonSimSymbol.DiffOrderY.sim_symbol
|
||||
|
|
Loading…
Reference in New Issue