feat: transform node w/sane `DataChanged`-chaining
We also implemented a JIT-based `rescale` for `ArrayFlow`, where the same function as for `LazyArrayRangeFlow` passes through and can do an arbitrary symbolic shift/rescale/order-preserving whatever. To make this fast for `ArrayFlow`, we utilize a "test" variable `a`, put it through the function and rescale/strip its units, then `lambdify` it and broadcast the function onto `ArrayFlow.values`. It's an immensely clean way of working. The `lambdify` seems fast, interestingly, but ideally we would of course also cache that somehow. Some details remain: - Fourier Transform index bounds / length are currently not presumed known before actually computing them; it's unclear what's best here, as the design cross-section of physical expectation, mathematical correctness, and ease of implementation (especially trying to keep actual data out of the `InfoFlow` hot-path). For now, the `0..\infty` bounds **will probably break** the `Viz` node. - We've reverted the notion that particular `sim_symbol`s must have a unit pre-defined, which causes a little more complexity for nodes like `TemporalShape`. This question is going to need resolving - The whole `InfoFlow` object really seems to be at the heart of certain lagginess when it comes to the math system. It, together with the index representations, would benefit greatly from a principled refactor. - The `Viewer` node is broken for 3D preview; see #70. Closes #59. Progress on #54.main
parent
39747e2d68
commit
a66a28da27
18
TODO.md
18
TODO.md
|
@ -2,15 +2,15 @@
|
|||
- [x] Wave Constant
|
||||
- Sources
|
||||
- [x] Temporal Shapes / Continuous Wave Temporal Shape
|
||||
- [ ] Temporal Shapes / Symbolic Temporal Shape
|
||||
- [ ] Plane Wave Source
|
||||
- [x] Temporal Shapes / Symbolic Temporal Shape
|
||||
- [x] Plane Wave Source
|
||||
- [ ] TFSF Source
|
||||
- [ ] Gaussian Beam Source
|
||||
- [x] Gaussian Beam Source
|
||||
- [ ] Astig. Gauss Beam
|
||||
- Monitors
|
||||
- [x] EH Field
|
||||
- [x] Power Flux
|
||||
- [ ] Permittivity
|
||||
- [x] Permittivity
|
||||
- [ ] Diffraction
|
||||
- Tidy3D / Integration
|
||||
- [ ] Exporter
|
||||
|
@ -23,7 +23,7 @@
|
|||
- [ ] Uniform
|
||||
- [ ] Data
|
||||
- Structures
|
||||
- [ ] Cylinder
|
||||
- [x] Cylinder
|
||||
- [ ] Cylinder Array
|
||||
- [ ] L-Cavity Cylinder
|
||||
- [ ] H-Cavity Cylinder
|
||||
|
@ -31,10 +31,10 @@
|
|||
- [ ] BCC Lattice
|
||||
- [ ] Monkey
|
||||
- Expr Socket
|
||||
- [ ] Array Mode
|
||||
- [x] LVF Mode
|
||||
- Math Nodes
|
||||
- [ ] Reduce Math
|
||||
- [ ] Transform Math - reindex freq->wl
|
||||
- [x] Transform Math - reindex freq->wl
|
||||
- Material Data Fitting
|
||||
- [ ] Data File Import
|
||||
- [ ] DataFit Medium
|
||||
|
@ -47,10 +47,10 @@
|
|||
- [ ] Debye Medium
|
||||
- [ ] Anisotropic Medium
|
||||
- Integration
|
||||
- [ ] Simulation and Analysis of Maxim's Cavity
|
||||
- [x] Simulation and Analysis of Maxim's Cavity
|
||||
- Constants
|
||||
- [x] Number Constant
|
||||
- [x] Vector Constant
|
||||
- [x] Physical Constant
|
||||
|
||||
- [ ] Fix many problems by persisting `_enum_cb_cache` and `_str_cb_cache`.
|
||||
- [x] Fix many problems by persisting `_enum_cb_cache` and `_str_cb_cache`.
|
||||
|
|
|
@ -20,6 +20,7 @@ import typing as typ
|
|||
|
||||
import jaxtyping as jtyp
|
||||
import numpy as np
|
||||
import sympy as sp
|
||||
import sympy.physics.units as spu
|
||||
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
|
@ -110,3 +111,24 @@ class ArrayFlow:
|
|||
|
||||
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}'
|
||||
raise ValueError(msg)
|
||||
|
||||
def rescale(
|
||||
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
|
||||
) -> typ.Self:
|
||||
# Compile JAX-Compatible Rescale Function
|
||||
a = sp.Symbol('a')
|
||||
rescale_expr = (
|
||||
spux.scale_to_unit(rescale_func(a * self.unit), new_unit)
|
||||
if self.unit is not None
|
||||
else rescale_func(a * self.unit)
|
||||
)
|
||||
log.critical([self.unit, new_unit, rescale_expr])
|
||||
_rescale_func = sp.lambdify(a, rescale_expr, 'jax')
|
||||
values = _rescale_func(self.values)
|
||||
|
||||
# Return ArrayFlow
|
||||
return ArrayFlow(
|
||||
values=values[::-1] if reverse else values,
|
||||
unit=new_unit,
|
||||
is_sorted=self.is_sorted,
|
||||
)
|
||||
|
|
|
@ -31,7 +31,9 @@ log = logger.get(__name__)
|
|||
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class InfoFlow:
|
||||
# Dimension Information
|
||||
####################
|
||||
# - Covariant Input
|
||||
####################
|
||||
dim_names: list[str] = dataclasses.field(default_factory=list)
|
||||
dim_idx: dict[str, ArrayFlow | LazyArrayRangeFlow] = dataclasses.field(
|
||||
default_factory=dict
|
||||
|
@ -67,6 +69,9 @@ class InfoFlow:
|
|||
for dim_idx in self.dim_idx.values()
|
||||
]
|
||||
|
||||
####################
|
||||
# - Contravariant Output
|
||||
####################
|
||||
# Output Information
|
||||
## TODO: Add PhysicalType
|
||||
output_name: str = dataclasses.field(default_factory=list)
|
||||
|
@ -94,6 +99,28 @@ class InfoFlow:
|
|||
####################
|
||||
# - Methods
|
||||
####################
|
||||
def replace_dim(
|
||||
self, old_dim_name: str, new_dim_idx: tuple[str, ArrayFlow | LazyArrayRangeFlow]
|
||||
) -> typ.Self:
|
||||
return InfoFlow(
|
||||
# Dimensions
|
||||
dim_names=[
|
||||
dim_name if dim_name != old_dim_name else new_dim_idx[0]
|
||||
for dim_name in self.dim_names
|
||||
],
|
||||
dim_idx={
|
||||
(dim_name if dim_name != old_dim_name else new_dim_idx[0]): (
|
||||
dim_idx if dim_name != old_dim_name else new_dim_idx[1]
|
||||
)
|
||||
for dim_name, dim_idx in self.dim_idx.items()
|
||||
},
|
||||
# Outputs
|
||||
output_name=self.output_name,
|
||||
output_shape=self.output_shape,
|
||||
output_mathtype=self.output_mathtype,
|
||||
output_unit=self.output_unit,
|
||||
)
|
||||
|
||||
def rescale_dim_idxs(self, new_dim_idxs: dict[str, LazyArrayRangeFlow]) -> typ.Self:
|
||||
return InfoFlow(
|
||||
# Dimensions
|
||||
|
|
|
@ -268,6 +268,32 @@ class LazyArrayRangeFlow:
|
|||
####################
|
||||
# - Bound Operations
|
||||
####################
|
||||
def rescale(
|
||||
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
|
||||
) -> typ.Self:
|
||||
new_pre_start = self.start if not reverse else self.stop
|
||||
new_pre_stop = self.stop if not reverse else self.start
|
||||
|
||||
new_start = rescale_func(new_pre_start * self.unit)
|
||||
new_stop = rescale_func(new_pre_stop * self.unit)
|
||||
|
||||
return LazyArrayRangeFlow(
|
||||
start=(
|
||||
spux.scale_to_unit(new_start, new_unit)
|
||||
if new_unit is not None
|
||||
else new_start
|
||||
),
|
||||
stop=(
|
||||
spux.scale_to_unit(new_stop, new_unit)
|
||||
if new_unit is not None
|
||||
else new_stop
|
||||
),
|
||||
steps=self.steps,
|
||||
scaling=self.scaling,
|
||||
unit=new_unit,
|
||||
symbols=self.symbols,
|
||||
)
|
||||
|
||||
def rescale_bounds(
|
||||
self,
|
||||
rescale_func: typ.Callable[
|
||||
|
|
|
@ -566,7 +566,9 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
}
|
||||
| {
|
||||
c: ct.ArrayFlow(
|
||||
values=xarr.get_index(c).values, unit=spu.radian, is_sorted=True
|
||||
values=xarr.get_index(c).values,
|
||||
unit=spu.radian,
|
||||
is_sorted=True,
|
||||
)
|
||||
for c in ['r', 'theta', 'phi']
|
||||
}
|
||||
|
|
|
@ -14,19 +14,19 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from . import filter_math, map_math, operate_math # , #reduce_math, transform_math
|
||||
from . import filter_math, map_math, operate_math, transform_math
|
||||
|
||||
BL_REGISTER = [
|
||||
*operate_math.BL_REGISTER,
|
||||
*map_math.BL_REGISTER,
|
||||
*filter_math.BL_REGISTER,
|
||||
# *reduce_math.BL_REGISTER,
|
||||
# *transform_math.BL_REGISTER,
|
||||
*transform_math.BL_REGISTER,
|
||||
]
|
||||
BL_NODES = {
|
||||
**operate_math.BL_NODES,
|
||||
**map_math.BL_NODES,
|
||||
**filter_math.BL_NODES,
|
||||
# **reduce_math.BL_NODES,
|
||||
# **transform_math.BL_NODES,
|
||||
**transform_math.BL_NODES,
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ class FilterOperation(enum.StrEnum):
|
|||
Pin = enum.auto()
|
||||
Swap = enum.auto()
|
||||
|
||||
# Interpret
|
||||
# Fold
|
||||
DimToVec = enum.auto()
|
||||
DimsToMat = enum.auto()
|
||||
|
||||
|
|
|
@ -439,7 +439,7 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
|
||||
@bl_cache.cached_bl_property()
|
||||
def expr_output_shape(self) -> ct.InfoFlow | None:
|
||||
info = self._compute_input('Expr', kind=ct.FlowKind.Info)
|
||||
info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True)
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
if has_info:
|
||||
return info.output_shape
|
||||
|
@ -491,9 +491,9 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
# Compute Sympy Function
|
||||
## -> The operation enum directly provides the appropriate function.
|
||||
if has_expr_value and operation is not None:
|
||||
operation.sp_func(expr)
|
||||
return operation.sp_func(expr)
|
||||
|
||||
return ct.Flowsignal.FlowPending
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
|
@ -529,7 +529,7 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||
)
|
||||
def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
|
||||
def compute_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
|
||||
operation = props['operation']
|
||||
info = input_sockets['Expr']
|
||||
|
||||
|
@ -546,8 +546,11 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': ct.FlowKind.Params},
|
||||
)
|
||||
def compute_data_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal:
|
||||
def compute_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal:
|
||||
has_params = not ct.FlowSignal.check(input_sockets['Expr'])
|
||||
if has_params:
|
||||
return input_sockets['Expr']
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
||||
####################
|
||||
|
|
|
@ -20,9 +20,12 @@ import enum
|
|||
import typing as typ
|
||||
|
||||
import bpy
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import sympy as sp
|
||||
import sympy.physics.units as spu
|
||||
|
||||
from blender_maxwell.utils import bl_cache, logger
|
||||
from blender_maxwell.utils import bl_cache, logger, sci_constants, sim_symbols
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
|
||||
from .... import contracts as ct
|
||||
from .... import sockets
|
||||
|
@ -31,6 +34,164 @@ from ... import base, events
|
|||
log = logger.get(__name__)
|
||||
|
||||
|
||||
####################
|
||||
# - Operation Enum
|
||||
####################
|
||||
class TransformOperation(enum.StrEnum):
|
||||
"""Valid operations for the `MapMathNode`.
|
||||
|
||||
Attributes:
|
||||
FreqToVacWL: Transform frequency axes to be indexed by vacuum wavelength.
|
||||
VacWLToFreq: Transform vacuum wavelength axes to be indexed by frequency.
|
||||
FFT: Compute the fourier transform of the input expression.
|
||||
InvFFT: Compute the inverse fourier transform of the input expression.
|
||||
"""
|
||||
|
||||
# Index
|
||||
FreqToVacWL = enum.auto()
|
||||
VacWLToFreq = enum.auto()
|
||||
# Fourier
|
||||
FFT1D = enum.auto()
|
||||
InvFFT1D = enum.auto()
|
||||
# Affine
|
||||
## TODO
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
@staticmethod
|
||||
def to_name(value: typ.Self) -> str:
|
||||
TO = TransformOperation
|
||||
return {
|
||||
# By Number
|
||||
TO.FreqToVacWL: '𝑓 → λᵥ',
|
||||
TO.VacWLToFreq: 'λᵥ → 𝑓',
|
||||
TO.FFT1D: 't → 𝑓',
|
||||
TO.InvFFT1D: '𝑓 → t',
|
||||
}[value]
|
||||
|
||||
@staticmethod
|
||||
def to_icon(value: typ.Self) -> str:
|
||||
return ''
|
||||
|
||||
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
|
||||
TO = TransformOperation
|
||||
return (
|
||||
str(self),
|
||||
TO.to_name(self),
|
||||
TO.to_name(self),
|
||||
TO.to_icon(self),
|
||||
i,
|
||||
)
|
||||
|
||||
####################
|
||||
# - Ops from Shape
|
||||
####################
|
||||
@staticmethod
|
||||
def by_element_shape(info: ct.InfoFlow) -> list[typ.Self]:
|
||||
TO = TransformOperation
|
||||
operations = []
|
||||
|
||||
# Freq <-> VacWL
|
||||
for dim_name in info.dim_names:
|
||||
if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq:
|
||||
operations.append(TO.FreqToVacWL)
|
||||
|
||||
if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq:
|
||||
operations.append(TO.VacWLToFreq)
|
||||
|
||||
# 1D Fourier
|
||||
if info.dim_names:
|
||||
last_physical_type = info.dim_physical_types[info.dim_names[-1]]
|
||||
if last_physical_type == spux.PhysicalType.Time:
|
||||
operations.append(TO.FFT1D)
|
||||
if last_physical_type == spux.PhysicalType.Freq:
|
||||
operations.append(TO.InvFFT1D)
|
||||
|
||||
return operations
|
||||
|
||||
####################
|
||||
# - Function Properties
|
||||
####################
|
||||
@property
|
||||
def sp_func(self):
|
||||
TO = TransformOperation
|
||||
return {
|
||||
# Index
|
||||
TO.FreqToVacWL: lambda expr: expr,
|
||||
TO.VacWLToFreq: lambda expr: expr,
|
||||
# Fourier
|
||||
TO.FFT1D: lambda expr: sp.fourier_transform(
|
||||
expr, sim_symbols.t, sim_symbols.freq
|
||||
),
|
||||
TO.InvFFT1D: lambda expr: sp.fourier_transform(
|
||||
expr, sim_symbols.freq, sim_symbols.t
|
||||
),
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def jax_func(self):
|
||||
TO = TransformOperation
|
||||
return {
|
||||
# Index
|
||||
TO.FreqToVacWL: lambda expr: expr,
|
||||
TO.VacWLToFreq: lambda expr: expr,
|
||||
# Fourier
|
||||
TO.FFT1D: lambda expr: jnp.fft(expr),
|
||||
TO.InvFFT1D: lambda expr: jnp.ifft(expr),
|
||||
}[self]
|
||||
|
||||
def transform_info(self, info: ct.InfoFlow | None) -> ct.InfoFlow | None:
|
||||
TO = TransformOperation
|
||||
if not info.dim_names:
|
||||
return None
|
||||
return {
|
||||
# Index
|
||||
TO.FreqToVacWL: lambda: info.replace_dim(
|
||||
(f_dim := info.dim_names[-1]),
|
||||
[
|
||||
'wl',
|
||||
info.dim_idx[f_dim].rescale(
|
||||
lambda el: sci_constants.vac_speed_of_light / el,
|
||||
reverse=True,
|
||||
new_unit=spu.nanometer,
|
||||
),
|
||||
],
|
||||
),
|
||||
TO.VacWLToFreq: lambda: info.replace_dim(
|
||||
(wl_dim := info.dim_names[-1]),
|
||||
[
|
||||
'f',
|
||||
info.dim_idx[wl_dim].rescale(
|
||||
lambda el: sci_constants.vac_speed_of_light / el,
|
||||
reverse=True,
|
||||
new_unit=spux.THz,
|
||||
),
|
||||
],
|
||||
),
|
||||
# Fourier
|
||||
TO.FFT1D: lambda: info.replace_dim(
|
||||
info.dim_names[-1],
|
||||
[
|
||||
'f',
|
||||
ct.LazyArrayRangeFlow(start=0, stop=sp.oo, steps=0, unit=spu.hertz),
|
||||
],
|
||||
),
|
||||
TO.InvFFT1D: info.replace_dim(
|
||||
info.dim_names[-1],
|
||||
[
|
||||
't',
|
||||
ct.LazyArrayRangeFlow(
|
||||
start=0, stop=sp.oo, steps=0, unit=spu.second
|
||||
),
|
||||
],
|
||||
),
|
||||
}.get(self, lambda: info)()
|
||||
|
||||
|
||||
####################
|
||||
# - Node
|
||||
####################
|
||||
class TransformMathNode(base.MaxwellSimNode):
|
||||
r"""Applies a function to the array as a whole, with arbitrary results.
|
||||
|
||||
|
@ -48,125 +209,153 @@ class TransformMathNode(base.MaxwellSimNode):
|
|||
bl_label = 'Transform Math'
|
||||
|
||||
input_sockets: typ.ClassVar = {
|
||||
'Data': sockets.DataSocketDef(format='jax'),
|
||||
}
|
||||
input_socket_sets: typ.ClassVar = {
|
||||
'Fourier': {},
|
||||
'Affine': {},
|
||||
'Convolve': {},
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc),
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Data': sockets.DataSocketDef(format='jax'),
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc),
|
||||
}
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
operation: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_operations()
|
||||
@events.on_value_changed(
|
||||
socket_name={'Expr'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||
input_sockets_optional={'Expr': True},
|
||||
)
|
||||
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
|
||||
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
|
||||
|
||||
info_pending = ct.FlowSignal.check_single(
|
||||
input_sockets['Expr'], ct.FlowSignal.FlowPending
|
||||
)
|
||||
|
||||
if has_info and not info_pending:
|
||||
self.expr_info = bl_cache.Signal.InvalidateCache
|
||||
|
||||
@bl_cache.cached_bl_property()
|
||||
def expr_info(self) -> ct.InfoFlow | None:
|
||||
info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True)
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
if has_info:
|
||||
return info
|
||||
|
||||
return None
|
||||
|
||||
operation: TransformOperation = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_operations(),
|
||||
cb_depends_on={'expr_info'},
|
||||
)
|
||||
|
||||
def search_operations(self) -> list[ct.BLEnumElement]:
|
||||
if self.active_socket_set == 'Fourier': # noqa: SIM114
|
||||
items = []
|
||||
elif self.active_socket_set == 'Affine': # noqa: SIM114
|
||||
items = []
|
||||
elif self.active_socket_set == 'Convolve':
|
||||
items = []
|
||||
else:
|
||||
msg = f'Active socket set {self.active_socket_set} is unknown'
|
||||
raise RuntimeError(msg)
|
||||
|
||||
return [(*item, '', i) for i, item in enumerate(items)]
|
||||
if self.expr_info is not None:
|
||||
return [
|
||||
operation.bl_enum_element(i)
|
||||
for i, operation in enumerate(
|
||||
TransformOperation.by_element_shape(self.expr_info)
|
||||
)
|
||||
]
|
||||
return []
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
layout.prop(self, self.blfields['operation'], text='')
|
||||
|
||||
####################
|
||||
# - Events
|
||||
# - UI
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
prop_name='active_socket_set',
|
||||
)
|
||||
def on_socket_set_changed(self):
|
||||
self.operation = bl_cache.Signal.ResetEnumItems
|
||||
def draw_label(self):
|
||||
if self.operation is not None:
|
||||
return 'Transform: ' + TransformOperation.to_name(self.operation)
|
||||
|
||||
return self.bl_label
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
layout.prop(self, self.blfields['operation'], text='')
|
||||
|
||||
####################
|
||||
# - Compute: LazyValueFunc / Array
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Value,
|
||||
props={'operation'},
|
||||
input_sockets={'Expr'},
|
||||
)
|
||||
def compute_value(self, props, input_sockets) -> ct.ValueFlow | ct.FlowSignal:
|
||||
operation = props['operation']
|
||||
expr = input_sockets['Expr']
|
||||
|
||||
has_expr_value = not ct.FlowSignal.check(expr)
|
||||
|
||||
# Compute Sympy Function
|
||||
## -> The operation enum directly provides the appropriate function.
|
||||
if has_expr_value and operation is not None:
|
||||
return operation.sp_func(expr)
|
||||
|
||||
return ct.Flowsignal.FlowPending
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.LazyValueFunc,
|
||||
props={'active_socket_set', 'operation'},
|
||||
input_sockets={'Data'},
|
||||
props={'operation'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={
|
||||
'Data': ct.FlowKind.LazyValueFunc,
|
||||
'Expr': ct.FlowKind.LazyValueFunc,
|
||||
},
|
||||
)
|
||||
def compute_data(self, props: dict, input_sockets: dict):
|
||||
has_data = not ct.FlowSignal.check(input_sockets['Data'])
|
||||
if not has_data or props['operation'] == 'NONE':
|
||||
return ct.FlowSignal.FlowPending
|
||||
def compute_func(
|
||||
self, props, input_sockets
|
||||
) -> ct.LazyValueFuncFlow | ct.FlowSignal:
|
||||
operation = props['operation']
|
||||
expr = input_sockets['Expr']
|
||||
|
||||
mapping_func: typ.Callable[[jax.Array], jax.Array] = {
|
||||
'Fourier': {},
|
||||
'Affine': {},
|
||||
'Convolve': {},
|
||||
}[props['active_socket_set']][props['operation']]
|
||||
has_expr = not ct.FlowSignal.check(expr)
|
||||
|
||||
# Compose w/Lazy Root Function Data
|
||||
return input_sockets['Data'].compose_within(
|
||||
mapping_func,
|
||||
if has_expr and operation is not None:
|
||||
return expr.compose_within(
|
||||
operation.jax_func,
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
kind=ct.FlowKind.Array,
|
||||
output_sockets={'Data'},
|
||||
output_socket_kinds={
|
||||
'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
|
||||
},
|
||||
)
|
||||
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow:
|
||||
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc]
|
||||
params = output_sockets['Data'][ct.FlowKind.Params]
|
||||
|
||||
if all(not ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]):
|
||||
return ct.ArrayFlow(
|
||||
values=lazy_value_func.func_jax(
|
||||
*params.func_args, **params.func_kwargs
|
||||
),
|
||||
unit=None,
|
||||
)
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - Compute Auxiliary: Info / Params
|
||||
# - FlowKind.Info|Params
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Info,
|
||||
props={'active_socket_set', 'operation'},
|
||||
input_sockets={'Data'},
|
||||
input_socket_kinds={'Data': ct.FlowKind.Info},
|
||||
props={'operation'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||
)
|
||||
def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
|
||||
info = input_sockets['Data']
|
||||
if ct.FlowSignal.check(info):
|
||||
def compute_info(
|
||||
self, props: dict, input_sockets: dict
|
||||
) -> ct.InfoFlow | typ.Literal[ct.FlowSignal.FlowPending]:
|
||||
operation = props['operation']
|
||||
info = input_sockets['Expr']
|
||||
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
|
||||
if has_info and operation is not None:
|
||||
transformed_info = operation.transform_info(info)
|
||||
if transformed_info is None:
|
||||
return ct.FlowSignal.FlowPending
|
||||
return transformed_info
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
return info
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Params,
|
||||
input_sockets={'Data'},
|
||||
input_socket_kinds={'Data': ct.FlowKind.Params},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': ct.FlowKind.Params},
|
||||
)
|
||||
def compute_data_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal:
|
||||
return input_sockets['Data']
|
||||
def compute_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal:
|
||||
has_params = not ct.FlowSignal.check(input_sockets['Expr'])
|
||||
if has_params:
|
||||
return input_sockets['Expr']
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
||||
####################
|
||||
|
|
|
@ -225,8 +225,24 @@ class VizNode(base.MaxwellSimNode):
|
|||
#####################
|
||||
## - Properties
|
||||
#####################
|
||||
@events.on_value_changed(
|
||||
socket_name={'Expr'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||
input_sockets_optional={'Expr': True},
|
||||
)
|
||||
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
|
||||
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
|
||||
|
||||
info_pending = ct.FlowSignal.check_single(
|
||||
input_sockets['Expr'], ct.FlowSignal.FlowPending
|
||||
)
|
||||
|
||||
if has_info and not info_pending:
|
||||
self.expr_info = bl_cache.Signal.InvalidateCache
|
||||
|
||||
@bl_cache.cached_bl_property()
|
||||
def input_info(self) -> ct.InfoFlow | None:
|
||||
def expr_info(self) -> ct.InfoFlow | None:
|
||||
info = self._compute_input('Expr', kind=ct.FlowKind.Info)
|
||||
if not ct.FlowSignal.check(info):
|
||||
return info
|
||||
|
@ -235,7 +251,7 @@ class VizNode(base.MaxwellSimNode):
|
|||
|
||||
viz_mode: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_viz_modes(),
|
||||
cb_depends_on={'input_info'},
|
||||
cb_depends_on={'expr_info'},
|
||||
)
|
||||
viz_target: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_targets(),
|
||||
|
@ -251,7 +267,7 @@ class VizNode(base.MaxwellSimNode):
|
|||
## - Searchers
|
||||
#####################
|
||||
def search_viz_modes(self) -> list[ct.BLEnumElement]:
|
||||
if self.input_info is not None:
|
||||
if self.expr_info is not None:
|
||||
return [
|
||||
(
|
||||
viz_mode,
|
||||
|
@ -260,7 +276,7 @@ class VizNode(base.MaxwellSimNode):
|
|||
VizMode.to_icon(viz_mode),
|
||||
i,
|
||||
)
|
||||
for i, viz_mode in enumerate(VizMode.valid_modes_for(self.input_info))
|
||||
for i, viz_mode in enumerate(VizMode.valid_modes_for(self.expr_info))
|
||||
]
|
||||
|
||||
return []
|
||||
|
@ -284,6 +300,12 @@ class VizNode(base.MaxwellSimNode):
|
|||
#####################
|
||||
## - UI
|
||||
#####################
|
||||
def draw_label(self):
|
||||
if self.viz_mode is not None:
|
||||
return 'Viz: ' + self.sim_node_name
|
||||
|
||||
return self.bl_label
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
|
||||
col.prop(self, self.blfields['viz_mode'], text='')
|
||||
col.prop(self, self.blfields['viz_target'], text='')
|
||||
|
@ -338,11 +360,23 @@ class VizNode(base.MaxwellSimNode):
|
|||
elif self.loose_input_sockets:
|
||||
self.loose_input_sockets = {}
|
||||
|
||||
self.input_info = bl_cache.Signal.InvalidateCache
|
||||
|
||||
#####################
|
||||
## - Plotting
|
||||
#####################
|
||||
@events.computes_output_socket(
|
||||
'Preview',
|
||||
kind=ct.FlowKind.Value,
|
||||
# Loaded
|
||||
props={'viz_mode', 'viz_target', 'colormap'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={
|
||||
'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info, ct.FlowKind.Params}
|
||||
},
|
||||
all_loose_input_sockets=True,
|
||||
)
|
||||
def compute_dummy_value(self, props, input_sockets, loose_input_sockets):
|
||||
return ct.FlowSignal.NoFlow
|
||||
|
||||
@events.on_show_plot(
|
||||
managed_objs={'plot'},
|
||||
props={'viz_mode', 'viz_target', 'colormap'},
|
||||
|
|
|
@ -787,7 +787,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
|||
altered_socket_kinds = set()
|
||||
|
||||
# Invalidate Caches on DataChanged
|
||||
if event == ct.FlowEvent.DataChanged:
|
||||
if event is ct.FlowEvent.DataChanged:
|
||||
input_socket_name = socket_name ## Trigger direction is forwards
|
||||
|
||||
# Invalidate Input Socket Cache
|
||||
|
@ -861,8 +861,18 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
|||
# )
|
||||
event_method(self)
|
||||
|
||||
# DataChanged Propagation Stop: No Altered Socket Kinds
|
||||
## -> If no FlowKinds were altered, then propagation makes no sense.
|
||||
## -> Semantically, **nothing has changed** == no DataChanged!
|
||||
if event is ct.FlowEvent.DataChanged and not altered_socket_kinds:
|
||||
return
|
||||
|
||||
# Constrain ShowPlot to First Node: Workaround
|
||||
if event is ct.FlowEvent.ShowPlot:
|
||||
return
|
||||
|
||||
# Propagate Event to All Sockets in "Trigger Direction"
|
||||
## The trigger chain goes node/socket/node/socket/...
|
||||
## -> The trigger chain goes node/socket/socket/node/socket/...
|
||||
if not stop_propagation:
|
||||
direc = ct.FlowEvent.flow_direction[event]
|
||||
triggered_sockets = self._bl_sockets(direc=direc)
|
||||
|
|
|
@ -139,8 +139,10 @@ class ViewerNode(base.MaxwellSimNode):
|
|||
|
||||
# Unset Plot if Nothing Plotted
|
||||
with node_tree.replot():
|
||||
if props['auto_plot']:
|
||||
self.trigger_event(ct.FlowEvent.ShowPlot)
|
||||
if props['auto_plot'] and self.inputs['Any'].is_linked:
|
||||
self.inputs['Any'].links[0].from_socket.node.trigger_event(
|
||||
ct.FlowEvent.ShowPlot
|
||||
)
|
||||
|
||||
@events.on_value_changed(
|
||||
socket_name='Any',
|
||||
|
|
|
@ -82,8 +82,8 @@ class TemporalShapeNode(base.MaxwellSimNode):
|
|||
default_steps=100,
|
||||
),
|
||||
'Envelope': sockets.ExprSocketDef(
|
||||
default_symbols=[sim_symbols.t_ps],
|
||||
default_value=10 * sim_symbols.t_ps.sp_symbol,
|
||||
default_symbols=[sim_symbols.t],
|
||||
default_value=10 * sim_symbols.t.sp_symbol,
|
||||
),
|
||||
},
|
||||
}
|
||||
|
|
|
@ -251,7 +251,8 @@ class BLInstance:
|
|||
if prop_name in deps:
|
||||
for dst_prop_name in deps[prop_name]:
|
||||
log.debug(
|
||||
'Property %s is invalidating %s',
|
||||
'%s: "%s" is invalidating "%s"',
|
||||
self.bl_label,
|
||||
prop_name,
|
||||
dst_prop_name,
|
||||
)
|
||||
|
|
|
@ -204,9 +204,9 @@ class CommonSimSymbol(enum.StrEnum):
|
|||
"""
|
||||
|
||||
X = enum.auto()
|
||||
TimePS = enum.auto()
|
||||
WavelengthNM = enum.auto()
|
||||
FrequencyTHZ = enum.auto()
|
||||
Time = enum.auto()
|
||||
Wavelength = enum.auto()
|
||||
Frequency = enum.auto()
|
||||
|
||||
@staticmethod
|
||||
def to_name(v: typ.Self) -> str:
|
||||
|
@ -241,9 +241,9 @@ class CommonSimSymbol(enum.StrEnum):
|
|||
CSS = CommonSimSymbol
|
||||
return {
|
||||
CSS.X: SSN.LowerX,
|
||||
CSS.TimePS: SSN.LowerT,
|
||||
CSS.WavelengthNM: SSN.Wavelength,
|
||||
CSS.FrequencyTHZ: SSN.Frequency,
|
||||
CSS.Time: SSN.LowerT,
|
||||
CSS.Wavelength: SSN.Wavelength,
|
||||
CSS.Frequency: SSN.Frequency,
|
||||
}[self]
|
||||
|
||||
@property
|
||||
|
@ -260,7 +260,7 @@ class CommonSimSymbol(enum.StrEnum):
|
|||
interval_inf=(True, True),
|
||||
interval_closed=(False, False),
|
||||
),
|
||||
CSS.TimePS: SimSymbol(
|
||||
CSS.Time: SimSymbol(
|
||||
sim_node_name=self.sim_symbol_name,
|
||||
mathtype=spux.MathType.Real,
|
||||
physical_type=spux.PhysicalType.Time,
|
||||
|
@ -269,7 +269,7 @@ class CommonSimSymbol(enum.StrEnum):
|
|||
interval_inf=(False, True),
|
||||
interval_closed=(True, False),
|
||||
),
|
||||
CSS.WavelengthNM: SimSymbol(
|
||||
CSS.Wavelength: SimSymbol(
|
||||
sim_node_name=self.sim_symbol_name,
|
||||
mathtype=spux.MathType.Real,
|
||||
physical_type=spux.PhysicalType.Length,
|
||||
|
@ -278,11 +278,10 @@ class CommonSimSymbol(enum.StrEnum):
|
|||
interval_inf=(False, True),
|
||||
interval_closed=(False, False),
|
||||
),
|
||||
CSS.FrequencyTHZ: SimSymbol(
|
||||
CSS.Frequency: SimSymbol(
|
||||
sim_node_name=self.sim_symbol_name,
|
||||
mathtype=spux.MathType.Real,
|
||||
physical_type=spux.PhysicalType.Freq,
|
||||
## TODO: Unit of THz
|
||||
interval_finite=(0, sys.float_info.max),
|
||||
interval_inf=(False, True),
|
||||
interval_closed=(False, False),
|
||||
|
@ -294,6 +293,6 @@ class CommonSimSymbol(enum.StrEnum):
|
|||
# - Selected Direct Access
|
||||
####################
|
||||
x = CommonSimSymbol.X.sim_symbol
|
||||
t_ps = CommonSimSymbol.TimePS.sim_symbol
|
||||
wl_nm = CommonSimSymbol.WavelengthNM.sim_symbol
|
||||
freq_thz = CommonSimSymbol.FrequencyTHZ.sim_symbol
|
||||
t = CommonSimSymbol.Time.sim_symbol
|
||||
wl = CommonSimSymbol.Wavelength.sim_symbol
|
||||
freq = CommonSimSymbol.Frequency.sim_symbol
|
||||
|
|
Loading…
Reference in New Issue