fix: use-after-free in socket pruner
Also the usual batch of improvements. Differentiability is misbehaving, intriguingly.main
parent
38e70a60d3
commit
c286d65193
|
@ -325,8 +325,16 @@ class FuncFlow(pyd.BaseModel):
|
|||
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
|
||||
{}
|
||||
),
|
||||
disallow_jax: bool = True,
|
||||
) -> typ.Self:
|
||||
"""Run the represented function with the best optimization available, given particular choices for all function arguments and for all unrealized symbols."""
|
||||
"""Run the represented function with the best optimization available, given particular choices for all function arguments and for all unrealized symbols.
|
||||
|
||||
Parameters:
|
||||
params: The parameter-tracking object from which function arguments will be computed.
|
||||
symbol_values: Values for all `SimSymbol`s that are not yet realized in `params`.
|
||||
disallow_jax: Don't use `self.func_jax` to evaluate, even if possible.
|
||||
This is desirable when the overhead of `jax.jit()` is known in advance to exceed the performance benefits.
|
||||
"""
|
||||
if self.supports_jax:
|
||||
return self.func_jax(
|
||||
*params.scaled_func_args(symbol_values),
|
||||
|
|
|
@ -374,19 +374,22 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
|||
"""
|
||||
node_tree = self.id_data
|
||||
for direc in ['input', 'output']:
|
||||
bl_sockets = self._bl_sockets(direc)
|
||||
active_socket_nametype = {
|
||||
bl_socket.name: bl_socket.socket_type
|
||||
for bl_socket in self._bl_sockets(direc)
|
||||
}
|
||||
active_socket_defs = self.active_socket_defs(direc)
|
||||
|
||||
# Determine Sockets to Remove
|
||||
## -> Name: If the existing socket name isn't "active".
|
||||
## -> Type: If the existing socket_type != "active" SocketDef.
|
||||
bl_sockets_to_remove = [
|
||||
bl_socket
|
||||
for socket_name, bl_socket in bl_sockets.items()
|
||||
active_sckname
|
||||
for active_sckname, active_scktype in active_socket_nametype.items()
|
||||
if (
|
||||
socket_name not in active_socket_defs
|
||||
or bl_socket.socket_type
|
||||
is not active_socket_defs[socket_name].socket_type
|
||||
active_sckname not in active_socket_defs
|
||||
or active_scktype
|
||||
is not active_socket_defs[active_sckname].socket_type
|
||||
)
|
||||
]
|
||||
|
||||
|
@ -394,39 +397,50 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
|||
## -> Name: If the existing socket name is "active".
|
||||
## -> Type: If the existing socket_type == "active" SocketDef.
|
||||
## -> Compare: If the existing socket differs from the SocketDef.
|
||||
## -> NOTE: Reload bl_sockets in case to-update scks were removed.
|
||||
bl_sockets_to_update = [
|
||||
bl_socket
|
||||
for socket_name, bl_socket in bl_sockets.items()
|
||||
active_sckname
|
||||
for active_sckname, active_scktype in active_socket_nametype.items()
|
||||
if (
|
||||
socket_name in active_socket_defs
|
||||
and bl_socket.socket_type
|
||||
is active_socket_defs[socket_name].socket_type
|
||||
and not active_socket_defs[socket_name].compare(bl_socket)
|
||||
active_sckname in active_socket_defs
|
||||
and active_scktype is active_socket_defs[active_sckname].socket_type
|
||||
and not active_socket_defs[active_sckname].compare(
|
||||
self._bl_sockets(direc)[active_sckname]
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# Remove Sockets
|
||||
for bl_socket in bl_sockets_to_remove:
|
||||
bl_socket_name = bl_socket.name
|
||||
## -> The symptom of using a deleted socket is... hard crash.
|
||||
## -> Therefore, we must be EXTREMELY careful with bl_socket refs.
|
||||
## -> The multi-stage for-loop helps us guard from deleted sockets.
|
||||
for active_sckname in bl_sockets_to_remove:
|
||||
bl_socket = self._bl_sockets(direc).get(active_sckname)
|
||||
|
||||
# 1. Report the socket removal to the NodeTree.
|
||||
## -> The NodeLinkCache needs to be adjusted manually.
|
||||
node_tree.on_node_socket_removed(bl_socket)
|
||||
|
||||
for active_sckname in bl_sockets_to_remove:
|
||||
bl_sockets = self._bl_sockets(direc)
|
||||
bl_socket = bl_sockets.get(active_sckname)
|
||||
|
||||
# 2. Perform the removal using Blender's API.
|
||||
## -> Actually removes the socket.
|
||||
## -> Must be protected from auto-removed use-after-free.
|
||||
if bl_socket is not None:
|
||||
bl_sockets.remove(bl_socket)
|
||||
|
||||
if direc == 'input':
|
||||
# 3. Invalidate the input socket cache across all kinds.
|
||||
## -> Prevents phantom values from remaining available.
|
||||
## -> Done after socket removal to protect from race condition.
|
||||
self._compute_input.invalidate(
|
||||
input_socket_name=bl_socket_name,
|
||||
input_socket_name=active_sckname,
|
||||
kind=...,
|
||||
unit_system=...,
|
||||
)
|
||||
|
||||
if direc == 'input':
|
||||
# 4. Run all trigger-only `on_value_changed` callbacks.
|
||||
## -> Runs any event methods that relied on the socket.
|
||||
## -> Only methods that don't **require** the socket.
|
||||
|
@ -435,18 +449,33 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
|||
triggered_event_methods = [
|
||||
event_method
|
||||
for event_method in self.filtered_event_methods_by_event(
|
||||
ct.FlowEvent.DataChanged, (bl_socket_name, None, None)
|
||||
ct.FlowEvent.DataChanged, (active_sckname, None, None)
|
||||
)
|
||||
if bl_socket_name
|
||||
if active_sckname
|
||||
not in event_method.callback_info.must_load_sockets
|
||||
]
|
||||
for event_method in triggered_event_methods:
|
||||
event_method(self)
|
||||
|
||||
else:
|
||||
# 3. Invalidate the output socket cache across all kinds.
|
||||
## -> Prevents phantom values from remaining available.
|
||||
## -> Done after socket removal to protect from race condition.
|
||||
self.compute_output.invalidate(
|
||||
input_socket_name=active_sckname,
|
||||
kind=...,
|
||||
)
|
||||
|
||||
# Update Sockets
|
||||
for bl_socket in bl_sockets_to_update:
|
||||
bl_socket_name = bl_socket.name
|
||||
socket_def = active_socket_defs[bl_socket_name]
|
||||
## -> The symptom of using a deleted socket is... hard crash.
|
||||
## -> Therefore, we must be EXTREMELY careful with bl_socket refs.
|
||||
## -> The multi-stage for-loop helps us guard from deleted sockets.
|
||||
for active_sckname in bl_sockets_to_update:
|
||||
bl_sockets = self._bl_sockets(direc)
|
||||
bl_socket = bl_sockets.get(active_sckname)
|
||||
|
||||
if bl_socket is not None:
|
||||
socket_def = active_socket_defs[active_sckname]
|
||||
|
||||
# 1. Pretend to Initialize for the First Time
|
||||
## -> NOTE: The socket's caches will be completely regenerated.
|
||||
|
@ -456,19 +485,34 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
|
|||
socket_def.init(bl_socket)
|
||||
socket_def.postinit(bl_socket)
|
||||
|
||||
for active_sckname in bl_sockets_to_update:
|
||||
bl_sockets = self._bl_sockets(direc)
|
||||
bl_socket = bl_sockets.get(active_sckname)
|
||||
|
||||
if bl_socket is not None:
|
||||
# 2. Re-Test Socket Capabilities
|
||||
## -> Factors influencing CapabilitiesFlow may have changed.
|
||||
## -> Therefore, we must re-test all link capabilities.
|
||||
bl_socket.remove_invalidated_links()
|
||||
|
||||
if direc == 'input':
|
||||
# 3. Invalidate the input socket cache across all kinds.
|
||||
## -> Prevents phantom values from remaining available.
|
||||
self._compute_input.invalidate(
|
||||
input_socket_name=bl_socket_name,
|
||||
input_socket_name=active_sckname,
|
||||
kind=...,
|
||||
unit_system=...,
|
||||
)
|
||||
|
||||
if direc == 'output':
|
||||
# 3. Invalidate the output socket cache across all kinds.
|
||||
## -> Prevents phantom values from remaining available.
|
||||
## -> Done after socket removal to protect from race condition.
|
||||
self.compute_output.invalidate(
|
||||
input_socket_name=active_sckname,
|
||||
kind=...,
|
||||
)
|
||||
|
||||
def _add_new_active_sockets(self):
|
||||
"""Add and initialize all "active" sockets that aren't on the node.
|
||||
|
||||
|
|
|
@ -104,6 +104,11 @@ class SymbolConstantNode(base.MaxwellSimNode):
|
|||
interval_inf_im: tuple[bool, bool] = bl_cache.BLField((True, True))
|
||||
interval_closed_im: tuple[bool, bool] = bl_cache.BLField((True, True))
|
||||
|
||||
preview_value_z: int = bl_cache.BLField(0)
|
||||
preview_value_q: tuple[int, int] = bl_cache.BLField((0, 1))
|
||||
preview_value_re: float = bl_cache.BLField(0.0)
|
||||
preview_value_im: float = bl_cache.BLField(0.0)
|
||||
|
||||
####################
|
||||
# - Computed Properties
|
||||
####################
|
||||
|
@ -122,6 +127,10 @@ class SymbolConstantNode(base.MaxwellSimNode):
|
|||
'interval_finite_im',
|
||||
'interval_inf_im',
|
||||
'interval_closed_im',
|
||||
'preview_value_z',
|
||||
'preview_value_q',
|
||||
'preview_value_re',
|
||||
'preview_value_im',
|
||||
}
|
||||
)
|
||||
def symbol(self) -> sim_symbols.SimSymbol:
|
||||
|
@ -140,6 +149,10 @@ class SymbolConstantNode(base.MaxwellSimNode):
|
|||
interval_finite_im=self.interval_finite_im,
|
||||
interval_inf_im=self.interval_inf_im,
|
||||
interval_closed_im=self.interval_closed_im,
|
||||
preview_value_z=self.preview_value_z,
|
||||
preview_value_q=self.preview_value_q,
|
||||
preview_value_re=self.preview_value_re,
|
||||
preview_value_im=self.preview_value_im,
|
||||
)
|
||||
|
||||
####################
|
||||
|
@ -164,36 +177,68 @@ class SymbolConstantNode(base.MaxwellSimNode):
|
|||
|
||||
col.prop(self, self.blfields['physical_type'], text='')
|
||||
|
||||
# Domain - Infinite
|
||||
row = col.row(align=True)
|
||||
row.alignment = 'CENTER'
|
||||
row.label(text='Domain')
|
||||
row.label(text='Domain - Is Infinite')
|
||||
|
||||
row = col.row(align=True)
|
||||
if self.mathtype is spux.MathType.Complex:
|
||||
row.prop(self, self.blfields['interval_inf'], text='ℝ')
|
||||
row.prop(self, self.blfields['interval_inf_im'], text='𝕀')
|
||||
else:
|
||||
row.prop(self, self.blfields['interval_inf'], text='')
|
||||
|
||||
if any(not b for b in self.interval_inf):
|
||||
# Domain - Closure
|
||||
row = col.row(align=True)
|
||||
row.alignment = 'CENTER'
|
||||
row.label(text='Domain - Closure')
|
||||
|
||||
row = col.row(align=True)
|
||||
if self.mathtype is spux.MathType.Complex:
|
||||
row.prop(self, self.blfields['interval_closed'], text='ℝ')
|
||||
row.prop(self, self.blfields['interval_closed_im'], text='𝕀')
|
||||
else:
|
||||
row.prop(self, self.blfields['interval_closed'], text='')
|
||||
|
||||
# Domain - Finite
|
||||
row = col.row(align=True)
|
||||
row.alignment = 'CENTER'
|
||||
row.label(text='Domain - Interval')
|
||||
|
||||
row = col.row(align=True)
|
||||
match self.mathtype:
|
||||
case spux.MathType.Integer:
|
||||
col.prop(self, self.blfields['interval_finite_z'], text='')
|
||||
col.prop(self, self.blfields['interval_inf'], text='Infinite')
|
||||
col.prop(self, self.blfields['interval_closed'], text='Closed')
|
||||
row.prop(self, self.blfields['interval_finite_z'], text='')
|
||||
|
||||
case spux.MathType.Rational:
|
||||
col.prop(self, self.blfields['interval_finite_q'], text='')
|
||||
col.prop(self, self.blfields['interval_inf'], text='Infinite')
|
||||
col.prop(self, self.blfields['interval_closed'], text='Closed')
|
||||
row.prop(self, self.blfields['interval_finite_q'], text='')
|
||||
|
||||
case spux.MathType.Real:
|
||||
col.prop(self, self.blfields['interval_finite_re'], text='')
|
||||
col.prop(self, self.blfields['interval_inf'], text='Infinite')
|
||||
col.prop(self, self.blfields['interval_closed'], text='Closed')
|
||||
row.prop(self, self.blfields['interval_finite_re'], text='')
|
||||
|
||||
case spux.MathType.Complex:
|
||||
col.prop(self, self.blfields['interval_finite_re'], text='ℝ')
|
||||
col.prop(self, self.blfields['interval_inf'], text='ℝ Infinite')
|
||||
col.prop(self, self.blfields['interval_closed'], text='ℝ Closed')
|
||||
row.prop(self, self.blfields['interval_finite_re'], text='ℝ')
|
||||
row.prop(self, self.blfields['interval_finite_im'], text='𝕀')
|
||||
|
||||
col.separator()
|
||||
# Domain - Closure
|
||||
row = col.row(align=True)
|
||||
row.alignment = 'CENTER'
|
||||
row.label(text='Preview Value')
|
||||
match self.mathtype:
|
||||
case spux.MathType.Integer:
|
||||
row.prop(self, self.blfields['preview_value_z'], text='')
|
||||
|
||||
col.prop(self, self.blfields['interval_finite_im'], text='𝕀')
|
||||
col.prop(self, self.blfields['interval_inf'], text='𝕀 Infinite')
|
||||
col.prop(self, self.blfields['interval_closed'], text='𝕀 Closed')
|
||||
case spux.MathType.Rational:
|
||||
row.prop(self, self.blfields['preview_value_q'], text='')
|
||||
|
||||
case spux.MathType.Real:
|
||||
row.prop(self, self.blfields['preview_value_re'], text='')
|
||||
|
||||
case spux.MathType.Complex:
|
||||
row.prop(self, self.blfields['preview_value_re'], text='ℝ')
|
||||
row.prop(self, self.blfields['preview_value_im'], text='𝕀')
|
||||
|
||||
####################
|
||||
# - FlowKinds
|
||||
|
|
|
@ -132,7 +132,7 @@ class SimDomainNode(base.MaxwellSimNode):
|
|||
| grid
|
||||
| medium
|
||||
).compose_within(
|
||||
enclosing_func=lambda els: {
|
||||
lambda els: {
|
||||
'run_time': els[0],
|
||||
'center': tuple(els[1].flatten()),
|
||||
'size': tuple(els[2].flatten()),
|
||||
|
|
|
@ -88,36 +88,19 @@ class BoxStructureNode(base.MaxwellSimNode):
|
|||
'Structure',
|
||||
kind=ct.FlowKind.Value,
|
||||
# Loaded
|
||||
input_sockets={'Medium', 'Center', 'Size'},
|
||||
output_sockets={'Structure'},
|
||||
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
||||
output_socket_kinds={'Structure': {ct.FlowKind.Func, ct.FlowKind.Params}},
|
||||
)
|
||||
def compute_value(self, input_sockets, output_sockets) -> td.Box:
|
||||
"""Compute a single box structure object, given that all inputs are non-symbolic."""
|
||||
center = input_sockets['Center']
|
||||
size = input_sockets['Size']
|
||||
medium = input_sockets['Medium']
|
||||
output_params = output_sockets['Structure']
|
||||
def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
|
||||
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
|
||||
output_func = output_sockets['Structure'][ct.FlowKind.Func]
|
||||
output_params = output_sockets['Structure'][ct.FlowKind.Params]
|
||||
|
||||
has_center = not ct.FlowSignal.check(center)
|
||||
has_size = not ct.FlowSignal.check(size)
|
||||
has_medium = not ct.FlowSignal.check(medium)
|
||||
has_output_func = not ct.FlowSignal.check(output_func)
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
if (
|
||||
has_center
|
||||
and has_size
|
||||
and has_medium
|
||||
and has_output_params
|
||||
and not output_params.symbols
|
||||
):
|
||||
return td.Structure(
|
||||
geometry=td.Box(
|
||||
center=spux.scale_to_unit_system(center, ct.UNITS_TIDY3D),
|
||||
size=spux.scale_to_unit_system(size, ct.UNITS_TIDY3D),
|
||||
),
|
||||
medium=medium,
|
||||
)
|
||||
if has_output_func and has_output_params and not output_params.symbols:
|
||||
return output_func.realize(output_params, disallow_jax=True)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
|
@ -134,45 +117,43 @@ class BoxStructureNode(base.MaxwellSimNode):
|
|||
'Center': ct.FlowKind.Func,
|
||||
'Size': ct.FlowKind.Func,
|
||||
},
|
||||
output_sockets={'Structure'},
|
||||
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
||||
)
|
||||
def compute_structure_func(self, props, input_sockets, output_sockets) -> td.Box:
|
||||
def compute_structure_func(self, props, input_sockets) -> td.Box:
|
||||
"""Compute a possibly-differentiable function, producing a box structure from the input parameters."""
|
||||
output_params = output_sockets['Structure']
|
||||
center = input_sockets['Center']
|
||||
size = input_sockets['Size']
|
||||
medium = input_sockets['Medium']
|
||||
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
has_center = not ct.FlowSignal.check(center)
|
||||
has_size = not ct.FlowSignal.check(size)
|
||||
has_medium = not ct.FlowSignal.check(medium)
|
||||
|
||||
if has_output_params and has_center and has_size and has_medium:
|
||||
if has_center and has_size and has_medium:
|
||||
differentiable = props['differentiable']
|
||||
if differentiable:
|
||||
return (center | size | medium).compose_within(
|
||||
enclosing_func=lambda els: tdadj.JaxStructure(
|
||||
return (
|
||||
center.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
| size.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
| medium
|
||||
).compose_within(
|
||||
lambda els: tdadj.JaxStructure(
|
||||
geometry=tdadj.JaxBox(
|
||||
center=tuple(els[0].flatten()),
|
||||
size=tuple(els[1].flatten()),
|
||||
center_jax=tuple(els[0].flatten()),
|
||||
size_jax=tuple(els[1].flatten()),
|
||||
),
|
||||
medium=els[2],
|
||||
),
|
||||
supports_jax=True,
|
||||
)
|
||||
return (center | size | medium).compose_within(
|
||||
## TODO: Unit conversion within the composed function??
|
||||
## -- We do need Tidy3D to be given ex. micrometers in particular.
|
||||
## -- But the previous numerical output might not be micrometers.
|
||||
## -- There must be a way to add a conversion in, without strangeness.
|
||||
## -- Ex. can compose_within() take a unit system?
|
||||
## -- This would require
|
||||
enclosing_func=lambda els: td.Structure(
|
||||
return (
|
||||
center.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
| size.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
| medium
|
||||
).compose_within(
|
||||
lambda els: td.Structure(
|
||||
geometry=td.Box(
|
||||
center=tuple(els[0].flatten()),
|
||||
size=tuple(els[1].flatten()),
|
||||
center=els[0].flatten().tolist(),
|
||||
size=els[1].flatten().tolist(),
|
||||
),
|
||||
medium=els[2],
|
||||
),
|
||||
|
@ -187,7 +168,6 @@ class BoxStructureNode(base.MaxwellSimNode):
|
|||
'Structure',
|
||||
kind=ct.FlowKind.Params,
|
||||
# Loaded
|
||||
props={'differentiable'},
|
||||
input_sockets={'Medium', 'Center', 'Size'},
|
||||
input_socket_kinds={
|
||||
'Medium': ct.FlowKind.Params,
|
||||
|
@ -195,7 +175,7 @@ class BoxStructureNode(base.MaxwellSimNode):
|
|||
'Size': ct.FlowKind.Params,
|
||||
},
|
||||
)
|
||||
def compute_params(self, props, input_sockets) -> td.Box:
|
||||
def compute_params(self, input_sockets) -> td.Box:
|
||||
center = input_sockets['Center']
|
||||
size = input_sockets['Size']
|
||||
medium = input_sockets['Medium']
|
||||
|
@ -238,13 +218,16 @@ class BoxStructureNode(base.MaxwellSimNode):
|
|||
output_sockets={'Structure'},
|
||||
output_socket_kinds={'Structure': ct.FlowKind.Params},
|
||||
)
|
||||
def on_inputs_changed(self, managed_objs, input_sockets, output_sockets):
|
||||
output_params = output_sockets['Structure']
|
||||
def on_previewable_changed(self, managed_objs, input_sockets, output_sockets):
|
||||
center = input_sockets['Center']
|
||||
size = input_sockets['Size']
|
||||
output_params = output_sockets['Structure']
|
||||
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
has_center = not ct.FlowSignal.check(center)
|
||||
if has_center and has_output_params and not output_params.symbols:
|
||||
has_size = not ct.FlowSignal.check(size)
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
if has_center and has_size and has_output_params and not output_params.symbols:
|
||||
## TODO: There are strategies for handling examples of symbol values.
|
||||
|
||||
# Push Loose Input Values to GeoNodes Modifier
|
||||
|
@ -254,7 +237,7 @@ class BoxStructureNode(base.MaxwellSimNode):
|
|||
'node_group': import_geonodes(GeoNodes.StructurePrimitiveBox),
|
||||
'unit_system': ct.UNITS_BLENDER,
|
||||
'inputs': {
|
||||
'Size': input_sockets['Size'],
|
||||
'Size': size,
|
||||
},
|
||||
},
|
||||
location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),
|
||||
|
|
|
@ -32,6 +32,8 @@ log = logger.get(__name__)
|
|||
|
||||
|
||||
class CylinderStructureNode(base.MaxwellSimNode):
|
||||
"""A generic cylinder structure with configurable radius and height."""
|
||||
|
||||
node_type = ct.NodeType.CylinderStructure
|
||||
bl_label = 'Cylinder Structure'
|
||||
use_sim_node_name = True
|
||||
|
@ -64,27 +66,101 @@ class CylinderStructureNode(base.MaxwellSimNode):
|
|||
}
|
||||
|
||||
####################
|
||||
# - Output Socket Computation
|
||||
# - FlowKind.Value
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Structure',
|
||||
kind=ct.FlowKind.Value,
|
||||
# Loaded
|
||||
output_sockets={'Structure'},
|
||||
output_socket_kinds={'Structure': {ct.FlowKind.Func, ct.FlowKind.Params}},
|
||||
)
|
||||
def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
|
||||
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
|
||||
output_func = output_sockets['Structure'][ct.FlowKind.Func]
|
||||
output_params = output_sockets['Structure'][ct.FlowKind.Params]
|
||||
|
||||
has_output_func = not ct.FlowSignal.check(output_func)
|
||||
has_output_params = not ct.FlowSignal.check(output_params)
|
||||
|
||||
if has_output_func and has_output_params and not output_params.symbols:
|
||||
return output_func.realize(output_params, disallow_jax=True)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Func
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Structure',
|
||||
kind=ct.FlowKind.Func,
|
||||
# Loaded
|
||||
input_sockets={'Center', 'Radius', 'Medium', 'Height'},
|
||||
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
|
||||
scale_input_sockets={
|
||||
'Center': 'Tidy3DUnits',
|
||||
'Radius': 'Tidy3DUnits',
|
||||
'Height': 'Tidy3DUnits',
|
||||
input_socket_kinds={
|
||||
'Center': ct.FlowKind.Func,
|
||||
'Radius': ct.FlowKind.Func,
|
||||
'Height': ct.FlowKind.Func,
|
||||
'Medium': ct.FlowKind.Func,
|
||||
},
|
||||
)
|
||||
def compute_structure(self, input_sockets, unit_systems) -> td.Box:
|
||||
return td.Structure(
|
||||
def compute_func(self, input_sockets) -> td.Box:
|
||||
"""Compute a single cylinder structure object, given that all inputs are non-symbolic."""
|
||||
center = input_sockets['Center']
|
||||
radius = input_sockets['Radius']
|
||||
height = input_sockets['Height']
|
||||
medium = input_sockets['Medium']
|
||||
|
||||
has_center = not ct.FlowSignal.check(center)
|
||||
has_radius = not ct.FlowSignal.check(radius)
|
||||
has_height = not ct.FlowSignal.check(height)
|
||||
has_medium = not ct.FlowSignal.check(medium)
|
||||
|
||||
if has_center and has_radius and has_height and has_medium:
|
||||
return (
|
||||
center.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
| radius.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
| height.scale_to_unit_system(ct.UNITS_TIDY3D)
|
||||
| medium
|
||||
).compose_within(
|
||||
lambda els: td.Structure(
|
||||
geometry=td.Cylinder(
|
||||
radius=input_sockets['Radius'],
|
||||
center=input_sockets['Center'],
|
||||
length=input_sockets['Height'],
|
||||
center=els[0].flatten().tolist(),
|
||||
radius=els[1],
|
||||
length=els[2],
|
||||
),
|
||||
medium=input_sockets['Medium'],
|
||||
medium=els[3],
|
||||
)
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Params
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Structure',
|
||||
kind=ct.FlowKind.Params,
|
||||
# Loaded
|
||||
input_sockets={'Center', 'Radius', 'Medium', 'Height'},
|
||||
input_socket_kinds={
|
||||
'Center': ct.FlowKind.Params,
|
||||
'Radius': ct.FlowKind.Params,
|
||||
'Height': ct.FlowKind.Params,
|
||||
'Medium': ct.FlowKind.Params,
|
||||
},
|
||||
)
|
||||
def compute_params(self, input_sockets) -> td.Box:
|
||||
center = input_sockets['Center']
|
||||
radius = input_sockets['Radius']
|
||||
height = input_sockets['Height']
|
||||
medium = input_sockets['Medium']
|
||||
|
||||
has_center = not ct.FlowSignal.check(center)
|
||||
has_radius = not ct.FlowSignal.check(radius)
|
||||
has_height = not ct.FlowSignal.check(height)
|
||||
has_medium = not ct.FlowSignal.check(medium)
|
||||
|
||||
if has_center and has_radius and has_height and has_medium:
|
||||
return center | radius | height | medium
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - Preview
|
||||
|
|
|
@ -791,14 +791,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
####################
|
||||
# - FlowKind: Func (w/Params if Constant)
|
||||
####################
|
||||
@bl_cache.cached_bl_property(
|
||||
depends_on={
|
||||
'value',
|
||||
'sorted_sp_symbols',
|
||||
'sorted_symbols',
|
||||
'output_sym',
|
||||
}
|
||||
)
|
||||
@bl_cache.cached_bl_property(depends_on={'output_sym'})
|
||||
def lazy_func(self) -> ct.FuncFlow:
|
||||
"""Returns a lazy value that computes the expression returned by `self.value`.
|
||||
|
||||
|
@ -806,36 +799,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
Otherwise, the returned lazy value function will be a simple excuse for `self.params` to pass the verbatim `self.value`.
|
||||
"""
|
||||
if self.output_sym is not None:
|
||||
match self.active_kind:
|
||||
case ct.FlowKind.Value | ct.FlowKind.Func if (
|
||||
self.sorted_symbols and not ct.FlowSignal.check(self.value)
|
||||
):
|
||||
return ct.FuncFlow(
|
||||
func=sp.lambdify(
|
||||
self.sorted_sp_symbols,
|
||||
self.output_sym.conform(self.value, strip_unit=True),
|
||||
'jax',
|
||||
),
|
||||
func_args=list(self.sorted_symbols),
|
||||
func_output=self.output_sym,
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
case ct.FlowKind.Value | ct.FlowKind.Func if not self.sorted_symbols:
|
||||
return ct.FuncFlow(
|
||||
func=lambda v: v,
|
||||
func_args=[self.output_sym],
|
||||
func_output=self.output_sym,
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
case ct.FlowKind.Range if self.sorted_symbols:
|
||||
msg = 'RangeFlow support not yet implemented for when self.sorted_symbols is not empty'
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
case ct.FlowKind.Range if (
|
||||
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
|
||||
):
|
||||
return ct.FuncFlow(
|
||||
func=lambda v: v,
|
||||
func_args=[self.output_sym],
|
||||
|
@ -854,22 +817,15 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be passed into `ParamsFlow`, which will thus be parameterized (and require realization before use).
|
||||
Otherwise, `self.value` is passed verbatim as the only `ParamsFlow.func_arg`.
|
||||
"""
|
||||
output_sym = self.output_sym
|
||||
if output_sym is not None:
|
||||
if self.output_sym is not None:
|
||||
match self.active_kind:
|
||||
case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols:
|
||||
return ct.ParamsFlow(
|
||||
arg_targets=list(self.sorted_symbols),
|
||||
func_args=[sym.sp_symbol for sym in self.sorted_symbols],
|
||||
symbols=set(self.sorted_symbols),
|
||||
)
|
||||
|
||||
case ct.FlowKind.Value | ct.FlowKind.Func if (
|
||||
not self.sorted_symbols and not ct.FlowSignal.check(self.value)
|
||||
not ct.FlowSignal.check(self.value)
|
||||
):
|
||||
return ct.ParamsFlow(
|
||||
arg_targets=[self.output_sym],
|
||||
func_args=[self.value],
|
||||
symbols=set(self.sorted_symbols),
|
||||
)
|
||||
|
||||
case ct.FlowKind.Range if self.sorted_symbols:
|
||||
|
@ -899,20 +855,14 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
|
||||
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
|
||||
"""
|
||||
output_sym = self.output_sym
|
||||
if output_sym is not None:
|
||||
if self.output_sym is not None:
|
||||
match self.active_kind:
|
||||
case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols:
|
||||
case ct.FlowKind.Value | ct.FlowKind.Func:
|
||||
return ct.InfoFlow(
|
||||
dims={sym: None for sym in self.sorted_symbols},
|
||||
output=self.output_sym,
|
||||
)
|
||||
|
||||
case ct.FlowKind.Value | ct.FlowKind.Func if (
|
||||
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
|
||||
):
|
||||
return ct.InfoFlow(output=self.output_sym)
|
||||
|
||||
case ct.FlowKind.Range if self.sorted_symbols:
|
||||
msg = 'InfoFlow support not yet implemented for when self.sorted_symbols is not empty'
|
||||
raise NotImplementedError(msg)
|
||||
|
|
|
@ -14,7 +14,10 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import math
|
||||
|
||||
import bpy
|
||||
import jax.numpy as jnp
|
||||
import scipy as sc
|
||||
import sympy.physics.units as spu
|
||||
import tidy3d as td
|
||||
|
@ -28,9 +31,15 @@ from .. import base
|
|||
|
||||
log = logger.get(__name__)
|
||||
|
||||
VAC_SPEED_OF_LIGHT = sc.constants.speed_of_light * spu.meter / spu.second
|
||||
|
||||
FIXED_WL = 500 * spu.nm
|
||||
_VAC_SPEED_OF_LIGHT = sc.constants.speed_of_light * spu.meter / spu.second
|
||||
_FIXED_WL = 500 * spu.nm
|
||||
FIXED_FREQ = spux.scale_to_unit_system(
|
||||
spu.convert_to(
|
||||
_VAC_SPEED_OF_LIGHT / _FIXED_WL,
|
||||
spu.hertz,
|
||||
),
|
||||
ct.UNITS_TIDY3D,
|
||||
)
|
||||
|
||||
|
||||
class MaxwellMediumBLSocket(base.MaxwellSimSocket):
|
||||
|
@ -49,24 +58,17 @@ class MaxwellMediumBLSocket(base.MaxwellSimSocket):
|
|||
####################
|
||||
@bl_cache.cached_bl_property(depends_on={'eps_rel', 'differentiable'})
|
||||
def value(self) -> td.Medium:
|
||||
freq = (
|
||||
spu.convert_to(
|
||||
VAC_SPEED_OF_LIGHT / FIXED_WL,
|
||||
spu.hertz,
|
||||
)
|
||||
/ spu.hertz
|
||||
)
|
||||
eps_r_re = self.eps_rel[0]
|
||||
conductivity = FIXED_FREQ * self.eps_rel[1] ## TODO: Um?
|
||||
|
||||
if self.differentiable:
|
||||
return tdadj.JaxMedium.from_nk(
|
||||
n=self.eps_rel[0],
|
||||
k=self.eps_rel[1],
|
||||
freq=freq,
|
||||
return tdadj.JaxMedium(
|
||||
permittivity_jax=jnp.array(eps_r_re, dtype=float),
|
||||
conductivity_jax=jnp.array(conductivity, dtype=float),
|
||||
)
|
||||
return td.Medium.from_nk(
|
||||
n=self.eps_rel[0],
|
||||
k=self.eps_rel[1],
|
||||
freq=freq,
|
||||
return td.Medium(
|
||||
permittivity=eps_r_re,
|
||||
conductivity=conductivity,
|
||||
)
|
||||
|
||||
@value.setter
|
||||
|
|
|
@ -263,6 +263,11 @@ class SimSymbol(pyd.BaseModel):
|
|||
interval_inf_im: tuple[bool, bool] = (True, True)
|
||||
interval_closed_im: tuple[bool, bool] = (False, False)
|
||||
|
||||
preview_value_z: int = 0
|
||||
preview_value_q: tuple[int, int] = (0, 1)
|
||||
preview_value_re: float = 0.0
|
||||
preview_value_im: float = 0.0
|
||||
|
||||
####################
|
||||
# - Core
|
||||
####################
|
||||
|
|
Loading…
Reference in New Issue