From c286d65193be7f664793e894208431a52b00a66b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sofus=20Albert=20H=C3=B8gsbro=20Rose?= Date: Thu, 30 May 2024 21:39:48 +0200 Subject: [PATCH] fix: use-after-free in socket pruner Also the usual batch of improvements. Differentiability is misbehaving, intriguingly. --- .../contracts/flow_kinds/lazy_func.py | 10 +- .../maxwell_sim_nodes/nodes/base.py | 138 ++++++++++++------ .../nodes/inputs/constants/symbol_constant.py | 83 ++++++++--- .../nodes/simulations/sim_domain.py | 2 +- .../structures/primitives/box_structure.py | 87 +++++------ .../primitives/cylinder_structure.py | 106 ++++++++++++-- .../maxwell_sim_nodes/sockets/expr.py | 74 ++-------- .../sockets/maxwell/medium.py | 38 ++--- src/blender_maxwell/utils/sim_symbols.py | 5 + 9 files changed, 328 insertions(+), 215 deletions(-) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py index c8b32f0..22c9d25 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds/lazy_func.py @@ -325,8 +325,16 @@ class FuncFlow(pyd.BaseModel): symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType( {} ), + disallow_jax: bool = True, ) -> typ.Self: - """Run the represented function with the best optimization available, given particular choices for all function arguments and for all unrealized symbols.""" + """Run the represented function with the best optimization available, given particular choices for all function arguments and for all unrealized symbols. + + Parameters: + params: The parameter-tracking object from which function arguments will be computed. + symbol_values: Values for all `SimSymbol`s that are not yet realized in `params`. + disallow_jax: Don't use `self.func_jax` to evaluate, even if possible. + This is desirable when the overhead of `jax.jit()` is known in advance to exceed the performance benefits. + """ if self.supports_jax: return self.func_jax( *params.scaled_func_args(symbol_values), diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py index 538cbe0..764f62d 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/base.py @@ -374,19 +374,22 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance): """ node_tree = self.id_data for direc in ['input', 'output']: - bl_sockets = self._bl_sockets(direc) + active_socket_nametype = { + bl_socket.name: bl_socket.socket_type + for bl_socket in self._bl_sockets(direc) + } active_socket_defs = self.active_socket_defs(direc) # Determine Sockets to Remove ## -> Name: If the existing socket name isn't "active". ## -> Type: If the existing socket_type != "active" SocketDef. bl_sockets_to_remove = [ - bl_socket - for socket_name, bl_socket in bl_sockets.items() + active_sckname + for active_sckname, active_scktype in active_socket_nametype.items() if ( - socket_name not in active_socket_defs - or bl_socket.socket_type - is not active_socket_defs[socket_name].socket_type + active_sckname not in active_socket_defs + or active_scktype + is not active_socket_defs[active_sckname].socket_type ) ] @@ -394,39 +397,50 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance): ## -> Name: If the existing socket name is "active". ## -> Type: If the existing socket_type == "active" SocketDef. ## -> Compare: If the existing socket differs from the SocketDef. + ## -> NOTE: Reload bl_sockets in case to-update scks were removed. bl_sockets_to_update = [ - bl_socket - for socket_name, bl_socket in bl_sockets.items() + active_sckname + for active_sckname, active_scktype in active_socket_nametype.items() if ( - socket_name in active_socket_defs - and bl_socket.socket_type - is active_socket_defs[socket_name].socket_type - and not active_socket_defs[socket_name].compare(bl_socket) + active_sckname in active_socket_defs + and active_scktype is active_socket_defs[active_sckname].socket_type + and not active_socket_defs[active_sckname].compare( + self._bl_sockets(direc)[active_sckname] + ) ) ] # Remove Sockets - for bl_socket in bl_sockets_to_remove: - bl_socket_name = bl_socket.name + ## -> The symptom of using a deleted socket is... hard crash. + ## -> Therefore, we must be EXTREMELY careful with bl_socket refs. + ## -> The multi-stage for-loop helps us guard from deleted sockets. + for active_sckname in bl_sockets_to_remove: + bl_socket = self._bl_sockets(direc).get(active_sckname) # 1. Report the socket removal to the NodeTree. ## -> The NodeLinkCache needs to be adjusted manually. node_tree.on_node_socket_removed(bl_socket) + for active_sckname in bl_sockets_to_remove: + bl_sockets = self._bl_sockets(direc) + bl_socket = bl_sockets.get(active_sckname) + # 2. Perform the removal using Blender's API. ## -> Actually removes the socket. - bl_sockets.remove(bl_socket) - - # 3. Invalidate the input socket cache across all kinds. - ## -> Prevents phantom values from remaining available. - ## -> Done after socket removal to protect from race condition. - self._compute_input.invalidate( - input_socket_name=bl_socket_name, - kind=..., - unit_system=..., - ) + ## -> Must be protected from auto-removed use-after-free. + if bl_socket is not None: + bl_sockets.remove(bl_socket) if direc == 'input': + # 3. Invalidate the input socket cache across all kinds. + ## -> Prevents phantom values from remaining available. + ## -> Done after socket removal to protect from race condition. + self._compute_input.invalidate( + input_socket_name=active_sckname, + kind=..., + unit_system=..., + ) + # 4. Run all trigger-only `on_value_changed` callbacks. ## -> Runs any event methods that relied on the socket. ## -> Only methods that don't **require** the socket. @@ -435,39 +449,69 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance): triggered_event_methods = [ event_method for event_method in self.filtered_event_methods_by_event( - ct.FlowEvent.DataChanged, (bl_socket_name, None, None) + ct.FlowEvent.DataChanged, (active_sckname, None, None) ) - if bl_socket_name + if active_sckname not in event_method.callback_info.must_load_sockets ] for event_method in triggered_event_methods: event_method(self) + else: + # 3. Invalidate the output socket cache across all kinds. + ## -> Prevents phantom values from remaining available. + ## -> Done after socket removal to protect from race condition. + self.compute_output.invalidate( + input_socket_name=active_sckname, + kind=..., + ) + # Update Sockets - for bl_socket in bl_sockets_to_update: - bl_socket_name = bl_socket.name - socket_def = active_socket_defs[bl_socket_name] + ## -> The symptom of using a deleted socket is... hard crash. + ## -> Therefore, we must be EXTREMELY careful with bl_socket refs. + ## -> The multi-stage for-loop helps us guard from deleted sockets. + for active_sckname in bl_sockets_to_update: + bl_sockets = self._bl_sockets(direc) + bl_socket = bl_sockets.get(active_sckname) - # 1. Pretend to Initialize for the First Time - ## -> NOTE: The socket's caches will be completely regenerated. - ## -> NOTE: A full FlowKind update will occur, but only one. - bl_socket.is_initializing = True - socket_def.preinit(bl_socket) - socket_def.init(bl_socket) - socket_def.postinit(bl_socket) + if bl_socket is not None: + socket_def = active_socket_defs[active_sckname] - # 2. Re-Test Socket Capabilities - ## -> Factors influencing CapabilitiesFlow may have changed. - ## -> Therefore, we must re-test all link capabilities. - bl_socket.remove_invalidated_links() + # 1. Pretend to Initialize for the First Time + ## -> NOTE: The socket's caches will be completely regenerated. + ## -> NOTE: A full FlowKind update will occur, but only one. + bl_socket.is_initializing = True + socket_def.preinit(bl_socket) + socket_def.init(bl_socket) + socket_def.postinit(bl_socket) - # 3. Invalidate the input socket cache across all kinds. - ## -> Prevents phantom values from remaining available. - self._compute_input.invalidate( - input_socket_name=bl_socket_name, - kind=..., - unit_system=..., - ) + for active_sckname in bl_sockets_to_update: + bl_sockets = self._bl_sockets(direc) + bl_socket = bl_sockets.get(active_sckname) + + if bl_socket is not None: + # 2. Re-Test Socket Capabilities + ## -> Factors influencing CapabilitiesFlow may have changed. + ## -> Therefore, we must re-test all link capabilities. + bl_socket.remove_invalidated_links() + + if direc == 'input': + # 3. Invalidate the input socket cache across all kinds. + ## -> Prevents phantom values from remaining available. + self._compute_input.invalidate( + input_socket_name=active_sckname, + kind=..., + unit_system=..., + ) + + if direc == 'output': + # 3. Invalidate the output socket cache across all kinds. + ## -> Prevents phantom values from remaining available. + ## -> Done after socket removal to protect from race condition. + self.compute_output.invalidate( + input_socket_name=active_sckname, + kind=..., + ) def _add_new_active_sockets(self): """Add and initialize all "active" sockets that aren't on the node. diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py index 9dc93bd..db9df67 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/symbol_constant.py @@ -104,6 +104,11 @@ class SymbolConstantNode(base.MaxwellSimNode): interval_inf_im: tuple[bool, bool] = bl_cache.BLField((True, True)) interval_closed_im: tuple[bool, bool] = bl_cache.BLField((True, True)) + preview_value_z: int = bl_cache.BLField(0) + preview_value_q: tuple[int, int] = bl_cache.BLField((0, 1)) + preview_value_re: float = bl_cache.BLField(0.0) + preview_value_im: float = bl_cache.BLField(0.0) + #################### # - Computed Properties #################### @@ -122,6 +127,10 @@ class SymbolConstantNode(base.MaxwellSimNode): 'interval_finite_im', 'interval_inf_im', 'interval_closed_im', + 'preview_value_z', + 'preview_value_q', + 'preview_value_re', + 'preview_value_im', } ) def symbol(self) -> sim_symbols.SimSymbol: @@ -140,6 +149,10 @@ class SymbolConstantNode(base.MaxwellSimNode): interval_finite_im=self.interval_finite_im, interval_inf_im=self.interval_inf_im, interval_closed_im=self.interval_closed_im, + preview_value_z=self.preview_value_z, + preview_value_q=self.preview_value_q, + preview_value_re=self.preview_value_re, + preview_value_im=self.preview_value_im, ) #################### @@ -164,36 +177,68 @@ class SymbolConstantNode(base.MaxwellSimNode): col.prop(self, self.blfields['physical_type'], text='') + # Domain - Infinite row = col.row(align=True) row.alignment = 'CENTER' - row.label(text='Domain') + row.label(text='Domain - Is Infinite') + row = col.row(align=True) + if self.mathtype is spux.MathType.Complex: + row.prop(self, self.blfields['interval_inf'], text='ℝ') + row.prop(self, self.blfields['interval_inf_im'], text='𝕀') + else: + row.prop(self, self.blfields['interval_inf'], text='') + + if any(not b for b in self.interval_inf): + # Domain - Closure + row = col.row(align=True) + row.alignment = 'CENTER' + row.label(text='Domain - Closure') + + row = col.row(align=True) + if self.mathtype is spux.MathType.Complex: + row.prop(self, self.blfields['interval_closed'], text='ℝ') + row.prop(self, self.blfields['interval_closed_im'], text='𝕀') + else: + row.prop(self, self.blfields['interval_closed'], text='') + + # Domain - Finite + row = col.row(align=True) + row.alignment = 'CENTER' + row.label(text='Domain - Interval') + + row = col.row(align=True) + match self.mathtype: + case spux.MathType.Integer: + row.prop(self, self.blfields['interval_finite_z'], text='') + + case spux.MathType.Rational: + row.prop(self, self.blfields['interval_finite_q'], text='') + + case spux.MathType.Real: + row.prop(self, self.blfields['interval_finite_re'], text='') + + case spux.MathType.Complex: + row.prop(self, self.blfields['interval_finite_re'], text='ℝ') + row.prop(self, self.blfields['interval_finite_im'], text='𝕀') + + # Domain - Closure + row = col.row(align=True) + row.alignment = 'CENTER' + row.label(text='Preview Value') match self.mathtype: case spux.MathType.Integer: - col.prop(self, self.blfields['interval_finite_z'], text='') - col.prop(self, self.blfields['interval_inf'], text='Infinite') - col.prop(self, self.blfields['interval_closed'], text='Closed') + row.prop(self, self.blfields['preview_value_z'], text='') case spux.MathType.Rational: - col.prop(self, self.blfields['interval_finite_q'], text='') - col.prop(self, self.blfields['interval_inf'], text='Infinite') - col.prop(self, self.blfields['interval_closed'], text='Closed') + row.prop(self, self.blfields['preview_value_q'], text='') case spux.MathType.Real: - col.prop(self, self.blfields['interval_finite_re'], text='') - col.prop(self, self.blfields['interval_inf'], text='Infinite') - col.prop(self, self.blfields['interval_closed'], text='Closed') + row.prop(self, self.blfields['preview_value_re'], text='') case spux.MathType.Complex: - col.prop(self, self.blfields['interval_finite_re'], text='ℝ') - col.prop(self, self.blfields['interval_inf'], text='ℝ Infinite') - col.prop(self, self.blfields['interval_closed'], text='ℝ Closed') - - col.separator() - - col.prop(self, self.blfields['interval_finite_im'], text='𝕀') - col.prop(self, self.blfields['interval_inf'], text='𝕀 Infinite') - col.prop(self, self.blfields['interval_closed'], text='𝕀 Closed') + row.prop(self, self.blfields['preview_value_re'], text='ℝ') + row.prop(self, self.blfields['preview_value_im'], text='𝕀') #################### # - FlowKinds diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py index 4f8a0bb..81034da 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/simulations/sim_domain.py @@ -132,7 +132,7 @@ class SimDomainNode(base.MaxwellSimNode): | grid | medium ).compose_within( - enclosing_func=lambda els: { + lambda els: { 'run_time': els[0], 'center': tuple(els[1].flatten()), 'size': tuple(els[2].flatten()), diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py index 0d955df..edd2d95 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/box_structure.py @@ -88,36 +88,19 @@ class BoxStructureNode(base.MaxwellSimNode): 'Structure', kind=ct.FlowKind.Value, # Loaded - input_sockets={'Medium', 'Center', 'Size'}, output_sockets={'Structure'}, - output_socket_kinds={'Structure': ct.FlowKind.Params}, + output_socket_kinds={'Structure': {ct.FlowKind.Func, ct.FlowKind.Params}}, ) - def compute_value(self, input_sockets, output_sockets) -> td.Box: - """Compute a single box structure object, given that all inputs are non-symbolic.""" - center = input_sockets['Center'] - size = input_sockets['Size'] - medium = input_sockets['Medium'] - output_params = output_sockets['Structure'] + def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal: + """Compute the particular value of the simulation domain from strictly non-symbolic inputs.""" + output_func = output_sockets['Structure'][ct.FlowKind.Func] + output_params = output_sockets['Structure'][ct.FlowKind.Params] - has_center = not ct.FlowSignal.check(center) - has_size = not ct.FlowSignal.check(size) - has_medium = not ct.FlowSignal.check(medium) + has_output_func = not ct.FlowSignal.check(output_func) has_output_params = not ct.FlowSignal.check(output_params) - if ( - has_center - and has_size - and has_medium - and has_output_params - and not output_params.symbols - ): - return td.Structure( - geometry=td.Box( - center=spux.scale_to_unit_system(center, ct.UNITS_TIDY3D), - size=spux.scale_to_unit_system(size, ct.UNITS_TIDY3D), - ), - medium=medium, - ) + if has_output_func and has_output_params and not output_params.symbols: + return output_func.realize(output_params, disallow_jax=True) return ct.FlowSignal.FlowPending #################### @@ -134,45 +117,43 @@ class BoxStructureNode(base.MaxwellSimNode): 'Center': ct.FlowKind.Func, 'Size': ct.FlowKind.Func, }, - output_sockets={'Structure'}, - output_socket_kinds={'Structure': ct.FlowKind.Params}, ) - def compute_structure_func(self, props, input_sockets, output_sockets) -> td.Box: + def compute_structure_func(self, props, input_sockets) -> td.Box: """Compute a possibly-differentiable function, producing a box structure from the input parameters.""" - output_params = output_sockets['Structure'] center = input_sockets['Center'] size = input_sockets['Size'] medium = input_sockets['Medium'] - has_output_params = not ct.FlowSignal.check(output_params) has_center = not ct.FlowSignal.check(center) has_size = not ct.FlowSignal.check(size) has_medium = not ct.FlowSignal.check(medium) - if has_output_params and has_center and has_size and has_medium: + if has_center and has_size and has_medium: differentiable = props['differentiable'] if differentiable: - return (center | size | medium).compose_within( - enclosing_func=lambda els: tdadj.JaxStructure( + return ( + center.scale_to_unit_system(ct.UNITS_TIDY3D) + | size.scale_to_unit_system(ct.UNITS_TIDY3D) + | medium + ).compose_within( + lambda els: tdadj.JaxStructure( geometry=tdadj.JaxBox( - center=tuple(els[0].flatten()), - size=tuple(els[1].flatten()), + center_jax=tuple(els[0].flatten()), + size_jax=tuple(els[1].flatten()), ), medium=els[2], ), supports_jax=True, ) - return (center | size | medium).compose_within( - ## TODO: Unit conversion within the composed function?? - ## -- We do need Tidy3D to be given ex. micrometers in particular. - ## -- But the previous numerical output might not be micrometers. - ## -- There must be a way to add a conversion in, without strangeness. - ## -- Ex. can compose_within() take a unit system? - ## -- This would require - enclosing_func=lambda els: td.Structure( + return ( + center.scale_to_unit_system(ct.UNITS_TIDY3D) + | size.scale_to_unit_system(ct.UNITS_TIDY3D) + | medium + ).compose_within( + lambda els: td.Structure( geometry=td.Box( - center=tuple(els[0].flatten()), - size=tuple(els[1].flatten()), + center=els[0].flatten().tolist(), + size=els[1].flatten().tolist(), ), medium=els[2], ), @@ -187,7 +168,6 @@ class BoxStructureNode(base.MaxwellSimNode): 'Structure', kind=ct.FlowKind.Params, # Loaded - props={'differentiable'}, input_sockets={'Medium', 'Center', 'Size'}, input_socket_kinds={ 'Medium': ct.FlowKind.Params, @@ -195,7 +175,7 @@ class BoxStructureNode(base.MaxwellSimNode): 'Size': ct.FlowKind.Params, }, ) - def compute_params(self, props, input_sockets) -> td.Box: + def compute_params(self, input_sockets) -> td.Box: center = input_sockets['Center'] size = input_sockets['Size'] medium = input_sockets['Medium'] @@ -238,13 +218,16 @@ class BoxStructureNode(base.MaxwellSimNode): output_sockets={'Structure'}, output_socket_kinds={'Structure': ct.FlowKind.Params}, ) - def on_inputs_changed(self, managed_objs, input_sockets, output_sockets): - output_params = output_sockets['Structure'] + def on_previewable_changed(self, managed_objs, input_sockets, output_sockets): center = input_sockets['Center'] + size = input_sockets['Size'] + output_params = output_sockets['Structure'] - has_output_params = not ct.FlowSignal.check(output_params) has_center = not ct.FlowSignal.check(center) - if has_center and has_output_params and not output_params.symbols: + has_size = not ct.FlowSignal.check(size) + has_output_params = not ct.FlowSignal.check(output_params) + + if has_center and has_size and has_output_params and not output_params.symbols: ## TODO: There are strategies for handling examples of symbol values. # Push Loose Input Values to GeoNodes Modifier @@ -254,7 +237,7 @@ class BoxStructureNode(base.MaxwellSimNode): 'node_group': import_geonodes(GeoNodes.StructurePrimitiveBox), 'unit_system': ct.UNITS_BLENDER, 'inputs': { - 'Size': input_sockets['Size'], + 'Size': size, }, }, location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER), diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py index c1f7440..a43540e 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/structures/primitives/cylinder_structure.py @@ -32,6 +32,8 @@ log = logger.get(__name__) class CylinderStructureNode(base.MaxwellSimNode): + """A generic cylinder structure with configurable radius and height.""" + node_type = ct.NodeType.CylinderStructure bl_label = 'Cylinder Structure' use_sim_node_name = True @@ -64,27 +66,101 @@ class CylinderStructureNode(base.MaxwellSimNode): } #################### - # - Output Socket Computation + # - FlowKind.Value #################### @events.computes_output_socket( 'Structure', + kind=ct.FlowKind.Value, + # Loaded + output_sockets={'Structure'}, + output_socket_kinds={'Structure': {ct.FlowKind.Func, ct.FlowKind.Params}}, + ) + def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal: + """Compute the particular value of the simulation domain from strictly non-symbolic inputs.""" + output_func = output_sockets['Structure'][ct.FlowKind.Func] + output_params = output_sockets['Structure'][ct.FlowKind.Params] + + has_output_func = not ct.FlowSignal.check(output_func) + has_output_params = not ct.FlowSignal.check(output_params) + + if has_output_func and has_output_params and not output_params.symbols: + return output_func.realize(output_params, disallow_jax=True) + return ct.FlowSignal.FlowPending + + #################### + # - FlowKind.Func + #################### + @events.computes_output_socket( + 'Structure', + kind=ct.FlowKind.Func, + # Loaded input_sockets={'Center', 'Radius', 'Medium', 'Height'}, - unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D}, - scale_input_sockets={ - 'Center': 'Tidy3DUnits', - 'Radius': 'Tidy3DUnits', - 'Height': 'Tidy3DUnits', + input_socket_kinds={ + 'Center': ct.FlowKind.Func, + 'Radius': ct.FlowKind.Func, + 'Height': ct.FlowKind.Func, + 'Medium': ct.FlowKind.Func, }, ) - def compute_structure(self, input_sockets, unit_systems) -> td.Box: - return td.Structure( - geometry=td.Cylinder( - radius=input_sockets['Radius'], - center=input_sockets['Center'], - length=input_sockets['Height'], - ), - medium=input_sockets['Medium'], - ) + def compute_func(self, input_sockets) -> td.Box: + """Compute a single cylinder structure object, given that all inputs are non-symbolic.""" + center = input_sockets['Center'] + radius = input_sockets['Radius'] + height = input_sockets['Height'] + medium = input_sockets['Medium'] + + has_center = not ct.FlowSignal.check(center) + has_radius = not ct.FlowSignal.check(radius) + has_height = not ct.FlowSignal.check(height) + has_medium = not ct.FlowSignal.check(medium) + + if has_center and has_radius and has_height and has_medium: + return ( + center.scale_to_unit_system(ct.UNITS_TIDY3D) + | radius.scale_to_unit_system(ct.UNITS_TIDY3D) + | height.scale_to_unit_system(ct.UNITS_TIDY3D) + | medium + ).compose_within( + lambda els: td.Structure( + geometry=td.Cylinder( + center=els[0].flatten().tolist(), + radius=els[1], + length=els[2], + ), + medium=els[3], + ) + ) + return ct.FlowSignal.FlowPending + + #################### + # - FlowKind.Params + #################### + @events.computes_output_socket( + 'Structure', + kind=ct.FlowKind.Params, + # Loaded + input_sockets={'Center', 'Radius', 'Medium', 'Height'}, + input_socket_kinds={ + 'Center': ct.FlowKind.Params, + 'Radius': ct.FlowKind.Params, + 'Height': ct.FlowKind.Params, + 'Medium': ct.FlowKind.Params, + }, + ) + def compute_params(self, input_sockets) -> td.Box: + center = input_sockets['Center'] + radius = input_sockets['Radius'] + height = input_sockets['Height'] + medium = input_sockets['Medium'] + + has_center = not ct.FlowSignal.check(center) + has_radius = not ct.FlowSignal.check(radius) + has_height = not ct.FlowSignal.check(height) + has_medium = not ct.FlowSignal.check(medium) + + if has_center and has_radius and has_height and has_medium: + return center | radius | height | medium + return ct.FlowSignal.FlowPending #################### # - Preview diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py index 1cdc9d7..8b57f02 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py @@ -791,14 +791,7 @@ class ExprBLSocket(base.MaxwellSimSocket): #################### # - FlowKind: Func (w/Params if Constant) #################### - @bl_cache.cached_bl_property( - depends_on={ - 'value', - 'sorted_sp_symbols', - 'sorted_symbols', - 'output_sym', - } - ) + @bl_cache.cached_bl_property(depends_on={'output_sym'}) def lazy_func(self) -> ct.FuncFlow: """Returns a lazy value that computes the expression returned by `self.value`. @@ -806,42 +799,12 @@ class ExprBLSocket(base.MaxwellSimSocket): Otherwise, the returned lazy value function will be a simple excuse for `self.params` to pass the verbatim `self.value`. """ if self.output_sym is not None: - match self.active_kind: - case ct.FlowKind.Value | ct.FlowKind.Func if ( - self.sorted_symbols and not ct.FlowSignal.check(self.value) - ): - return ct.FuncFlow( - func=sp.lambdify( - self.sorted_sp_symbols, - self.output_sym.conform(self.value, strip_unit=True), - 'jax', - ), - func_args=list(self.sorted_symbols), - func_output=self.output_sym, - supports_jax=True, - ) - - case ct.FlowKind.Value | ct.FlowKind.Func if not self.sorted_symbols: - return ct.FuncFlow( - func=lambda v: v, - func_args=[self.output_sym], - func_output=self.output_sym, - supports_jax=True, - ) - - case ct.FlowKind.Range if self.sorted_symbols: - msg = 'RangeFlow support not yet implemented for when self.sorted_symbols is not empty' - raise NotImplementedError(msg) - - case ct.FlowKind.Range if ( - not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range) - ): - return ct.FuncFlow( - func=lambda v: v, - func_args=[self.output_sym], - func_output=self.output_sym, - supports_jax=True, - ) + return ct.FuncFlow( + func=lambda v: v, + func_args=[self.output_sym], + func_output=self.output_sym, + supports_jax=True, + ) return ct.FlowSignal.FlowPending @@ -854,22 +817,15 @@ class ExprBLSocket(base.MaxwellSimSocket): If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be passed into `ParamsFlow`, which will thus be parameterized (and require realization before use). Otherwise, `self.value` is passed verbatim as the only `ParamsFlow.func_arg`. """ - output_sym = self.output_sym - if output_sym is not None: + if self.output_sym is not None: match self.active_kind: - case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols: - return ct.ParamsFlow( - arg_targets=list(self.sorted_symbols), - func_args=[sym.sp_symbol for sym in self.sorted_symbols], - symbols=set(self.sorted_symbols), - ) - case ct.FlowKind.Value | ct.FlowKind.Func if ( - not self.sorted_symbols and not ct.FlowSignal.check(self.value) + not ct.FlowSignal.check(self.value) ): return ct.ParamsFlow( arg_targets=[self.output_sym], func_args=[self.value], + symbols=set(self.sorted_symbols), ) case ct.FlowKind.Range if self.sorted_symbols: @@ -899,20 +855,14 @@ class ExprBLSocket(base.MaxwellSimSocket): Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along. """ - output_sym = self.output_sym - if output_sym is not None: + if self.output_sym is not None: match self.active_kind: - case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols: + case ct.FlowKind.Value | ct.FlowKind.Func: return ct.InfoFlow( dims={sym: None for sym in self.sorted_symbols}, output=self.output_sym, ) - case ct.FlowKind.Value | ct.FlowKind.Func if ( - not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range) - ): - return ct.InfoFlow(output=self.output_sym) - case ct.FlowKind.Range if self.sorted_symbols: msg = 'InfoFlow support not yet implemented for when self.sorted_symbols is not empty' raise NotImplementedError(msg) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py index 521605d..2c2119f 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/maxwell/medium.py @@ -14,7 +14,10 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import math + import bpy +import jax.numpy as jnp import scipy as sc import sympy.physics.units as spu import tidy3d as td @@ -28,9 +31,15 @@ from .. import base log = logger.get(__name__) -VAC_SPEED_OF_LIGHT = sc.constants.speed_of_light * spu.meter / spu.second - -FIXED_WL = 500 * spu.nm +_VAC_SPEED_OF_LIGHT = sc.constants.speed_of_light * spu.meter / spu.second +_FIXED_WL = 500 * spu.nm +FIXED_FREQ = spux.scale_to_unit_system( + spu.convert_to( + _VAC_SPEED_OF_LIGHT / _FIXED_WL, + spu.hertz, + ), + ct.UNITS_TIDY3D, +) class MaxwellMediumBLSocket(base.MaxwellSimSocket): @@ -49,24 +58,17 @@ class MaxwellMediumBLSocket(base.MaxwellSimSocket): #################### @bl_cache.cached_bl_property(depends_on={'eps_rel', 'differentiable'}) def value(self) -> td.Medium: - freq = ( - spu.convert_to( - VAC_SPEED_OF_LIGHT / FIXED_WL, - spu.hertz, - ) - / spu.hertz - ) + eps_r_re = self.eps_rel[0] + conductivity = FIXED_FREQ * self.eps_rel[1] ## TODO: Um? if self.differentiable: - return tdadj.JaxMedium.from_nk( - n=self.eps_rel[0], - k=self.eps_rel[1], - freq=freq, + return tdadj.JaxMedium( + permittivity_jax=jnp.array(eps_r_re, dtype=float), + conductivity_jax=jnp.array(conductivity, dtype=float), ) - return td.Medium.from_nk( - n=self.eps_rel[0], - k=self.eps_rel[1], - freq=freq, + return td.Medium( + permittivity=eps_r_re, + conductivity=conductivity, ) @value.setter diff --git a/src/blender_maxwell/utils/sim_symbols.py b/src/blender_maxwell/utils/sim_symbols.py index 1efaec3..46d4ce3 100644 --- a/src/blender_maxwell/utils/sim_symbols.py +++ b/src/blender_maxwell/utils/sim_symbols.py @@ -263,6 +263,11 @@ class SimSymbol(pyd.BaseModel): interval_inf_im: tuple[bool, bool] = (True, True) interval_closed_im: tuple[bool, bool] = (False, False) + preview_value_z: int = 0 + preview_value_q: tuple[int, int] = (0, 1) + preview_value_re: float = 0.0 + preview_value_im: float = 0.0 + #################### # - Core ####################