feat: transform node w/sane `DataChanged`-chaining

We also implemented a JIT-based `rescale` for `ArrayFlow`, where the
same function as for `LazyArrayRangeFlow` passes through and can do an
arbitrary symbolic shift/rescale/order-preserving whatever.
To make this fast for `ArrayFlow`, we utilize a "test" variable `a`,
put it through the function and rescale/strip its units, then `lambdify`
it and broadcast the function onto `ArrayFlow.values`.

It's an immensely clean way of working.
The `lambdify` seems fast, interestingly, but ideally we would of course
also cache that somehow.

Some details remain:

- Fourier Transform index bounds / length are currently not presumed
  known before actually computing them; it's unclear what's best here,
  as the design cross-section of physical expectation, mathematical
  correctness, and ease of implementation (especially trying to keep
  actual data out of the `InfoFlow` hot-path). For now, the `0..\infty`
  bounds **will probably break** the `Viz` node.
- We've reverted the notion that particular `sim_symbol`s must have a
  unit pre-defined, which causes a little more complexity for nodes like
  `TemporalShape`. This question is going to need resolving
- The whole `InfoFlow` object really seems to be at the heart of certain
  lagginess when it comes to the math system. It, together with the
  index representations, would benefit greatly from a principled
  refactor.
- The `Viewer` node is broken for 3D preview; see #70.

Closes #59. Progress on #54.
main
Sofus Albert Høgsbro Rose 2024-05-19 07:20:23 +02:00
parent 39747e2d68
commit a66a28da27
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
15 changed files with 441 additions and 126 deletions

18
TODO.md
View File

@ -2,15 +2,15 @@
- [x] Wave Constant - [x] Wave Constant
- Sources - Sources
- [x] Temporal Shapes / Continuous Wave Temporal Shape - [x] Temporal Shapes / Continuous Wave Temporal Shape
- [ ] Temporal Shapes / Symbolic Temporal Shape - [x] Temporal Shapes / Symbolic Temporal Shape
- [ ] Plane Wave Source - [x] Plane Wave Source
- [ ] TFSF Source - [ ] TFSF Source
- [ ] Gaussian Beam Source - [x] Gaussian Beam Source
- [ ] Astig. Gauss Beam - [ ] Astig. Gauss Beam
- Monitors - Monitors
- [x] EH Field - [x] EH Field
- [x] Power Flux - [x] Power Flux
- [ ] Permittivity - [x] Permittivity
- [ ] Diffraction - [ ] Diffraction
- Tidy3D / Integration - Tidy3D / Integration
- [ ] Exporter - [ ] Exporter
@ -23,7 +23,7 @@
- [ ] Uniform - [ ] Uniform
- [ ] Data - [ ] Data
- Structures - Structures
- [ ] Cylinder - [x] Cylinder
- [ ] Cylinder Array - [ ] Cylinder Array
- [ ] L-Cavity Cylinder - [ ] L-Cavity Cylinder
- [ ] H-Cavity Cylinder - [ ] H-Cavity Cylinder
@ -31,10 +31,10 @@
- [ ] BCC Lattice - [ ] BCC Lattice
- [ ] Monkey - [ ] Monkey
- Expr Socket - Expr Socket
- [ ] Array Mode - [x] LVF Mode
- Math Nodes - Math Nodes
- [ ] Reduce Math - [ ] Reduce Math
- [ ] Transform Math - reindex freq->wl - [x] Transform Math - reindex freq->wl
- Material Data Fitting - Material Data Fitting
- [ ] Data File Import - [ ] Data File Import
- [ ] DataFit Medium - [ ] DataFit Medium
@ -47,10 +47,10 @@
- [ ] Debye Medium - [ ] Debye Medium
- [ ] Anisotropic Medium - [ ] Anisotropic Medium
- Integration - Integration
- [ ] Simulation and Analysis of Maxim's Cavity - [x] Simulation and Analysis of Maxim's Cavity
- Constants - Constants
- [x] Number Constant - [x] Number Constant
- [x] Vector Constant - [x] Vector Constant
- [x] Physical Constant - [x] Physical Constant
- [ ] Fix many problems by persisting `_enum_cb_cache` and `_str_cb_cache`. - [x] Fix many problems by persisting `_enum_cb_cache` and `_str_cb_cache`.

View File

@ -20,6 +20,7 @@ import typing as typ
import jaxtyping as jtyp import jaxtyping as jtyp
import numpy as np import numpy as np
import sympy as sp
import sympy.physics.units as spu import sympy.physics.units as spu
from blender_maxwell.utils import extra_sympy_units as spux from blender_maxwell.utils import extra_sympy_units as spux
@ -110,3 +111,24 @@ class ArrayFlow:
msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}' msg = f'Tried to rescale unitless LazyDataValueRange to unit {unit}'
raise ValueError(msg) raise ValueError(msg)
def rescale(
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
) -> typ.Self:
# Compile JAX-Compatible Rescale Function
a = sp.Symbol('a')
rescale_expr = (
spux.scale_to_unit(rescale_func(a * self.unit), new_unit)
if self.unit is not None
else rescale_func(a * self.unit)
)
log.critical([self.unit, new_unit, rescale_expr])
_rescale_func = sp.lambdify(a, rescale_expr, 'jax')
values = _rescale_func(self.values)
# Return ArrayFlow
return ArrayFlow(
values=values[::-1] if reverse else values,
unit=new_unit,
is_sorted=self.is_sorted,
)

View File

@ -31,7 +31,9 @@ log = logger.get(__name__)
@dataclasses.dataclass(frozen=True, kw_only=True) @dataclasses.dataclass(frozen=True, kw_only=True)
class InfoFlow: class InfoFlow:
# Dimension Information ####################
# - Covariant Input
####################
dim_names: list[str] = dataclasses.field(default_factory=list) dim_names: list[str] = dataclasses.field(default_factory=list)
dim_idx: dict[str, ArrayFlow | LazyArrayRangeFlow] = dataclasses.field( dim_idx: dict[str, ArrayFlow | LazyArrayRangeFlow] = dataclasses.field(
default_factory=dict default_factory=dict
@ -67,6 +69,9 @@ class InfoFlow:
for dim_idx in self.dim_idx.values() for dim_idx in self.dim_idx.values()
] ]
####################
# - Contravariant Output
####################
# Output Information # Output Information
## TODO: Add PhysicalType ## TODO: Add PhysicalType
output_name: str = dataclasses.field(default_factory=list) output_name: str = dataclasses.field(default_factory=list)
@ -94,6 +99,28 @@ class InfoFlow:
#################### ####################
# - Methods # - Methods
#################### ####################
def replace_dim(
self, old_dim_name: str, new_dim_idx: tuple[str, ArrayFlow | LazyArrayRangeFlow]
) -> typ.Self:
return InfoFlow(
# Dimensions
dim_names=[
dim_name if dim_name != old_dim_name else new_dim_idx[0]
for dim_name in self.dim_names
],
dim_idx={
(dim_name if dim_name != old_dim_name else new_dim_idx[0]): (
dim_idx if dim_name != old_dim_name else new_dim_idx[1]
)
for dim_name, dim_idx in self.dim_idx.items()
},
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
)
def rescale_dim_idxs(self, new_dim_idxs: dict[str, LazyArrayRangeFlow]) -> typ.Self: def rescale_dim_idxs(self, new_dim_idxs: dict[str, LazyArrayRangeFlow]) -> typ.Self:
return InfoFlow( return InfoFlow(
# Dimensions # Dimensions

View File

@ -268,6 +268,32 @@ class LazyArrayRangeFlow:
#################### ####################
# - Bound Operations # - Bound Operations
#################### ####################
def rescale(
self, rescale_func, reverse: bool = False, new_unit: spux.Unit | None = None
) -> typ.Self:
new_pre_start = self.start if not reverse else self.stop
new_pre_stop = self.stop if not reverse else self.start
new_start = rescale_func(new_pre_start * self.unit)
new_stop = rescale_func(new_pre_stop * self.unit)
return LazyArrayRangeFlow(
start=(
spux.scale_to_unit(new_start, new_unit)
if new_unit is not None
else new_start
),
stop=(
spux.scale_to_unit(new_stop, new_unit)
if new_unit is not None
else new_stop
),
steps=self.steps,
scaling=self.scaling,
unit=new_unit,
symbols=self.symbols,
)
def rescale_bounds( def rescale_bounds(
self, self,
rescale_func: typ.Callable[ rescale_func: typ.Callable[

View File

@ -566,7 +566,9 @@ class ExtractDataNode(base.MaxwellSimNode):
} }
| { | {
c: ct.ArrayFlow( c: ct.ArrayFlow(
values=xarr.get_index(c).values, unit=spu.radian, is_sorted=True values=xarr.get_index(c).values,
unit=spu.radian,
is_sorted=True,
) )
for c in ['r', 'theta', 'phi'] for c in ['r', 'theta', 'phi']
} }

View File

@ -14,19 +14,19 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import filter_math, map_math, operate_math # , #reduce_math, transform_math from . import filter_math, map_math, operate_math, transform_math
BL_REGISTER = [ BL_REGISTER = [
*operate_math.BL_REGISTER, *operate_math.BL_REGISTER,
*map_math.BL_REGISTER, *map_math.BL_REGISTER,
*filter_math.BL_REGISTER, *filter_math.BL_REGISTER,
# *reduce_math.BL_REGISTER, # *reduce_math.BL_REGISTER,
# *transform_math.BL_REGISTER, *transform_math.BL_REGISTER,
] ]
BL_NODES = { BL_NODES = {
**operate_math.BL_NODES, **operate_math.BL_NODES,
**map_math.BL_NODES, **map_math.BL_NODES,
**filter_math.BL_NODES, **filter_math.BL_NODES,
# **reduce_math.BL_NODES, # **reduce_math.BL_NODES,
# **transform_math.BL_NODES, **transform_math.BL_NODES,
} }

View File

@ -48,7 +48,7 @@ class FilterOperation(enum.StrEnum):
Pin = enum.auto() Pin = enum.auto()
Swap = enum.auto() Swap = enum.auto()
# Interpret # Fold
DimToVec = enum.auto() DimToVec = enum.auto()
DimsToMat = enum.auto() DimsToMat = enum.auto()

View File

@ -439,7 +439,7 @@ class MapMathNode(base.MaxwellSimNode):
@bl_cache.cached_bl_property() @bl_cache.cached_bl_property()
def expr_output_shape(self) -> ct.InfoFlow | None: def expr_output_shape(self) -> ct.InfoFlow | None:
info = self._compute_input('Expr', kind=ct.FlowKind.Info) info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True)
has_info = not ct.FlowSignal.check(info) has_info = not ct.FlowSignal.check(info)
if has_info: if has_info:
return info.output_shape return info.output_shape
@ -491,9 +491,9 @@ class MapMathNode(base.MaxwellSimNode):
# Compute Sympy Function # Compute Sympy Function
## -> The operation enum directly provides the appropriate function. ## -> The operation enum directly provides the appropriate function.
if has_expr_value and operation is not None: if has_expr_value and operation is not None:
operation.sp_func(expr) return operation.sp_func(expr)
return ct.Flowsignal.FlowPending return ct.FlowSignal.FlowPending
@events.computes_output_socket( @events.computes_output_socket(
'Expr', 'Expr',
@ -529,7 +529,7 @@ class MapMathNode(base.MaxwellSimNode):
input_sockets={'Expr'}, input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info}, input_socket_kinds={'Expr': ct.FlowKind.Info},
) )
def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow: def compute_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
operation = props['operation'] operation = props['operation']
info = input_sockets['Expr'] info = input_sockets['Expr']
@ -546,8 +546,11 @@ class MapMathNode(base.MaxwellSimNode):
input_sockets={'Expr'}, input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Params}, input_socket_kinds={'Expr': ct.FlowKind.Params},
) )
def compute_data_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal: def compute_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal:
return input_sockets['Expr'] has_params = not ct.FlowSignal.check(input_sockets['Expr'])
if has_params:
return input_sockets['Expr']
return ct.FlowSignal.FlowPending
#################### ####################

View File

@ -20,9 +20,12 @@ import enum
import typing as typ import typing as typ
import bpy import bpy
import jax import jax.numpy as jnp
import sympy as sp
import sympy.physics.units as spu
from blender_maxwell.utils import bl_cache, logger from blender_maxwell.utils import bl_cache, logger, sci_constants, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct from .... import contracts as ct
from .... import sockets from .... import sockets
@ -31,6 +34,164 @@ from ... import base, events
log = logger.get(__name__) log = logger.get(__name__)
####################
# - Operation Enum
####################
class TransformOperation(enum.StrEnum):
"""Valid operations for the `MapMathNode`.
Attributes:
FreqToVacWL: Transform frequency axes to be indexed by vacuum wavelength.
VacWLToFreq: Transform vacuum wavelength axes to be indexed by frequency.
FFT: Compute the fourier transform of the input expression.
InvFFT: Compute the inverse fourier transform of the input expression.
"""
# Index
FreqToVacWL = enum.auto()
VacWLToFreq = enum.auto()
# Fourier
FFT1D = enum.auto()
InvFFT1D = enum.auto()
# Affine
## TODO
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
TO = TransformOperation
return {
# By Number
TO.FreqToVacWL: '𝑓 → λᵥ',
TO.VacWLToFreq: 'λᵥ → 𝑓',
TO.FFT1D: 't → 𝑓',
TO.InvFFT1D: '𝑓 → t',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> str:
return ''
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
TO = TransformOperation
return (
str(self),
TO.to_name(self),
TO.to_name(self),
TO.to_icon(self),
i,
)
####################
# - Ops from Shape
####################
@staticmethod
def by_element_shape(info: ct.InfoFlow) -> list[typ.Self]:
TO = TransformOperation
operations = []
# Freq <-> VacWL
for dim_name in info.dim_names:
if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq:
operations.append(TO.FreqToVacWL)
if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq:
operations.append(TO.VacWLToFreq)
# 1D Fourier
if info.dim_names:
last_physical_type = info.dim_physical_types[info.dim_names[-1]]
if last_physical_type == spux.PhysicalType.Time:
operations.append(TO.FFT1D)
if last_physical_type == spux.PhysicalType.Freq:
operations.append(TO.InvFFT1D)
return operations
####################
# - Function Properties
####################
@property
def sp_func(self):
TO = TransformOperation
return {
# Index
TO.FreqToVacWL: lambda expr: expr,
TO.VacWLToFreq: lambda expr: expr,
# Fourier
TO.FFT1D: lambda expr: sp.fourier_transform(
expr, sim_symbols.t, sim_symbols.freq
),
TO.InvFFT1D: lambda expr: sp.fourier_transform(
expr, sim_symbols.freq, sim_symbols.t
),
}[self]
@property
def jax_func(self):
TO = TransformOperation
return {
# Index
TO.FreqToVacWL: lambda expr: expr,
TO.VacWLToFreq: lambda expr: expr,
# Fourier
TO.FFT1D: lambda expr: jnp.fft(expr),
TO.InvFFT1D: lambda expr: jnp.ifft(expr),
}[self]
def transform_info(self, info: ct.InfoFlow | None) -> ct.InfoFlow | None:
TO = TransformOperation
if not info.dim_names:
return None
return {
# Index
TO.FreqToVacWL: lambda: info.replace_dim(
(f_dim := info.dim_names[-1]),
[
'wl',
info.dim_idx[f_dim].rescale(
lambda el: sci_constants.vac_speed_of_light / el,
reverse=True,
new_unit=spu.nanometer,
),
],
),
TO.VacWLToFreq: lambda: info.replace_dim(
(wl_dim := info.dim_names[-1]),
[
'f',
info.dim_idx[wl_dim].rescale(
lambda el: sci_constants.vac_speed_of_light / el,
reverse=True,
new_unit=spux.THz,
),
],
),
# Fourier
TO.FFT1D: lambda: info.replace_dim(
info.dim_names[-1],
[
'f',
ct.LazyArrayRangeFlow(start=0, stop=sp.oo, steps=0, unit=spu.hertz),
],
),
TO.InvFFT1D: info.replace_dim(
info.dim_names[-1],
[
't',
ct.LazyArrayRangeFlow(
start=0, stop=sp.oo, steps=0, unit=spu.second
),
],
),
}.get(self, lambda: info)()
####################
# - Node
####################
class TransformMathNode(base.MaxwellSimNode): class TransformMathNode(base.MaxwellSimNode):
r"""Applies a function to the array as a whole, with arbitrary results. r"""Applies a function to the array as a whole, with arbitrary results.
@ -48,125 +209,153 @@ class TransformMathNode(base.MaxwellSimNode):
bl_label = 'Transform Math' bl_label = 'Transform Math'
input_sockets: typ.ClassVar = { input_sockets: typ.ClassVar = {
'Data': sockets.DataSocketDef(format='jax'), 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc),
}
input_socket_sets: typ.ClassVar = {
'Fourier': {},
'Affine': {},
'Convolve': {},
} }
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {
'Data': sockets.DataSocketDef(format='jax'), 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc),
} }
#################### ####################
# - Properties # - Properties
#################### ####################
operation: enum.StrEnum = bl_cache.BLField( @events.on_value_changed(
enum_cb=lambda self, _: self.search_operations() socket_name={'Expr'},
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info},
input_sockets_optional={'Expr': True},
)
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
info_pending = ct.FlowSignal.check_single(
input_sockets['Expr'], ct.FlowSignal.FlowPending
)
if has_info and not info_pending:
self.expr_info = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def expr_info(self) -> ct.InfoFlow | None:
info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True)
has_info = not ct.FlowSignal.check(info)
if has_info:
return info
return None
operation: TransformOperation = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_info'},
) )
def search_operations(self) -> list[ct.BLEnumElement]: def search_operations(self) -> list[ct.BLEnumElement]:
if self.active_socket_set == 'Fourier': # noqa: SIM114 if self.expr_info is not None:
items = [] return [
elif self.active_socket_set == 'Affine': # noqa: SIM114 operation.bl_enum_element(i)
items = [] for i, operation in enumerate(
elif self.active_socket_set == 'Convolve': TransformOperation.by_element_shape(self.expr_info)
items = [] )
else: ]
msg = f'Active socket set {self.active_socket_set} is unknown' return []
raise RuntimeError(msg)
return [(*item, '', i) for i, item in enumerate(items)]
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, self.blfields['operation'], text='') layout.prop(self, self.blfields['operation'], text='')
#################### ####################
# - Events # - UI
#################### ####################
@events.on_value_changed( def draw_label(self):
prop_name='active_socket_set', if self.operation is not None:
) return 'Transform: ' + TransformOperation.to_name(self.operation)
def on_socket_set_changed(self):
self.operation = bl_cache.Signal.ResetEnumItems return self.bl_label
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, self.blfields['operation'], text='')
#################### ####################
# - Compute: LazyValueFunc / Array # - Compute: LazyValueFunc / Array
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Expr',
kind=ct.FlowKind.LazyValueFunc, kind=ct.FlowKind.Value,
props={'active_socket_set', 'operation'}, props={'operation'},
input_sockets={'Data'}, input_sockets={'Expr'},
input_socket_kinds={
'Data': ct.FlowKind.LazyValueFunc,
},
) )
def compute_data(self, props: dict, input_sockets: dict): def compute_value(self, props, input_sockets) -> ct.ValueFlow | ct.FlowSignal:
has_data = not ct.FlowSignal.check(input_sockets['Data']) operation = props['operation']
if not has_data or props['operation'] == 'NONE': expr = input_sockets['Expr']
return ct.FlowSignal.FlowPending
mapping_func: typ.Callable[[jax.Array], jax.Array] = { has_expr_value = not ct.FlowSignal.check(expr)
'Fourier': {},
'Affine': {},
'Convolve': {},
}[props['active_socket_set']][props['operation']]
# Compose w/Lazy Root Function Data # Compute Sympy Function
return input_sockets['Data'].compose_within( ## -> The operation enum directly provides the appropriate function.
mapping_func, if has_expr_value and operation is not None:
supports_jax=True, return operation.sp_func(expr)
)
return ct.Flowsignal.FlowPending
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Expr',
kind=ct.FlowKind.Array, kind=ct.FlowKind.LazyValueFunc,
output_sockets={'Data'}, props={'operation'},
output_socket_kinds={ input_sockets={'Expr'},
'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params}, input_socket_kinds={
'Expr': ct.FlowKind.LazyValueFunc,
}, },
) )
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow: def compute_func(
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc] self, props, input_sockets
params = output_sockets['Data'][ct.FlowKind.Params] ) -> ct.LazyValueFuncFlow | ct.FlowSignal:
operation = props['operation']
expr = input_sockets['Expr']
if all(not ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]): has_expr = not ct.FlowSignal.check(expr)
return ct.ArrayFlow(
values=lazy_value_func.func_jax( if has_expr and operation is not None:
*params.func_args, **params.func_kwargs return expr.compose_within(
), operation.jax_func,
unit=None, supports_jax=True,
) )
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
#################### ####################
# - Compute Auxiliary: Info / Params # - FlowKind.Info|Params
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Expr',
kind=ct.FlowKind.Info, kind=ct.FlowKind.Info,
props={'active_socket_set', 'operation'}, props={'operation'},
input_sockets={'Data'}, input_sockets={'Expr'},
input_socket_kinds={'Data': ct.FlowKind.Info}, input_socket_kinds={'Expr': ct.FlowKind.Info},
) )
def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow: def compute_info(
info = input_sockets['Data'] self, props: dict, input_sockets: dict
if ct.FlowSignal.check(info): ) -> ct.InfoFlow | typ.Literal[ct.FlowSignal.FlowPending]:
return ct.FlowSignal.FlowPending operation = props['operation']
info = input_sockets['Expr']
return info has_info = not ct.FlowSignal.check(info)
if has_info and operation is not None:
transformed_info = operation.transform_info(info)
if transformed_info is None:
return ct.FlowSignal.FlowPending
return transformed_info
return ct.FlowSignal.FlowPending
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Expr',
kind=ct.FlowKind.Params, kind=ct.FlowKind.Params,
input_sockets={'Data'}, input_sockets={'Expr'},
input_socket_kinds={'Data': ct.FlowKind.Params}, input_socket_kinds={'Expr': ct.FlowKind.Params},
) )
def compute_data_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal: def compute_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal:
return input_sockets['Data'] has_params = not ct.FlowSignal.check(input_sockets['Expr'])
if has_params:
return input_sockets['Expr']
return ct.FlowSignal.FlowPending
#################### ####################

View File

@ -225,8 +225,24 @@ class VizNode(base.MaxwellSimNode):
##################### #####################
## - Properties ## - Properties
##################### #####################
@events.on_value_changed(
socket_name={'Expr'},
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info},
input_sockets_optional={'Expr': True},
)
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
info_pending = ct.FlowSignal.check_single(
input_sockets['Expr'], ct.FlowSignal.FlowPending
)
if has_info and not info_pending:
self.expr_info = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property() @bl_cache.cached_bl_property()
def input_info(self) -> ct.InfoFlow | None: def expr_info(self) -> ct.InfoFlow | None:
info = self._compute_input('Expr', kind=ct.FlowKind.Info) info = self._compute_input('Expr', kind=ct.FlowKind.Info)
if not ct.FlowSignal.check(info): if not ct.FlowSignal.check(info):
return info return info
@ -235,7 +251,7 @@ class VizNode(base.MaxwellSimNode):
viz_mode: enum.StrEnum = bl_cache.BLField( viz_mode: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_viz_modes(), enum_cb=lambda self, _: self.search_viz_modes(),
cb_depends_on={'input_info'}, cb_depends_on={'expr_info'},
) )
viz_target: enum.StrEnum = bl_cache.BLField( viz_target: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_targets(), enum_cb=lambda self, _: self.search_targets(),
@ -251,7 +267,7 @@ class VizNode(base.MaxwellSimNode):
## - Searchers ## - Searchers
##################### #####################
def search_viz_modes(self) -> list[ct.BLEnumElement]: def search_viz_modes(self) -> list[ct.BLEnumElement]:
if self.input_info is not None: if self.expr_info is not None:
return [ return [
( (
viz_mode, viz_mode,
@ -260,7 +276,7 @@ class VizNode(base.MaxwellSimNode):
VizMode.to_icon(viz_mode), VizMode.to_icon(viz_mode),
i, i,
) )
for i, viz_mode in enumerate(VizMode.valid_modes_for(self.input_info)) for i, viz_mode in enumerate(VizMode.valid_modes_for(self.expr_info))
] ]
return [] return []
@ -284,6 +300,12 @@ class VizNode(base.MaxwellSimNode):
##################### #####################
## - UI ## - UI
##################### #####################
def draw_label(self):
if self.viz_mode is not None:
return 'Viz: ' + self.sim_node_name
return self.bl_label
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout): def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
col.prop(self, self.blfields['viz_mode'], text='') col.prop(self, self.blfields['viz_mode'], text='')
col.prop(self, self.blfields['viz_target'], text='') col.prop(self, self.blfields['viz_target'], text='')
@ -338,11 +360,23 @@ class VizNode(base.MaxwellSimNode):
elif self.loose_input_sockets: elif self.loose_input_sockets:
self.loose_input_sockets = {} self.loose_input_sockets = {}
self.input_info = bl_cache.Signal.InvalidateCache
##################### #####################
## - Plotting ## - Plotting
##################### #####################
@events.computes_output_socket(
'Preview',
kind=ct.FlowKind.Value,
# Loaded
props={'viz_mode', 'viz_target', 'colormap'},
input_sockets={'Expr'},
input_socket_kinds={
'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info, ct.FlowKind.Params}
},
all_loose_input_sockets=True,
)
def compute_dummy_value(self, props, input_sockets, loose_input_sockets):
return ct.FlowSignal.NoFlow
@events.on_show_plot( @events.on_show_plot(
managed_objs={'plot'}, managed_objs={'plot'},
props={'viz_mode', 'viz_target', 'colormap'}, props={'viz_mode', 'viz_target', 'colormap'},

View File

@ -787,7 +787,7 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
altered_socket_kinds = set() altered_socket_kinds = set()
# Invalidate Caches on DataChanged # Invalidate Caches on DataChanged
if event == ct.FlowEvent.DataChanged: if event is ct.FlowEvent.DataChanged:
input_socket_name = socket_name ## Trigger direction is forwards input_socket_name = socket_name ## Trigger direction is forwards
# Invalidate Input Socket Cache # Invalidate Input Socket Cache
@ -861,8 +861,18 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
# ) # )
event_method(self) event_method(self)
# DataChanged Propagation Stop: No Altered Socket Kinds
## -> If no FlowKinds were altered, then propagation makes no sense.
## -> Semantically, **nothing has changed** == no DataChanged!
if event is ct.FlowEvent.DataChanged and not altered_socket_kinds:
return
# Constrain ShowPlot to First Node: Workaround
if event is ct.FlowEvent.ShowPlot:
return
# Propagate Event to All Sockets in "Trigger Direction" # Propagate Event to All Sockets in "Trigger Direction"
## The trigger chain goes node/socket/node/socket/... ## -> The trigger chain goes node/socket/socket/node/socket/...
if not stop_propagation: if not stop_propagation:
direc = ct.FlowEvent.flow_direction[event] direc = ct.FlowEvent.flow_direction[event]
triggered_sockets = self._bl_sockets(direc=direc) triggered_sockets = self._bl_sockets(direc=direc)

View File

@ -139,8 +139,10 @@ class ViewerNode(base.MaxwellSimNode):
# Unset Plot if Nothing Plotted # Unset Plot if Nothing Plotted
with node_tree.replot(): with node_tree.replot():
if props['auto_plot']: if props['auto_plot'] and self.inputs['Any'].is_linked:
self.trigger_event(ct.FlowEvent.ShowPlot) self.inputs['Any'].links[0].from_socket.node.trigger_event(
ct.FlowEvent.ShowPlot
)
@events.on_value_changed( @events.on_value_changed(
socket_name='Any', socket_name='Any',

View File

@ -82,8 +82,8 @@ class TemporalShapeNode(base.MaxwellSimNode):
default_steps=100, default_steps=100,
), ),
'Envelope': sockets.ExprSocketDef( 'Envelope': sockets.ExprSocketDef(
default_symbols=[sim_symbols.t_ps], default_symbols=[sim_symbols.t],
default_value=10 * sim_symbols.t_ps.sp_symbol, default_value=10 * sim_symbols.t.sp_symbol,
), ),
}, },
} }

View File

@ -251,7 +251,8 @@ class BLInstance:
if prop_name in deps: if prop_name in deps:
for dst_prop_name in deps[prop_name]: for dst_prop_name in deps[prop_name]:
log.debug( log.debug(
'Property %s is invalidating %s', '%s: "%s" is invalidating "%s"',
self.bl_label,
prop_name, prop_name,
dst_prop_name, dst_prop_name,
) )

View File

@ -204,9 +204,9 @@ class CommonSimSymbol(enum.StrEnum):
""" """
X = enum.auto() X = enum.auto()
TimePS = enum.auto() Time = enum.auto()
WavelengthNM = enum.auto() Wavelength = enum.auto()
FrequencyTHZ = enum.auto() Frequency = enum.auto()
@staticmethod @staticmethod
def to_name(v: typ.Self) -> str: def to_name(v: typ.Self) -> str:
@ -241,9 +241,9 @@ class CommonSimSymbol(enum.StrEnum):
CSS = CommonSimSymbol CSS = CommonSimSymbol
return { return {
CSS.X: SSN.LowerX, CSS.X: SSN.LowerX,
CSS.TimePS: SSN.LowerT, CSS.Time: SSN.LowerT,
CSS.WavelengthNM: SSN.Wavelength, CSS.Wavelength: SSN.Wavelength,
CSS.FrequencyTHZ: SSN.Frequency, CSS.Frequency: SSN.Frequency,
}[self] }[self]
@property @property
@ -260,7 +260,7 @@ class CommonSimSymbol(enum.StrEnum):
interval_inf=(True, True), interval_inf=(True, True),
interval_closed=(False, False), interval_closed=(False, False),
), ),
CSS.TimePS: SimSymbol( CSS.Time: SimSymbol(
sim_node_name=self.sim_symbol_name, sim_node_name=self.sim_symbol_name,
mathtype=spux.MathType.Real, mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Time, physical_type=spux.PhysicalType.Time,
@ -269,7 +269,7 @@ class CommonSimSymbol(enum.StrEnum):
interval_inf=(False, True), interval_inf=(False, True),
interval_closed=(True, False), interval_closed=(True, False),
), ),
CSS.WavelengthNM: SimSymbol( CSS.Wavelength: SimSymbol(
sim_node_name=self.sim_symbol_name, sim_node_name=self.sim_symbol_name,
mathtype=spux.MathType.Real, mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Length, physical_type=spux.PhysicalType.Length,
@ -278,11 +278,10 @@ class CommonSimSymbol(enum.StrEnum):
interval_inf=(False, True), interval_inf=(False, True),
interval_closed=(False, False), interval_closed=(False, False),
), ),
CSS.FrequencyTHZ: SimSymbol( CSS.Frequency: SimSymbol(
sim_node_name=self.sim_symbol_name, sim_node_name=self.sim_symbol_name,
mathtype=spux.MathType.Real, mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.Freq, physical_type=spux.PhysicalType.Freq,
## TODO: Unit of THz
interval_finite=(0, sys.float_info.max), interval_finite=(0, sys.float_info.max),
interval_inf=(False, True), interval_inf=(False, True),
interval_closed=(False, False), interval_closed=(False, False),
@ -294,6 +293,6 @@ class CommonSimSymbol(enum.StrEnum):
# - Selected Direct Access # - Selected Direct Access
#################### ####################
x = CommonSimSymbol.X.sim_symbol x = CommonSimSymbol.X.sim_symbol
t_ps = CommonSimSymbol.TimePS.sim_symbol t = CommonSimSymbol.Time.sim_symbol
wl_nm = CommonSimSymbol.WavelengthNM.sim_symbol wl = CommonSimSymbol.Wavelength.sim_symbol
freq_thz = CommonSimSymbol.FrequencyTHZ.sim_symbol freq = CommonSimSymbol.Frequency.sim_symbol