fix: sim Value single realizations

main
Sofus Albert Høgsbro Rose 2024-05-31 15:57:25 +02:00
parent 02e309db4d
commit 572d53f41e
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
5 changed files with 45 additions and 17 deletions

View File

@ -322,7 +322,9 @@ class BlochBoundCondNode(base.MaxwellSimNode):
if has_bloch_vector:
return bloch_vector.compose_within(
enclosing_func=lambda: td.BlochBoundary(bloch_vec=bloch_vector),
enclosing_func=lambda _bloch_vector: td.BlochBoundary(
bloch_vec=_bloch_vector
),
supports_jax=False,
)
return ct.FlowSignal.FlowPending

View File

@ -162,10 +162,15 @@ class CombineNode(base.MaxwellSimNode):
for inp in loose_input_sockets.values()
if not ct.FlowSignal.check(inp)
]
if func_flows:
if len(func_flows) > 1:
return functools.reduce(
lambda a, b: a | b, func_flows
).compose_within(lambda els: list(els))
if len(func_flows) == 1:
return func_flows[0].compose_within(lambda el: [el])
return ct.FlowSignal.FlowPending
case (ct.FlowKind.Func, ct.FlowKind.Params):

View File

@ -83,8 +83,8 @@ class FDTDSimNode(base.MaxwellSimNode):
####################
@events.on_value_changed(
# Trigger
socket_name={'Sources', 'Structures', 'Domain', 'BCs', 'Monitors'},
run_on_init=True,
socket_name={'BCs', 'Domain', 'Sources', 'Structures', 'Monitors'},
prop_name={'active_socket_set'},
# Loaded
props={'active_socket_set'},
output_sockets={'Sim'},
@ -92,14 +92,18 @@ class FDTDSimNode(base.MaxwellSimNode):
)
def on_any_changed(self, props, output_sockets) -> None:
"""Create loose input sockets."""
params = output_sockets['Sim']
has_params = not ct.FlowSignal.check(params)
output_params = output_sockets['Sim']
has_output_params = not ct.FlowSignal.check(output_params)
# Declare Loose Sockets that Realize Symbols
## -> This happens if Params contains not-yet-realized symbols.
active_socket_set = props['active_socket_set']
if active_socket_set == 'Single' and has_params and params.symbols:
if set(self.loose_input_sockets) != {sym.name for sym in params.symbols}:
if (
active_socket_set == 'Single'
and has_output_params
and output_params.symbols
):
if set(self.loose_input_sockets) != {
sym.name for sym in output_params.symbols
}:
self.loose_input_sockets = {
sym.name: sockets.ExprSocketDef(
**(
@ -112,7 +116,7 @@ class FDTDSimNode(base.MaxwellSimNode):
}
)
)
for sym, expr_info in params.sym_expr_infos.items()
for sym, expr_info in output_params.sym_expr_infos.items()
}
elif self.loose_input_sockets:
@ -125,10 +129,13 @@ class FDTDSimNode(base.MaxwellSimNode):
'Sim',
kind=ct.FlowKind.Value,
# Loaded
all_loose_input_sockets=True,
output_sockets={'Sim'},
output_socket_kinds={'Sim': {ct.FlowKind.Func, ct.FlowKind.Params}},
)
def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
def compute_value(
self, loose_input_sockets, output_sockets
) -> ct.ParamsFlow | ct.FlowSignal:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
output_func = output_sockets['Sim'][ct.FlowKind.Func]
output_params = output_sockets['Sim'][ct.FlowKind.Params]
@ -136,8 +143,15 @@ class FDTDSimNode(base.MaxwellSimNode):
has_output_func = not ct.FlowSignal.check(output_func)
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_func and has_output_params and not output_params.symbols:
return output_func.realize(output_params, disallow_jax=True)
if has_output_func and has_output_params:
return output_func.realize(
output_params,
symbol_values={
sym: loose_input_sockets[sym.name]
for sym in output_params.sorted_symbols
},
disallow_jax=True,
)
return ct.FlowSignal.FlowPending
####################

View File

@ -134,8 +134,8 @@ class SimDomainNode(base.MaxwellSimNode):
).compose_within(
lambda els: {
'run_time': els[0],
'center': tuple(els[1].flatten()),
'size': tuple(els[2].flatten()),
'center': els[1].flatten().tolist(),
'size': els[2].flatten().tolist(),
'grid_spec': els[3],
'medium': els[4],
},

View File

@ -68,7 +68,7 @@ class SimSymbolName(enum.StrEnum):
LowerTheta = enum.auto()
LowerPhi = enum.auto()
# Fields
# EM Fields
Ex = enum.auto()
Ey = enum.auto()
Ez = enum.auto()
@ -97,6 +97,10 @@ class SimSymbolName(enum.StrEnum):
DiffOrderX = enum.auto()
DiffOrderY = enum.auto()
BlochX = enum.auto()
BlochY = enum.auto()
BlochZ = enum.auto()
####################
# - UI
####################
@ -168,6 +172,9 @@ class SimSymbolName(enum.StrEnum):
SSN.Flux: 'flux',
SSN.DiffOrderX: 'order_x',
SSN.DiffOrderY: 'order_y',
SSN.BlochX: 'bloch_x',
SSN.BlochY: 'bloch_y',
SSN.BlochZ: 'bloch_z',
}
)[self]