fix: use-after-free in socket pruner

Also the usual batch of improvements.
Differentiability is misbehaving, intriguingly.
main
Sofus Albert Høgsbro Rose 2024-05-30 21:39:48 +02:00
parent 38e70a60d3
commit c286d65193
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
9 changed files with 328 additions and 215 deletions

View File

@ -325,8 +325,16 @@ class FuncFlow(pyd.BaseModel):
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType( symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{} {}
), ),
disallow_jax: bool = True,
) -> typ.Self: ) -> typ.Self:
"""Run the represented function with the best optimization available, given particular choices for all function arguments and for all unrealized symbols.""" """Run the represented function with the best optimization available, given particular choices for all function arguments and for all unrealized symbols.
Parameters:
params: The parameter-tracking object from which function arguments will be computed.
symbol_values: Values for all `SimSymbol`s that are not yet realized in `params`.
disallow_jax: Don't use `self.func_jax` to evaluate, even if possible.
This is desirable when the overhead of `jax.jit()` is known in advance to exceed the performance benefits.
"""
if self.supports_jax: if self.supports_jax:
return self.func_jax( return self.func_jax(
*params.scaled_func_args(symbol_values), *params.scaled_func_args(symbol_values),

View File

@ -374,19 +374,22 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
""" """
node_tree = self.id_data node_tree = self.id_data
for direc in ['input', 'output']: for direc in ['input', 'output']:
bl_sockets = self._bl_sockets(direc) active_socket_nametype = {
bl_socket.name: bl_socket.socket_type
for bl_socket in self._bl_sockets(direc)
}
active_socket_defs = self.active_socket_defs(direc) active_socket_defs = self.active_socket_defs(direc)
# Determine Sockets to Remove # Determine Sockets to Remove
## -> Name: If the existing socket name isn't "active". ## -> Name: If the existing socket name isn't "active".
## -> Type: If the existing socket_type != "active" SocketDef. ## -> Type: If the existing socket_type != "active" SocketDef.
bl_sockets_to_remove = [ bl_sockets_to_remove = [
bl_socket active_sckname
for socket_name, bl_socket in bl_sockets.items() for active_sckname, active_scktype in active_socket_nametype.items()
if ( if (
socket_name not in active_socket_defs active_sckname not in active_socket_defs
or bl_socket.socket_type or active_scktype
is not active_socket_defs[socket_name].socket_type is not active_socket_defs[active_sckname].socket_type
) )
] ]
@ -394,39 +397,50 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
## -> Name: If the existing socket name is "active". ## -> Name: If the existing socket name is "active".
## -> Type: If the existing socket_type == "active" SocketDef. ## -> Type: If the existing socket_type == "active" SocketDef.
## -> Compare: If the existing socket differs from the SocketDef. ## -> Compare: If the existing socket differs from the SocketDef.
## -> NOTE: Reload bl_sockets in case to-update scks were removed.
bl_sockets_to_update = [ bl_sockets_to_update = [
bl_socket active_sckname
for socket_name, bl_socket in bl_sockets.items() for active_sckname, active_scktype in active_socket_nametype.items()
if ( if (
socket_name in active_socket_defs active_sckname in active_socket_defs
and bl_socket.socket_type and active_scktype is active_socket_defs[active_sckname].socket_type
is active_socket_defs[socket_name].socket_type and not active_socket_defs[active_sckname].compare(
and not active_socket_defs[socket_name].compare(bl_socket) self._bl_sockets(direc)[active_sckname]
)
) )
] ]
# Remove Sockets # Remove Sockets
for bl_socket in bl_sockets_to_remove: ## -> The symptom of using a deleted socket is... hard crash.
bl_socket_name = bl_socket.name ## -> Therefore, we must be EXTREMELY careful with bl_socket refs.
## -> The multi-stage for-loop helps us guard from deleted sockets.
for active_sckname in bl_sockets_to_remove:
bl_socket = self._bl_sockets(direc).get(active_sckname)
# 1. Report the socket removal to the NodeTree. # 1. Report the socket removal to the NodeTree.
## -> The NodeLinkCache needs to be adjusted manually. ## -> The NodeLinkCache needs to be adjusted manually.
node_tree.on_node_socket_removed(bl_socket) node_tree.on_node_socket_removed(bl_socket)
for active_sckname in bl_sockets_to_remove:
bl_sockets = self._bl_sockets(direc)
bl_socket = bl_sockets.get(active_sckname)
# 2. Perform the removal using Blender's API. # 2. Perform the removal using Blender's API.
## -> Actually removes the socket. ## -> Actually removes the socket.
## -> Must be protected from auto-removed use-after-free.
if bl_socket is not None:
bl_sockets.remove(bl_socket) bl_sockets.remove(bl_socket)
if direc == 'input':
# 3. Invalidate the input socket cache across all kinds. # 3. Invalidate the input socket cache across all kinds.
## -> Prevents phantom values from remaining available. ## -> Prevents phantom values from remaining available.
## -> Done after socket removal to protect from race condition. ## -> Done after socket removal to protect from race condition.
self._compute_input.invalidate( self._compute_input.invalidate(
input_socket_name=bl_socket_name, input_socket_name=active_sckname,
kind=..., kind=...,
unit_system=..., unit_system=...,
) )
if direc == 'input':
# 4. Run all trigger-only `on_value_changed` callbacks. # 4. Run all trigger-only `on_value_changed` callbacks.
## -> Runs any event methods that relied on the socket. ## -> Runs any event methods that relied on the socket.
## -> Only methods that don't **require** the socket. ## -> Only methods that don't **require** the socket.
@ -435,18 +449,33 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
triggered_event_methods = [ triggered_event_methods = [
event_method event_method
for event_method in self.filtered_event_methods_by_event( for event_method in self.filtered_event_methods_by_event(
ct.FlowEvent.DataChanged, (bl_socket_name, None, None) ct.FlowEvent.DataChanged, (active_sckname, None, None)
) )
if bl_socket_name if active_sckname
not in event_method.callback_info.must_load_sockets not in event_method.callback_info.must_load_sockets
] ]
for event_method in triggered_event_methods: for event_method in triggered_event_methods:
event_method(self) event_method(self)
else:
# 3. Invalidate the output socket cache across all kinds.
## -> Prevents phantom values from remaining available.
## -> Done after socket removal to protect from race condition.
self.compute_output.invalidate(
input_socket_name=active_sckname,
kind=...,
)
# Update Sockets # Update Sockets
for bl_socket in bl_sockets_to_update: ## -> The symptom of using a deleted socket is... hard crash.
bl_socket_name = bl_socket.name ## -> Therefore, we must be EXTREMELY careful with bl_socket refs.
socket_def = active_socket_defs[bl_socket_name] ## -> The multi-stage for-loop helps us guard from deleted sockets.
for active_sckname in bl_sockets_to_update:
bl_sockets = self._bl_sockets(direc)
bl_socket = bl_sockets.get(active_sckname)
if bl_socket is not None:
socket_def = active_socket_defs[active_sckname]
# 1. Pretend to Initialize for the First Time # 1. Pretend to Initialize for the First Time
## -> NOTE: The socket's caches will be completely regenerated. ## -> NOTE: The socket's caches will be completely regenerated.
@ -456,19 +485,34 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
socket_def.init(bl_socket) socket_def.init(bl_socket)
socket_def.postinit(bl_socket) socket_def.postinit(bl_socket)
for active_sckname in bl_sockets_to_update:
bl_sockets = self._bl_sockets(direc)
bl_socket = bl_sockets.get(active_sckname)
if bl_socket is not None:
# 2. Re-Test Socket Capabilities # 2. Re-Test Socket Capabilities
## -> Factors influencing CapabilitiesFlow may have changed. ## -> Factors influencing CapabilitiesFlow may have changed.
## -> Therefore, we must re-test all link capabilities. ## -> Therefore, we must re-test all link capabilities.
bl_socket.remove_invalidated_links() bl_socket.remove_invalidated_links()
if direc == 'input':
# 3. Invalidate the input socket cache across all kinds. # 3. Invalidate the input socket cache across all kinds.
## -> Prevents phantom values from remaining available. ## -> Prevents phantom values from remaining available.
self._compute_input.invalidate( self._compute_input.invalidate(
input_socket_name=bl_socket_name, input_socket_name=active_sckname,
kind=..., kind=...,
unit_system=..., unit_system=...,
) )
if direc == 'output':
# 3. Invalidate the output socket cache across all kinds.
## -> Prevents phantom values from remaining available.
## -> Done after socket removal to protect from race condition.
self.compute_output.invalidate(
input_socket_name=active_sckname,
kind=...,
)
def _add_new_active_sockets(self): def _add_new_active_sockets(self):
"""Add and initialize all "active" sockets that aren't on the node. """Add and initialize all "active" sockets that aren't on the node.

View File

@ -104,6 +104,11 @@ class SymbolConstantNode(base.MaxwellSimNode):
interval_inf_im: tuple[bool, bool] = bl_cache.BLField((True, True)) interval_inf_im: tuple[bool, bool] = bl_cache.BLField((True, True))
interval_closed_im: tuple[bool, bool] = bl_cache.BLField((True, True)) interval_closed_im: tuple[bool, bool] = bl_cache.BLField((True, True))
preview_value_z: int = bl_cache.BLField(0)
preview_value_q: tuple[int, int] = bl_cache.BLField((0, 1))
preview_value_re: float = bl_cache.BLField(0.0)
preview_value_im: float = bl_cache.BLField(0.0)
#################### ####################
# - Computed Properties # - Computed Properties
#################### ####################
@ -122,6 +127,10 @@ class SymbolConstantNode(base.MaxwellSimNode):
'interval_finite_im', 'interval_finite_im',
'interval_inf_im', 'interval_inf_im',
'interval_closed_im', 'interval_closed_im',
'preview_value_z',
'preview_value_q',
'preview_value_re',
'preview_value_im',
} }
) )
def symbol(self) -> sim_symbols.SimSymbol: def symbol(self) -> sim_symbols.SimSymbol:
@ -140,6 +149,10 @@ class SymbolConstantNode(base.MaxwellSimNode):
interval_finite_im=self.interval_finite_im, interval_finite_im=self.interval_finite_im,
interval_inf_im=self.interval_inf_im, interval_inf_im=self.interval_inf_im,
interval_closed_im=self.interval_closed_im, interval_closed_im=self.interval_closed_im,
preview_value_z=self.preview_value_z,
preview_value_q=self.preview_value_q,
preview_value_re=self.preview_value_re,
preview_value_im=self.preview_value_im,
) )
#################### ####################
@ -164,36 +177,68 @@ class SymbolConstantNode(base.MaxwellSimNode):
col.prop(self, self.blfields['physical_type'], text='') col.prop(self, self.blfields['physical_type'], text='')
# Domain - Infinite
row = col.row(align=True) row = col.row(align=True)
row.alignment = 'CENTER' row.alignment = 'CENTER'
row.label(text='Domain') row.label(text='Domain - Is Infinite')
row = col.row(align=True)
if self.mathtype is spux.MathType.Complex:
row.prop(self, self.blfields['interval_inf'], text='')
row.prop(self, self.blfields['interval_inf_im'], text='𝕀')
else:
row.prop(self, self.blfields['interval_inf'], text='')
if any(not b for b in self.interval_inf):
# Domain - Closure
row = col.row(align=True)
row.alignment = 'CENTER'
row.label(text='Domain - Closure')
row = col.row(align=True)
if self.mathtype is spux.MathType.Complex:
row.prop(self, self.blfields['interval_closed'], text='')
row.prop(self, self.blfields['interval_closed_im'], text='𝕀')
else:
row.prop(self, self.blfields['interval_closed'], text='')
# Domain - Finite
row = col.row(align=True)
row.alignment = 'CENTER'
row.label(text='Domain - Interval')
row = col.row(align=True)
match self.mathtype: match self.mathtype:
case spux.MathType.Integer: case spux.MathType.Integer:
col.prop(self, self.blfields['interval_finite_z'], text='') row.prop(self, self.blfields['interval_finite_z'], text='')
col.prop(self, self.blfields['interval_inf'], text='Infinite')
col.prop(self, self.blfields['interval_closed'], text='Closed')
case spux.MathType.Rational: case spux.MathType.Rational:
col.prop(self, self.blfields['interval_finite_q'], text='') row.prop(self, self.blfields['interval_finite_q'], text='')
col.prop(self, self.blfields['interval_inf'], text='Infinite')
col.prop(self, self.blfields['interval_closed'], text='Closed')
case spux.MathType.Real: case spux.MathType.Real:
col.prop(self, self.blfields['interval_finite_re'], text='') row.prop(self, self.blfields['interval_finite_re'], text='')
col.prop(self, self.blfields['interval_inf'], text='Infinite')
col.prop(self, self.blfields['interval_closed'], text='Closed')
case spux.MathType.Complex: case spux.MathType.Complex:
col.prop(self, self.blfields['interval_finite_re'], text='') row.prop(self, self.blfields['interval_finite_re'], text='')
col.prop(self, self.blfields['interval_inf'], text=' Infinite') row.prop(self, self.blfields['interval_finite_im'], text='𝕀')
col.prop(self, self.blfields['interval_closed'], text=' Closed')
col.separator() # Domain - Closure
row = col.row(align=True)
row.alignment = 'CENTER'
row.label(text='Preview Value')
match self.mathtype:
case spux.MathType.Integer:
row.prop(self, self.blfields['preview_value_z'], text='')
col.prop(self, self.blfields['interval_finite_im'], text='𝕀') case spux.MathType.Rational:
col.prop(self, self.blfields['interval_inf'], text='𝕀 Infinite') row.prop(self, self.blfields['preview_value_q'], text='')
col.prop(self, self.blfields['interval_closed'], text='𝕀 Closed')
case spux.MathType.Real:
row.prop(self, self.blfields['preview_value_re'], text='')
case spux.MathType.Complex:
row.prop(self, self.blfields['preview_value_re'], text='')
row.prop(self, self.blfields['preview_value_im'], text='𝕀')
#################### ####################
# - FlowKinds # - FlowKinds

View File

@ -132,7 +132,7 @@ class SimDomainNode(base.MaxwellSimNode):
| grid | grid
| medium | medium
).compose_within( ).compose_within(
enclosing_func=lambda els: { lambda els: {
'run_time': els[0], 'run_time': els[0],
'center': tuple(els[1].flatten()), 'center': tuple(els[1].flatten()),
'size': tuple(els[2].flatten()), 'size': tuple(els[2].flatten()),

View File

@ -88,36 +88,19 @@ class BoxStructureNode(base.MaxwellSimNode):
'Structure', 'Structure',
kind=ct.FlowKind.Value, kind=ct.FlowKind.Value,
# Loaded # Loaded
input_sockets={'Medium', 'Center', 'Size'},
output_sockets={'Structure'}, output_sockets={'Structure'},
output_socket_kinds={'Structure': ct.FlowKind.Params}, output_socket_kinds={'Structure': {ct.FlowKind.Func, ct.FlowKind.Params}},
) )
def compute_value(self, input_sockets, output_sockets) -> td.Box: def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
"""Compute a single box structure object, given that all inputs are non-symbolic.""" """Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
center = input_sockets['Center'] output_func = output_sockets['Structure'][ct.FlowKind.Func]
size = input_sockets['Size'] output_params = output_sockets['Structure'][ct.FlowKind.Params]
medium = input_sockets['Medium']
output_params = output_sockets['Structure']
has_center = not ct.FlowSignal.check(center) has_output_func = not ct.FlowSignal.check(output_func)
has_size = not ct.FlowSignal.check(size)
has_medium = not ct.FlowSignal.check(medium)
has_output_params = not ct.FlowSignal.check(output_params) has_output_params = not ct.FlowSignal.check(output_params)
if ( if has_output_func and has_output_params and not output_params.symbols:
has_center return output_func.realize(output_params, disallow_jax=True)
and has_size
and has_medium
and has_output_params
and not output_params.symbols
):
return td.Structure(
geometry=td.Box(
center=spux.scale_to_unit_system(center, ct.UNITS_TIDY3D),
size=spux.scale_to_unit_system(size, ct.UNITS_TIDY3D),
),
medium=medium,
)
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
#################### ####################
@ -134,45 +117,43 @@ class BoxStructureNode(base.MaxwellSimNode):
'Center': ct.FlowKind.Func, 'Center': ct.FlowKind.Func,
'Size': ct.FlowKind.Func, 'Size': ct.FlowKind.Func,
}, },
output_sockets={'Structure'},
output_socket_kinds={'Structure': ct.FlowKind.Params},
) )
def compute_structure_func(self, props, input_sockets, output_sockets) -> td.Box: def compute_structure_func(self, props, input_sockets) -> td.Box:
"""Compute a possibly-differentiable function, producing a box structure from the input parameters.""" """Compute a possibly-differentiable function, producing a box structure from the input parameters."""
output_params = output_sockets['Structure']
center = input_sockets['Center'] center = input_sockets['Center']
size = input_sockets['Size'] size = input_sockets['Size']
medium = input_sockets['Medium'] medium = input_sockets['Medium']
has_output_params = not ct.FlowSignal.check(output_params)
has_center = not ct.FlowSignal.check(center) has_center = not ct.FlowSignal.check(center)
has_size = not ct.FlowSignal.check(size) has_size = not ct.FlowSignal.check(size)
has_medium = not ct.FlowSignal.check(medium) has_medium = not ct.FlowSignal.check(medium)
if has_output_params and has_center and has_size and has_medium: if has_center and has_size and has_medium:
differentiable = props['differentiable'] differentiable = props['differentiable']
if differentiable: if differentiable:
return (center | size | medium).compose_within( return (
enclosing_func=lambda els: tdadj.JaxStructure( center.scale_to_unit_system(ct.UNITS_TIDY3D)
| size.scale_to_unit_system(ct.UNITS_TIDY3D)
| medium
).compose_within(
lambda els: tdadj.JaxStructure(
geometry=tdadj.JaxBox( geometry=tdadj.JaxBox(
center=tuple(els[0].flatten()), center_jax=tuple(els[0].flatten()),
size=tuple(els[1].flatten()), size_jax=tuple(els[1].flatten()),
), ),
medium=els[2], medium=els[2],
), ),
supports_jax=True, supports_jax=True,
) )
return (center | size | medium).compose_within( return (
## TODO: Unit conversion within the composed function?? center.scale_to_unit_system(ct.UNITS_TIDY3D)
## -- We do need Tidy3D to be given ex. micrometers in particular. | size.scale_to_unit_system(ct.UNITS_TIDY3D)
## -- But the previous numerical output might not be micrometers. | medium
## -- There must be a way to add a conversion in, without strangeness. ).compose_within(
## -- Ex. can compose_within() take a unit system? lambda els: td.Structure(
## -- This would require
enclosing_func=lambda els: td.Structure(
geometry=td.Box( geometry=td.Box(
center=tuple(els[0].flatten()), center=els[0].flatten().tolist(),
size=tuple(els[1].flatten()), size=els[1].flatten().tolist(),
), ),
medium=els[2], medium=els[2],
), ),
@ -187,7 +168,6 @@ class BoxStructureNode(base.MaxwellSimNode):
'Structure', 'Structure',
kind=ct.FlowKind.Params, kind=ct.FlowKind.Params,
# Loaded # Loaded
props={'differentiable'},
input_sockets={'Medium', 'Center', 'Size'}, input_sockets={'Medium', 'Center', 'Size'},
input_socket_kinds={ input_socket_kinds={
'Medium': ct.FlowKind.Params, 'Medium': ct.FlowKind.Params,
@ -195,7 +175,7 @@ class BoxStructureNode(base.MaxwellSimNode):
'Size': ct.FlowKind.Params, 'Size': ct.FlowKind.Params,
}, },
) )
def compute_params(self, props, input_sockets) -> td.Box: def compute_params(self, input_sockets) -> td.Box:
center = input_sockets['Center'] center = input_sockets['Center']
size = input_sockets['Size'] size = input_sockets['Size']
medium = input_sockets['Medium'] medium = input_sockets['Medium']
@ -238,13 +218,16 @@ class BoxStructureNode(base.MaxwellSimNode):
output_sockets={'Structure'}, output_sockets={'Structure'},
output_socket_kinds={'Structure': ct.FlowKind.Params}, output_socket_kinds={'Structure': ct.FlowKind.Params},
) )
def on_inputs_changed(self, managed_objs, input_sockets, output_sockets): def on_previewable_changed(self, managed_objs, input_sockets, output_sockets):
output_params = output_sockets['Structure']
center = input_sockets['Center'] center = input_sockets['Center']
size = input_sockets['Size']
output_params = output_sockets['Structure']
has_output_params = not ct.FlowSignal.check(output_params)
has_center = not ct.FlowSignal.check(center) has_center = not ct.FlowSignal.check(center)
if has_center and has_output_params and not output_params.symbols: has_size = not ct.FlowSignal.check(size)
has_output_params = not ct.FlowSignal.check(output_params)
if has_center and has_size and has_output_params and not output_params.symbols:
## TODO: There are strategies for handling examples of symbol values. ## TODO: There are strategies for handling examples of symbol values.
# Push Loose Input Values to GeoNodes Modifier # Push Loose Input Values to GeoNodes Modifier
@ -254,7 +237,7 @@ class BoxStructureNode(base.MaxwellSimNode):
'node_group': import_geonodes(GeoNodes.StructurePrimitiveBox), 'node_group': import_geonodes(GeoNodes.StructurePrimitiveBox),
'unit_system': ct.UNITS_BLENDER, 'unit_system': ct.UNITS_BLENDER,
'inputs': { 'inputs': {
'Size': input_sockets['Size'], 'Size': size,
}, },
}, },
location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER), location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),

View File

@ -32,6 +32,8 @@ log = logger.get(__name__)
class CylinderStructureNode(base.MaxwellSimNode): class CylinderStructureNode(base.MaxwellSimNode):
"""A generic cylinder structure with configurable radius and height."""
node_type = ct.NodeType.CylinderStructure node_type = ct.NodeType.CylinderStructure
bl_label = 'Cylinder Structure' bl_label = 'Cylinder Structure'
use_sim_node_name = True use_sim_node_name = True
@ -64,27 +66,101 @@ class CylinderStructureNode(base.MaxwellSimNode):
} }
#################### ####################
# - Output Socket Computation # - FlowKind.Value
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Structure', 'Structure',
kind=ct.FlowKind.Value,
# Loaded
output_sockets={'Structure'},
output_socket_kinds={'Structure': {ct.FlowKind.Func, ct.FlowKind.Params}},
)
def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
output_func = output_sockets['Structure'][ct.FlowKind.Func]
output_params = output_sockets['Structure'][ct.FlowKind.Params]
has_output_func = not ct.FlowSignal.check(output_func)
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_func and has_output_params and not output_params.symbols:
return output_func.realize(output_params, disallow_jax=True)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Structure',
kind=ct.FlowKind.Func,
# Loaded
input_sockets={'Center', 'Radius', 'Medium', 'Height'}, input_sockets={'Center', 'Radius', 'Medium', 'Height'},
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D}, input_socket_kinds={
scale_input_sockets={ 'Center': ct.FlowKind.Func,
'Center': 'Tidy3DUnits', 'Radius': ct.FlowKind.Func,
'Radius': 'Tidy3DUnits', 'Height': ct.FlowKind.Func,
'Height': 'Tidy3DUnits', 'Medium': ct.FlowKind.Func,
}, },
) )
def compute_structure(self, input_sockets, unit_systems) -> td.Box: def compute_func(self, input_sockets) -> td.Box:
return td.Structure( """Compute a single cylinder structure object, given that all inputs are non-symbolic."""
center = input_sockets['Center']
radius = input_sockets['Radius']
height = input_sockets['Height']
medium = input_sockets['Medium']
has_center = not ct.FlowSignal.check(center)
has_radius = not ct.FlowSignal.check(radius)
has_height = not ct.FlowSignal.check(height)
has_medium = not ct.FlowSignal.check(medium)
if has_center and has_radius and has_height and has_medium:
return (
center.scale_to_unit_system(ct.UNITS_TIDY3D)
| radius.scale_to_unit_system(ct.UNITS_TIDY3D)
| height.scale_to_unit_system(ct.UNITS_TIDY3D)
| medium
).compose_within(
lambda els: td.Structure(
geometry=td.Cylinder( geometry=td.Cylinder(
radius=input_sockets['Radius'], center=els[0].flatten().tolist(),
center=input_sockets['Center'], radius=els[1],
length=input_sockets['Height'], length=els[2],
), ),
medium=input_sockets['Medium'], medium=els[3],
) )
)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Structure',
kind=ct.FlowKind.Params,
# Loaded
input_sockets={'Center', 'Radius', 'Medium', 'Height'},
input_socket_kinds={
'Center': ct.FlowKind.Params,
'Radius': ct.FlowKind.Params,
'Height': ct.FlowKind.Params,
'Medium': ct.FlowKind.Params,
},
)
def compute_params(self, input_sockets) -> td.Box:
center = input_sockets['Center']
radius = input_sockets['Radius']
height = input_sockets['Height']
medium = input_sockets['Medium']
has_center = not ct.FlowSignal.check(center)
has_radius = not ct.FlowSignal.check(radius)
has_height = not ct.FlowSignal.check(height)
has_medium = not ct.FlowSignal.check(medium)
if has_center and has_radius and has_height and has_medium:
return center | radius | height | medium
return ct.FlowSignal.FlowPending
#################### ####################
# - Preview # - Preview

View File

@ -791,14 +791,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
#################### ####################
# - FlowKind: Func (w/Params if Constant) # - FlowKind: Func (w/Params if Constant)
#################### ####################
@bl_cache.cached_bl_property( @bl_cache.cached_bl_property(depends_on={'output_sym'})
depends_on={
'value',
'sorted_sp_symbols',
'sorted_symbols',
'output_sym',
}
)
def lazy_func(self) -> ct.FuncFlow: def lazy_func(self) -> ct.FuncFlow:
"""Returns a lazy value that computes the expression returned by `self.value`. """Returns a lazy value that computes the expression returned by `self.value`.
@ -806,36 +799,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
Otherwise, the returned lazy value function will be a simple excuse for `self.params` to pass the verbatim `self.value`. Otherwise, the returned lazy value function will be a simple excuse for `self.params` to pass the verbatim `self.value`.
""" """
if self.output_sym is not None: if self.output_sym is not None:
match self.active_kind:
case ct.FlowKind.Value | ct.FlowKind.Func if (
self.sorted_symbols and not ct.FlowSignal.check(self.value)
):
return ct.FuncFlow(
func=sp.lambdify(
self.sorted_sp_symbols,
self.output_sym.conform(self.value, strip_unit=True),
'jax',
),
func_args=list(self.sorted_symbols),
func_output=self.output_sym,
supports_jax=True,
)
case ct.FlowKind.Value | ct.FlowKind.Func if not self.sorted_symbols:
return ct.FuncFlow(
func=lambda v: v,
func_args=[self.output_sym],
func_output=self.output_sym,
supports_jax=True,
)
case ct.FlowKind.Range if self.sorted_symbols:
msg = 'RangeFlow support not yet implemented for when self.sorted_symbols is not empty'
raise NotImplementedError(msg)
case ct.FlowKind.Range if (
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
):
return ct.FuncFlow( return ct.FuncFlow(
func=lambda v: v, func=lambda v: v,
func_args=[self.output_sym], func_args=[self.output_sym],
@ -854,22 +817,15 @@ class ExprBLSocket(base.MaxwellSimSocket):
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be passed into `ParamsFlow`, which will thus be parameterized (and require realization before use). If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be passed into `ParamsFlow`, which will thus be parameterized (and require realization before use).
Otherwise, `self.value` is passed verbatim as the only `ParamsFlow.func_arg`. Otherwise, `self.value` is passed verbatim as the only `ParamsFlow.func_arg`.
""" """
output_sym = self.output_sym if self.output_sym is not None:
if output_sym is not None:
match self.active_kind: match self.active_kind:
case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols:
return ct.ParamsFlow(
arg_targets=list(self.sorted_symbols),
func_args=[sym.sp_symbol for sym in self.sorted_symbols],
symbols=set(self.sorted_symbols),
)
case ct.FlowKind.Value | ct.FlowKind.Func if ( case ct.FlowKind.Value | ct.FlowKind.Func if (
not self.sorted_symbols and not ct.FlowSignal.check(self.value) not ct.FlowSignal.check(self.value)
): ):
return ct.ParamsFlow( return ct.ParamsFlow(
arg_targets=[self.output_sym], arg_targets=[self.output_sym],
func_args=[self.value], func_args=[self.value],
symbols=set(self.sorted_symbols),
) )
case ct.FlowKind.Range if self.sorted_symbols: case ct.FlowKind.Range if self.sorted_symbols:
@ -899,20 +855,14 @@ class ExprBLSocket(base.MaxwellSimSocket):
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along. Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
""" """
output_sym = self.output_sym if self.output_sym is not None:
if output_sym is not None:
match self.active_kind: match self.active_kind:
case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols: case ct.FlowKind.Value | ct.FlowKind.Func:
return ct.InfoFlow( return ct.InfoFlow(
dims={sym: None for sym in self.sorted_symbols}, dims={sym: None for sym in self.sorted_symbols},
output=self.output_sym, output=self.output_sym,
) )
case ct.FlowKind.Value | ct.FlowKind.Func if (
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
):
return ct.InfoFlow(output=self.output_sym)
case ct.FlowKind.Range if self.sorted_symbols: case ct.FlowKind.Range if self.sorted_symbols:
msg = 'InfoFlow support not yet implemented for when self.sorted_symbols is not empty' msg = 'InfoFlow support not yet implemented for when self.sorted_symbols is not empty'
raise NotImplementedError(msg) raise NotImplementedError(msg)

View File

@ -14,7 +14,10 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import math
import bpy import bpy
import jax.numpy as jnp
import scipy as sc import scipy as sc
import sympy.physics.units as spu import sympy.physics.units as spu
import tidy3d as td import tidy3d as td
@ -28,9 +31,15 @@ from .. import base
log = logger.get(__name__) log = logger.get(__name__)
VAC_SPEED_OF_LIGHT = sc.constants.speed_of_light * spu.meter / spu.second _VAC_SPEED_OF_LIGHT = sc.constants.speed_of_light * spu.meter / spu.second
_FIXED_WL = 500 * spu.nm
FIXED_WL = 500 * spu.nm FIXED_FREQ = spux.scale_to_unit_system(
spu.convert_to(
_VAC_SPEED_OF_LIGHT / _FIXED_WL,
spu.hertz,
),
ct.UNITS_TIDY3D,
)
class MaxwellMediumBLSocket(base.MaxwellSimSocket): class MaxwellMediumBLSocket(base.MaxwellSimSocket):
@ -49,24 +58,17 @@ class MaxwellMediumBLSocket(base.MaxwellSimSocket):
#################### ####################
@bl_cache.cached_bl_property(depends_on={'eps_rel', 'differentiable'}) @bl_cache.cached_bl_property(depends_on={'eps_rel', 'differentiable'})
def value(self) -> td.Medium: def value(self) -> td.Medium:
freq = ( eps_r_re = self.eps_rel[0]
spu.convert_to( conductivity = FIXED_FREQ * self.eps_rel[1] ## TODO: Um?
VAC_SPEED_OF_LIGHT / FIXED_WL,
spu.hertz,
)
/ spu.hertz
)
if self.differentiable: if self.differentiable:
return tdadj.JaxMedium.from_nk( return tdadj.JaxMedium(
n=self.eps_rel[0], permittivity_jax=jnp.array(eps_r_re, dtype=float),
k=self.eps_rel[1], conductivity_jax=jnp.array(conductivity, dtype=float),
freq=freq,
) )
return td.Medium.from_nk( return td.Medium(
n=self.eps_rel[0], permittivity=eps_r_re,
k=self.eps_rel[1], conductivity=conductivity,
freq=freq,
) )
@value.setter @value.setter

View File

@ -263,6 +263,11 @@ class SimSymbol(pyd.BaseModel):
interval_inf_im: tuple[bool, bool] = (True, True) interval_inf_im: tuple[bool, bool] = (True, True)
interval_closed_im: tuple[bool, bool] = (False, False) interval_closed_im: tuple[bool, bool] = (False, False)
preview_value_z: int = 0
preview_value_q: tuple[int, int] = (0, 1)
preview_value_re: float = 0.0
preview_value_im: float = 0.0
#################### ####################
# - Core # - Core
#################### ####################