feat: first fully parameterized simulation
We also believe we've fixed the crash related to read-during-write of certain properties, by altering `BLField`'s invalidation to only invalidate the non-persistent cache when it isn't suppressed. This has no effect on the non-persistent cache, which is invalidated correctly anyway during the write() that suppression is defined for; however, by not stripping it away while that write() is being implemented by Blender, CPython's builtin atomic `dict` operations (especially `.get()`) manage to provide the thread-safety that we seem to have been missing. We tested the race-condition by the usual "drive everything from one parameter", and weren't able to trigger one. Here's to hoping. We've also started to become large fans of the "`Value` from `Func`" principle, which both allows `Value` to pick up on fully realized symbols in preceding flows, while also guaranteeing that we're not maintaining duplicate code-paths. However, the cost of the organizational reduction is that it is admittedly slower (in the interactive hot-loop) than we'd like - in large part due to having to realize a `FuncFlow` every time. However, certain optimizations have a lot of potential: - Reducing extraneous invalidations of `.lazy_func` on the `ExprSocket`, though some cleverness with `.output_sym` is needed. - Pure performance work on all `Flow` objects. Surprisingly, they seem to mean a lot.main
parent
3624d2ff45
commit
02e309db4d
|
@ -624,7 +624,7 @@ class FuncFlow(pyd.BaseModel):
|
|||
# Compose Unit-Converted FuncFlow
|
||||
return self.compose_within(
|
||||
enclosing_func=unit_convert_func,
|
||||
supports_jax=True,
|
||||
supports_jax=self.supports_jax,
|
||||
enclosing_func_output=self.func_output.update(unit=None),
|
||||
)
|
||||
|
||||
|
|
|
@ -261,7 +261,7 @@ class ManagedBLImage(base.ManagedObj):
|
|||
dpi: int | None = None,
|
||||
bl_select: bool = False,
|
||||
):
|
||||
times = ['START', time.perf_counter()]
|
||||
times = [time.perf_counter()]
|
||||
|
||||
# Compute Plot Dimensions
|
||||
# aspect_ratio, _dpi, _width_inches, _height_inches, width_px, height_px = (
|
||||
|
|
|
@ -86,9 +86,8 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
|
|||
),
|
||||
},
|
||||
}
|
||||
output_socket_sets: typ.ClassVar = {
|
||||
'Freq Domain': {'Freq Monitor': sockets.MaxwellMonitorSocketDef()},
|
||||
'Time Domain': {'Time Monitor': sockets.MaxwellMonitorSocketDef()},
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Monitor': sockets.MaxwellMonitorSocketDef(active_kind=ct.FlowKind.Func),
|
||||
}
|
||||
|
||||
managed_obj_types: typ.ClassVar = {
|
||||
|
@ -107,111 +106,190 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
|
|||
layout.prop(self, self.blfields['fields'], expand=True)
|
||||
|
||||
####################
|
||||
# - Output
|
||||
# - FlowKind.Value
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Freq Monitor',
|
||||
props={'sim_node_name', 'fields'},
|
||||
'Monitor',
|
||||
kind=ct.FlowKind.Value,
|
||||
# Loaded
|
||||
output_sockets={'Monitor'},
|
||||
output_socket_kinds={'Monitor': {ct.FlowKind.Func, ct.FlowKind.Params}},
|
||||
)
|
||||
def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
|
||||
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
|
||||
output_func = output_sockets['Monitor'][ct.FlowKind.Func]
|
||||
output_params = output_sockets['Monitor'][ct.FlowKind.Params]
|
||||
|
||||
has_output_func = not ct.FlowSignal.check(output_func)
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
if has_output_func and has_output_params and not output_params.symbols:
|
||||
return output_func.realize(output_params, disallow_jax=True)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Func
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Monitor',
|
||||
kind=ct.FlowKind.Func,
|
||||
# Loaded
|
||||
props={'active_socket_set', 'sim_node_name', 'fields'},
|
||||
input_sockets={
|
||||
'Center',
|
||||
'Size',
|
||||
'Stride',
|
||||
'Freqs',
|
||||
},
|
||||
input_socket_kinds={
|
||||
'Freqs': ct.FlowKind.Range,
|
||||
},
|
||||
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
|
||||
scale_input_sockets={
|
||||
'Center': 'Tidy3DUnits',
|
||||
'Size': 'Tidy3DUnits',
|
||||
'Freqs': 'Tidy3DUnits',
|
||||
},
|
||||
)
|
||||
def compute_freq_monitor(
|
||||
self,
|
||||
input_sockets: dict,
|
||||
props: dict,
|
||||
unit_systems: dict,
|
||||
) -> td.FieldMonitor:
|
||||
log.info(
|
||||
'Computing FieldMonitor (name="%s") with center="%s", size="%s"',
|
||||
props['sim_node_name'],
|
||||
input_sockets['Center'],
|
||||
input_sockets['Size'],
|
||||
)
|
||||
return td.FieldMonitor(
|
||||
center=input_sockets['Center'],
|
||||
size=input_sockets['Size'],
|
||||
name=props['sim_node_name'],
|
||||
interval_space=tuple(input_sockets['Stride']),
|
||||
freqs=input_sockets['Freqs'].realize().values,
|
||||
fields=props['fields'],
|
||||
)
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Time Monitor',
|
||||
props={'sim_node_name', 'fields'},
|
||||
input_sockets={
|
||||
'Center',
|
||||
'Size',
|
||||
'Stride',
|
||||
't Range',
|
||||
't Stride',
|
||||
},
|
||||
input_socket_kinds={
|
||||
't Range': ct.FlowKind.Range,
|
||||
},
|
||||
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
|
||||
scale_input_sockets={
|
||||
'Center': 'Tidy3DUnits',
|
||||
'Size': 'Tidy3DUnits',
|
||||
't Range': 'Tidy3DUnits',
|
||||
'Center': ct.FlowKind.Func,
|
||||
'Size': ct.FlowKind.Func,
|
||||
'Stride': ct.FlowKind.Func,
|
||||
'Freqs': ct.FlowKind.Func,
|
||||
't Range': ct.FlowKind.Func,
|
||||
't Stride': ct.FlowKind.Func,
|
||||
},
|
||||
)
|
||||
def compute_time_monitor(
|
||||
self,
|
||||
input_sockets: dict,
|
||||
props: dict,
|
||||
unit_systems: dict,
|
||||
) -> td.FieldMonitor:
|
||||
log.info(
|
||||
'Computing FieldMonitor (name="%s") with center="%s", size="%s"',
|
||||
props['sim_node_name'],
|
||||
input_sockets['Center'],
|
||||
input_sockets['Size'],
|
||||
def compute_func(self, props, input_sockets) -> td.FieldMonitor:
|
||||
center = input_sockets['Center']
|
||||
size = input_sockets['Size']
|
||||
stride = input_sockets['Stride']
|
||||
|
||||
has_center = not ct.FlowSignal.check(center)
|
||||
has_size = not ct.FlowSignal.check(size)
|
||||
has_stride = not ct.FlowSignal.check(stride)
|
||||
|
||||
if has_center and has_size and has_stride:
|
||||
name = props['sim_node_name']
|
||||
fields = props['fields']
|
||||
|
||||
common_func_flow = (
|
||||
center.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
| size.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
| stride
|
||||
)
|
||||
return td.FieldTimeMonitor(
|
||||
center=input_sockets['Center'],
|
||||
size=input_sockets['Size'],
|
||||
name=props['sim_node_name'],
|
||||
interval_space=tuple(input_sockets['Stride']),
|
||||
start=input_sockets['t Range'].realize_start(),
|
||||
stop=input_sockets['t Range'].realize_stop(),
|
||||
interval=input_sockets['t Stride'],
|
||||
fields=props['fields'],
|
||||
|
||||
match props['active_socket_set']:
|
||||
case 'Freq Domain':
|
||||
freqs = input_sockets['Freqs']
|
||||
has_freqs = not ct.FlowSignal.check(freqs)
|
||||
|
||||
if has_freqs:
|
||||
return (
|
||||
common_func_flow
|
||||
| freqs.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
).compose_within(
|
||||
lambda els: td.FieldMonitor(
|
||||
center=els[0].flatten().tolist(),
|
||||
size=els[1].flatten().tolist(),
|
||||
name=name,
|
||||
interval_space=els[2].flatten().tolist(),
|
||||
freqs=els[3].flatten(),
|
||||
fields=fields,
|
||||
)
|
||||
)
|
||||
|
||||
case 'Time Domain':
|
||||
t_range = input_sockets['t Range']
|
||||
t_stride = input_sockets['t Stride']
|
||||
|
||||
has_t_range = not ct.FlowSignal.check(t_range)
|
||||
has_t_stride = not ct.FlowSignal.check(t_stride)
|
||||
|
||||
if has_t_range and has_t_stride:
|
||||
return (
|
||||
common_func_flow
|
||||
| t_range.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
| t_stride.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
).compose_within(
|
||||
lambda els: td.FieldTimeMonitor(
|
||||
center=els[0].flatten().tolist(),
|
||||
size=els[1].flatten().tolist(),
|
||||
name=name,
|
||||
interval_space=els[2].flatten().tolist(),
|
||||
start=els[3][0],
|
||||
stop=els[3][-1],
|
||||
interval=els[4],
|
||||
fields=fields,
|
||||
)
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Params
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Monitor',
|
||||
kind=ct.FlowKind.Params,
|
||||
# Loaded
|
||||
props={'active_socket_set'},
|
||||
input_sockets={
|
||||
'Center',
|
||||
'Size',
|
||||
'Stride',
|
||||
'Freqs',
|
||||
't Range',
|
||||
't Stride',
|
||||
},
|
||||
input_socket_kinds={
|
||||
'Center': ct.FlowKind.Params,
|
||||
'Size': ct.FlowKind.Params,
|
||||
'Stride': ct.FlowKind.Params,
|
||||
'Freqs': ct.FlowKind.Params,
|
||||
't Range': ct.FlowKind.Params,
|
||||
't Stride': ct.FlowKind.Params,
|
||||
},
|
||||
)
|
||||
def compute_params(self, props, input_sockets) -> None:
|
||||
center = input_sockets['Center']
|
||||
size = input_sockets['Size']
|
||||
stride = input_sockets['Stride']
|
||||
|
||||
has_center = not ct.FlowSignal.check(center)
|
||||
has_size = not ct.FlowSignal.check(size)
|
||||
has_stride = not ct.FlowSignal.check(stride)
|
||||
|
||||
if has_center and has_size and has_stride:
|
||||
common_params = center | size | stride
|
||||
match props['active_socket_set']:
|
||||
case 'Freq Domain':
|
||||
freqs = input_sockets['Freqs']
|
||||
has_freqs = not ct.FlowSignal.check(freqs)
|
||||
|
||||
if has_freqs:
|
||||
return common_params | freqs
|
||||
|
||||
case 'Time Domain':
|
||||
t_range = input_sockets['t Range']
|
||||
t_stride = input_sockets['t Stride']
|
||||
|
||||
has_t_range = not ct.FlowSignal.check(t_range)
|
||||
has_t_stride = not ct.FlowSignal.check(t_stride)
|
||||
|
||||
if has_t_range and has_t_stride:
|
||||
return common_params | t_range | t_stride
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - Preview
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Time Monitor',
|
||||
'Monitor',
|
||||
kind=ct.FlowKind.Previews,
|
||||
# Loaded
|
||||
props={'sim_node_name'},
|
||||
output_sockets={'Monitor'},
|
||||
output_socket_kinds={'Monitor': ct.FlowKind.Params},
|
||||
)
|
||||
def compute_previews_time(self, props):
|
||||
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
|
||||
def compute_previews(self, props, output_sockets):
|
||||
output_params = output_sockets['Monitor']
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Freq Monitor',
|
||||
kind=ct.FlowKind.Previews,
|
||||
# Loaded
|
||||
props={'sim_node_name'},
|
||||
)
|
||||
def compute_previews_freq(self, props):
|
||||
if has_output_params and not output_params.symbols:
|
||||
return ct.PreviewsFlow(bl_object_names={props['sim_node_name']})
|
||||
return ct.PreviewsFlow()
|
||||
|
||||
@events.on_value_changed(
|
||||
# Trigger
|
||||
|
@ -220,28 +298,30 @@ class EHFieldMonitorNode(base.MaxwellSimNode):
|
|||
# Loaded
|
||||
managed_objs={'modifier'},
|
||||
input_sockets={'Center', 'Size'},
|
||||
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
|
||||
scale_input_sockets={
|
||||
'Center': 'BlenderUnits',
|
||||
},
|
||||
output_sockets={'Monitor'},
|
||||
output_socket_kinds={'Monitor': ct.FlowKind.Params},
|
||||
)
|
||||
def on_inputs_changed(
|
||||
self,
|
||||
managed_objs,
|
||||
input_sockets,
|
||||
unit_systems,
|
||||
):
|
||||
def on_previewable_changed(self, managed_objs, input_sockets, output_sockets):
|
||||
center = input_sockets['Center']
|
||||
size = input_sockets['Size']
|
||||
output_params = output_sockets['Monitor']
|
||||
|
||||
has_center = not ct.FlowSignal.check(center)
|
||||
has_size = not ct.FlowSignal.check(size)
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
if has_center and has_size and has_output_params and not output_params.symbols:
|
||||
# Push Input Values to GeoNodes Modifier
|
||||
managed_objs['modifier'].bl_modifier(
|
||||
'NODES',
|
||||
{
|
||||
'node_group': import_geonodes(GeoNodes.MonitorEHField),
|
||||
'unit_system': unit_systems['BlenderUnits'],
|
||||
'unit_system': ct.UNITS_BLENDER,
|
||||
'inputs': {
|
||||
'Size': input_sockets['Size'],
|
||||
'Size': size,
|
||||
},
|
||||
},
|
||||
location=input_sockets['Center'],
|
||||
location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -163,7 +163,9 @@ class CombineNode(base.MaxwellSimNode):
|
|||
if not ct.FlowSignal.check(inp)
|
||||
]
|
||||
if func_flows:
|
||||
return func_flows
|
||||
return functools.reduce(
|
||||
lambda a, b: a | b, func_flows
|
||||
).compose_within(lambda els: list(els))
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
case (ct.FlowKind.Func, ct.FlowKind.Params):
|
||||
|
@ -171,11 +173,13 @@ class CombineNode(base.MaxwellSimNode):
|
|||
params_flow
|
||||
for inp_sckname in self.inputs.keys() # noqa: SIM118
|
||||
if not ct.FlowSignal.check(
|
||||
params_flow := self._compute_input(inp_sckname, kind='params')
|
||||
params_flow := self._compute_input(
|
||||
inp_sckname, kind=ct.FlowKind.Params
|
||||
)
|
||||
)
|
||||
]
|
||||
if params_flows:
|
||||
return params_flows
|
||||
return functools.reduce(lambda a, b: a | b, params_flows)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
|
|
@ -57,10 +57,7 @@ class FDTDSimNode(base.MaxwellSimNode):
|
|||
'Single': {
|
||||
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Value),
|
||||
},
|
||||
'Batch': {
|
||||
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Array),
|
||||
},
|
||||
'Lazy': {
|
||||
'Func': {
|
||||
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=ct.FlowKind.Func),
|
||||
},
|
||||
}
|
||||
|
@ -101,7 +98,7 @@ class FDTDSimNode(base.MaxwellSimNode):
|
|||
# Declare Loose Sockets that Realize Symbols
|
||||
## -> This happens if Params contains not-yet-realized symbols.
|
||||
active_socket_set = props['active_socket_set']
|
||||
if active_socket_set in ['Value', 'Batch'] and has_params and params.symbols:
|
||||
if active_socket_set == 'Single' and has_params and params.symbols:
|
||||
if set(self.loose_input_sockets) != {sym.name for sym in params.symbols}:
|
||||
self.loose_input_sockets = {
|
||||
sym.name: sockets.ExprSocketDef(
|
||||
|
@ -128,51 +125,19 @@ class FDTDSimNode(base.MaxwellSimNode):
|
|||
'Sim',
|
||||
kind=ct.FlowKind.Value,
|
||||
# Loaded
|
||||
props={'differentiable'},
|
||||
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
|
||||
input_socket_kinds={
|
||||
'Sources': ct.FlowKind.Array,
|
||||
'Structures': ct.FlowKind.Array,
|
||||
'Monitors': ct.FlowKind.Array,
|
||||
},
|
||||
output_sockets={'Sim'},
|
||||
output_socket_kinds={'Sim': ct.FlowKind.Params},
|
||||
output_socket_kinds={'Sim': {ct.FlowKind.Func, ct.FlowKind.Params}},
|
||||
)
|
||||
def compute_fdtd_sim_value(
|
||||
self, props, input_sockets, output_sockets
|
||||
) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
|
||||
"""Compute a single FDTD simulation definition, so long as the inputs are neither symbolic or differentiable."""
|
||||
sim_domain = input_sockets['Domain']
|
||||
sources = input_sockets['Sources']
|
||||
structures = input_sockets['Structures']
|
||||
bounds = input_sockets['BCs']
|
||||
monitors = input_sockets['Monitors']
|
||||
output_params = output_sockets['Sim']
|
||||
def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
|
||||
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
|
||||
output_func = output_sockets['Sim'][ct.FlowKind.Func]
|
||||
output_params = output_sockets['Sim'][ct.FlowKind.Params]
|
||||
|
||||
has_sim_domain = not ct.FlowSignal.check(sim_domain)
|
||||
has_sources = not ct.FlowSignal.check(sources)
|
||||
has_structures = not ct.FlowSignal.check(structures)
|
||||
has_bounds = not ct.FlowSignal.check(bounds)
|
||||
has_monitors = not ct.FlowSignal.check(monitors)
|
||||
has_output_func = not ct.FlowSignal.check(output_func)
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
differentiable = props['differentiable']
|
||||
if (
|
||||
has_sim_domain
|
||||
and has_sources
|
||||
and has_structures
|
||||
and has_bounds
|
||||
and has_monitors
|
||||
and has_output_params
|
||||
and not differentiable
|
||||
):
|
||||
return td.Simulation(
|
||||
**sim_domain,
|
||||
sources=sources,
|
||||
structures=structures,
|
||||
boundary_spec=bounds,
|
||||
monitors=monitors,
|
||||
)
|
||||
if has_output_func and has_output_params and not output_params.symbols:
|
||||
return output_func.realize(output_params, disallow_jax=True)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
|
@ -185,6 +150,8 @@ class FDTDSimNode(base.MaxwellSimNode):
|
|||
props={'differentiable'},
|
||||
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
|
||||
input_socket_kinds={
|
||||
'BCs': ct.FlowKind.Func,
|
||||
'Domain': ct.FlowKind.Func,
|
||||
'Sources': ct.FlowKind.Func,
|
||||
'Structures': ct.FlowKind.Func,
|
||||
'Monitors': ct.FlowKind.Func,
|
||||
|
@ -196,17 +163,17 @@ class FDTDSimNode(base.MaxwellSimNode):
|
|||
self, props, input_sockets, output_sockets
|
||||
) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
|
||||
"""Compute a single simulation, given that all inputs are non-symbolic."""
|
||||
bounds = input_sockets['BCs']
|
||||
sim_domain = input_sockets['Domain']
|
||||
sources = input_sockets['Sources']
|
||||
structures = input_sockets['Structures']
|
||||
bounds = input_sockets['BCs']
|
||||
monitors = input_sockets['Monitors']
|
||||
output_params = output_sockets['Sim']
|
||||
|
||||
has_bounds = not ct.FlowSignal.check(bounds)
|
||||
has_sim_domain = not ct.FlowSignal.check(sim_domain)
|
||||
has_sources = not ct.FlowSignal.check(sources)
|
||||
has_structures = not ct.FlowSignal.check(structures)
|
||||
has_bounds = not ct.FlowSignal.check(bounds)
|
||||
has_monitors = not ct.FlowSignal.check(monitors)
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
|
@ -220,28 +187,16 @@ class FDTDSimNode(base.MaxwellSimNode):
|
|||
):
|
||||
differentiable = props['differentiable']
|
||||
if differentiable:
|
||||
raise NotImplementedError
|
||||
|
||||
return (
|
||||
sim_domain | sources | structures | bounds | monitors
|
||||
).compose_within(
|
||||
enclosing_func=lambda els: tdadj.JaxSimulation(
|
||||
**els[0],
|
||||
sources=els[1],
|
||||
structures=els[2]['static'],
|
||||
input_structures=els[2]['differentiable'],
|
||||
boundary_spec=els[3],
|
||||
monitors=els[4]['static'],
|
||||
output_monitors=els[4]['differentiable'],
|
||||
),
|
||||
supports_jax=True,
|
||||
)
|
||||
return (
|
||||
sim_domain | sources | structures | bounds | monitors
|
||||
bounds | sim_domain | sources | structures | monitors
|
||||
).compose_within(
|
||||
enclosing_func=lambda els: td.Simulation(
|
||||
**els[0],
|
||||
sources=els[1],
|
||||
structures=els[2],
|
||||
boundary_spec=els[3],
|
||||
boundary_spec=els[0],
|
||||
**els[1],
|
||||
sources=els[2],
|
||||
structures=els[3],
|
||||
monitors=els[4],
|
||||
),
|
||||
supports_jax=False,
|
||||
|
@ -258,6 +213,8 @@ class FDTDSimNode(base.MaxwellSimNode):
|
|||
props={'differentiable'},
|
||||
input_sockets={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
|
||||
input_socket_kinds={
|
||||
'BCs': ct.FlowKind.Params,
|
||||
'Domain': ct.FlowKind.Params,
|
||||
'Sources': ct.FlowKind.Params,
|
||||
'Structures': ct.FlowKind.Params,
|
||||
'Monitors': ct.FlowKind.Params,
|
||||
|
@ -267,31 +224,26 @@ class FDTDSimNode(base.MaxwellSimNode):
|
|||
self, props, input_sockets
|
||||
) -> td.Simulation | tdadj.JaxSimulation | ct.FlowSignal:
|
||||
"""Compute a single simulation, given that all inputs are non-symbolic."""
|
||||
bounds = input_sockets['BCs']
|
||||
sim_domain = input_sockets['Domain']
|
||||
sources = input_sockets['Sources']
|
||||
structures = input_sockets['Structures']
|
||||
bounds = input_sockets['BCs']
|
||||
monitors = input_sockets['Monitors']
|
||||
|
||||
has_bounds = not ct.FlowSignal.check(bounds)
|
||||
has_sim_domain = not ct.FlowSignal.check(sim_domain)
|
||||
has_sources = not ct.FlowSignal.check(sources)
|
||||
has_structures = not ct.FlowSignal.check(structures)
|
||||
has_bounds = not ct.FlowSignal.check(bounds)
|
||||
has_monitors = not ct.FlowSignal.check(monitors)
|
||||
|
||||
if (
|
||||
has_sim_domain
|
||||
has_bounds
|
||||
and has_sim_domain
|
||||
and has_sources
|
||||
and has_structures
|
||||
and has_bounds
|
||||
and has_monitors
|
||||
):
|
||||
# Determine Differentiable Match
|
||||
## -> 'structures' is diff when **any** are diff.
|
||||
## -> 'monitors' is also diff when **any** are diff.
|
||||
## -> Only parameters through diff structs can be diff'ed by.
|
||||
## -> Similarly, only diff monitors will have gradients computed.
|
||||
return sim_domain | sources | structures | bounds | monitors
|
||||
return bounds | sim_domain | sources | structures | monitors
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
||||
|
|
|
@ -71,7 +71,7 @@ class PlaneWaveSourceNode(base.MaxwellSimNode):
|
|||
),
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Angled Source': sockets.MaxwellSourceSocketDef(),
|
||||
'Angled Source': sockets.MaxwellSourceSocketDef(active_kind=ct.FlowKind.Func),
|
||||
}
|
||||
|
||||
managed_obj_types: typ.ClassVar = {
|
||||
|
@ -103,8 +103,8 @@ class PlaneWaveSourceNode(base.MaxwellSimNode):
|
|||
)
|
||||
def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
|
||||
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
|
||||
output_func = output_sockets['Source'][ct.FlowKind.Func]
|
||||
output_params = output_sockets['Source'][ct.FlowKind.Params]
|
||||
output_func = output_sockets['Angled Source'][ct.FlowKind.Func]
|
||||
output_params = output_sockets['Angled Source'][ct.FlowKind.Params]
|
||||
|
||||
has_output_func = not ct.FlowSignal.check(output_func)
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
@ -122,11 +122,11 @@ class PlaneWaveSourceNode(base.MaxwellSimNode):
|
|||
# Loaded
|
||||
props={'sim_node_name', 'injection_axis', 'injection_direction'},
|
||||
input_sockets={'Temporal Shape', 'Center', 'Spherical', 'Pol ∡'},
|
||||
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
|
||||
scale_input_sockets={
|
||||
'Center': 'Tidy3DUnits',
|
||||
'Spherical': 'Tidy3DUnits',
|
||||
'Pol ∡': 'Tidy3DUnits',
|
||||
input_socket_kinds={
|
||||
'Temporal Shape': ct.FlowKind.Func,
|
||||
'Center': ct.FlowKind.Func,
|
||||
'Spherical': ct.FlowKind.Func,
|
||||
'Pol ∡': ct.FlowKind.Func,
|
||||
},
|
||||
)
|
||||
def compute_func(self, props, input_sockets) -> None:
|
||||
|
@ -157,17 +157,47 @@ class PlaneWaveSourceNode(base.MaxwellSimNode):
|
|||
).compose_within(
|
||||
lambda els: td.PlaneWave(
|
||||
name=name,
|
||||
center=els[0],
|
||||
center=els[0].flatten().tolist(),
|
||||
size=size,
|
||||
source_time=els[1],
|
||||
direction=inj_dir,
|
||||
angle_theta=els[2][0],
|
||||
angle_phi=els[2][1],
|
||||
angle_theta=els[2][0].item(0),
|
||||
angle_phi=els[2][1].item(0),
|
||||
pol_angle=els[3],
|
||||
)
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Params
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Angled Source',
|
||||
kind=ct.FlowKind.Params,
|
||||
# Loaded
|
||||
input_sockets={'Temporal Shape', 'Center', 'Spherical', 'Pol ∡'},
|
||||
input_socket_kinds={
|
||||
'Temporal Shape': ct.FlowKind.Params,
|
||||
'Center': ct.FlowKind.Params,
|
||||
'Spherical': ct.FlowKind.Params,
|
||||
'Pol ∡': ct.FlowKind.Params,
|
||||
},
|
||||
)
|
||||
def compute_params(self, input_sockets) -> None:
|
||||
center = input_sockets['Center']
|
||||
temporal_shape = input_sockets['Temporal Shape']
|
||||
spherical = input_sockets['Spherical']
|
||||
pol_ang = input_sockets['Pol ∡']
|
||||
|
||||
has_center = not ct.FlowSignal.check(center)
|
||||
has_temporal_shape = not ct.FlowSignal.check(temporal_shape)
|
||||
has_spherical = not ct.FlowSignal.check(spherical)
|
||||
has_pol_ang = not ct.FlowSignal.check(pol_ang)
|
||||
|
||||
if has_center and has_temporal_shape and has_spherical and has_pol_ang:
|
||||
return center | temporal_shape | spherical | pol_ang
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - Preview - Changes to Input Sockets
|
||||
####################
|
||||
|
@ -176,11 +206,11 @@ class PlaneWaveSourceNode(base.MaxwellSimNode):
|
|||
kind=ct.FlowKind.Previews,
|
||||
# Loaded
|
||||
props={'sim_node_name'},
|
||||
output_sockets={'Structure'},
|
||||
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
||||
output_sockets={'Angled Source'},
|
||||
output_socket_kinds={'Angled Source': ct.FlowKind.Params},
|
||||
)
|
||||
def compute_previews(self, props, output_sockets):
|
||||
output_params = output_sockets['Structure']
|
||||
output_params = output_sockets['Angled Source']
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
if has_output_params and not output_params.symbols:
|
||||
|
@ -199,7 +229,9 @@ class PlaneWaveSourceNode(base.MaxwellSimNode):
|
|||
output_sockets={'Angled Source'},
|
||||
output_socket_kinds={'Angled Source': ct.FlowKind.Params},
|
||||
)
|
||||
def on_inputs_changed(self, managed_objs, props, input_sockets, output_sockets):
|
||||
def on_previewable_changed(
|
||||
self, managed_objs, props, input_sockets, output_sockets
|
||||
):
|
||||
center = input_sockets['Center']
|
||||
spherical = input_sockets['Spherical']
|
||||
pol_ang = input_sockets['Pol ∡']
|
||||
|
|
|
@ -172,11 +172,11 @@ class PointDipoleSourceNode(base.MaxwellSimNode):
|
|||
kind=ct.FlowKind.Previews,
|
||||
# Loaded
|
||||
props={'sim_node_name'},
|
||||
output_sockets={'Structure'},
|
||||
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
||||
output_sockets={'Source'},
|
||||
output_socket_kinds={'Source': ct.FlowKind.Params},
|
||||
)
|
||||
def compute_previews(self, props, output_sockets):
|
||||
output_params = output_sockets['Structure']
|
||||
output_params = output_sockets['Source']
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
if has_output_params and not output_params.symbols:
|
||||
|
@ -184,6 +184,7 @@ class PointDipoleSourceNode(base.MaxwellSimNode):
|
|||
return ct.PreviewsFlow()
|
||||
|
||||
@events.on_value_changed(
|
||||
# Trigger
|
||||
socket_name={'Center'},
|
||||
prop_name='pol',
|
||||
run_on_init=True,
|
||||
|
@ -194,7 +195,7 @@ class PointDipoleSourceNode(base.MaxwellSimNode):
|
|||
output_sockets={'Source'},
|
||||
output_socket_kinds={'Source': ct.FlowKind.Params},
|
||||
)
|
||||
def on_inputs_changed(
|
||||
def on_previewable_changed(
|
||||
self, managed_objs, props, input_sockets, output_sockets
|
||||
) -> None:
|
||||
SFP = ct.SimFieldPols
|
||||
|
@ -220,8 +221,10 @@ class PointDipoleSourceNode(base.MaxwellSimNode):
|
|||
'NODES',
|
||||
{
|
||||
'node_group': import_geonodes(GeoNodes.SourcePointDipole),
|
||||
'inputs': {'Axis': axis},
|
||||
'unit_system': ct.UNITS_BLENDER,
|
||||
'inputs': {
|
||||
'Axis': axis,
|
||||
},
|
||||
},
|
||||
location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
|
||||
)
|
||||
|
|
|
@ -213,8 +213,8 @@ class BoxStructureNode(base.MaxwellSimNode):
|
|||
socket_name={'Center', 'Size'},
|
||||
run_on_init=True,
|
||||
# Loaded
|
||||
input_sockets={'Center', 'Size'},
|
||||
managed_objs={'modifier'},
|
||||
input_sockets={'Center', 'Size'},
|
||||
output_sockets={'Structure'},
|
||||
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
||||
)
|
||||
|
|
|
@ -282,9 +282,9 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
|
|||
|
||||
# Property Callbacks: Per-Socket
|
||||
## -> NOTE: User-defined handlers might recurse on_prop_changed.
|
||||
self.is_initializing = True
|
||||
# self.is_initializing = True
|
||||
self.on_socket_props_changed(set_of_cleared_blfields)
|
||||
self.is_initializing = False
|
||||
# self.is_initializing = False
|
||||
|
||||
# Trigger Event
|
||||
## -> Before SocketDef.postinit(), never emit DataChanged.
|
||||
|
|
|
@ -408,7 +408,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
## NOTE: Depends on suppressed on_prop_changed
|
||||
if ('selected_value_range', 'invalidate') in cleared_blfields:
|
||||
self.active_kind = self.selected_value_range
|
||||
self.on_active_kind_changed()
|
||||
# self.on_active_kind_changed()
|
||||
|
||||
# Conditional Unit-Conversion
|
||||
## -> This is niche functionality, but the only way to convert units.
|
||||
|
@ -434,6 +434,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
self.lazy_range = self.lazy_range.correct_unit(prev_unit)
|
||||
|
||||
self.prev_unit = self.active_unit
|
||||
# self.unit = bl_cache.Signal.InvalidateCache
|
||||
|
||||
####################
|
||||
# - Value Utilities
|
||||
|
@ -843,7 +844,9 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'sorted_symbols', 'output_sym'})
|
||||
@bl_cache.cached_bl_property(
|
||||
depends_on={'sorted_symbols', 'output_sym', 'lazy_range'}
|
||||
)
|
||||
def info(self) -> ct.InfoFlow:
|
||||
r"""Returns parameter symbols/values to accompany `self.lazy_func`.
|
||||
|
||||
|
|
|
@ -307,6 +307,13 @@ class BLField:
|
|||
## -> As a result, the value must be reloaded from the property.
|
||||
## The 'on_prop_changed' method on the bl_instance might also be called.
|
||||
if value is Signal.InvalidateCache or value is Signal.InvalidateCacheNoUpdate:
|
||||
# Invalidate Non-Persistent Cache
|
||||
## -> Invalidations via on_bl_prop_changed may come by here.
|
||||
## -> In many cases, we may have suppressed that BLProp.write().
|
||||
## -> BLProp.write() itself guarantees inv. of non-persist cache.
|
||||
## -> In the meantime, don't strip away thread safety!
|
||||
## -> (Dict ops are thread-safe in CPython; Blender Props are NOT)
|
||||
if not self.suppressed_update.get(bl_instance.instance_id, False):
|
||||
self.bl_prop.invalidate_nonpersist(bl_instance)
|
||||
|
||||
# Trigger Update Chain
|
||||
|
|
|
@ -182,9 +182,12 @@ class CachedBLProperty:
|
|||
# Fill Caches
|
||||
## -> persist=True: Fill Persist and Non-Persist Cache
|
||||
## -> persist=False: Fill Non-Persist Cache
|
||||
if not self.suppressed_update.get(bl_instance.instance_id, False):
|
||||
if self.persist:
|
||||
with self.suppress_update(bl_instance):
|
||||
self.bl_prop.write(bl_instance, self.getter_method(bl_instance))
|
||||
self.bl_prop.write(
|
||||
bl_instance, self.getter_method(bl_instance)
|
||||
)
|
||||
|
||||
else:
|
||||
self.bl_prop.write_nonpersist(
|
||||
|
@ -195,7 +198,7 @@ class CachedBLProperty:
|
|||
## -> Use InvalidateCacheNoUpdate to explicitly disable update.
|
||||
## -> If 'suppress_update' context manager is active, don't update.
|
||||
if value is Signal.InvalidateCache and not self.suppressed_update.get(
|
||||
bl_instance.instance_id
|
||||
bl_instance.instance_id, False
|
||||
):
|
||||
bl_instance.on_prop_changed(self.bl_prop.name)
|
||||
|
||||
|
|
|
@ -1408,7 +1408,11 @@ class PhysicalType(enum.StrEnum):
|
|||
ValueError: If no `PhysicalType` could be matched, and `optional` is `False`.
|
||||
"""
|
||||
if unit is None:
|
||||
return ct.PhysicalType.NonPhysical
|
||||
return PhysicalType.NonPhysical
|
||||
|
||||
## TODO_ This enough?
|
||||
if unit in [spu.radian, spu.degree]:
|
||||
return PhysicalType.Angle
|
||||
|
||||
unit_dim_deps = unit_to_unit_dim_deps(unit)
|
||||
if unit_dim_deps is not None:
|
||||
|
|
Loading…
Reference in New Issue