fix: use-after-free in socket pruner

Also the usual batch of improvements.
Differentiability is misbehaving, intriguingly.
main
Sofus Albert Høgsbro Rose 2024-05-30 21:39:48 +02:00
parent 38e70a60d3
commit c286d65193
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
9 changed files with 328 additions and 215 deletions

View File

@ -325,8 +325,16 @@ class FuncFlow(pyd.BaseModel):
symbol_values: dict[sim_symbols.SimSymbol, spux.SympyExpr] = MappingProxyType(
{}
),
disallow_jax: bool = True,
) -> typ.Self:
"""Run the represented function with the best optimization available, given particular choices for all function arguments and for all unrealized symbols."""
"""Run the represented function with the best optimization available, given particular choices for all function arguments and for all unrealized symbols.
Parameters:
params: The parameter-tracking object from which function arguments will be computed.
symbol_values: Values for all `SimSymbol`s that are not yet realized in `params`.
disallow_jax: Don't use `self.func_jax` to evaluate, even if possible.
This is desirable when the overhead of `jax.jit()` is known in advance to exceed the performance benefits.
"""
if self.supports_jax:
return self.func_jax(
*params.scaled_func_args(symbol_values),

View File

@ -374,19 +374,22 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
"""
node_tree = self.id_data
for direc in ['input', 'output']:
bl_sockets = self._bl_sockets(direc)
active_socket_nametype = {
bl_socket.name: bl_socket.socket_type
for bl_socket in self._bl_sockets(direc)
}
active_socket_defs = self.active_socket_defs(direc)
# Determine Sockets to Remove
## -> Name: If the existing socket name isn't "active".
## -> Type: If the existing socket_type != "active" SocketDef.
bl_sockets_to_remove = [
bl_socket
for socket_name, bl_socket in bl_sockets.items()
active_sckname
for active_sckname, active_scktype in active_socket_nametype.items()
if (
socket_name not in active_socket_defs
or bl_socket.socket_type
is not active_socket_defs[socket_name].socket_type
active_sckname not in active_socket_defs
or active_scktype
is not active_socket_defs[active_sckname].socket_type
)
]
@ -394,39 +397,50 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
## -> Name: If the existing socket name is "active".
## -> Type: If the existing socket_type == "active" SocketDef.
## -> Compare: If the existing socket differs from the SocketDef.
## -> NOTE: Reload bl_sockets in case to-update scks were removed.
bl_sockets_to_update = [
bl_socket
for socket_name, bl_socket in bl_sockets.items()
active_sckname
for active_sckname, active_scktype in active_socket_nametype.items()
if (
socket_name in active_socket_defs
and bl_socket.socket_type
is active_socket_defs[socket_name].socket_type
and not active_socket_defs[socket_name].compare(bl_socket)
active_sckname in active_socket_defs
and active_scktype is active_socket_defs[active_sckname].socket_type
and not active_socket_defs[active_sckname].compare(
self._bl_sockets(direc)[active_sckname]
)
)
]
# Remove Sockets
for bl_socket in bl_sockets_to_remove:
bl_socket_name = bl_socket.name
## -> The symptom of using a deleted socket is... hard crash.
## -> Therefore, we must be EXTREMELY careful with bl_socket refs.
## -> The multi-stage for-loop helps us guard from deleted sockets.
for active_sckname in bl_sockets_to_remove:
bl_socket = self._bl_sockets(direc).get(active_sckname)
# 1. Report the socket removal to the NodeTree.
## -> The NodeLinkCache needs to be adjusted manually.
node_tree.on_node_socket_removed(bl_socket)
for active_sckname in bl_sockets_to_remove:
bl_sockets = self._bl_sockets(direc)
bl_socket = bl_sockets.get(active_sckname)
# 2. Perform the removal using Blender's API.
## -> Actually removes the socket.
bl_sockets.remove(bl_socket)
# 3. Invalidate the input socket cache across all kinds.
## -> Prevents phantom values from remaining available.
## -> Done after socket removal to protect from race condition.
self._compute_input.invalidate(
input_socket_name=bl_socket_name,
kind=...,
unit_system=...,
)
## -> Must be protected from auto-removed use-after-free.
if bl_socket is not None:
bl_sockets.remove(bl_socket)
if direc == 'input':
# 3. Invalidate the input socket cache across all kinds.
## -> Prevents phantom values from remaining available.
## -> Done after socket removal to protect from race condition.
self._compute_input.invalidate(
input_socket_name=active_sckname,
kind=...,
unit_system=...,
)
# 4. Run all trigger-only `on_value_changed` callbacks.
## -> Runs any event methods that relied on the socket.
## -> Only methods that don't **require** the socket.
@ -435,39 +449,69 @@ class MaxwellSimNode(bpy.types.Node, bl_instance.BLInstance):
triggered_event_methods = [
event_method
for event_method in self.filtered_event_methods_by_event(
ct.FlowEvent.DataChanged, (bl_socket_name, None, None)
ct.FlowEvent.DataChanged, (active_sckname, None, None)
)
if bl_socket_name
if active_sckname
not in event_method.callback_info.must_load_sockets
]
for event_method in triggered_event_methods:
event_method(self)
else:
# 3. Invalidate the output socket cache across all kinds.
## -> Prevents phantom values from remaining available.
## -> Done after socket removal to protect from race condition.
self.compute_output.invalidate(
input_socket_name=active_sckname,
kind=...,
)
# Update Sockets
for bl_socket in bl_sockets_to_update:
bl_socket_name = bl_socket.name
socket_def = active_socket_defs[bl_socket_name]
## -> The symptom of using a deleted socket is... hard crash.
## -> Therefore, we must be EXTREMELY careful with bl_socket refs.
## -> The multi-stage for-loop helps us guard from deleted sockets.
for active_sckname in bl_sockets_to_update:
bl_sockets = self._bl_sockets(direc)
bl_socket = bl_sockets.get(active_sckname)
# 1. Pretend to Initialize for the First Time
## -> NOTE: The socket's caches will be completely regenerated.
## -> NOTE: A full FlowKind update will occur, but only one.
bl_socket.is_initializing = True
socket_def.preinit(bl_socket)
socket_def.init(bl_socket)
socket_def.postinit(bl_socket)
if bl_socket is not None:
socket_def = active_socket_defs[active_sckname]
# 2. Re-Test Socket Capabilities
## -> Factors influencing CapabilitiesFlow may have changed.
## -> Therefore, we must re-test all link capabilities.
bl_socket.remove_invalidated_links()
# 1. Pretend to Initialize for the First Time
## -> NOTE: The socket's caches will be completely regenerated.
## -> NOTE: A full FlowKind update will occur, but only one.
bl_socket.is_initializing = True
socket_def.preinit(bl_socket)
socket_def.init(bl_socket)
socket_def.postinit(bl_socket)
# 3. Invalidate the input socket cache across all kinds.
## -> Prevents phantom values from remaining available.
self._compute_input.invalidate(
input_socket_name=bl_socket_name,
kind=...,
unit_system=...,
)
for active_sckname in bl_sockets_to_update:
bl_sockets = self._bl_sockets(direc)
bl_socket = bl_sockets.get(active_sckname)
if bl_socket is not None:
# 2. Re-Test Socket Capabilities
## -> Factors influencing CapabilitiesFlow may have changed.
## -> Therefore, we must re-test all link capabilities.
bl_socket.remove_invalidated_links()
if direc == 'input':
# 3. Invalidate the input socket cache across all kinds.
## -> Prevents phantom values from remaining available.
self._compute_input.invalidate(
input_socket_name=active_sckname,
kind=...,
unit_system=...,
)
if direc == 'output':
# 3. Invalidate the output socket cache across all kinds.
## -> Prevents phantom values from remaining available.
## -> Done after socket removal to protect from race condition.
self.compute_output.invalidate(
input_socket_name=active_sckname,
kind=...,
)
def _add_new_active_sockets(self):
"""Add and initialize all "active" sockets that aren't on the node.

View File

@ -104,6 +104,11 @@ class SymbolConstantNode(base.MaxwellSimNode):
interval_inf_im: tuple[bool, bool] = bl_cache.BLField((True, True))
interval_closed_im: tuple[bool, bool] = bl_cache.BLField((True, True))
preview_value_z: int = bl_cache.BLField(0)
preview_value_q: tuple[int, int] = bl_cache.BLField((0, 1))
preview_value_re: float = bl_cache.BLField(0.0)
preview_value_im: float = bl_cache.BLField(0.0)
####################
# - Computed Properties
####################
@ -122,6 +127,10 @@ class SymbolConstantNode(base.MaxwellSimNode):
'interval_finite_im',
'interval_inf_im',
'interval_closed_im',
'preview_value_z',
'preview_value_q',
'preview_value_re',
'preview_value_im',
}
)
def symbol(self) -> sim_symbols.SimSymbol:
@ -140,6 +149,10 @@ class SymbolConstantNode(base.MaxwellSimNode):
interval_finite_im=self.interval_finite_im,
interval_inf_im=self.interval_inf_im,
interval_closed_im=self.interval_closed_im,
preview_value_z=self.preview_value_z,
preview_value_q=self.preview_value_q,
preview_value_re=self.preview_value_re,
preview_value_im=self.preview_value_im,
)
####################
@ -164,36 +177,68 @@ class SymbolConstantNode(base.MaxwellSimNode):
col.prop(self, self.blfields['physical_type'], text='')
# Domain - Infinite
row = col.row(align=True)
row.alignment = 'CENTER'
row.label(text='Domain')
row.label(text='Domain - Is Infinite')
row = col.row(align=True)
if self.mathtype is spux.MathType.Complex:
row.prop(self, self.blfields['interval_inf'], text='')
row.prop(self, self.blfields['interval_inf_im'], text='𝕀')
else:
row.prop(self, self.blfields['interval_inf'], text='')
if any(not b for b in self.interval_inf):
# Domain - Closure
row = col.row(align=True)
row.alignment = 'CENTER'
row.label(text='Domain - Closure')
row = col.row(align=True)
if self.mathtype is spux.MathType.Complex:
row.prop(self, self.blfields['interval_closed'], text='')
row.prop(self, self.blfields['interval_closed_im'], text='𝕀')
else:
row.prop(self, self.blfields['interval_closed'], text='')
# Domain - Finite
row = col.row(align=True)
row.alignment = 'CENTER'
row.label(text='Domain - Interval')
row = col.row(align=True)
match self.mathtype:
case spux.MathType.Integer:
row.prop(self, self.blfields['interval_finite_z'], text='')
case spux.MathType.Rational:
row.prop(self, self.blfields['interval_finite_q'], text='')
case spux.MathType.Real:
row.prop(self, self.blfields['interval_finite_re'], text='')
case spux.MathType.Complex:
row.prop(self, self.blfields['interval_finite_re'], text='')
row.prop(self, self.blfields['interval_finite_im'], text='𝕀')
# Domain - Closure
row = col.row(align=True)
row.alignment = 'CENTER'
row.label(text='Preview Value')
match self.mathtype:
case spux.MathType.Integer:
col.prop(self, self.blfields['interval_finite_z'], text='')
col.prop(self, self.blfields['interval_inf'], text='Infinite')
col.prop(self, self.blfields['interval_closed'], text='Closed')
row.prop(self, self.blfields['preview_value_z'], text='')
case spux.MathType.Rational:
col.prop(self, self.blfields['interval_finite_q'], text='')
col.prop(self, self.blfields['interval_inf'], text='Infinite')
col.prop(self, self.blfields['interval_closed'], text='Closed')
row.prop(self, self.blfields['preview_value_q'], text='')
case spux.MathType.Real:
col.prop(self, self.blfields['interval_finite_re'], text='')
col.prop(self, self.blfields['interval_inf'], text='Infinite')
col.prop(self, self.blfields['interval_closed'], text='Closed')
row.prop(self, self.blfields['preview_value_re'], text='')
case spux.MathType.Complex:
col.prop(self, self.blfields['interval_finite_re'], text='')
col.prop(self, self.blfields['interval_inf'], text=' Infinite')
col.prop(self, self.blfields['interval_closed'], text=' Closed')
col.separator()
col.prop(self, self.blfields['interval_finite_im'], text='𝕀')
col.prop(self, self.blfields['interval_inf'], text='𝕀 Infinite')
col.prop(self, self.blfields['interval_closed'], text='𝕀 Closed')
row.prop(self, self.blfields['preview_value_re'], text='')
row.prop(self, self.blfields['preview_value_im'], text='𝕀')
####################
# - FlowKinds

View File

@ -132,7 +132,7 @@ class SimDomainNode(base.MaxwellSimNode):
| grid
| medium
).compose_within(
enclosing_func=lambda els: {
lambda els: {
'run_time': els[0],
'center': tuple(els[1].flatten()),
'size': tuple(els[2].flatten()),

View File

@ -88,36 +88,19 @@ class BoxStructureNode(base.MaxwellSimNode):
'Structure',
kind=ct.FlowKind.Value,
# Loaded
input_sockets={'Medium', 'Center', 'Size'},
output_sockets={'Structure'},
output_socket_kinds={'Structure': ct.FlowKind.Params},
output_socket_kinds={'Structure': {ct.FlowKind.Func, ct.FlowKind.Params}},
)
def compute_value(self, input_sockets, output_sockets) -> td.Box:
"""Compute a single box structure object, given that all inputs are non-symbolic."""
center = input_sockets['Center']
size = input_sockets['Size']
medium = input_sockets['Medium']
output_params = output_sockets['Structure']
def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
output_func = output_sockets['Structure'][ct.FlowKind.Func]
output_params = output_sockets['Structure'][ct.FlowKind.Params]
has_center = not ct.FlowSignal.check(center)
has_size = not ct.FlowSignal.check(size)
has_medium = not ct.FlowSignal.check(medium)
has_output_func = not ct.FlowSignal.check(output_func)
has_output_params = not ct.FlowSignal.check(output_params)
if (
has_center
and has_size
and has_medium
and has_output_params
and not output_params.symbols
):
return td.Structure(
geometry=td.Box(
center=spux.scale_to_unit_system(center, ct.UNITS_TIDY3D),
size=spux.scale_to_unit_system(size, ct.UNITS_TIDY3D),
),
medium=medium,
)
if has_output_func and has_output_params and not output_params.symbols:
return output_func.realize(output_params, disallow_jax=True)
return ct.FlowSignal.FlowPending
####################
@ -134,45 +117,43 @@ class BoxStructureNode(base.MaxwellSimNode):
'Center': ct.FlowKind.Func,
'Size': ct.FlowKind.Func,
},
output_sockets={'Structure'},
output_socket_kinds={'Structure': ct.FlowKind.Params},
)
def compute_structure_func(self, props, input_sockets, output_sockets) -> td.Box:
def compute_structure_func(self, props, input_sockets) -> td.Box:
"""Compute a possibly-differentiable function, producing a box structure from the input parameters."""
output_params = output_sockets['Structure']
center = input_sockets['Center']
size = input_sockets['Size']
medium = input_sockets['Medium']
has_output_params = not ct.FlowSignal.check(output_params)
has_center = not ct.FlowSignal.check(center)
has_size = not ct.FlowSignal.check(size)
has_medium = not ct.FlowSignal.check(medium)
if has_output_params and has_center and has_size and has_medium:
if has_center and has_size and has_medium:
differentiable = props['differentiable']
if differentiable:
return (center | size | medium).compose_within(
enclosing_func=lambda els: tdadj.JaxStructure(
return (
center.scale_to_unit_system(ct.UNITS_TIDY3D)
| size.scale_to_unit_system(ct.UNITS_TIDY3D)
| medium
).compose_within(
lambda els: tdadj.JaxStructure(
geometry=tdadj.JaxBox(
center=tuple(els[0].flatten()),
size=tuple(els[1].flatten()),
center_jax=tuple(els[0].flatten()),
size_jax=tuple(els[1].flatten()),
),
medium=els[2],
),
supports_jax=True,
)
return (center | size | medium).compose_within(
## TODO: Unit conversion within the composed function??
## -- We do need Tidy3D to be given ex. micrometers in particular.
## -- But the previous numerical output might not be micrometers.
## -- There must be a way to add a conversion in, without strangeness.
## -- Ex. can compose_within() take a unit system?
## -- This would require
enclosing_func=lambda els: td.Structure(
return (
center.scale_to_unit_system(ct.UNITS_TIDY3D)
| size.scale_to_unit_system(ct.UNITS_TIDY3D)
| medium
).compose_within(
lambda els: td.Structure(
geometry=td.Box(
center=tuple(els[0].flatten()),
size=tuple(els[1].flatten()),
center=els[0].flatten().tolist(),
size=els[1].flatten().tolist(),
),
medium=els[2],
),
@ -187,7 +168,6 @@ class BoxStructureNode(base.MaxwellSimNode):
'Structure',
kind=ct.FlowKind.Params,
# Loaded
props={'differentiable'},
input_sockets={'Medium', 'Center', 'Size'},
input_socket_kinds={
'Medium': ct.FlowKind.Params,
@ -195,7 +175,7 @@ class BoxStructureNode(base.MaxwellSimNode):
'Size': ct.FlowKind.Params,
},
)
def compute_params(self, props, input_sockets) -> td.Box:
def compute_params(self, input_sockets) -> td.Box:
center = input_sockets['Center']
size = input_sockets['Size']
medium = input_sockets['Medium']
@ -238,13 +218,16 @@ class BoxStructureNode(base.MaxwellSimNode):
output_sockets={'Structure'},
output_socket_kinds={'Structure': ct.FlowKind.Params},
)
def on_inputs_changed(self, managed_objs, input_sockets, output_sockets):
output_params = output_sockets['Structure']
def on_previewable_changed(self, managed_objs, input_sockets, output_sockets):
center = input_sockets['Center']
size = input_sockets['Size']
output_params = output_sockets['Structure']
has_output_params = not ct.FlowSignal.check(output_params)
has_center = not ct.FlowSignal.check(center)
if has_center and has_output_params and not output_params.symbols:
has_size = not ct.FlowSignal.check(size)
has_output_params = not ct.FlowSignal.check(output_params)
if has_center and has_size and has_output_params and not output_params.symbols:
## TODO: There are strategies for handling examples of symbol values.
# Push Loose Input Values to GeoNodes Modifier
@ -254,7 +237,7 @@ class BoxStructureNode(base.MaxwellSimNode):
'node_group': import_geonodes(GeoNodes.StructurePrimitiveBox),
'unit_system': ct.UNITS_BLENDER,
'inputs': {
'Size': input_sockets['Size'],
'Size': size,
},
},
location=spux.scale_to_unit_system(center, ct.UNITS_BLENDER),

View File

@ -32,6 +32,8 @@ log = logger.get(__name__)
class CylinderStructureNode(base.MaxwellSimNode):
"""A generic cylinder structure with configurable radius and height."""
node_type = ct.NodeType.CylinderStructure
bl_label = 'Cylinder Structure'
use_sim_node_name = True
@ -64,27 +66,101 @@ class CylinderStructureNode(base.MaxwellSimNode):
}
####################
# - Output Socket Computation
# - FlowKind.Value
####################
@events.computes_output_socket(
'Structure',
kind=ct.FlowKind.Value,
# Loaded
output_sockets={'Structure'},
output_socket_kinds={'Structure': {ct.FlowKind.Func, ct.FlowKind.Params}},
)
def compute_value(self, output_sockets) -> ct.ParamsFlow | ct.FlowSignal:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
output_func = output_sockets['Structure'][ct.FlowKind.Func]
output_params = output_sockets['Structure'][ct.FlowKind.Params]
has_output_func = not ct.FlowSignal.check(output_func)
has_output_params = not ct.FlowSignal.check(output_params)
if has_output_func and has_output_params and not output_params.symbols:
return output_func.realize(output_params, disallow_jax=True)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Func
####################
@events.computes_output_socket(
'Structure',
kind=ct.FlowKind.Func,
# Loaded
input_sockets={'Center', 'Radius', 'Medium', 'Height'},
unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D},
scale_input_sockets={
'Center': 'Tidy3DUnits',
'Radius': 'Tidy3DUnits',
'Height': 'Tidy3DUnits',
input_socket_kinds={
'Center': ct.FlowKind.Func,
'Radius': ct.FlowKind.Func,
'Height': ct.FlowKind.Func,
'Medium': ct.FlowKind.Func,
},
)
def compute_structure(self, input_sockets, unit_systems) -> td.Box:
return td.Structure(
geometry=td.Cylinder(
radius=input_sockets['Radius'],
center=input_sockets['Center'],
length=input_sockets['Height'],
),
medium=input_sockets['Medium'],
)
def compute_func(self, input_sockets) -> td.Box:
"""Compute a single cylinder structure object, given that all inputs are non-symbolic."""
center = input_sockets['Center']
radius = input_sockets['Radius']
height = input_sockets['Height']
medium = input_sockets['Medium']
has_center = not ct.FlowSignal.check(center)
has_radius = not ct.FlowSignal.check(radius)
has_height = not ct.FlowSignal.check(height)
has_medium = not ct.FlowSignal.check(medium)
if has_center and has_radius and has_height and has_medium:
return (
center.scale_to_unit_system(ct.UNITS_TIDY3D)
| radius.scale_to_unit_system(ct.UNITS_TIDY3D)
| height.scale_to_unit_system(ct.UNITS_TIDY3D)
| medium
).compose_within(
lambda els: td.Structure(
geometry=td.Cylinder(
center=els[0].flatten().tolist(),
radius=els[1],
length=els[2],
),
medium=els[3],
)
)
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'Structure',
kind=ct.FlowKind.Params,
# Loaded
input_sockets={'Center', 'Radius', 'Medium', 'Height'},
input_socket_kinds={
'Center': ct.FlowKind.Params,
'Radius': ct.FlowKind.Params,
'Height': ct.FlowKind.Params,
'Medium': ct.FlowKind.Params,
},
)
def compute_params(self, input_sockets) -> td.Box:
center = input_sockets['Center']
radius = input_sockets['Radius']
height = input_sockets['Height']
medium = input_sockets['Medium']
has_center = not ct.FlowSignal.check(center)
has_radius = not ct.FlowSignal.check(radius)
has_height = not ct.FlowSignal.check(height)
has_medium = not ct.FlowSignal.check(medium)
if has_center and has_radius and has_height and has_medium:
return center | radius | height | medium
return ct.FlowSignal.FlowPending
####################
# - Preview

View File

@ -791,14 +791,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
####################
# - FlowKind: Func (w/Params if Constant)
####################
@bl_cache.cached_bl_property(
depends_on={
'value',
'sorted_sp_symbols',
'sorted_symbols',
'output_sym',
}
)
@bl_cache.cached_bl_property(depends_on={'output_sym'})
def lazy_func(self) -> ct.FuncFlow:
"""Returns a lazy value that computes the expression returned by `self.value`.
@ -806,42 +799,12 @@ class ExprBLSocket(base.MaxwellSimSocket):
Otherwise, the returned lazy value function will be a simple excuse for `self.params` to pass the verbatim `self.value`.
"""
if self.output_sym is not None:
match self.active_kind:
case ct.FlowKind.Value | ct.FlowKind.Func if (
self.sorted_symbols and not ct.FlowSignal.check(self.value)
):
return ct.FuncFlow(
func=sp.lambdify(
self.sorted_sp_symbols,
self.output_sym.conform(self.value, strip_unit=True),
'jax',
),
func_args=list(self.sorted_symbols),
func_output=self.output_sym,
supports_jax=True,
)
case ct.FlowKind.Value | ct.FlowKind.Func if not self.sorted_symbols:
return ct.FuncFlow(
func=lambda v: v,
func_args=[self.output_sym],
func_output=self.output_sym,
supports_jax=True,
)
case ct.FlowKind.Range if self.sorted_symbols:
msg = 'RangeFlow support not yet implemented for when self.sorted_symbols is not empty'
raise NotImplementedError(msg)
case ct.FlowKind.Range if (
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
):
return ct.FuncFlow(
func=lambda v: v,
func_args=[self.output_sym],
func_output=self.output_sym,
supports_jax=True,
)
return ct.FuncFlow(
func=lambda v: v,
func_args=[self.output_sym],
func_output=self.output_sym,
supports_jax=True,
)
return ct.FlowSignal.FlowPending
@ -854,22 +817,15 @@ class ExprBLSocket(base.MaxwellSimSocket):
If `self.value` has unknown symbols (as indicated by `self.symbols`), then these will be passed into `ParamsFlow`, which will thus be parameterized (and require realization before use).
Otherwise, `self.value` is passed verbatim as the only `ParamsFlow.func_arg`.
"""
output_sym = self.output_sym
if output_sym is not None:
if self.output_sym is not None:
match self.active_kind:
case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols:
return ct.ParamsFlow(
arg_targets=list(self.sorted_symbols),
func_args=[sym.sp_symbol for sym in self.sorted_symbols],
symbols=set(self.sorted_symbols),
)
case ct.FlowKind.Value | ct.FlowKind.Func if (
not self.sorted_symbols and not ct.FlowSignal.check(self.value)
not ct.FlowSignal.check(self.value)
):
return ct.ParamsFlow(
arg_targets=[self.output_sym],
func_args=[self.value],
symbols=set(self.sorted_symbols),
)
case ct.FlowKind.Range if self.sorted_symbols:
@ -899,20 +855,14 @@ class ExprBLSocket(base.MaxwellSimSocket):
Otherwise, only the output name/size/mathtype/unit corresponding to the socket is passed along.
"""
output_sym = self.output_sym
if output_sym is not None:
if self.output_sym is not None:
match self.active_kind:
case ct.FlowKind.Value | ct.FlowKind.Func if self.sorted_symbols:
case ct.FlowKind.Value | ct.FlowKind.Func:
return ct.InfoFlow(
dims={sym: None for sym in self.sorted_symbols},
output=self.output_sym,
)
case ct.FlowKind.Value | ct.FlowKind.Func if (
not self.sorted_symbols and not ct.FlowSignal.check(self.lazy_range)
):
return ct.InfoFlow(output=self.output_sym)
case ct.FlowKind.Range if self.sorted_symbols:
msg = 'InfoFlow support not yet implemented for when self.sorted_symbols is not empty'
raise NotImplementedError(msg)

View File

@ -14,7 +14,10 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import math
import bpy
import jax.numpy as jnp
import scipy as sc
import sympy.physics.units as spu
import tidy3d as td
@ -28,9 +31,15 @@ from .. import base
log = logger.get(__name__)
VAC_SPEED_OF_LIGHT = sc.constants.speed_of_light * spu.meter / spu.second
FIXED_WL = 500 * spu.nm
_VAC_SPEED_OF_LIGHT = sc.constants.speed_of_light * spu.meter / spu.second
_FIXED_WL = 500 * spu.nm
FIXED_FREQ = spux.scale_to_unit_system(
spu.convert_to(
_VAC_SPEED_OF_LIGHT / _FIXED_WL,
spu.hertz,
),
ct.UNITS_TIDY3D,
)
class MaxwellMediumBLSocket(base.MaxwellSimSocket):
@ -49,24 +58,17 @@ class MaxwellMediumBLSocket(base.MaxwellSimSocket):
####################
@bl_cache.cached_bl_property(depends_on={'eps_rel', 'differentiable'})
def value(self) -> td.Medium:
freq = (
spu.convert_to(
VAC_SPEED_OF_LIGHT / FIXED_WL,
spu.hertz,
)
/ spu.hertz
)
eps_r_re = self.eps_rel[0]
conductivity = FIXED_FREQ * self.eps_rel[1] ## TODO: Um?
if self.differentiable:
return tdadj.JaxMedium.from_nk(
n=self.eps_rel[0],
k=self.eps_rel[1],
freq=freq,
return tdadj.JaxMedium(
permittivity_jax=jnp.array(eps_r_re, dtype=float),
conductivity_jax=jnp.array(conductivity, dtype=float),
)
return td.Medium.from_nk(
n=self.eps_rel[0],
k=self.eps_rel[1],
freq=freq,
return td.Medium(
permittivity=eps_r_re,
conductivity=conductivity,
)
@value.setter

View File

@ -263,6 +263,11 @@ class SimSymbol(pyd.BaseModel):
interval_inf_im: tuple[bool, bool] = (True, True)
interval_closed_im: tuple[bool, bool] = (False, False)
preview_value_z: int = 0
preview_value_q: tuple[int, int] = (0, 1)
preview_value_re: float = 0.0
preview_value_im: float = 0.0
####################
# - Core
####################