feat: mode solving / thesis demo ready

main
Sofus Albert Høgsbro Rose 2024-09-05 17:20:46 +02:00
parent 81a71b2c47
commit 4fc0528f6e
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
27 changed files with 984 additions and 147 deletions

View File

@ -6,6 +6,8 @@
# features: []
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: false
absl-py==2.1.0
# via chex
@ -138,6 +140,7 @@ numpy==1.24.3
# via opt-einsum
# via optax
# via orbax-checkpoint
# via pandas
# via patsy
# via pydantic-tensor
# via scipy

View File

@ -6,6 +6,8 @@
# features: []
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: false
absl-py==2.1.0
# via chex
@ -114,6 +116,7 @@ numpy==1.24.3
# via opt-einsum
# via optax
# via orbax-checkpoint
# via pandas
# via patsy
# via pydantic-tensor
# via scipy

View File

@ -57,3 +57,7 @@ class OperatorType(enum.StrEnum):
NodeReleaseUploadedTask = enum.auto()
NodeRunSimulation = enum.auto()
NodeReloadTrackedTask = enum.auto()
# Node: ModeSolver
NodeSolveModes = enum.auto()
NodeReleaseSolvedModes = enum.auto()

View File

@ -197,7 +197,7 @@ class FlowKind(enum.StrEnum):
):
"""Perform unit-system scaling per-`FlowKind`."""
match self:
case FlowKind.Value if isinstance(spux.SympyType):
case FlowKind.Value if isinstance(flow, spux.SympyType):
return spux.scale_to_unit_system(
flow,
unit_system,

View File

@ -17,7 +17,6 @@
import enum
import functools
import typing as typ
from types import MappingProxyType
import jax.numpy as jnp
import jaxtyping as jtyp
@ -294,8 +293,8 @@ class RangeFlow(pyd.BaseModel):
and array.is_sorted
):
return RangeFlow(
start=sp.S(array.values[0]),
stop=sp.S(array.values[-1]),
start=sp.S(array.values.item(0)),
stop=sp.S(array.values.item(-1)),
steps=len(array.values),
unit=array.unit,
)
@ -338,7 +337,7 @@ class RangeFlow(pyd.BaseModel):
@method_lru(maxsize=16)
def nearest_idx_of(self, value: spux.SympyType, require_sorted: bool = True) -> int:
raise NotImplementedError
return (value - self.start) / self.realize_step_size()
@functools.cached_property
def bound_fourier_transform(self):
@ -632,6 +631,9 @@ class RangeFlow(pyd.BaseModel):
symbols=self.symbols,
)
if isinstance(subscript, int) and self.scaling == ScalingMode.Lin:
return self.start + subscript * self.realize_step_size()
raise NotImplementedError
####################

View File

@ -125,7 +125,8 @@ class ParamsFlow(pyd.BaseModel):
return [
sp.lambdify(
self.all_sorted_sp_symbols,
target_sym.conform(func_arg, strip_unit=True),
spux.strip_unit_system(func_arg), ## TODO: THIS IS A WORKAROUND
# target_sym.conform(func_arg, strip_unit=True),
'jax',
)
for func_arg, target_sym in zip(self.func_args, arg_targets, strict=True)

View File

@ -49,6 +49,7 @@ SOCKET_COLORS = {
ST.MaxwellSimGrid: (0.5, 0.4, 0.3, 1.0), # Dark Gold
ST.MaxwellSimGridAxis: (0.4, 0.3, 0.25, 1.0), # Darkest Gold
ST.MaxwellSimDomain: (0.4, 0.3, 0.25, 1.0), # Darkest Gold
ST.MaxwellMode: (0.5, 0.3, 0.25, 1.0),
# Tidy3D
ST.Tidy3DCloudTask: (0.4, 0.3, 0.25, 1.0), # Darkest Gold
}

View File

@ -60,6 +60,8 @@ class SocketType(blender_type_enum.BlenderTypeEnum):
MaxwellSimGrid = enum.auto()
MaxwellSimGridAxis = enum.auto()
MaxwellMode = enum.auto()
# Tidy3D
Tidy3DCloudTask = enum.auto()

View File

@ -197,7 +197,7 @@ class FilterOperation(enum.StrEnum):
return [dim for dim in info.dims if not info.has_idx_labels(dim)]
case FO.SliceIdx:
return [dim for dim in info.dims if not info.has_idx_labels(dim)]
return list(info.dims)
# Pin
case FO.PinLen1:

View File

@ -47,48 +47,107 @@ FS = ct.FlowSignal
####################
# - Monitor Labelling
####################
def valid_monitor_variants(monitor_type: str) -> list[str]:
"""Deduce the valid monitor variants."""
match monitor_type:
case (
'Field'
| 'FieldTime'
| 'FieldProjectionAngle'
| 'FieldProjectionKSpace'
| 'Diffraction'
):
return ['E', 'H']
case 'Mode' | 'ModeSolver':
return ['E', 'H', 'n_complex', 'n_eff', 'k_eff']
case 'Flux' | 'FluxTime':
return ['flux']
case 'Permittivity':
return ['eps']
def monitor_variant_symbol(monitor_variant: str) -> sim_symbols.SimSymbol: # noqa: PLR0911
"""Deduce the correct output symbol based on the monitor variant."""
match monitor_variant:
case 'E':
return sim_symbols.field_e(spu.volt / spu.micrometer)
case 'H':
return sim_symbols.field_h(spu.ampere / spu.micrometer)
case 'n_complex':
return sim_symbols.rel_eps(None)
case 'n_eff':
return sim_symbols.rel_eps_re(None)
case 'k_eff':
return sim_symbols.rel_eps_im(None)
case 'flux':
return sim_symbols.flux(spu.watt)
case 'eps':
return sim_symbols.rel_eps(None)
def valid_monitor_attrs(
example_sim_data: td.SimulationData, monitor_name: str
example_sim_data: td.SimulationData, monitor_name: str, monitor_variant: str
) -> tuple[str, ...]:
"""Retrieve the valid attributes of `sim_data.monitor_data' from a valid `sim_data` of type `td.SimulationData`.
Parameters:
monitor_type: The name of the monitor type, with the 'Data' prefix removed.
example_sim_data: A representative simulation data from a batch.
In particular, the shape of its monitor data should be representative.
monitor_name: The name of the monitor, whose attributes should be checked for validity.
monitor_variant: The variant of the monitor's output to extract attributes of.
"""
monitor_data = example_sim_data.monitor_data[monitor_name]
monitor_type = monitor_data.type.removesuffix('Data')
match monitor_type:
case 'Field' | 'FieldTime' | 'Mode':
## TODO: flux, poynting, intensity
match (monitor_variant, monitor_type):
# E: Electric Field
case ('E', 'Field' | 'FieldTime' | 'Mode' | 'ModeSolver'):
return tuple(
[
field_component
for field_component in ['Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz']
for field_component in ['Ex', 'Ey', 'Ez']
if getattr(monitor_data, field_component, None) is not None
]
)
case 'Permittivity':
return ('eps_xx', 'eps_yy', 'eps_zz')
# H: Electric Field
case ('H', 'Field' | 'FieldTime' | 'Mode' | 'ModeSolver'):
return tuple(
[
field_component
for field_component in ['Hx', 'Hy', 'Hz']
if getattr(monitor_data, field_component, None) is not None
]
)
case 'Flux' | 'FluxTime':
# n_complex | n_eff | k_eff: Effective Complex IOR
case ('n_complex', 'Mode' | 'ModeSolver'):
return ('n_complex',)
case ('n_eff', 'Mode' | 'ModeSolver'):
return ('n_eff',)
case ('k_eff', 'Mode' | 'ModeSolver'):
return ('k_eff',)
case ('flux', 'Flux' | 'FluxTime'):
return ('flux',)
case (
'FieldProjectionAngle'
| 'FieldProjectionCartesian'
| 'FieldProjectionKSpace'
| 'Diffraction'
):
return (
'Er',
'Etheta',
'Ephi',
'Hr',
'Htheta',
'Hphi',
)
case ('eps', 'Permittivity'):
return ('eps_xx', 'eps_yy', 'eps_zz')
# case (
# 'FieldProjectionAngle'
# | 'FieldProjectionCartesian'
# | 'FieldProjectionKSpace'
# | 'Diffraction'
# ):
# return (
# 'Er',
# 'Etheta',
# 'Ephi',
# 'Hr',
# 'Htheta',
# 'Hphi',
# )
raise TypeError
@ -98,7 +157,8 @@ def valid_monitor_attrs(
####################
MONITOR_SYMBOLS: dict[str, sim_symbols.SimSymbol] = {
# Field Label
'EH*': sim_symbols.sim_axis_idx(None),
'sim_axis': sim_symbols.sim_axis_idx(None),
'mode_index': sim_symbols.mode_idx(None),
# Cartesian
'x': sim_symbols.space_x(spu.micrometer),
'y': sim_symbols.space_y(spu.micrometer),
@ -126,6 +186,8 @@ MONITOR_SYMBOLS: dict[str, sim_symbols.SimSymbol] = {
def _mk_idx_array(xarr: xarray.DataArray, axis: str) -> ct.RangeFlow | ct.ArrayFlow:
if axis == 'mode_index':
return [str(i) for i in xarr.get_index(axis).values]
return ct.RangeFlow.try_from_array(
ct.ArrayFlow(
jax_bytes=xarr.get_index(axis).values,
@ -135,121 +197,129 @@ def _mk_idx_array(xarr: xarray.DataArray, axis: str) -> ct.RangeFlow | ct.ArrayF
)
def output_symbol_by_type(monitor_type: str) -> sim_symbols.SimSymbol:
match monitor_type:
case 'Field' | 'FieldProjectionCartesian' | 'Permittivity' | 'Mode':
return MONITOR_SYMBOLS['field_e']
case 'FieldTime':
return MONITOR_SYMBOLS['field']
case 'Flux':
return MONITOR_SYMBOLS['flux']
case 'FluxTime':
return MONITOR_SYMBOLS['flux']
case 'FieldProjectionAngle':
return MONITOR_SYMBOLS['field']
case 'FieldProjectionKSpace':
return MONITOR_SYMBOLS['field']
case 'Diffraction':
return MONITOR_SYMBOLS['field']
return None
def _extract_info(
example_xarr: xarray.DataArray,
monitor_type: str,
monitor_variant: str,
monitor_attrs: tuple[str, ...],
batch_dims: dict[sim_symbols.SimSymbol, ct.RangeFlow | ct.ArrayFlow],
) -> ct.InfoFlow | None:
log.debug([monitor_type, monitor_attrs, batch_dims])
mk_idx_array = functools.partial(_mk_idx_array, example_xarr)
match monitor_type:
case 'Field' | 'FieldProjectionCartesian' | 'Permittivity' | 'Mode':
match (monitor_variant, monitor_type):
case ('E' | 'H', 'Field'):
return ct.InfoFlow(
dims=batch_dims
| {
MONITOR_SYMBOLS['EH*']: monitor_attrs,
MONITOR_SYMBOLS['sim_axis']: monitor_attrs,
MONITOR_SYMBOLS['x']: mk_idx_array('x'),
MONITOR_SYMBOLS['y']: mk_idx_array('y'),
MONITOR_SYMBOLS['z']: mk_idx_array('z'),
MONITOR_SYMBOLS['f']: mk_idx_array('f'),
},
output=MONITOR_SYMBOLS['field_e'],
output=monitor_variant_symbol(monitor_variant),
)
case 'FieldTime':
case ('E' | 'H', 'FieldTime'):
return ct.InfoFlow(
dims=batch_dims
| {
MONITOR_SYMBOLS['EH*']: monitor_attrs,
MONITOR_SYMBOLS['sim_axis']: monitor_attrs,
MONITOR_SYMBOLS['x']: mk_idx_array('x'),
MONITOR_SYMBOLS['y']: mk_idx_array('y'),
MONITOR_SYMBOLS['z']: mk_idx_array('z'),
MONITOR_SYMBOLS['t']: mk_idx_array('t'),
},
output=MONITOR_SYMBOLS['field'],
output=monitor_variant_symbol(monitor_variant),
)
case 'Flux':
case ('E' | 'H', 'Mode' | 'ModeSolver'):
return ct.InfoFlow(
dims=batch_dims
| {
MONITOR_SYMBOLS['sim_axis']: monitor_attrs,
MONITOR_SYMBOLS['x']: mk_idx_array('x'),
MONITOR_SYMBOLS['y']: mk_idx_array('y'),
MONITOR_SYMBOLS['z']: mk_idx_array('z'),
MONITOR_SYMBOLS['f']: mk_idx_array('f'),
MONITOR_SYMBOLS['mode_index']: mk_idx_array('mode_index'),
},
output=monitor_variant_symbol(monitor_variant),
)
case ('n_complex', 'Mode' | 'ModeSolver'):
return ct.InfoFlow(
dims=batch_dims
| {
MONITOR_SYMBOLS['f']: mk_idx_array('f'),
MONITOR_SYMBOLS['mode_index']: mk_idx_array('mode_index'),
},
output=monitor_variant_symbol(monitor_variant),
)
case ('n_eff' | 'k_eff', 'Mode' | 'ModeSolver'):
return ct.InfoFlow(
dims=batch_dims
| {
MONITOR_SYMBOLS['f']: mk_idx_array('f'),
MONITOR_SYMBOLS['mode_index']: mk_idx_array('mode_index'),
},
output=monitor_variant_symbol(monitor_variant),
)
case ('flux', 'Flux'):
return ct.InfoFlow(
dims=batch_dims
| {
MONITOR_SYMBOLS['f']: mk_idx_array('f'),
},
output=MONITOR_SYMBOLS['flux'],
output=monitor_variant_symbol(monitor_variant),
)
case 'FluxTime':
case ('flux', 'FluxTime'):
return ct.InfoFlow(
dims=batch_dims
| {
MONITOR_SYMBOLS['t']: mk_idx_array('t'),
},
output=MONITOR_SYMBOLS['flux'],
output=monitor_variant_symbol(monitor_variant),
)
case 'FieldProjectionAngle':
case ('E' | 'H', 'FieldProjectionAngle'):
return ct.InfoFlow(
dims=batch_dims
| {
MONITOR_SYMBOLS['EH*']: monitor_attrs,
MONITOR_SYMBOLS['sim_axis']: monitor_attrs,
MONITOR_SYMBOLS['r']: mk_idx_array('r'),
MONITOR_SYMBOLS['theta']: mk_idx_array('theta'),
MONITOR_SYMBOLS['phi']: mk_idx_array('phi'),
MONITOR_SYMBOLS['f']: mk_idx_array('f'),
},
output=MONITOR_SYMBOLS['field'],
output=monitor_variant_symbol(monitor_variant),
)
case 'FieldProjectionKSpace':
case ('E' | 'H', 'FieldProjectionKSpace'):
return ct.InfoFlow(
dims=batch_dims
| {
MONITOR_SYMBOLS['EH*']: monitor_attrs,
MONITOR_SYMBOLS['sim_axis']: monitor_attrs,
MONITOR_SYMBOLS['ux']: mk_idx_array('ux'),
MONITOR_SYMBOLS['uy']: mk_idx_array('uy'),
MONITOR_SYMBOLS['r']: mk_idx_array('r'),
MONITOR_SYMBOLS['f']: mk_idx_array('f'),
},
output=MONITOR_SYMBOLS['field'],
output=monitor_variant_symbol(monitor_variant),
)
case 'Diffraction':
case ('E' | 'H', 'Diffraction'):
return ct.InfoFlow(
dims=batch_dims
| {
MONITOR_SYMBOLS['EH*']: monitor_attrs,
MONITOR_SYMBOLS['sim_axis']: monitor_attrs,
MONITOR_SYMBOLS['orders_x']: mk_idx_array('orders_x'),
MONITOR_SYMBOLS['orders_y']: mk_idx_array('orders_y'),
MONITOR_SYMBOLS['f']: mk_idx_array('f'),
},
output=MONITOR_SYMBOLS['field'],
output=monitor_variant_symbol(monitor_variant),
)
raise TypeError
@ -268,7 +338,9 @@ def extract_monitor_xarrs(
def extract_info(
monitor_datas: dict[RealizedSymsVals, typ.Any], monitor_attrs: tuple[str, ...]
monitor_datas: dict[RealizedSymsVals, typ.Any],
monitor_attrs: tuple[str, ...],
monitor_variant: str,
) -> dict[RealizedSymsVals, ct.InfoFlow]:
"""Extract an InfoFlow describing monitor data from a batch of simulations."""
# Extract Dimension from Batched Values
@ -318,6 +390,7 @@ def extract_info(
return _extract_info(
example_xarr,
example_monitor_data.type.removesuffix('Data'),
monitor_variant,
monitor_attrs,
batch_dims,
)
@ -489,6 +562,39 @@ class ExtractDataNode(base.MaxwellSimNode):
return []
monitor_variant: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_monitor_variants(),
cb_depends_on={'monitor_name', 'monitor_types'},
)
def search_monitor_variants(self) -> list[ct.BLEnumElement]:
"""Compute valid values for `self.monitor_attr`, for a dynamic `EnumProperty`.
Notes:
Should be reset (via `self.monitor_attr`) with (after) `self.sim_data_monitor_nametypes`, `self.monitor_data_attrs`, and (implicitly) `self.monitor_type`.
See `bl_cache.BLField` for more on dynamic `EnumProperty`.
Returns:
Valid `self.monitor_attr` in a format compatible with dynamic `EnumProperty`.
"""
if self.monitor_name is not None and self.monitor_types is not None:
monitor_type = self.monitor_types.get(self.monitor_name)
if monitor_type is not None:
return [
(
monitor_variant,
monitor_variant,
monitor_variant,
'',
i,
)
for i, monitor_variant in enumerate(
valid_monitor_variants(monitor_type)
)
]
return []
####################
# - Properties: Monitor Information
####################
@ -502,11 +608,19 @@ class ExtractDataNode(base.MaxwellSimNode):
}
return None
@bl_cache.cached_bl_property(depends_on={'example_sim_data', 'monitor_name'})
@bl_cache.cached_bl_property(
depends_on={'example_sim_data', 'monitor_name', 'monitor_variant'}
)
def valid_monitor_attrs(self) -> SimDataArrayInfo | None:
"""Valid attributes of the monitor, from the example sim data under the presumption that the entire batch shares the same attribute validity."""
if self.example_sim_data is not None and self.monitor_name is not None:
return valid_monitor_attrs(self.example_sim_data, self.monitor_name)
if (
self.example_sim_data is not None
and self.monitor_name is not None
and self.monitor_variant is not None
):
return valid_monitor_attrs(
self.example_sim_data, self.monitor_name, self.monitor_variant
)
return None
####################
@ -530,6 +644,7 @@ class ExtractDataNode(base.MaxwellSimNode):
col: UI target for drawing.
"""
col.prop(self, self.blfields['monitor_name'], text='')
col.prop(self, self.blfields['monitor_variant'], text='')
####################
# - FlowKind.Func
@ -538,21 +653,19 @@ class ExtractDataNode(base.MaxwellSimNode):
'Expr',
kind=FK.Func,
# Loaded
props={'monitor_datas', 'valid_monitor_attrs'},
props={'monitor_datas', 'valid_monitor_attrs', 'monitor_variant'},
)
def compute_extracted_data_func(self, props) -> ct.FuncFlow | FS:
"""Aggregates the selected monitor's data across all batched symbolic realizations, into a single FuncFlow."""
monitor_datas = props['monitor_datas']
valid_monitor_attrs = props['valid_monitor_attrs']
monitor_variant = props['monitor_variant']
if monitor_datas is not None and valid_monitor_attrs is not None:
monitor_datas_xarrs = extract_monitor_xarrs(
monitor_datas, valid_monitor_attrs
)
example_monitor_data = next(iter(monitor_datas.values()))
monitor_type = example_monitor_data.type.removesuffix('Data')
output_sym = output_symbol_by_type(monitor_type)
output_sym = monitor_variant_symbol(monitor_variant)
# Stack Inner Dimensions: components | *
## -> Each realization maps to exactly one xarray.
@ -648,7 +761,7 @@ class ExtractDataNode(base.MaxwellSimNode):
'Expr',
kind=FK.Info,
# Loaded
props={'monitor_datas', 'valid_monitor_attrs'},
props={'monitor_datas', 'valid_monitor_attrs', 'monitor_variant'},
)
def compute_extracted_data_info(self, props) -> ct.InfoFlow | FS:
"""Declare `Data:Info` by manually selecting appropriate axes, units, etc. for each monitor type.
@ -658,9 +771,10 @@ class ExtractDataNode(base.MaxwellSimNode):
"""
monitor_datas = props['monitor_datas']
valid_monitor_attrs = props['valid_monitor_attrs']
monitor_variant = props['monitor_variant']
if monitor_datas is not None and valid_monitor_attrs is not None:
return extract_info(monitor_datas, valid_monitor_attrs)
return extract_info(monitor_datas, valid_monitor_attrs, monitor_variant)
return FS.FlowPending
####################

View File

@ -137,7 +137,7 @@ class Tidy3DFileImporterNode(base.MaxwellSimNode):
# - UI
####################
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout):
col.prop(self, 'tidy3d_type', text='')
col.prop(self, self.blfields['tidy3d_type'], text='')
####################
# - Event Methods: Setup Output Socket

View File

@ -79,7 +79,7 @@ class FDTDSimNode(base.MaxwellSimNode):
"""Definition of a complete FDTD simulation, including boundary conditions, domain, sources, structures, monitors, and other configuration."""
node_type = ct.NodeType.FDTDSim
bl_label = 'FDTD Simulation'
bl_label = 'Maxwell Simulation'
####################
# - Sockets
@ -158,7 +158,10 @@ class FDTDSimNode(base.MaxwellSimNode):
def min_wl(self) -> SimArrayInfo | None:
"""The smallest wavelength that occurs in the simulation."""
if self.sims is not None:
return {k: sim.wvl_mat_min * spu.um for k, sim in self.sims.items()}
return {
k: sim.wvl_mat_min * spu.um if sim.sources else None
for k, sim in self.sims.items()
}
return None
####################
@ -265,7 +268,7 @@ class FDTDSimNode(base.MaxwellSimNode):
for k, sim in self.sims.items(): # noqa: B007
try:
pass ## TODO: VERY slow, batch checking is infeasible
# sim.validate_pre_upload(source_required=True)
# sim.validate_pre_upload(source_required=False)
except td.exceptions.SetupError:
validity[k] = False
else:
@ -322,7 +325,14 @@ class FDTDSimNode(base.MaxwellSimNode):
('max t', spux.sp_to_str(self.time_range[syms_vals][1])),
('min f', spux.sp_to_str(self.freq_range[syms_vals][0])),
('max f', spux.sp_to_str(self.freq_range[syms_vals][1])),
('min λ', spux.sp_to_str(self.min_wl[syms_vals].n(2))),
(
'min λ',
spux.sp_to_str(
self.min_wl[syms_vals].n(2)
if self.min_wl[syms_vals] is not None
else None
),
),
]
labels += [

View File

@ -14,5 +14,367 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
BL_REGISTER = []
BL_NODES = {}
"""Implements `ModeSolverNode`."""
import enum
import typing as typ
import bpy
import sympy as sp
import tidy3d as td
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils.frozendict import frozendict
from ... import contracts as ct
from ... import sockets
from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
PT = spux.PhysicalType
class ModePol(enum.StrEnum):
"""Polarization filter with which to arrange mode indices with."""
Arbitrary = enum.auto()
TE = enum.auto()
TM = enum.auto()
@property
def tidy3d_pol(self) -> str | None:
MP = ModePol
return {
MP.Arbitrary: None,
MP.TE: 'te',
MP.TM: 'tm',
}[self]
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
"""A human-readable UI-oriented name for a physical type."""
MP = ModePol
return {
MP.Arbitrary: 'Arbitrary',
MP.TE: 'TE',
MP.TM: 'TM',
}[value]
@staticmethod
def to_icon(_: typ.Self) -> str:
"""No icons."""
return ''
class RunModalSolve(bpy.types.Operator):
"""Run a modal solver on data provided to a `ModeSolverNode`."""
bl_idname = ct.OperatorType.NodeSolveModes
bl_label = 'Solve Modes'
bl_description = 'Solve for the eigenmodes of the simulation setup.'
@classmethod
def poll(cls, context):
"""Allow running when there are runnable tasks."""
return (
# Check FDTDSolverNode is Accessible
hasattr(context, 'node')
and hasattr(context.node, 'node_type')
and context.node.node_type == ct.NodeType.ModeSolver
# Check Task is Runnable
and not context.node.solved
and context.node.mode_solver is not None
)
def execute(self, context):
"""Run all uploaded, runnable tasks."""
node = context.node
node.trigger_event(ct.FlowEvent.EnableLock)
node.locked = False
try:
node.mode_solver.solve()
except: # noqa: E722
node.trigger_event(ct.FlowEvent.DisableLock)
self.report({'ERROR'}, 'Modal solution failed. Please check logs.')
return {'FINISHED'}
else:
node.solved = True
return {'FINISHED'}
class ReleaseSolvedModes(bpy.types.Operator):
"""Run a modal solver on data provided to a `ModeSolverNode`."""
bl_idname = ct.OperatorType.NodeReleaseSolvedModes
bl_label = 'Release Solved Modes'
bl_description = 'Release the solved eigenmodes.'
@classmethod
def poll(cls, context):
"""Allow running when there are runnable tasks."""
return (
# Check FDTDSolverNode is Accessible
hasattr(context, 'node')
and hasattr(context.node, 'node_type')
and context.node.node_type == ct.NodeType.ModeSolver
# Check Task is Runnable
and context.node.solved
and context.node.mode_solver is not None
)
def execute(self, context):
"""Reset the modal solver."""
node = context.node
node.solved = False
node.trigger_event(ct.FlowEvent.DisableLock)
node.mode_solver = bl_cache.Signal.InvalidateCache
return {'FINISHED'}
class ModeSolverNode(base.MaxwellSimNode):
"""Definition of a complete FDTD simulation, including boundary conditions, domain, sources, structures, monitors, and other configuration."""
node_type = ct.NodeType.ModeSolver
bl_label = 'Mode Solver'
####################
# - Sockets
####################
input_sockets: typ.ClassVar = {
'Sim': sockets.MaxwellFDTDSimSocketDef(active_kind=FK.Value),
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
mathtype=MT.Real,
physical_type=PT.Length,
default_value=sp.ImmutableMatrix([0, 0, 0]),
),
'Plane Size': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec2,
mathtype=MT.Real,
physical_type=PT.Length,
default_value=sp.ImmutableMatrix([1, 1]),
),
'Freqs': sockets.ExprSocketDef(
active_kind=FK.Range,
physical_type=PT.Freq,
default_unit=spux.THz,
default_min=374.7406, ## 800nm
default_max=1498.962, ## 200nm
default_steps=100,
),
'# Modes': sockets.ExprSocketDef(
mathtype=MT.Integer,
default_value=3,
abs_min=1,
),
'Guess neff': sockets.ExprSocketDef(
mathtype=MT.Real,
default_value=2.0,
),
## TODO: Spherical, bend, mode tracking, group index, extra PML
}
output_sockets: typ.ClassVar = {
'Solved Modes': sockets.MaxwellModeSocketDef(),
'Sim Data': sockets.MaxwellFDTDSimDataSocketDef(),
}
####################
# - Properties: UI
####################
injection_axis: ct.SimSpaceAxis = bl_cache.BLField(ct.SimSpaceAxis.X)
injection_direction: ct.SimAxisDir = bl_cache.BLField(ct.SimAxisDir.Plus)
mode_pol: ModePol = bl_cache.BLField(ModePol.Arbitrary)
solved: bool = bl_cache.BLField(False)
####################
# - Properties: Mode Solver
####################
@events.on_value_changed(
socket_name={
'Sim': {FK.Func, FK.Params},
'Center': {FK.Func, FK.Params},
'Plane Size': {FK.Func, FK.Params},
'Freqs': {FK.Func, FK.Params},
'# Modes': {FK.Func, FK.Params},
'Guess neff': {FK.Func, FK.Params},
},
)
def on_input_changed(self) -> None:
"""Recomputes the mode solver in response to inputs."""
self.mode_solver = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property(
depends_on={'injection_axis', 'injection_direction', 'mode_pol'}
)
def mode_solver(self) -> td.plugins.mode.ModeSolver | None:
sim = self._compute_input('Sim', kind=FK.Value)
center = events.realize_known(
frozendict(
{
kind: self._compute_input(
'Center', kind=kind, unit_system=ct.UNITS_TIDY3D
)
for kind in [FK.Func, FK.Params]
}
),
freeze=True,
)
size_2d = events.realize_known(
frozendict(
{
kind: self._compute_input(
'Plane Size', kind=kind, unit_system=ct.UNITS_TIDY3D
)
for kind in [FK.Func, FK.Params]
}
),
freeze=True,
)
freqs = events.realize_known(
frozendict(
{
kind: self._compute_input(
'Freqs', kind=kind, unit_system=ct.UNITS_TIDY3D
)
for kind in [FK.Func, FK.Params]
}
),
freeze=True,
)
num_modes = events.realize_known(
frozendict(
{
kind: self._compute_input('# Modes', kind=kind)
for kind in [FK.Func, FK.Params]
}
),
freeze=True,
)
guess_neff = events.realize_known(
frozendict(
{
kind: self._compute_input('Guess neff', kind=kind)
for kind in [FK.Func, FK.Params]
}
),
freeze=True,
)
if not FS.check(sim) and all(
flow is not None
for flow in [
center,
size_2d,
num_modes,
guess_neff,
freqs,
]
):
mode_pol = self.mode_pol
injection_direction = self.injection_direction
_size_2d = sp.flatten(size_2d)
size = {
ct.SimSpaceAxis.X: (0, *_size_2d),
ct.SimSpaceAxis.Y: (_size_2d[0], 0, _size_2d[1]),
ct.SimSpaceAxis.Z: (*_size_2d, 0),
}[self.injection_axis]
return td.plugins.mode.ModeSolver(
simulation=sim,
plane=td.Box(center=sp.flatten(center), size=size),
mode_spec=td.ModeSpec(
num_modes=num_modes,
target_neff=guess_neff,
filter_pol=mode_pol.tidy3d_pol,
),
freqs=freqs,
direction=injection_direction.plus_or_minus,
)
return None
####################
# - UI
####################
def draw_operators(self, _: bpy.types.Context, layout: bpy.types.UILayout):
row = layout.row(align=True)
row.operator(
ct.OperatorType.NodeSolveModes,
text='Solve',
)
if self.solved:
row.operator(
ct.OperatorType.NodeReleaseSolvedModes,
icon='LOOP_BACK',
text='',
)
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout):
"""Present choices of injection axis, direction, and polarization filter."""
row = layout.row(align=True)
row.alignment = 'CENTER'
row.label(text='Injection')
layout.prop(self, self.blfields['injection_axis'], expand=True)
layout.prop(self, self.blfields['injection_direction'], expand=True)
row = layout.row(align=True)
row.alignment = 'CENTER'
row.label(text='Pol Filter')
layout.prop(self, self.blfields['mode_pol'], text='')
####################
# - FlowKind.Value: Solved Modes
####################
@events.computes_output_socket(
'Solved Modes',
kind=FK.Value,
# Loaded
props={'mode_solver', 'solved'},
)
def compute_solved_modes(self, props) -> td.ModeSolverData | FS:
mode_solver = props['mode_solver']
solved = props['solved']
if mode_solver is not None and solved:
return mode_solver.data
return FS.FlowPending
####################
# - FlowKind.Value: Sim Data
####################
@events.computes_output_socket(
'Sim Data',
kind=FK.Value,
# Loaded
props={'mode_solver', 'solved'},
)
def compute_sim_data_value(self, props) -> td.SimulationData | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
mode_solver = props['mode_solver']
solved = props['solved']
if mode_solver is not None and solved:
return mode_solver.sim_data
return FS.FlowPending
####################
# - Blender Registration
####################
BL_REGISTER = [
RunModalSolve,
ReleaseSolvedModes,
ModeSolverNode,
]
BL_NODES = {ct.NodeType.ModeSolver: (ct.NodeCategory.MAXWELLSIM_SOLVERS)}

View File

@ -0,0 +1,128 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Implements `ModeSolverNode`."""
import typing as typ
import tidy3d as td
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import sympy_extra as spux
from blender_maxwell.utils.frozendict import frozendict
from ... import contracts as ct
from ... import sockets
from .. import base, events
log = logger.get(__name__)
FK = ct.FlowKind
FS = ct.FlowSignal
MT = spux.MathType
PT = spux.PhysicalType
class ModeSolverNode(base.MaxwellSimNode):
"""Definition of a complete FDTD simulation, including boundary conditions, domain, sources, structures, monitors, and other configuration."""
node_type = ct.NodeType.ModeSolver
bl_label = 'Mode Solver'
use_sim_node_name = True
####################
# - Sockets
####################
input_sockets: typ.ClassVar = {
'Solved Modes': sockets.MaxwellModeSocketDef(),
'Inj Mode': sockets.ExprSocketDef(
mathtype=MT.Integer,
abs_min=0,
default_value=0,
),
'Guess neff': sockets.ExprSocketDef(
mathtype=MT.Integer,
abs_min=1,
),
## TODO: Spherical, bend, mode tracking, group index
}
output_sockets: typ.ClassVar = {
'Angled Source': sockets.MaxwellSourceSocketDef(active_kind=FK.Value),
}
####################
# - Properties: UI
####################
injection_axis: ct.SimSpaceAxis = bl_cache.BLField(ct.SimSpaceAxis.X)
injection_direction: ct.SimAxisDir = bl_cache.BLField(ct.SimAxisDir.Plus)
####################
# - FlowKind.Value
####################
@events.computes_output_socket(
'Sim',
kind=FK.Value,
# Loaded
props={'active_socket_set'},
outscks_kinds={'Sim': {FK.Func, FK.Params}},
all_loose_input_sockets=True,
loose_input_sockets_kind={FK.Func, FK.Params},
)
def compute_value(
self, props, loose_input_sockets, output_sockets
) -> td.Simulation | FS:
"""Compute the particular value of the simulation domain from strictly non-symbolic inputs."""
func = output_sockets['Sim'][FK.Func]
params = output_sockets['Sim'][FK.Params]
has_func = not FS.check(func)
has_params = not FS.check(params)
active_socket_set = props['active_socket_set']
if has_func and has_params and active_socket_set == 'Single':
symbol_values = {
sym: events.realize_known(loose_input_sockets[sym.name])
for sym in params.sorted_symbols
}
return func.realize(
params,
symbol_values=frozendict(
{
sym: events.realize_known(loose_input_sockets[sym.name])
for sym in params.sorted_symbols
}
),
).updated_copy(
attrs=dict(
ct.SimMetadata(
realizations=ct.SimRealizations(
syms=tuple(symbol_values.keys()),
vals=tuple(symbol_values.values()),
)
)
)
)
return FS.FlowPending
####################
# - Blender Registration
####################
BL_REGISTER = [
ModeSolverNode,
]
BL_NODES = {ct.NodeType.ModeSolver: (ct.NodeCategory.MAXWELLSIM_SIMS)}

View File

@ -48,7 +48,7 @@ class BoxStructureNode(base.MaxwellSimNode):
# - Sockets
####################
input_sockets: typ.ClassVar = {
'Medium': sockets.MaxwellMediumSocketDef(),
'Medium': sockets.MaxwellMediumSocketDef(active_kind=FK.Func),
'Center': sockets.ExprSocketDef(
size=spux.NumberSize1D.Vec3,
default_unit=spu.micrometer,

View File

@ -19,6 +19,7 @@
import typing as typ
import bpy
import jax.numpy as jnp
import sympy as sp
import sympy.physics.units as spu
@ -208,46 +209,84 @@ class WaveConstantNode(base.MaxwellSimNode):
####################
# - FlowKind.Func
####################
# @events.computes_output_socket(
# 'WL',
# kind=FK.Func,
# # Loaded
# inscks_kinds={'WL': FK.Func, 'Freq': FK.Func},
# input_sockets_optional={'WL', 'Freq'},
# )
# def compute_wl_func(self, input_sockets: dict) -> ct.FuncFlow | FS:
# """Compute a single wavelength value from either wavelength/frequency."""
# wl = input_sockets['WL']
# has_wl = not FS.check(wl)
# if has_wl:
# return wl
@events.computes_output_socket(
'WL',
kind=FK.Func,
# Loaded
props={'active_socket_set', 'use_range'},
inscks_kinds={'WL': FK.Func, 'Freq': FK.Func},
input_sockets_optional={'WL', 'Freq'},
)
def compute_wl_func(self, props, input_sockets) -> ct.FuncFlow | FS:
"""Compute a single wavelength value from either wavelength/frequency."""
active_socket_set = props['active_socket_set']
use_range = props['use_range']
# freq = input_sockets['Freq']
# has_freq = not FS.check(freq)
# if has_freq:
# return wl.compose_within(
# return spu.convert_to(
# sci_constants.vac_speed_of_light / input_sockets['Freq'], spu.um
# )
wl = input_sockets['WL']
freq = input_sockets['Freq']
# return FS.FlowPending
match active_socket_set:
case 'Wavelength' if not FS.check(wl):
return wl
# @events.computes_output_socket(
# 'Freq',
# kind=FK.Value,
# # Loaded
# inscks_kinds={'WL': FK.Value, 'Freq': FK.Value},
# input_sockets_optional={'WL', 'Freq'},
# )
# def compute_freq_value(self, input_sockets: dict) -> sp.Expr:
# """Compute a single frequency value from either wavelength/frequency."""
# has_freq = not FS.check(input_sockets['Freq'])
# if has_freq:
# return input_sockets['Freq']
case 'Frequency' if not FS.check(freq):
a = MT.Real.sp_symbol_a
scaling_expr = spux.scale_to_unit(
sci_constants.vac_speed_of_light / (a * freq.func_output.unit),
spu.um,
)
scaler = sp.lambdify(a, scaling_expr, 'jax')
# return spu.convert_to(
# sci_constants.vac_speed_of_light / input_sockets['WL'], spux.THz
# )
if use_range:
return freq.compose_within(
lambda _freq: jnp.flip(scaler(_freq)),
enclosing_func_output=sim_symbols.wl(spu.um),
)
return freq.compose_within(
lambda _freq: scaler(_freq),
enclosing_func_output=sim_symbols.wl(spu.um),
)
return FS.FlowPending
@events.computes_output_socket(
'Freq',
kind=FK.Func,
# Loaded
props={'active_socket_set', 'use_range'},
inscks_kinds={'WL': FK.Func, 'Freq': FK.Func},
input_sockets_optional={'WL', 'Freq'},
)
def compute_freq_func(self, props, input_sockets) -> ct.FuncFlow | FS:
"""Compute a single wavelength value from either wavelength/frequency."""
active_socket_set = props['active_socket_set']
use_range = props['use_range']
wl = input_sockets['WL']
freq = input_sockets['Freq']
match active_socket_set:
case 'Wavelength' if not FS.check(wl):
a = MT.Real.sp_symbol_a
scaling_expr = spux.scale_to_unit(
sci_constants.vac_speed_of_light / (a * wl.func_output.unit),
spux.THz,
)
scaler = sp.lambdify(a, scaling_expr, 'jax')
if use_range:
return wl.compose_within(
lambda _wl: jnp.flip(scaler(_wl)),
enclosing_func_output=sim_symbols.freq(spux.THz),
)
return wl.compose_within(
lambda _wl: scaler(_wl),
enclosing_func_output=sim_symbols.freq(spux.THz),
)
case 'Frequency' if not FS.check(freq):
return freq
return FS.FlowPending
####################
# - FlowKind.Info
@ -272,6 +311,55 @@ class WaveConstantNode(base.MaxwellSimNode):
output=sim_symbols.freq(spux.THz),
)
####################
# - FlowKind.Params
####################
@events.computes_output_socket(
'WL',
kind=FK.Params,
# Loaded
props={'active_socket_set'},
inscks_kinds={'WL': FK.Params, 'Freq': FK.Params},
input_sockets_optional={'WL', 'Freq'},
)
def compute_freq_params(self, props, input_sockets) -> ct.FuncFlow | FS:
"""Compute a single wavelength value from either wavelength/frequency."""
wl = input_sockets['WL']
freq = input_sockets['Freq']
active_socket_set = props['active_socket_set']
match active_socket_set:
case 'Wavelength' if not FS.check(wl):
return wl
case 'Frequency' if not FS.check(freq):
return freq
return FS.FlowPending
@events.computes_output_socket(
'Freq',
kind=FK.Params,
# Loaded
props={'active_socket_set'},
inscks_kinds={'WL': FK.Params, 'Freq': FK.Params},
input_sockets_optional={'WL', 'Freq'},
)
def compute_freq_params(self, props, input_sockets) -> ct.FuncFlow | FS:
"""Compute a single wavelength value from either wavelength/frequency."""
wl = input_sockets['WL']
freq = input_sockets['Freq']
active_socket_set = props['active_socket_set']
match active_socket_set:
case 'Wavelength' if not FS.check(wl):
return wl
case 'Frequency' if not FS.check(freq):
return freq
return FS.FlowPending
####################
# - Blender Registration

View File

@ -832,6 +832,8 @@ class MaxwellSimSocket(bpy.types.NodeSocket, bl_instance.BLInstance):
if FS.check_single(flow, FS.FlowPending) and not self.flow_error:
bpy.app.timers.register(self.declare_flow_error)
# elif self.flow_error:
# bpy.app.timers.register(self.clear_flow_error)
return flow

View File

@ -987,12 +987,6 @@ class ExprBLSocket(base.MaxwellSimSocket):
func_output=self.output_sym,
supports_jax=True,
)
return ct.FuncFlow(
func=lambda v: v,
func_args=[self.output_sym],
func_output=self.output_sym,
supports_jax=True,
)
return ct.FlowSignal.FlowPending
@ -1709,9 +1703,9 @@ class ExprSocketDef(base.SocketDef):
bl_socket.size = self.size
bl_socket.mathtype = self.mathtype
bl_socket.physical_type = self.physical_type
bl_socket.active_unit = bl_cache.Signal.ResetEnumItems
bl_socket.unit = bl_cache.Signal.InvalidateCache
bl_socket.unit_factor = bl_cache.Signal.InvalidateCache
# bl_socket.active_unit = bl_cache.Signal.InvalidateCache
# bl_socket.unit = bl_cache.Signal.InvalidateCache
# bl_socket.unit_factor = bl_cache.Signal.InvalidateCache
bl_socket.symbols = self.default_symbols
# Domain

View File

@ -21,6 +21,7 @@ from . import (
fdtd_sim_data,
medium,
medium_non_linearity,
mode,
monitor,
monitor_data,
sim_domain,
@ -47,6 +48,7 @@ MaxwellSimGridAxisSocketDef = sim_grid_axis.MaxwellSimGridAxisSocketDef
MaxwellSourceSocketDef = source.MaxwellSourceSocketDef
MaxwellStructureSocketDef = structure.MaxwellStructureSocketDef
MaxwellTemporalShapeSocketDef = temporal_shape.MaxwellTemporalShapeSocketDef
MaxwellModeSocketDef = mode.MaxwellModeSocketDef
BL_REGISTER = [
@ -56,6 +58,7 @@ BL_REGISTER = [
*fdtd_sim_data.BL_REGISTER,
*medium.BL_REGISTER,
*medium_non_linearity.BL_REGISTER,
*mode.BL_REGISTER,
*monitor.BL_REGISTER,
*monitor_data.BL_REGISTER,
*sim_domain.BL_REGISTER,

View File

@ -0,0 +1,46 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import typing as typ
from ... import contracts as ct
from .. import base
class MaxwellModeBLSocket(base.MaxwellSimSocket):
socket_type = ct.SocketType.MaxwellMode
bl_label = 'Maxwell FDTD Simulation'
####################
# - Socket Configuration
####################
class MaxwellModeSocketDef(base.SocketDef):
socket_type: ct.SocketType = ct.SocketType.MaxwellMode
def init(self, bl_socket: MaxwellModeBLSocket) -> None:
pass
def local_compare(self, _: MaxwellModeBLSocket) -> None:
return True
####################
# - Blender Registration
####################
BL_REGISTER = [
MaxwellModeBLSocket,
]

View File

@ -14,14 +14,35 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from blender_maxwell.utils import bl_cache
from ... import contracts as ct
from .. import base
FK = ct.FlowKind
FS = ct.FlowSignal
class MaxwellMonitorBLSocket(base.MaxwellSimSocket):
socket_type = ct.SocketType.MaxwellMonitor
bl_label = 'Maxwell Monitor'
@bl_cache.cached_bl_property()
def array(self) -> list:
return []
@bl_cache.cached_bl_property()
def lazy_func(self) -> ct.FuncFlow | FS:
if self.active_kind is FK.Array:
return ct.FuncFlow(func=list)
return FS.NoFlow
@bl_cache.cached_bl_property()
def params(self) -> ct.ParamsFlow:
if self.active_kind is FK.Array:
return ct.ParamsFlow()
return FS.NoFlow
####################
# - Socket Configuration

View File

@ -14,14 +14,35 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from blender_maxwell.utils import bl_cache
from ... import contracts as ct
from .. import base
FK = ct.FlowKind
FS = ct.FlowSignal
class MaxwellSourceBLSocket(base.MaxwellSimSocket):
socket_type = ct.SocketType.MaxwellSource
bl_label = 'Maxwell Source'
@bl_cache.cached_bl_property()
def array(self) -> list:
return []
@bl_cache.cached_bl_property()
def lazy_func(self) -> ct.FuncFlow | FS:
if self.active_kind is FK.Array:
return ct.FuncFlow(func=list)
return FS.NoFlow
@bl_cache.cached_bl_property()
def params(self) -> ct.ParamsFlow:
if self.active_kind is FK.Array:
return ct.ParamsFlow()
return FS.NoFlow
####################
# - Socket Configuration

View File

@ -163,7 +163,7 @@ def pinned_labels(pinned_data) -> str:
'\n'
+ ', '.join(
[
f'{sym.name_pretty}:' + _parse_val(val)
f'{sym.name_pretty}=' + _parse_val(val)
for sym, val in pinned_data.items()
]
)
@ -257,7 +257,7 @@ def plot_heatmap_2d(data, ax: mpl_ax.Axis) -> None:
x_sym, y_sym, c_sym, pinned = list(data.keys())
heatmap = ax.imshow(data[c_sym], aspect='equal', interpolation='none')
# ax.figure.colorbar(heatmap, ax=ax)
ax.figure.colorbar(heatmap, ax=ax)
ax.set_title(
f'({x_sym.name_pretty}, {y_sym.name_pretty}) → {c_sym.plot_label} {pinned_labels(data[pinned])}'

View File

@ -37,6 +37,8 @@ from .common import (
flux,
freq,
idx,
mode_idx,
rel_eps,
rel_eps_im,
rel_eps_re,
sim_axis_idx,
@ -60,6 +62,8 @@ from .utils import (
__all__ = [
'CommonSimSymbol',
'idx',
'mode_idx',
'rel_eps',
'rel_eps_im',
'rel_eps_re',
'sim_axis_idx',

View File

@ -46,6 +46,7 @@ class CommonSimSymbol(enum.StrEnum):
Index = enum.auto()
SimAxisIdx = enum.auto()
ModeIdx = enum.auto()
# Space|Time
SpaceX = enum.auto()
@ -88,6 +89,7 @@ class CommonSimSymbol(enum.StrEnum):
DiffOrderX = enum.auto()
DiffOrderY = enum.auto()
RelEps = enum.auto()
RelEpsRe = enum.auto()
RelEpsIm = enum.auto()
@ -128,6 +130,7 @@ class CommonSimSymbol(enum.StrEnum):
return {
CSS.Index: SSN.LowerI,
CSS.SimAxisIdx: SSN.SimAxisIdx,
CSS.ModeIdx: SSN.ModeIdx,
# Space|Time
CSS.SpaceX: SSN.LowerX,
CSS.SpaceY: SSN.LowerY,
@ -156,6 +159,7 @@ class CommonSimSymbol(enum.StrEnum):
CSS.Flux: SSN.Flux,
CSS.DiffOrderX: SSN.DiffOrderX,
CSS.DiffOrderY: SSN.DiffOrderY,
CSS.RelEps: SSN.Perm,
CSS.RelEpsRe: SSN.RelEpsRe,
CSS.RelEpsIm: SSN.RelEpsIm,
}[self]
@ -201,6 +205,11 @@ class CommonSimSymbol(enum.StrEnum):
mathtype=spux.MathType.Integer,
domain=spux.BlessedSet(sp.FiniteSet(0, 1, 2)),
),
CSS.ModeIdx: SimSymbol(
sym_name=self.name,
mathtype=spux.MathType.Integer,
domain=spux.BlessedSet(sp.Naturals0),
),
# Space|Time
CSS.SpaceX: sym_space,
CSS.SpaceY: sym_space,
@ -275,6 +284,11 @@ class CommonSimSymbol(enum.StrEnum):
mathtype=spux.MathType.Integer,
domain=spux.BlessedSet(sp.Integers),
),
CSS.RelEps: SimSymbol(
sym_name=self.name,
mathtype=spux.MathType.Complex,
domain=spux.BlessedSet(sp.Complexes),
),
CSS.RelEpsRe: SimSymbol(
sym_name=self.name,
mathtype=spux.MathType.Real,
@ -293,6 +307,8 @@ class CommonSimSymbol(enum.StrEnum):
####################
idx = CommonSimSymbol.Index.sim_symbol
sim_axis_idx = CommonSimSymbol.SimAxisIdx.sim_symbol
mode_idx = CommonSimSymbol.ModeIdx.sim_symbol
t = CommonSimSymbol.Time.sim_symbol
wl = CommonSimSymbol.Wavelength.sim_symbol
freq = CommonSimSymbol.Frequency.sim_symbol
@ -323,5 +339,6 @@ flux = CommonSimSymbol.Flux.sim_symbol
diff_order_x = CommonSimSymbol.DiffOrderX.sim_symbol
diff_order_y = CommonSimSymbol.DiffOrderY.sim_symbol
rel_eps = CommonSimSymbol.RelEps.sim_symbol
rel_eps_re = CommonSimSymbol.RelEpsRe.sim_symbol
rel_eps_im = CommonSimSymbol.RelEpsIm.sim_symbol

View File

@ -95,6 +95,7 @@ class SimSymbolName(enum.StrEnum):
RelEpsIm = enum.auto()
SimAxisIdx = enum.auto()
ModeIdx = enum.auto()
####################
# - UI
@ -177,6 +178,7 @@ class SimSymbolName(enum.StrEnum):
SSN.RelEpsRe: 'eps_r_re',
SSN.RelEpsIm: 'eps_r_im',
SSN.SimAxisIdx: '[xyz]',
SSN.ModeIdx: 'mode',
}
)[self]

View File

@ -675,6 +675,13 @@ class SimSymbol(pyd.BaseModel):
else:
res = sp.S(obj)
## TODO: THIS IS A WORKAROUND
## TODO: Only a plain MatrixSymbol is detected, this is **not enough**.
## -- We also need to handle expressions containing MatrixSymbols.
## -- We do this before adding units to catch some cases.
## -- This is rather fragile right now.
is_matrix = isinstance(res, sp.MatrixBase | sp.MatrixSymbol)
# Unit Conversion
match (spux.uses_units(res), self.unit is not None):
case (True, True):
@ -694,10 +701,12 @@ class SimSymbol(pyd.BaseModel):
self.depths == ()
and (self.rows > 1 or self.cols > 1)
and not isinstance(res, sp.MatrixBase | sp.MatrixSymbol)
and not is_matrix
):
res = sp.ImmutableMatrix.ones(self.rows, self.cols).applyfunc(
lambda el: 5 * el
)
res = res * sp.ImmutableMatrix.ones(self.rows, self.cols)
# res = sp.ImmutableMatrix.ones(self.rows, self.cols).applyfunc(
# lambda el: res * el
# )
return res