fix: use-after-free in socket pruner
Also the usual batch of improvements. Differentiability is misbehaving, intriguingly.main
parent
38e70a60d3
commit
c286d65193
|
@ -325,8 +325,16 @@ class FuncFlow(pyd.BaseModel):
|
||||||
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||||
{}
|
{}
|
||||||
),
|
),
|
||||||
|
disallow_jax: bool = True,
|
||||||
) -> typ.Self:
|
) -> typ.Self:
|
||||||
"""Run the represented function with the best optimization available, given particular choices for all function arguments and for all unrealized symbols."""
|
"""Run the represented function with the best optimization available, given particular choices for all function arguments and for all unrealized symbols.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
params: The parameter-tracking object from which function arguments will be computed.
|
||||||
|
symbol_values: Values for all `SimSymbol`s that are not yet realized in `params`.
|
||||||
|
disallow_jax: Don't use `self.func_jax` to evaluate, even if possible.
|
||||||
|
This is desirable when the overhead of `jax.jit()` is known in advance to exceed the performance benefits.
|
||||||
|
"""
|
||||||
if self.supports_jax:
|
if self.supports_jax:
|
||||||
return self.func_jax(
|
return self.func_jax(
|
||||||
*params.scaled_func_args(symbol_values),
|
*params.scaled_func_args(symbol_values),
|
||||||
|
|
|
@ -374,19 +374,22 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
||||||
"""
|
"""
|
||||||
node_tree = self.id_data
|
node_tree = self.id_data
|
||||||
for direc in ['input', 'output']:
|
for direc in ['input', 'output']:
|
||||||
bl_sockets = self._bl_sockets(direc)
|
active_socket_nametype = {
|
||||||
|
bl_socket.name: bl_socket.socket_type
|
||||||
|
for bl_socket in self._bl_sockets(direc)
|
||||||
|
}
|
||||||
active_socket_defs = self.active_socket_defs(direc)
|
active_socket_defs = self.active_socket_defs(direc)
|
||||||
|
|
||||||
# Determine Sockets to Remove
|
# Determine Sockets to Remove
|
||||||
## -> Name: If the existing socket name isn't "active".
|
## -> Name: If the existing socket name isn't "active".
|
||||||
## -> Type: If the existing socket_type != "active" SocketDef.
|
## -> Type: If the existing socket_type != "active" SocketDef.
|
||||||
bl_sockets_to_remove = [
|
bl_sockets_to_remove = [
|
||||||
bl_socket
|
active_sckname
|
||||||
for socket_name, bl_socket in bl_sockets.items()
|
for active_sckname, active_scktype in active_socket_nametype.items()
|
||||||
if (
|
if (
|
||||||
socket_name not in active_socket_defs
|
active_sckname not in active_socket_defs
|
||||||
or bl_socket.socket_type
|
or active_scktype
|
||||||
is not active_socket_defs[socket_name].socket_type
|
is not active_socket_defs[active_sckname].socket_type
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -394,39 +397,50 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
||||||
## -> Name: If the existing socket name is "active".
|
## -> Name: If the existing socket name is "active".
|
||||||
## -> Type: If the existing socket_type == "active" SocketDef.
|
## -> Type: If the existing socket_type == "active" SocketDef.
|
||||||
## -> Compare: If the existing socket differs from the SocketDef.
|
## -> Compare: If the existing socket differs from the SocketDef.
|
||||||
|
## -> NOTE: Reload bl_sockets in case to-update scks were removed.
|
||||||
bl_sockets_to_update = [
|
bl_sockets_to_update = [
|
||||||
bl_socket
|
active_sckname
|
||||||
for socket_name, bl_socket in bl_sockets.items()
|
for active_sckname, active_scktype in active_socket_nametype.items()
|
||||||
if (
|
if (
|
||||||
socket_name in active_socket_defs
|
active_sckname in active_socket_defs
|
||||||
and bl_socket.socket_type
|
and active_scktype is active_socket_defs[active_sckname].socket_type
|
||||||
is active_socket_defs[socket_name].socket_type
|
and not active_socket_defs[active_sckname].compare(
|
||||||
and not active_socket_defs[socket_name].compare(bl_socket)
|
self._bl_sockets(direc)[active_sckname]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Remove Sockets
|
# Remove Sockets
|
||||||
for bl_socket in bl_sockets_to_remove:
|
## -> The symptom of using a deleted socket is... hard crash.
|
||||||
bl_socket_name = bl_socket.name
|
## -> Therefore, we must be EXTREMELY careful with bl_socket refs.
|
||||||
|
## -> The multi-stage for-loop helps us guard from deleted sockets.
|
||||||
|
for active_sckname in bl_sockets_to_remove:
|
||||||
|
bl_socket = self._bl_sockets(direc).get(active_sckname)
|
||||||
|
|
||||||
# 1. Report the socket removal to the NodeTree.
|
# 1. Report the socket removal to the NodeTree.
|
||||||
## -> The NodeLinkCache needs to be adjusted manually.
|
## -> The NodeLinkCache needs to be adjusted manually.
|
||||||
node_tree.on_node_socket_removed(bl_socket)
|
node_tree.on_node_socket_removed(bl_socket)
|
||||||
|
|
||||||
|
for active_sckname in bl_sockets_to_remove:
|
||||||
|
bl_sockets = self._bl_sockets(direc)
|
||||||
|
bl_socket = bl_sockets.get(active_sckname)
|
||||||
|
|
||||||
# 2. Perform the removal using Blender's API.
|
# 2. Perform the removal using Blender's API.
|
||||||
## -> Actually removes the socket.
|
## -> Actually removes the socket.
|
||||||
bl_sockets.remove(bl_socket)
|
## -> Must be protected from auto-removed use-after-free.
|
||||||
|
if bl_socket is not None:
|
||||||
# 3. Invalidate the input socket cache across all kinds.
|
bl_sockets.remove(bl_socket)
|
||||||
## -> Prevents phantom values from remaining available.
|
|
||||||
## -> Done after socket removal to protect from race condition.
|
|
||||||
self._compute_input.invalidate(
|
|
||||||
input_socket_name=bl_socket_name,
|
|
||||||
kind=...,
|
|
||||||
unit_system=...,
|
|
||||||
)
|
|
||||||
|
|
||||||
if direc == 'input':
|
if direc == 'input':
|
||||||
|
# 3. Invalidate the input socket cache across all kinds.
|
||||||
|
## -> Prevents phantom values from remaining available.
|
||||||
|
## -> Done after socket removal to protect from race condition.
|
||||||
|
self._compute_input.invalidate(
|
||||||
|
input_socket_name=active_sckname,
|
||||||
|
kind=...,
|
||||||
|
unit_system=...,
|
||||||
|
)
|
||||||
|
|
||||||
# 4. Run all trigger-only `on_value_changed` callbacks.
|
# 4. Run all trigger-only `on_value_changed` callbacks.
|
||||||
## -> Runs any event methods that relied on the socket.
|
## -> Runs any event methods that relied on the socket.
|
||||||
## -> Only methods that don't **require** the socket.
|
## -> Only methods that don't **require** the socket.
|
||||||
|
@ -435,39 +449,69 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
||||||
triggered_event_methods = [
|
triggered_event_methods = [
|
||||||
event_method
|
event_method
|
||||||
for event_method in self.filtered_event_methods_by_event(
|
for event_method in self.filtered_event_methods_by_event(
|
||||||
ct.FlowEvent.DataChanged, (bl_socket_name, None, None)
|
ct.FlowEvent.DataChanged, (active_sckname, None, None)
|
||||||
)
|
)
|
||||||
if bl_socket_name
|
if active_sckname
|
||||||
not in event_method.callback_info.must_load_sockets
|
not in event_method.callback_info.must_load_sockets
|
||||||
]
|
]
|
||||||
for event_method in triggered_event_methods:
|
for event_method in triggered_event_methods:
|
||||||
event_method(self)
|
event_method(self)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 3. Invalidate the output socket cache across all kinds.
|
||||||
|
## -> Prevents phantom values from remaining available.
|
||||||
|
## -> Done after socket removal to protect from race condition.
|
||||||
|
self.compute_output.invalidate(
|
||||||
|
input_socket_name=active_sckname,
|
||||||
|
kind=...,
|
||||||
|
)
|
||||||
|
|
||||||
# Update Sockets
|
# Update Sockets
|
||||||
for bl_socket in bl_sockets_to_update:
|
## -> The symptom of using a deleted socket is... hard crash.
|
||||||
bl_socket_name = bl_socket.name
|
## -> Therefore, we must be EXTREMELY careful with bl_socket refs.
|
||||||
socket_def = active_socket_defs[bl_socket_name]
|
## -> The multi-stage for-loop helps us guard from deleted sockets.
|
||||||
|
for active_sckname in bl_sockets_to_update:
|
||||||
|
bl_sockets = self._bl_sockets(direc)
|
||||||
|
bl_socket = bl_sockets.get(active_sckname)
|
||||||
|
|
||||||
# 1. Pretend to Initialize for the First Time
|
if bl_socket is not None:
|
||||||
## -> NOTE: The socket's caches will be completely regenerated.
|
socket_def = active_socket_defs[active_sckname]
|
||||||
## -> NOTE: A full FlowKind update will occur, but only one.
|
|
||||||
bl_socket.is_initializing = True
|
|
||||||
socket_def.preinit(bl_socket)
|
|
||||||
socket_def.init(bl_socket)
|
|
||||||
socket_def.postinit(bl_socket)
|
|
||||||
|
|
||||||
# 2. Re-Test Socket Capabilities
|
# 1. Pretend to Initialize for the First Time
|
||||||
## -> Factors influencing CapabilitiesFlow may have changed.
|
## -> NOTE: The socket's caches will be completely regenerated.
|
||||||
## -> Therefore, we must re-test all link capabilities.
|
## -> NOTE: A full FlowKind update will occur, but only one.
|
||||||
bl_socket.remove_invalidated_links()
|
bl_socket.is_initializing = True
|
||||||
|
socket_def.preinit(bl_socket)
|
||||||
|
socket_def.init(bl_socket)
|
||||||
|
socket_def.postinit(bl_socket)
|
||||||
|
|
||||||
# 3. Invalidate the input socket cache across all kinds.
|
for active_sckname in bl_sockets_to_update:
|
||||||
## -> Prevents phantom values from remaining available.
|
bl_sockets = self._bl_sockets(direc)
|
||||||
self._compute_input.invalidate(
|
bl_socket = bl_sockets.get(active_sckname)
|
||||||
input_socket_name=bl_socket_name,
|
|
||||||
kind=...,
|
if bl_socket is not None:
|
||||||
unit_system=...,
|
# 2. Re-Test Socket Capabilities
|
||||||
)
|
## -> Factors influencing CapabilitiesFlow may have changed.
|
||||||
|
## -> Therefore, we must re-test all link capabilities.
|
||||||
|
bl_socket.remove_invalidated_links()
|
||||||
|
|
||||||
|
if direc == 'input':
|
||||||
|
# 3. Invalidate the input socket cache across all kinds.
|
||||||
|
## -> Prevents phantom values from remaining available.
|
||||||
|
self._compute_input.invalidate(
|
||||||
|
input_socket_name=active_sckname,
|
||||||
|
kind=...,
|
||||||
|
unit_system=...,
|
||||||
|
)
|
||||||
|
|
||||||
|
if direc == 'output':
|
||||||
|
# 3. Invalidate the output socket cache across all kinds.
|
||||||
|
## -> Prevents phantom values from remaining available.
|
||||||
|
## -> Done after socket removal to protect from race condition.
|
||||||
|
self.compute_output.invalidate(
|
||||||
|
input_socket_name=active_sckname,
|
||||||
|
kind=...,
|
||||||
|
)
|
||||||
|
|
||||||
def _add_new_active_sockets(self):
|
def _add_new_active_sockets(self):
|
||||||
"""Add and initialize all "active" sockets that aren't on the node.
|
"""Add and initialize all "active" sockets that aren't on the node.
|
||||||
|
|
|
@ -104,6 +104,11 @@ class SymbolConstantNode(base.MaxwellSimNode):
|
||||||
interval_inf_im: tuple[bool, bool] = bl_cache.BLField((True, True))
|
interval_inf_im: tuple[bool, bool] = bl_cache.BLField((True, True))
|
||||||
interval_closed_im: tuple[bool, bool] = bl_cache.BLField((True, True))
|
interval_closed_im: tuple[bool, bool] = bl_cache.BLField((True, True))
|
||||||
|
|
||||||
|
preview_value_z: int = bl_cache.BLField(0)
|
||||||
|
preview_value_q: tuple[int, int] = bl_cache.BLField((0, 1))
|
||||||
|
preview_value_re: float = bl_cache.BLField(0.0)
|
||||||
|
preview_value_im: float = bl_cache.BLField(0.0)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Computed Properties
|
# - Computed Properties
|
||||||
####################
|
####################
|
||||||
|
@ -122,6 +127,10 @@ class SymbolConstantNode(base.MaxwellSimNode):
|
||||||
'interval_finite_im',
|
'interval_finite_im',
|
||||||
'interval_inf_im',
|
'interval_inf_im',
|
||||||
'interval_closed_im',
|
'interval_closed_im',
|
||||||
|
'preview_value_z',
|
||||||
|
'preview_value_q',
|
||||||
|
'preview_value_re',
|
||||||
|
'preview_value_im',
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def symbol(self) -> sim_symbols.SimSymbol:
|
def symbol(self) -> sim_symbols.SimSymbol:
|
||||||
|
@ -140,6 +149,10 @@ class SymbolConstantNode(base.MaxwellSimNode):
|
||||||
interval_finite_im=self.interval_finite_im,
|
interval_finite_im=self.interval_finite_im,
|
||||||
interval_inf_im=self.interval_inf_im,
|
interval_inf_im=self.interval_inf_im,
|
||||||
interval_closed_im=self.interval_closed_im,
|
interval_closed_im=self.interval_closed_im,
|
||||||
|
preview_value_z=self.preview_value_z,
|
||||||
|
preview_value_q=self.preview_value_q,
|
||||||
|
preview_value_re=self.preview_value_re,
|
||||||
|
preview_value_im=self.preview_value_im,
|
||||||
)
|
)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
@ -164,36 +177,68 @@ class SymbolConstantNode(base.MaxwellSimNode):
|
||||||
|
|
||||||
col.prop(self, self.blfields['physical_type'], text='')
|
col.prop(self, self.blfields['physical_type'], text='')
|
||||||
|
|
||||||
|
# Domain - Infinite
|
||||||
row = col.row(align=True)
|
row = col.row(align=True)
|
||||||
row.alignment = 'CENTER'
|
row.alignment = 'CENTER'
|
||||||
row.label(text='Domain')
|
row.label(text='Domain - Is Infinite')
|
||||||
|
|
||||||
|
row = col.row(align=True)
|
||||||
|
if self.mathtype is spux.MathType.Complex:
|
||||||
|
row.prop(self, self.blfields['interval_inf'], text='ℝ')
|
||||||
|
row.prop(self, self.blfields['interval_inf_im'], text='𝕀')
|
||||||
|
else:
|
||||||
|
row.prop(self, self.blfields['interval_inf'], text='')
|
||||||
|
|
||||||
|
if any(not b for b in self.interval_inf):
|
||||||
|
# Domain - Closure
|
||||||
|
row = col.row(align=True)
|
||||||
|
row.alignment = 'CENTER'
|
||||||
|
row.label(text='Domain - Closure')
|
||||||
|
|
||||||
|
row = col.row(align=True)
|
||||||
|
if self.mathtype is spux.MathType.Complex:
|
||||||
|
row.prop(self, self.blfields['interval_closed'], text='ℝ')
|
||||||
|
row.prop(self, self.blfields['interval_closed_im'], text='𝕀')
|
||||||
|
else:
|
||||||
|
row.prop(self, self.blfields['interval_closed'], text='')
|
||||||
|
|
||||||
|
# Domain - Finite
|
||||||
|
row = col.row(align=True)
|
||||||
|
row.alignment = 'CENTER'
|
||||||
|
row.label(text='Domain - Interval')
|
||||||
|
|
||||||
|
row = col.row(align=True)
|
||||||
|
match self.mathtype:
|
||||||
|
case spux.MathType.Integer:
|
||||||
|
row.prop(self, self.blfields['interval_finite_z'], text='')
|
||||||
|
|
||||||
|
case spux.MathType.Rational:
|
||||||
|
row.prop(self, self.blfields['interval_finite_q'], text='')
|
||||||
|
|
||||||
|
case spux.MathType.Real:
|
||||||
|
row.prop(self, self.blfields['interval_finite_re'], text='')
|
||||||
|
|
||||||
|
case spux.MathType.Complex:
|
||||||
|
row.prop(self, self.blfields['interval_finite_re'], text='ℝ')
|
||||||
|
row.prop(self, self.blfields['interval_finite_im'], text='𝕀')
|
||||||
|
|
||||||
|
# Domain - Closure
|
||||||
|
row = col.row(align=True)
|
||||||
|
row.alignment = 'CENTER'
|
||||||
|
row.label(text='Preview Value')
|
||||||
match self.mathtype:
|
match self.mathtype:
|
||||||
case spux.MathType.Integer:
|
case spux.MathType.Integer:
|
||||||
col.prop(self, self.blfields['interval_finite_z'], text='')
|
row.prop(self, self.blfields['preview_value_z'], text='')
|
||||||
col.prop(self, self.blfields['interval_inf'], text='Infinite')
|
|
||||||
col.prop(self, self.blfields['interval_closed'], text='Closed')
|
|
||||||
|
|
||||||
case spux.MathType.Rational:
|
case spux.MathType.Rational:
|
||||||
col.prop(self, self.blfields['interval_finite_q'], text='')
|
row.prop(self, self.blfields['preview_value_q'], text='')
|
||||||
col.prop(self, self.blfields['interval_inf'], text='Infinite')
|
|
||||||
col.prop(self, self.blfields['interval_closed'], text='Closed')
|
|
||||||
|
|
||||||
case spux.MathType.Real:
|
case spux.MathType.Real:
|
||||||
col.prop(self, self.blfields['interval_finite_re'], text='')
|
row.prop(self, self.blfields['preview_value_re'], text='')
|
||||||
col.prop(self, self.blfields['interval_inf'], text='Infinite')
|
|
||||||
col.prop(self, self.blfields['interval_closed'], text='Closed')
|
|
||||||
|
|
||||||
case spux.MathType.Complex:
|
case spux.MathType.Complex:
|
||||||
col.prop(self, self.blfields['interval_finite_re'], text='ℝ')
|
row.prop(self, self.blfields['preview_value_re'], text='ℝ')
|
||||||
col.prop(self, self.blfields['interval_inf'], text='ℝ Infinite')
|
row.prop(self, self.blfields['preview_value_im'], text='𝕀')
|
||||||
col.prop(self, self.blfields['interval_closed'], text='ℝ Closed')
|
|
||||||
|
|
||||||
col.separator()
|
|
||||||
|
|
||||||
col.prop(self, self.blfields['interval_finite_im'], text='𝕀')
|
|
||||||
col.prop(self, self.blfields['interval_inf'], text='𝕀 Infinite')
|
|
||||||
col.prop(self, self.blfields['interval_closed'], text='𝕀 Closed')
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - FlowKinds
|
# - FlowKinds
|
||||||
|
|
|
@ -132,7 +132,7 @@ class SimDomainNode(base.MaxwellSimNode):
|
||||||
| grid
|
| grid
|
||||||
| medium
|
| medium
|
||||||
).compose_within(
|
).compose_within(
|
||||||
enclosing_func=lambda els: {
|
lambda els: {
|
||||||
'run_time': els[0],
|
'run_time': els[0],
|
||||||
'center': tuple(els[1].flatten()),
|
'center': tuple(els[1].flatten()),
|
||||||
'size': tuple(els[2].flatten()),
|
'size': tuple(els[2].flatten()),
|
||||||
|
|
|
@ -88,36 +88,19 @@ class BoxStructureNode(base.MaxwellSimNode):
|
||||||
'Structure',
|
'Structure',
|
||||||
kind=ct.FlowKind.Value,
|
kind=ct.FlowKind.Value,
|
||||||
# Loaded
|
# Loaded
|
||||||
input_sockets={'Medium', 'Center', 'Size'},
|
|
||||||
output_sockets={'Structure'},
|
output_sockets={'Structure'},
|
||||||
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
output_socket_kinds={'Structure': {ct.FlowKind.Func, ct.FlowKind.Params}},
|
||||||
)
|
)
|
||||||
def compute_value(self, input_sockets, output_sockets) -> td.Box:
|
def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
|
||||||
"""Compute a single box structure object, given that all inputs are non-symbolic."""
|
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
|
||||||
center = input_sockets['Center']
|
output_func = output_sockets['Structure'][ct.FlowKind.Func]
|
||||||
size = input_sockets['Size']
|
output_params = output_sockets['Structure'][ct.FlowKind.Params]
|
||||||
medium = input_sockets['Medium']
|
|
||||||
output_params = output_sockets['Structure']
|
|
||||||
|
|
||||||
has_center = not ct.FlowSignal.check(center)
|
has_output_func = not ct.FlowSignal.check(output_func)
|
||||||
has_size = not ct.FlowSignal.check(size)
|
|
||||||
has_medium = not ct.FlowSignal.check(medium)
|
|
||||||
has_output_params = not ct.FlowSignal.check(output_params)
|
has_output_params = not ct.FlowSignal.check(output_params)
|
||||||
|
|
||||||
if (
|
if has_output_func and has_output_params and not output_params.symbols:
|
||||||
has_center
|
return output_func.realize(output_params, disallow_jax=True)
|
||||||
and has_size
|
|
||||||
and has_medium
|
|
||||||
and has_output_params
|
|
||||||
and not output_params.symbols
|
|
||||||
):
|
|
||||||
return td.Structure(
|
|
||||||
geometry=td.Box(
|
|
||||||
center=spux.scale_to_unit_system(center, ct.UNITS_TIDY3D),
|
|
||||||
size=spux.scale_to_unit_system(size, ct.UNITS_TIDY3D),
|
|
||||||
),
|
|
||||||
medium=medium,
|
|
||||||
)
|
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
@ -134,45 +117,43 @@ class BoxStructureNode(base.MaxwellSimNode):
|
||||||
'Center': ct.FlowKind.Func,
|
'Center': ct.FlowKind.Func,
|
||||||
'Size': ct.FlowKind.Func,
|
'Size': ct.FlowKind.Func,
|
||||||
},
|
},
|
||||||
output_sockets={'Structure'},
|
|
||||||
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
|
||||||
)
|
)
|
||||||
def compute_structure_func(self, props, input_sockets, output_sockets) -> td.Box:
|
def compute_structure_func(self, props, input_sockets) -> td.Box:
|
||||||
"""Compute a possibly-differentiable function, producing a box structure from the input parameters."""
|
"""Compute a possibly-differentiable function, producing a box structure from the input parameters."""
|
||||||
output_params = output_sockets['Structure']
|
|
||||||
center = input_sockets['Center']
|
center = input_sockets['Center']
|
||||||
size = input_sockets['Size']
|
size = input_sockets['Size']
|
||||||
medium = input_sockets['Medium']
|
medium = input_sockets['Medium']
|
||||||
|
|
||||||
has_output_params = not ct.FlowSignal.check(output_params)
|
|
||||||
has_center = not ct.FlowSignal.check(center)
|
has_center = not ct.FlowSignal.check(center)
|
||||||
has_size = not ct.FlowSignal.check(size)
|
has_size = not ct.FlowSignal.check(size)
|
||||||
has_medium = not ct.FlowSignal.check(medium)
|
has_medium = not ct.FlowSignal.check(medium)
|
||||||
|
|
||||||
if has_output_params and has_center and has_size and has_medium:
|
if has_center and has_size and has_medium:
|
||||||
differentiable = props['differentiable']
|
differentiable = props['differentiable']
|
||||||
if differentiable:
|
if differentiable:
|
||||||
return (center | size | medium).compose_within(
|
return (
|
||||||
enclosing_func=lambda els: tdadj.JaxStructure(
|
center.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||||
|
| size.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||||
|
| medium
|
||||||
|
).compose_within(
|
||||||
|
lambda els: tdadj.JaxStructure(
|
||||||
geometry=tdadj.JaxBox(
|
geometry=tdadj.JaxBox(
|
||||||
center=tuple(els[0].flatten()),
|
center_jax=tuple(els[0].flatten()),
|
||||||
size=tuple(els[1].flatten()),
|
size_jax=tuple(els[1].flatten()),
|
||||||
),
|
),
|
||||||
medium=els[2],
|
medium=els[2],
|
||||||
),
|
),
|
||||||
supports_jax=True,
|
supports_jax=True,
|
||||||
)
|
)
|
||||||
return (center | size | medium).compose_within(
|
return (
|
||||||
## TODO: Unit conversion within the composed function??
|
center.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||||
## -- We do need Tidy3D to be given ex. micrometers in particular.
|
| size.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||||
## -- But the previous numerical output might not be micrometers.
|
| medium
|
||||||
## -- There must be a way to add a conversion in, without strangeness.
|
).compose_within(
|
||||||
## -- Ex. can compose_within() take a unit system?
|
lambda els: td.Structure(
|
||||||
## -- This would require
|
|
||||||
enclosing_func=lambda els: td.Structure(
|
|
||||||
geometry=td.Box(
|
geometry=td.Box(
|
||||||
center=tuple(els[0].flatten()),
|
center=els[0].flatten().tolist(),
|
||||||
size=tuple(els[1].flatten()),
|
size=els[1].flatten().tolist(),
|
||||||
),
|
),
|
||||||
medium=els[2],
|
medium=els[2],
|
||||||
),
|
),
|
||||||
|
@ -187,7 +168,6 @@ class BoxStructureNode(base.MaxwellSimNode):
|
||||||
'Structure',
|
'Structure',
|
||||||
kind=ct.FlowKind.Params,
|
kind=ct.FlowKind.Params,
|
||||||
# Loaded
|
# Loaded
|
||||||
props={'differentiable'},
|
|
||||||
input_sockets={'Medium', 'Center', 'Size'},
|
input_sockets={'Medium', 'Center', 'Size'},
|
||||||
input_socket_kinds={
|
input_socket_kinds={
|
||||||
'Medium': ct.FlowKind.Params,
|
'Medium': ct.FlowKind.Params,
|
||||||
|
@ -195,7 +175,7 @@ class BoxStructureNode(base.MaxwellSimNode):
|
||||||
'Size': ct.FlowKind.Params,
|
'Size': ct.FlowKind.Params,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
def compute_params(self, props, input_sockets) -> td.Box:
|
def compute_params(self, input_sockets) -> td.Box:
|
||||||
center = input_sockets['Center']
|
center = input_sockets['Center']
|
||||||
size = input_sockets['Size']
|
size = input_sockets['Size']
|
||||||
medium = input_sockets['Medium']
|
medium = input_sockets['Medium']
|
||||||
|
@ -238,13 +218,16 @@ class BoxStructureNode(base.MaxwellSimNode):
|
||||||
output_sockets={'Structure'},
|
output_sockets={'Structure'},
|
||||||
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
||||||
)
|
)
|
||||||
def on_inputs_changed(self, managed_objs, input_sockets, output_sockets):
|
def on_previewable_changed(self, managed_objs, input_sockets, output_sockets):
|
||||||
output_params = output_sockets['Structure']
|
|
||||||
center = input_sockets['Center']
|
center = input_sockets['Center']
|
||||||
|
size = input_sockets['Size']
|
||||||
|
output_params = output_sockets['Structure']
|
||||||
|
|
||||||
has_output_params = not ct.FlowSignal.check(output_params)
|
|
||||||
has_center = not ct.FlowSignal.check(center)
|
has_center = not ct.FlowSignal.check(center)
|
||||||
if has_center and has_output_params and not output_params.symbols:
|
has_size = not ct.FlowSignal.check(size)
|
||||||
|
has_output_params = not ct.FlowSignal.check(output_params)
|
||||||
|
|
||||||
|
if has_center and has_size and has_output_params and not output_params.symbols:
|
||||||
## TODO: There are strategies for handling examples of symbol values.
|
## TODO: There are strategies for handling examples of symbol values.
|
||||||
|
|
||||||
# Push Loose Input Values to GeoNodes Modifier
|
# Push Loose Input Values to GeoNodes Modifier
|
||||||
|
@ -254,7 +237,7 @@ class BoxStructureNode(base.MaxwellSimNode):
|
||||||
'node_group': import_geonodes(GeoNodes.StructurePrimitiveBox),
|
'node_group': import_geonodes(GeoNodes.StructurePrimitiveBox),
|
||||||
'unit_system': ct.UNITS_BLENDER,
|
'unit_system': ct.UNITS_BLENDER,
|
||||||
'inputs': {
|
'inputs': {
|
||||||
'Size': input_sockets['Size'],
|
'Size': size,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
|
location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
|
||||||
|
|
|
@ -32,6 +32,8 @@ log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CylinderStructureNode(base.MaxwellSimNode):
|
class CylinderStructureNode(base.MaxwellSimNode):
|
||||||
|
"""A generic cylinder structure with configurable radius and height."""
|
||||||
|
|
||||||
node_type = ct.NodeType.CylinderStructure
|
node_type = ct.NodeType.CylinderStructure
|
||||||
bl_label = 'Cylinder Structure'
|
bl_label = 'Cylinder Structure'
|
||||||
use_sim_node_name = True
|
use_sim_node_name = True
|
||||||
|
@ -64,27 +66,101 @@ class CylinderStructureNode(base.MaxwellSimNode):
|
||||||
}
|
}
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Output Socket Computation
|
# - FlowKind.Value
|
||||||
####################
|
####################
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Structure',
|
'Structure',
|
||||||
|
kind=ct.FlowKind.Value,
|
||||||
|
# Loaded
|
||||||
|
output_sockets={'Structure'},
|
||||||
|
output_socket_kinds={'Structure': {ct.FlowKind.Func, ct.FlowKind.Params}},
|
||||||
|
)
|
||||||
|
def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
|
||||||
|
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
|
||||||
|
output_func = output_sockets['Structure'][ct.FlowKind.Func]
|
||||||
|
output_params = output_sockets['Structure'][ct.FlowKind.Params]
|
||||||
|
|
||||||
|
has_output_func = not ct.FlowSignal.check(output_func)
|
||||||
|
has_output_params = not ct.FlowSignal.check(output_params)
|
||||||
|
|
||||||
|
if has_output_func and has_output_params and not output_params.symbols:
|
||||||
|
return output_func.realize(output_params, disallow_jax=True)
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - FlowKind.Func
|
||||||
|
####################
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Structure',
|
||||||
|
kind=ct.FlowKind.Func,
|
||||||
|
# Loaded
|
||||||
input_sockets={'Center', 'Radius', 'Medium', 'Height'},
|
input_sockets={'Center', 'Radius', 'Medium', 'Height'},
|
||||||
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
|
input_socket_kinds={
|
||||||
scale_input_sockets={
|
'Center': ct.FlowKind.Func,
|
||||||
'Center': 'Tidy3DUnits',
|
'Radius': ct.FlowKind.Func,
|
||||||
'Radius': 'Tidy3DUnits',
|
'Height': ct.FlowKind.Func,
|
||||||
'Height': 'Tidy3DUnits',
|
'Medium': ct.FlowKind.Func,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
def compute_structure(self, input_sockets, unit_systems) -> td.Box:
|
def compute_func(self, input_sockets) -> td.Box:
|
||||||
return td.Structure(
|
"""Compute a single cylinder structure object, given that all inputs are non-symbolic."""
|
||||||
geometry=td.Cylinder(
|
center = input_sockets['Center']
|
||||||
radius=input_sockets['Radius'],
|
radius = input_sockets['Radius']
|
||||||
center=input_sockets['Center'],
|
height = input_sockets['Height']
|
||||||
length=input_sockets['Height'],
|
medium = input_sockets['Medium']
|
||||||
),
|
|
||||||
medium=input_sockets['Medium'],
|
has_center = not ct.FlowSignal.check(center)
|
||||||
)
|
has_radius = not ct.FlowSignal.check(radius)
|
||||||
|
has_height = not ct.FlowSignal.check(height)
|
||||||
|
has_medium = not ct.FlowSignal.check(medium)
|
||||||
|
|
||||||
|
if has_center and has_radius and has_height and has_medium:
|
||||||
|
return (
|
||||||
|
center.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||||
|
| radius.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||||
|
| height.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||||
|
| medium
|
||||||
|
).compose_within(
|
||||||
|
lambda els: td.Structure(
|
||||||
|
geometry=td.Cylinder(
|
||||||
|
center=els[0].flatten().tolist(),
|
||||||
|
radius=els[1],
|
||||||
|
length=els[2],
|
||||||
|
),
|
||||||
|
medium=els[3],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - FlowKind.Params
|
||||||
|
####################
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Structure',
|
||||||
|
kind=ct.FlowKind.Params,
|
||||||
|
# Loaded
|
||||||
|
input_sockets={'Center', 'Radius', 'Medium', 'Height'},
|
||||||
|
input_socket_kinds={
|
||||||
|
'Center': ct.FlowKind.Params,
|
||||||
|
'Radius': ct.FlowKind.Params,
|
||||||
|
'Height': ct.FlowKind.Params,
|
||||||
|
'Medium': ct.FlowKind.Params,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
def compute_params(self, input_sockets) -> td.Box:
|
||||||
|
center = input_sockets['Center']
|
||||||
|
radius = input_sockets['Radius']
|
||||||
|
height = input_sockets['Height']
|
||||||
|
medium = input_sockets['Medium']
|
||||||
|
|
||||||
|
has_center = not ct.FlowSignal.check(center)
|
||||||
|
has_radius = not ct.FlowSignal.check(radius)
|
||||||
|
has_height = not ct.FlowSignal.check(height)
|
||||||
|
has_medium = not ct.FlowSignal.check(medium)
|
||||||
|
|
||||||
|
if has_center and has_radius and has_height and has_medium:
|
||||||
|
return center | radius | height | medium
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Preview
|
# - Preview
|
||||||
|
|
|
@ -791,14 +791,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
####################
|
####################
|
||||||
# - FlowKind: Func (w/Params if Constant)
|
# - FlowKind: Func (w/Params if Constant)
|
||||||
####################
|
####################
|
||||||
@bl_cache.cached_bl_property(
|
@bl_cache.cached_bl_property(depends_on={'output_sym'})
|
||||||
depends_on={
|
|
||||||
'value',
|
|
||||||
'sorted_sp_symbols',
|
|
||||||
'sorted_symbols',
|
|
||||||
'output_sym',
|
|
||||||
}
|
|
||||||
)
|
|
||||||
def lazy_func(self) -> ct.FuncFlow:
|
def lazy_func(self) -> ct.FuncFlow:
|
||||||
"""Returns a lazy value that computes the expression returned by `self.value`.
|
"""Returns a lazy value that computes the expression returned by `self.value`.
|
||||||
|
|
||||||
|
@ -806,42 +799,12 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
Otherwise, the returned lazy value function will be a simple excuse for `self.params` to pass the verbatim `self.value`.
|
Otherwise, the returned lazy value function will be a simple excuse for `self.params` to pass the verbatim `self.value`.
|
||||||
"""
|
"""
|
||||||
if self.output_sym is not None:
|
if self.output_sym is not None:
|
||||||
match self.active_kind:
|
return ct.FuncFlow(
|
||||||
case ct.FlowKind.Value | ct.FlowKind.Func if (
|
func=lambda v: v,
|
||||||
self.sorted_symbols and not ct.FlowSignal.check(self.value)
|
func_args=[self.output_sym],
|
||||||
):
|
func_output=self.output_sym,
|
||||||
return ct.FuncFlow(
|
supports_jax=True,
|
||||||
func=sp.lambdify(
|
)
|
||||||
self.sorted_sp_symbols,
|
|
||||||
self.output_sym.conform(self.value, strip_unit=True),
|
|
||||||
'jax',
|
|
||||||
),
|
|
||||||
func_args=list(self.sorted_symbols),
|
|
||||||
func_output=self.output_sym,
|
|
||||||
supports_jax=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
case ct.FlowKind.Value | ct.FlowKind.Func if not self.sorted_symbols:
|
|
||||||
return ct.FuncFlow(
|
|
||||||
func=lambda v: v,
|
|
||||||
func_args=[self.output_sym],
|
|
||||||
func_output=self.output_sym,
|
|
||||||
supports_jax=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
case ct.FlowKind.Range if self.sorted_symbols:
|
|
||||||
msg = 'RangeFlow support not yet implemented for when self.sorted_symbols is not empty'
|
|
||||||
raise NotImplementedError(msg)
|
|
||||||
|
|
||||||
case ct.FlowKind.Range if (
|
|
||||||
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
|
|
||||||
):
|
|
||||||
return ct.FuncFlow(
|
|
||||||
func=lambda v: v,
|
|
||||||
func_args=[self.output_sym],
|
|
||||||
func_output=self.output_sym,
|
|
||||||
supports_jax=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
@ -854,22 +817,15 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be passed into `ParamsFlow`, which will thus be parameterized (and require realization before use).
|
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be passed into `ParamsFlow`, which will thus be parameterized (and require realization before use).
|
||||||
Otherwise, `self.value` is passed verbatim as the only `ParamsFlow.func_arg`.
|
Otherwise, `self.value` is passed verbatim as the only `ParamsFlow.func_arg`.
|
||||||
"""
|
"""
|
||||||
output_sym = self.output_sym
|
if self.output_sym is not None:
|
||||||
if output_sym is not None:
|
|
||||||
match self.active_kind:
|
match self.active_kind:
|
||||||
case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols:
|
|
||||||
return ct.ParamsFlow(
|
|
||||||
arg_targets=list(self.sorted_symbols),
|
|
||||||
func_args=[sym.sp_symbol for sym in self.sorted_symbols],
|
|
||||||
symbols=set(self.sorted_symbols),
|
|
||||||
)
|
|
||||||
|
|
||||||
case ct.FlowKind.Value | ct.FlowKind.Func if (
|
case ct.FlowKind.Value | ct.FlowKind.Func if (
|
||||||
not self.sorted_symbols and not ct.FlowSignal.check(self.value)
|
not ct.FlowSignal.check(self.value)
|
||||||
):
|
):
|
||||||
return ct.ParamsFlow(
|
return ct.ParamsFlow(
|
||||||
arg_targets=[self.output_sym],
|
arg_targets=[self.output_sym],
|
||||||
func_args=[self.value],
|
func_args=[self.value],
|
||||||
|
symbols=set(self.sorted_symbols),
|
||||||
)
|
)
|
||||||
|
|
||||||
case ct.FlowKind.Range if self.sorted_symbols:
|
case ct.FlowKind.Range if self.sorted_symbols:
|
||||||
|
@ -899,20 +855,14 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
|
|
||||||
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
|
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
|
||||||
"""
|
"""
|
||||||
output_sym = self.output_sym
|
if self.output_sym is not None:
|
||||||
if output_sym is not None:
|
|
||||||
match self.active_kind:
|
match self.active_kind:
|
||||||
case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols:
|
case ct.FlowKind.Value | ct.FlowKind.Func:
|
||||||
return ct.InfoFlow(
|
return ct.InfoFlow(
|
||||||
dims={sym: None for sym in self.sorted_symbols},
|
dims={sym: None for sym in self.sorted_symbols},
|
||||||
output=self.output_sym,
|
output=self.output_sym,
|
||||||
)
|
)
|
||||||
|
|
||||||
case ct.FlowKind.Value | ct.FlowKind.Func if (
|
|
||||||
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
|
|
||||||
):
|
|
||||||
return ct.InfoFlow(output=self.output_sym)
|
|
||||||
|
|
||||||
case ct.FlowKind.Range if self.sorted_symbols:
|
case ct.FlowKind.Range if self.sorted_symbols:
|
||||||
msg = 'InfoFlow support not yet implemented for when self.sorted_symbols is not empty'
|
msg = 'InfoFlow support not yet implemented for when self.sorted_symbols is not empty'
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
|
@ -14,7 +14,10 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
import bpy
|
import bpy
|
||||||
|
import jax.numpy as jnp
|
||||||
import scipy as sc
|
import scipy as sc
|
||||||
import sympy.physics.units as spu
|
import sympy.physics.units as spu
|
||||||
import tidy3d as td
|
import tidy3d as td
|
||||||
|
@ -28,9 +31,15 @@ from .. import base
|
||||||
|
|
||||||
log = logger.get(__name__)
|
log = logger.get(__name__)
|
||||||
|
|
||||||
VAC_SPEED_OF_LIGHT = sc.constants.speed_of_light * spu.meter / spu.second
|
_VAC_SPEED_OF_LIGHT = sc.constants.speed_of_light * spu.meter / spu.second
|
||||||
|
_FIXED_WL = 500 * spu.nm
|
||||||
FIXED_WL = 500 * spu.nm
|
FIXED_FREQ = spux.scale_to_unit_system(
|
||||||
|
spu.convert_to(
|
||||||
|
_VAC_SPEED_OF_LIGHT / _FIXED_WL,
|
||||||
|
spu.hertz,
|
||||||
|
),
|
||||||
|
ct.UNITS_TIDY3D,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MaxwellMediumBLSocket(base.MaxwellSimSocket):
|
class MaxwellMediumBLSocket(base.MaxwellSimSocket):
|
||||||
|
@ -49,24 +58,17 @@ class MaxwellMediumBLSocket(base.MaxwellSimSocket):
|
||||||
####################
|
####################
|
||||||
@bl_cache.cached_bl_property(depends_on={'eps_rel', 'differentiable'})
|
@bl_cache.cached_bl_property(depends_on={'eps_rel', 'differentiable'})
|
||||||
def value(self) -> td.Medium:
|
def value(self) -> td.Medium:
|
||||||
freq = (
|
eps_r_re = self.eps_rel[0]
|
||||||
spu.convert_to(
|
conductivity = FIXED_FREQ * self.eps_rel[1] ## TODO: Um?
|
||||||
VAC_SPEED_OF_LIGHT / FIXED_WL,
|
|
||||||
spu.hertz,
|
|
||||||
)
|
|
||||||
/ spu.hertz
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.differentiable:
|
if self.differentiable:
|
||||||
return tdadj.JaxMedium.from_nk(
|
return tdadj.JaxMedium(
|
||||||
n=self.eps_rel[0],
|
permittivity_jax=jnp.array(eps_r_re, dtype=float),
|
||||||
k=self.eps_rel[1],
|
conductivity_jax=jnp.array(conductivity, dtype=float),
|
||||||
freq=freq,
|
|
||||||
)
|
)
|
||||||
return td.Medium.from_nk(
|
return td.Medium(
|
||||||
n=self.eps_rel[0],
|
permittivity=eps_r_re,
|
||||||
k=self.eps_rel[1],
|
conductivity=conductivity,
|
||||||
freq=freq,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@value.setter
|
@value.setter
|
||||||
|
|
|
@ -263,6 +263,11 @@ class SimSymbol(pyd.BaseModel):
|
||||||
interval_inf_im: tuple[bool, bool] = (True, True)
|
interval_inf_im: tuple[bool, bool] = (True, True)
|
||||||
interval_closed_im: tuple[bool, bool] = (False, False)
|
interval_closed_im: tuple[bool, bool] = (False, False)
|
||||||
|
|
||||||
|
preview_value_z: int = 0
|
||||||
|
preview_value_q: tuple[int, int] = (0, 1)
|
||||||
|
preview_value_re: float = 0.0
|
||||||
|
preview_value_im: float = 0.0
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Core
|
# - Core
|
||||||
####################
|
####################
|
||||||
|
|
Loading…
Reference in New Issue