feat: implemented data loading w/new math ops

We implemented a node to load various kinds of data, notably `.npy`,
`.txt`, `.txt.gz`, and `.csv`. The `DataFileImporterNode` really should
expose some settings for setting name/mathtype/physical type/unit of
each unit, and/or treating a column from 2D data as index coordinates.
But the nuances of doing this in a manner general enough to deal with
!=2D data was a lot, and we needed similar abilities in the general math
system anyway.

So, we delved back into the `FilterMathNode` and a little into the
`TransformMathNode`. Fundamentally, a few difficult operations came out
of this:

- Filter:SliceIdx: Slice an array using the usual syntax, as baked into the
  function.
- Filter:PinIdx: Pin an axis by its actual index.
- Filter:SetDim: Set the `InfoFlow` index coordinates of an axis to a specific,
  loose-socket provided 1D array, and use a common symbol to set the
  name+physical type (and allow specifying an appropriate unit).
- Transform:IntDimToComplex: Fold a last length-2 integer-indexed axis
  into a real output type, which removes the dimension and produces a
  complex output type. Essentially, this is equivalent to folding it as
  a vector and treating the `R^2` numbers as real/imaginary, except this
  is more explicit.

By combining all of these, we managed to process and constrain the medium data to
be a well-suited, unit-aware (**though not on the output (yet)**) `wl->C` tensor.
In particular, the slicing is nice for avoiding discontinuities.

Workflow-wise, we'll see how important these are / what else we might
want. Also, it turns out Blender's text editor is really quite nice for
light data-text viewing.
main
Sofus Albert Høgsbro Rose 2024-05-19 18:04:58 +02:00
parent a66a28da27
commit 0f2f494868
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
10 changed files with 976 additions and 173 deletions

View File

@ -122,7 +122,6 @@ class ArrayFlow:
if self.unit is not None
else rescale_func(a * self.unit)
)
log.critical([self.unit, new_unit, rescale_expr])
_rescale_func = sp.lambdify(a, rescale_expr, 'jax')
values = _rescale_func(self.values)
@ -132,3 +131,13 @@ class ArrayFlow:
unit=new_unit,
is_sorted=self.is_sorted,
)
def __getitem__(self, subscript: slice):
if isinstance(subscript, slice):
return ArrayFlow(
values=self.values[subscript],
unit=self.unit,
is_sorted=self.is_sorted,
)
raise NotImplementedError

View File

@ -39,6 +39,16 @@ class InfoFlow:
default_factory=dict
) ## TODO: Rename to dim_idxs
@functools.cached_property
def dim_has_coords(self) -> dict[str, int]:
return {
dim_name: not (
isinstance(dim_idx, LazyArrayRangeFlow)
and (dim_idx.start.is_infinite or dim_idx.stop.is_infinite)
)
for dim_name, dim_idx in self.dim_idx.items()
}
@functools.cached_property
def dim_lens(self) -> dict[str, int]:
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
@ -99,9 +109,29 @@ class InfoFlow:
####################
# - Methods
####################
def slice_dim(self, dim_name: str, slice_tuple: tuple[int, int, int]) -> typ.Self:
return InfoFlow(
# Dimensions
dim_names=self.dim_names,
dim_idx={
_dim_name: (
dim_idx
if _dim_name != dim_name
else dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
)
for _dim_name, dim_idx in self.dim_idx.items()
},
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
)
def replace_dim(
self, old_dim_name: str, new_dim_idx: tuple[str, ArrayFlow | LazyArrayRangeFlow]
) -> typ.Self:
"""Replace a dimension (and its indexing) with a new name and index array/range."""
return InfoFlow(
# Dimensions
dim_names=[
@ -122,6 +152,7 @@ class InfoFlow:
)
def rescale_dim_idxs(self, new_dim_idxs: dict[str, LazyArrayRangeFlow]) -> typ.Self:
"""Replace several dimensional indices with new index arrays/ranges."""
return InfoFlow(
# Dimensions
dim_names=self.dim_names,
@ -156,7 +187,7 @@ class InfoFlow:
)
def swap_dimensions(self, dim_0_name: str, dim_1_name: str) -> typ.Self:
"""Delete a dimension."""
"""Swap the position of two dimensions."""
# Compute Swapped Dimension Name List
def name_swapper(dim_name):
@ -181,7 +212,7 @@ class InfoFlow:
)
def set_output_mathtype(self, output_mathtype: spux.MathType) -> typ.Self:
"""Set the MathType of a particular output name."""
"""Set the MathType of the output."""
return InfoFlow(
dim_names=self.dim_names,
dim_idx=self.dim_idx,
@ -198,6 +229,7 @@ class InfoFlow:
collapsed_mathtype: spux.MathType,
collapsed_unit: spux.Unit,
) -> typ.Self:
"""Replace the (scalar) output with the given corrected values."""
return InfoFlow(
# Dimensions
dim_names=self.dim_names,

View File

@ -418,7 +418,11 @@ class LazyArrayRangeFlow:
self,
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
) -> ArrayFlow | LazyValueFuncFlow:
return (self.realize_stop() - self.realize_start()) / self.steps
raw_step_size = (self.realize_stop() - self.realize_start() + 1) / self.steps
if self.mathtype is spux.MathType.Integer and raw_step_size.is_integer():
return int(raw_step_size)
return raw_step_size
def realize(
self,
@ -463,3 +467,28 @@ class LazyArrayRangeFlow:
@functools.cached_property
def realize_array(self) -> ArrayFlow:
return self.realize()
def __getitem__(self, subscript: slice):
if isinstance(subscript, slice) and self.scaling == ScalingMode.Lin:
# Parse Slice
start = subscript.start if subscript.start is not None else 0
stop = subscript.stop if subscript.stop is not None else self.steps
step = subscript.step if subscript.step is not None else 1
slice_steps = (stop - start + step - 1) // step
# Compute New Start/Stop
step_size = self.realize_step_size()
new_start = step_size * start
new_stop = new_start + step_size * slice_steps
return LazyArrayRangeFlow(
start=sp.S(new_start),
stop=sp.S(new_stop),
steps=slice_steps,
scaling=self.scaling,
unit=self.unit,
symbols=self.symbols,
)
raise NotImplementedError

View File

@ -58,9 +58,11 @@ class ParamsFlow:
## TODO: MutableDenseMatrix causes error with 'in' check bc it isn't hashable.
return [
(
spux.scale_to_unit_system(arg, unit_system, use_jax_array=True)
if arg not in symbol_values
else symbol_values[arg]
)
for arg in self.func_args
]

View File

@ -20,9 +20,11 @@ import enum
import typing as typ
import bpy
import jax.lax as jlax
import jax.numpy as jnp
import sympy as sp
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import bl_cache, logger, sim_symbols
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct
@ -43,64 +45,201 @@ class FilterOperation(enum.StrEnum):
Swap: Swap the positions of two dimensions.
"""
# Dimensions
# Slice
SliceIdx = enum.auto()
# Pin
PinLen1 = enum.auto()
Pin = enum.auto()
PinIdx = enum.auto()
# Reinterpret
Swap = enum.auto()
SetDim = enum.auto()
# Fold
DimToVec = enum.auto()
DimsToMat = enum.auto()
####################
# - UI
####################
@staticmethod
def to_name(value: typ.Self) -> str:
FO = FilterOperation
return {
# Dimensions
# Slice
FO.SliceIdx: 'a[...]',
# Pin
FO.PinLen1: 'pinₐ =1',
FO.Pin: 'pinₐ ≈v',
FO.PinIdx: 'pinₐ =a[v]',
# Reinterpret
FO.Swap: 'a₁ ↔ a₂',
# Interpret
FO.DimToVec: '→ Vector',
FO.DimsToMat: '→ Matrix',
FO.SetDim: 'setₐ =v',
}[value]
@staticmethod
def to_icon(value: typ.Self) -> str:
return ''
def are_dims_valid(self, dim_0: int | None, dim_1: int | None):
return not (
(
dim_0 is None
and self
in [FilterOperation.PinLen1, FilterOperation.Pin, FilterOperation.Swap]
)
or (dim_1 is None and self == FilterOperation.Swap)
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
FO = FilterOperation
return (
str(self),
FO.to_name(self),
FO.to_name(self),
FO.to_icon(self),
i,
)
def jax_func(self, axis_0: int | None, axis_1: int | None):
####################
# - Ops from Info
####################
@staticmethod
def by_info(info: ct.InfoFlow) -> list[typ.Self]:
FO = FilterOperation
operations = []
# Slice
if info.dim_names:
operations.append(FO.SliceIdx)
# Pin
## PinLen1
## -> There must be a dimension with length 1.
if 1 in list(info.dim_lens.values()):
operations.append(FO.PinLen1)
## Pin | PinIdx
## -> There must be a dimension, full stop.
if info.dim_names:
operations += [FO.Pin, FO.PinIdx]
# Reinterpret
## Swap
## -> There must be at least two dimensions.
if len(info.dim_names) >= 2: # noqa: PLR2004
operations.append(FO.Swap)
## SetDim
## -> There must be a dimension to correct.
if info.dim_names:
operations.append(FO.SetDim)
return operations
####################
# - Computed Properties
####################
@property
def func_args(self) -> list[spux.MathType]:
FO = FilterOperation
return {
# Interpret
FilterOperation.DimToVec: lambda data: data,
FilterOperation.DimsToMat: lambda data: data,
# Dimensions
FilterOperation.PinLen1: lambda data: jnp.squeeze(data, axis_0),
FilterOperation.Pin: lambda data, fixed_axis_idx: jnp.take(
data, fixed_axis_idx, axis=axis_0
),
FilterOperation.Swap: lambda data: jnp.swapaxes(data, axis_0, axis_1),
# Pin
FO.Pin: [spux.MathType.Integer],
FO.PinIdx: [spux.MathType.Integer],
}.get(self, [])
####################
# - Methods
####################
@property
def num_dim_inputs(self) -> None:
FO = FilterOperation
return {
# Slice
FO.SliceIdx: 1,
# Pin
FO.PinLen1: 1,
FO.Pin: 1,
FO.PinIdx: 1,
# Reinterpret
FO.Swap: 2,
FO.SetDim: 1,
}[self]
def transform_info(self, info: ct.InfoFlow, dim_0: str, dim_1: str):
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
FO = FilterOperation
match self:
case FO.SliceIdx:
return info.dim_names
# PinLen1: Only allow dimensions with length=1.
case FO.PinLen1:
return [
dim_name
for dim_name in info.dim_names
if info.dim_lens[dim_name] == 1
]
# Pin: Only allow dimensions with known indexing.
case FO.Pin:
return [
dim_name
for dim_name in info.dim_names
if info.dim_has_coords[dim_name] != 0
]
case FO.PinIdx | FO.Swap:
return info.dim_names
case FO.SetDim:
return [
dim_name
for dim_name in info.dim_names
if info.dim_mathtypes[dim_name] == spux.MathType.Integer
]
return []
def are_dims_valid(
self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None
) -> bool:
"""Check whether the given dimension inputs are valid in the context of this operation, and of the information."""
return (self.num_dim_inputs in [1, 2] and dim_0 in self.valid_dims(info)) or (
self.num_dim_inputs == 2 and dim_1 in self.valid_dims(info)
)
####################
# - UI
####################
def jax_func(
self,
axis_0: int | None,
axis_1: int | None,
slice_tuple: tuple[int, int, int] | None = None,
):
FO = FilterOperation
return {
# Interpret
FilterOperation.DimToVec: lambda: info.shift_last_input,
FilterOperation.DimsToMat: lambda: info.shift_last_input.shift_last_input,
# Dimensions
FilterOperation.PinLen1: lambda: info.delete_dimension(dim_0),
FilterOperation.Pin: lambda: info.delete_dimension(dim_0),
FilterOperation.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
# Pin
FO.SliceIdx: lambda expr: jlax.slice_in_dim(
expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
),
# Pin
FO.PinLen1: lambda expr: jnp.squeeze(expr, axis_0),
FO.Pin: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
# Reinterpret
FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1),
FO.SetDim: lambda expr: expr,
}[self]
def transform_info(
self,
info: ct.InfoFlow,
dim_0: str,
dim_1: str,
slice_tuple: tuple[int, int, int] | None = None,
corrected_dim: tuple[str, tuple[str, ct.ArrayFlow | ct.LazyArrayRangeFlow]]
| None = None,
):
FO = FilterOperation
return {
FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple),
# Pin
FO.PinLen1: lambda: info.delete_dimension(dim_0),
FO.Pin: lambda: info.delete_dimension(dim_0),
FO.PinIdx: lambda: info.delete_dimension(dim_0),
# Reinterpret
FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
FO.SetDim: lambda: info.replace_dim(*corrected_dim),
}[self]()
@ -133,115 +272,192 @@ class FilterMathNode(base.MaxwellSimNode):
}
####################
# - Properties
# - Properties: Expr InfoFlow
####################
operation: FilterOperation = bl_cache.BLField(
FilterOperation.PinLen1,
prop_ui=True,
@events.on_value_changed(
socket_name={'Expr'},
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info},
input_sockets_optional={'Expr': True},
)
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
info_pending = ct.FlowSignal.check_single(
input_sockets['Expr'], ct.FlowSignal.FlowPending
)
# Dimension Selection
dim_0: enum.StrEnum = bl_cache.BLField(enum_cb=lambda self, _: self.search_dims())
dim_1: enum.StrEnum = bl_cache.BLField(enum_cb=lambda self, _: self.search_dims())
if has_info and not info_pending:
self.expr_info = bl_cache.Signal.InvalidateCache
####################
# - Computed
####################
@property
def data_info(self) -> ct.InfoFlow | None:
info = self._compute_input('Expr', kind=ct.FlowKind.Info)
if not ct.FlowSignal.check(info):
@bl_cache.cached_bl_property()
def expr_info(self) -> ct.InfoFlow | None:
info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True)
has_info = not ct.FlowSignal.check(info)
if has_info:
return info
return None
####################
# - Search Dimensions
# - Properties: Operation
####################
operation: FilterOperation = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_info'},
)
def search_operations(self) -> list[ct.BLEnumElement]:
if self.expr_info is not None:
return [
operation.bl_enum_element(i)
for i, operation in enumerate(FilterOperation.by_info(self.expr_info))
]
return []
####################
# - Properties: Dimension Selection
####################
dim_0: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_dims(),
cb_depends_on={'operation', 'expr_info'},
)
dim_1: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_dims(),
cb_depends_on={'operation', 'expr_info'},
)
def search_dims(self) -> list[ct.BLEnumElement]:
if self.data_info is None:
if self.expr_info is not None and self.operation is not None:
return [
(dim_name, dim_name, dim_name, '', i)
for i, dim_name in enumerate(self.operation.valid_dims(self.expr_info))
]
return []
if self.operation == FilterOperation.PinLen1:
dims = [
(dim_name, dim_name, f'Dimension "{dim_name}" of length 1')
for dim_name in self.data_info.dim_names
if self.data_info.dim_lens[dim_name] == 1
####################
# - Properties: Slice
####################
slice_tuple: tuple[int, int, int] = bl_cache.BLField([0, 1, 1])
####################
# - Properties: Unit
####################
set_dim_symbol: sim_symbols.CommonSimSymbol = bl_cache.BLField(
sim_symbols.CommonSimSymbol.X
)
set_dim_active_unit: enum.StrEnum = bl_cache.BLField(
enum_cb=lambda self, _: self.search_valid_units(),
cb_depends_on={'set_dim_symbol'},
)
def search_valid_units(self) -> list[ct.BLEnumElement]:
"""Compute Blender enum elements of valid units for the current `physical_type`."""
physical_type = self.set_dim_symbol.sim_symbol.physical_type
if physical_type is not spux.PhysicalType.NonPhysical:
return [
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
for i, unit in enumerate(physical_type.valid_units)
]
elif self.operation in [FilterOperation.Pin, FilterOperation.Swap]:
dims = [
(dim_name, dim_name, f'Dimension "{dim_name}"')
for dim_name in self.data_info.dim_names
]
else:
return []
return [(*dim, '', i) for i, dim in enumerate(dims)]
@bl_cache.cached_bl_property(depends_on={'set_dim_active_unit'})
def set_dim_unit(self) -> spux.Unit | None:
if self.set_dim_active_unit is not None:
return spux.unit_str_to_unit(self.set_dim_active_unit)
return None
####################
# - UI
####################
def draw_label(self):
FO = FilterOperation
labels = {
FO.PinLen1: lambda: f'Filter: Pin {self.dim_0} (len=1)',
FO.Pin: lambda: f'Filter: Pin {self.dim_0}',
FO.Swap: lambda: f'Filter: Swap {self.dim_0}|{self.dim_1}',
FO.DimToVec: lambda: 'Filter: -> Vector',
FO.DimsToMat: lambda: 'Filter: -> Matrix',
}
match self.operation:
# Slice
case FO.SliceIdx:
slice_str = ':'.join([str(v) for v in self.slice_tuple])
return f'Filter: {self.dim_0}[{slice_str}]'
if (label := labels.get(self.operation)) is not None:
return label()
# Pin
case FO.PinLen1:
return f'Filter: Pin {self.dim_0}[0]'
case FO.Pin:
return f'Filter: Pin {self.dim_0}[...]'
case FO.PinIdx:
pin_idx_axis = self._compute_input(
'Axis', kind=ct.FlowKind.Value, optional=True
)
has_pin_idx_axis = not ct.FlowSignal.check(pin_idx_axis)
if has_pin_idx_axis:
return f'Filter: Pin {self.dim_0}[{pin_idx_axis}]'
return self.bl_label
# Reinterpret
case FO.Swap:
return f'Filter: Swap [{self.dim_0}]|[{self.dim_1}]'
case FO.SetDim:
return f'Filter: Set [{self.dim_0}]'
case _:
return self.bl_label
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, self.blfields['operation'], text='')
if self.operation in [FilterOperation.PinLen1, FilterOperation.Pin]:
if self.operation is not None:
match self.operation.num_dim_inputs:
case 1:
layout.prop(self, self.blfields['dim_0'], text='')
if self.operation == FilterOperation.Swap:
case 2:
row = layout.row(align=True)
row.prop(self, self.blfields['dim_0'], text='')
row.prop(self, self.blfields['dim_1'], text='')
if self.operation is FilterOperation.SliceIdx:
layout.prop(self, self.blfields['slice_tuple'], text='')
if self.operation is FilterOperation.SetDim:
row = layout.row(align=True)
row.prop(self, self.blfields['set_dim_symbol'], text='')
row.prop(self, self.blfields['set_dim_active_unit'], text='')
####################
# - Events
####################
@events.on_value_changed(
# Trigger
socket_name='Expr',
prop_name={'operation'},
run_on_init=True,
)
def on_input_changed(self) -> None:
self.dim_0 = bl_cache.Signal.ResetEnumItems
self.dim_1 = bl_cache.Signal.ResetEnumItems
@events.on_value_changed(
# Trigger
socket_name='Expr',
prop_name={'dim_0', 'dim_1', 'operation'},
run_on_init=True,
prop_name={'operation', 'dim_0', 'dim_1'},
# Loaded
props={'operation', 'dim_0', 'dim_1'},
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info},
)
def on_pin_changed(self, props: dict, input_sockets: dict):
def on_pin_factors_changed(self, props: dict, input_sockets: dict):
"""Synchronize loose input sockets to match the dimension-pinning method declared in `self.operation`.
To "pin" an axis, a particular index must be chosen to "extract".
One might choose axes of length 1 ("squeeze"), choose a particular index, or choose a coordinate that maps to a particular index.
Those last two options requires more information from the user: Which index?
Which coordinate?
To answer these questions, we create an appropriate loose input socket containing this data, so the user can make their decision.
"""
info = input_sockets['Expr']
has_info = not ct.FlowSignal.check(info)
if not has_info:
return
# "Dimensions"|"PIN": Add/Remove Input Socket
if props['operation'] == FilterOperation.Pin and props['dim_0'] is not None:
# Pin Dim by-Value: Synchronize Input Socket
## -> The user will be given a socket w/correct mathtype, unit, etc. .
## -> Internally, this value will map to a particular index.
if props['operation'] is FilterOperation.Pin and props['dim_0'] is not None:
# Deduce Pinned Information
pinned_unit = info.dim_units[props['dim_0']]
pinned_mathtype = info.dim_mathtypes[props['dim_0']]
pinned_physical_type = spux.PhysicalType.from_unit(pinned_unit)
wanted_mathtype = (
spux.MathType.Complex
if pinned_mathtype == spux.MathType.Complex
@ -250,9 +466,11 @@ class FilterMathNode(base.MaxwellSimNode):
)
# Get Current and Wanted Socket Defs
## -> 'Value' may already exist. If not, all is well.
current_bl_socket = self.loose_input_sockets.get('Value')
# Determine Whether to Declare New Loose Input SOcket
# Determine Whether to Construct
## -> If nothing needs to change, then nothing changes.
if (
current_bl_socket is None
or current_bl_socket.size is not spux.NumberSize1D.Scalar
@ -262,22 +480,68 @@ class FilterMathNode(base.MaxwellSimNode):
self.loose_input_sockets = {
'Value': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Value,
size=spux.NumberSize1D.Scalar,
physical_type=pinned_physical_type,
mathtype=wanted_mathtype,
default_unit=pinned_unit,
),
}
# Pin Dim by-Index: Synchronize Input Socket
## -> The user will be given a simple integer socket.
elif (
props['operation'] is FilterOperation.PinIdx and props['dim_0'] is not None
):
current_bl_socket = self.loose_input_sockets.get('Axis')
if (
current_bl_socket is None
or current_bl_socket.size is not spux.NumberSize1D.Scalar
or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical
or current_bl_socket.mathtype != spux.MathType.Integer
):
self.loose_input_sockets = {
'Axis': sockets.ExprSocketDef(
active_kind=ct.FlowKind.Value,
mathtype=spux.MathType.Integer,
)
}
# Set Dim: Synchronize Input Socket
## -> The user must provide a () -> array.
## -> It must be of identical length to the replaced axis.
elif (
props['operation'] is FilterOperation.SetDim
and props['dim_0'] is not None
and info.dim_mathtypes[props['dim_0']] is spux.MathType.Integer
and info.dim_physical_types[props['dim_0']] is spux.PhysicalType.NonPhysical
):
# Deduce Axis Information
current_bl_socket = self.loose_input_sockets.get('Dim')
if (
current_bl_socket is None
or current_bl_socket.active_kind != ct.FlowKind.LazyValueFunc
or current_bl_socket.mathtype != spux.MathType.Real
or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical
):
self.loose_input_sockets = {
'Dim': sockets.ExprSocketDef(
active_kind=ct.FlowKind.LazyValueFunc,
mathtype=spux.MathType.Real,
physical_type=spux.PhysicalType.NonPhysical,
show_info_columns=True,
)
}
# No Loose Value: Remove Input Sockets
elif self.loose_input_sockets:
self.loose_input_sockets = {}
####################
# - Output
# - FlowKind.Value|LazyValueFunc
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.LazyValueFunc,
props={'operation', 'dim_0', 'dim_1'},
props={'operation', 'dim_0', 'dim_1', 'slice_tuple'},
input_sockets={'Expr'},
input_socket_kinds={'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info}},
)
@ -296,82 +560,120 @@ class FilterMathNode(base.MaxwellSimNode):
has_lazy_value_func
and has_info
and operation is not None
and operation.are_dims_valid(dim_0, dim_1)
and operation.are_dims_valid(info, dim_0, dim_1)
):
axis_0 = info.dim_names.index(dim_0) if dim_0 is not None else None
axis_1 = info.dim_names.index(dim_1) if dim_1 is not None else None
slice_tuple = (
props['slice_tuple']
if self.operation is FilterOperation.SliceIdx
else None
)
return lazy_value_func.compose_within(
operation.jax_func(axis_0, axis_1),
enclosing_func_args=[int] if operation == FilterOperation.Pin else [],
operation.jax_func(axis_0, axis_1, slice_tuple),
enclosing_func_args=operation.func_args,
supports_jax=True,
)
return ct.FlowSignal.FlowPending
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Array,
output_sockets={'Expr'},
output_socket_kinds={
'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
},
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
)
def compute_array(self, output_sockets, unit_systems) -> ct.ArrayFlow:
lazy_value_func = output_sockets['Expr'][ct.FlowKind.LazyValueFunc]
params = output_sockets['Expr'][ct.FlowKind.Params]
has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func)
has_params = not ct.FlowSignal.check(params)
if has_lazy_value_func and has_params:
unit_system = unit_systems['BlenderUnits']
return ct.ArrayFlow(
values=lazy_value_func.func_jax(
*params.scaled_func_args(unit_system),
**params.scaled_func_kwargs(unit_system),
),
)
return ct.FlowSignal.FlowPending
####################
# - Auxiliary: Info
# - FlowKind.Info
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Info,
props={'dim_0', 'dim_1', 'operation'},
input_sockets={'Expr'},
input_socket_kinds={'Expr': ct.FlowKind.Info},
props={
'dim_0',
'dim_1',
'operation',
'slice_tuple',
'set_dim_symbol',
'set_dim_active_unit',
},
input_sockets={'Expr', 'Dim'},
input_socket_kinds={
'Expr': ct.FlowKind.Info,
'Dim': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params, ct.FlowKind.Info},
},
input_sockets_optional={'Dim': True},
)
def compute_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
def compute_info(self, props, input_sockets) -> ct.InfoFlow:
operation = props['operation']
info = input_sockets['Expr']
dim_coords = input_sockets['Dim'][ct.FlowKind.LazyValueFunc]
dim_params = input_sockets['Dim'][ct.FlowKind.Params]
dim_info = input_sockets['Dim'][ct.FlowKind.Info]
dim_symbol = props['set_dim_symbol']
dim_active_unit = props['set_dim_active_unit']
has_info = not ct.FlowSignal.check(info)
has_dim_coords = not ct.FlowSignal.check(dim_coords)
has_dim_params = not ct.FlowSignal.check(dim_params)
has_dim_info = not ct.FlowSignal.check(dim_info)
# Dimension(s)
dim_0 = props['dim_0']
dim_1 = props['dim_1']
slice_tuple = props['slice_tuple']
if has_info and operation is not None:
# Set Dimension: Retrieve Array
if props['operation'] is FilterOperation.SetDim:
if (
has_info
and operation is not None
and operation.are_dims_valid(dim_0, dim_1)
dim_0 is not None
# Check Replaced Dimension
and has_dim_coords
and len(dim_coords.func_args) == 1
and dim_coords.func_args[0] is spux.MathType.Integer
and not dim_coords.func_kwargs
and dim_coords.supports_jax
# Check Params
and has_dim_params
and len(dim_params.func_args) == 1
and not dim_params.func_kwargs
# Check Info
and has_dim_info
):
return operation.transform_info(info, dim_0, dim_1)
# Retrieve Dimension Coordinate Array
## -> It must be strictly compatible.
values = dim_coords.func_jax(int(dim_params.func_args[0]))
if (
len(values.shape) != 1
or values.shape[0] != info.dim_lens[dim_0]
):
return ct.FlowSignal.FlowPending
# Transform Info w/Corrected Dimension
## -> The existing dimension will be replaced.
if dim_active_unit is not None:
dim_unit = spux.unit_str_to_unit(dim_active_unit)
else:
dim_unit = None
new_dim_idx = ct.ArrayFlow(
values=values,
unit=dim_unit,
)
corrected_dim = [dim_0, (dim_symbol.name, new_dim_idx)]
return operation.transform_info(
info, dim_0, dim_1, corrected_dim=corrected_dim
)
return ct.FlowSignal.FlowPending
return operation.transform_info(info, dim_0, dim_1, slice_tuple=slice_tuple)
return ct.FlowSignal.FlowPending
####################
# - Auxiliary: Params
# - FlowKind.Params
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Params,
props={'dim_0', 'dim_1', 'operation'},
input_sockets={'Expr', 'Value'},
input_socket_kinds={'Expr': {ct.FlowKind.Info, ct.FlowKind.Params}},
input_sockets_optional={'Value': True},
input_sockets={'Expr', 'Value', 'Axis'},
input_socket_kinds={
'Expr': {ct.FlowKind.Info, ct.FlowKind.Params},
},
input_sockets_optional={'Value': True, 'Axis': True},
)
def compute_params(self, props: dict, input_sockets: dict) -> ct.ParamsFlow:
operation = props['operation']
@ -388,20 +690,30 @@ class FilterMathNode(base.MaxwellSimNode):
has_info
and has_params
and operation is not None
and operation.are_dims_valid(dim_0, dim_1)
and operation.are_dims_valid(info, dim_0, dim_1)
):
## Pinned Value
# Retrieve Pinned Value
pinned_value = input_sockets['Value']
has_pinned_value = not ct.FlowSignal.check(pinned_value)
if props['operation'] == FilterOperation.Pin and has_pinned_value:
pinned_axis = input_sockets['Axis']
has_pinned_axis = not ct.FlowSignal.check(pinned_axis)
# Pin by-Value: Compute Nearest IDX
## -> Presume a sorted index array to be able to use binary search.
if props['operation'] is FilterOperation.Pin and has_pinned_value:
nearest_idx_to_value = info.dim_idx[dim_0].nearest_idx_of(
pinned_value, require_sorted=True
)
return params.compose_within(enclosing_func_args=[nearest_idx_to_value])
# Pin by-Index
if props['operation'] is FilterOperation.PinIdx and has_pinned_axis:
return params.compose_within(enclosing_func_args=[pinned_axis])
return params
return ct.FlowSignal.FlowPending

View File

@ -21,6 +21,7 @@ import typing as typ
import bpy
import jax.numpy as jnp
import jaxtyping as jtyp
import sympy as sp
import sympy.physics.units as spu
@ -47,13 +48,20 @@ class TransformOperation(enum.StrEnum):
InvFFT: Compute the inverse fourier transform of the input expression.
"""
# Index
# Covariant Transform
FreqToVacWL = enum.auto()
VacWLToFreq = enum.auto()
# Fold
IntDimToComplex = enum.auto()
DimToVec = enum.auto()
DimsToMat = enum.auto()
# Fourier
FFT1D = enum.auto()
InvFFT1D = enum.auto()
# Affine
# TODO: Affine
## TODO
####################
@ -63,9 +71,14 @@ class TransformOperation(enum.StrEnum):
def to_name(value: typ.Self) -> str:
TO = TransformOperation
return {
# By Number
# Covariant Transform
TO.FreqToVacWL: '𝑓 → λᵥ',
TO.VacWLToFreq: 'λᵥ → 𝑓',
# Fold
TO.IntDimToComplex: '',
TO.DimToVec: '→ Vector',
TO.DimsToMat: '→ Matrix',
# Fourier
TO.FFT1D: 't → 𝑓',
TO.InvFFT1D: '𝑓 → t',
}[value]
@ -92,7 +105,8 @@ class TransformOperation(enum.StrEnum):
TO = TransformOperation
operations = []
# Freq <-> VacWL
# Covariant Transform
## Freq <-> VacWL
for dim_name in info.dim_names:
if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq:
operations.append(TO.FreqToVacWL)
@ -100,7 +114,23 @@ class TransformOperation(enum.StrEnum):
if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq:
operations.append(TO.VacWLToFreq)
# 1D Fourier
# Fold
## (Last) Int Dim (=2) to Complex
if len(info.dim_names) >= 1:
last_dim_name = info.dim_names[-1]
if info.dim_lens[last_dim_name] == 2: # noqa: PLR2004
operations.append(TO.IntDimToComplex)
## To Vector
if len(info.dim_names) >= 1:
operations.append(TO.DimToVec)
## To Matrix
if len(info.dim_names) >= 2: # noqa: PLR2004
operations.append(TO.DimsToMat)
# Fourier
## 1D Fourier
if info.dim_names:
last_physical_type = info.dim_physical_types[info.dim_names[-1]]
if last_physical_type == spux.PhysicalType.Time:
@ -117,9 +147,13 @@ class TransformOperation(enum.StrEnum):
def sp_func(self):
TO = TransformOperation
return {
# Index
# Covariant Transform
TO.FreqToVacWL: lambda expr: expr,
TO.VacWLToFreq: lambda expr: expr,
# Fold
# TO.IntDimToComplex: lambda expr: expr, ## TODO: Won't work?
TO.DimToVec: lambda expr: expr,
TO.DimsToMat: lambda expr: expr,
# Fourier
TO.FFT1D: lambda expr: sp.fourier_transform(
expr, sim_symbols.t, sim_symbols.freq
@ -133,15 +167,26 @@ class TransformOperation(enum.StrEnum):
def jax_func(self):
TO = TransformOperation
return {
# Index
# Covariant Transform
TO.FreqToVacWL: lambda expr: expr,
TO.VacWLToFreq: lambda expr: expr,
# Fold
## -> To Complex: With a little imagination, this is a noop :)
## -> **Requires** dims[-1] to be integer-indexed w/length of 2.
TO.IntDimToComplex: lambda expr: expr.view(dtype=jnp.complex64).squeeze(),
TO.DimToVec: lambda expr: expr,
TO.DimsToMat: lambda expr: expr,
# Fourier
TO.FFT1D: lambda expr: jnp.fft(expr),
TO.InvFFT1D: lambda expr: jnp.ifft(expr),
}[self]
def transform_info(self, info: ct.InfoFlow | None) -> ct.InfoFlow | None:
def transform_info(
self,
info: ct.InfoFlow | None,
data: jtyp.Shaped[jtyp.Array, '...'] | None = None,
unit: spux.Unit | None = None,
) -> ct.InfoFlow | None:
TO = TransformOperation
if not info.dim_names:
return None
@ -169,6 +214,12 @@ class TransformOperation(enum.StrEnum):
),
],
),
# Fold
TO.IntDimToComplex: lambda: info.delete_dimension(
info.dim_names[-1]
).set_output_mathtype(spux.MathType.Complex),
TO.DimToVec: lambda: info.shift_last_input,
TO.DimsToMat: lambda: info.shift_last_input.shift_last_input,
# Fourier
TO.FFT1D: lambda: info.replace_dim(
info.dim_names[-1],
@ -216,7 +267,7 @@ class TransformMathNode(base.MaxwellSimNode):
}
####################
# - Properties
# - Properties: Expr InfoFlow
####################
@events.on_value_changed(
socket_name={'Expr'},
@ -243,6 +294,9 @@ class TransformMathNode(base.MaxwellSimNode):
return None
####################
# - Properties: Operation
####################
operation: TransformOperation = bl_cache.BLField(
enum_cb=lambda self, _: self.search_operations(),
cb_depends_on={'expr_info'},
@ -258,9 +312,6 @@ class TransformMathNode(base.MaxwellSimNode):
]
return []
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, self.blfields['operation'], text='')
####################
# - UI
####################
@ -339,6 +390,7 @@ class TransformMathNode(base.MaxwellSimNode):
if has_info and operation is not None:
transformed_info = operation.transform_info(info)
if transformed_info is None:
return ct.FlowSignal.FlowPending
return transformed_info

View File

@ -14,11 +14,13 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import tidy_3d_file_importer
from . import data_file_importer, tidy_3d_file_importer
BL_REGISTER = [
*data_file_importer.BL_REGISTER,
*tidy_3d_file_importer.BL_REGISTER,
]
BL_NODES = {
**data_file_importer.BL_NODES,
**tidy_3d_file_importer.BL_NODES,
}

View File

@ -0,0 +1,356 @@
# blender_maxwell
# Copyright (C) 2024 blender_maxwell Project Contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import enum
import typing as typ
from pathlib import Path
import bpy
import jax.numpy as jnp
import jaxtyping as jtyp
import numpy as np
import pandas as pd
import sympy as sp
import tidy3d as td
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
####################
# - Data File Extensions
####################
_DATA_FILE_EXTS = {
'.txt',
'.txt.gz',
'.csv',
'.npy',
}
class DataFileExt(enum.StrEnum):
Txt = enum.auto()
TxtGz = enum.auto()
Csv = enum.auto()
Npy = enum.auto()
####################
# - Enum Elements
####################
@staticmethod
def to_name(v: typ.Self) -> str:
return DataFileExt(v).extension
@staticmethod
def to_icon(v: typ.Self) -> str:
return ''
####################
# - Computed Properties
####################
@property
def extension(self) -> str:
"""Map to the actual string extension."""
E = DataFileExt
return {
E.Txt: '.txt',
E.TxtGz: '.txt.gz',
E.Csv: '.csv',
E.Npy: '.npy',
}[self]
@property
def loader(self) -> typ.Callable[[Path], jtyp.Shaped[jtyp.Array, '...']]:
def load_txt(path: Path):
return jnp.asarray(np.loadtxt(path))
def load_csv(path: Path):
return jnp.asarray(pd.read_csv(path).values)
def load_npy(path: Path):
return jnp.load(path)
E = DataFileExt
return {
E.Txt: load_txt,
E.TxtGz: load_txt,
E.Csv: load_csv,
E.Npy: load_npy,
}[self]
@property
def loader_is_jax_compatible(self) -> bool:
E = DataFileExt
return {
E.Txt: True,
E.TxtGz: True,
E.Csv: False,
E.Npy: True,
}[self]
####################
# - Creation
####################
@staticmethod
def from_ext(ext: str) -> typ.Self | None:
return {
_ext: _data_file_ext
for _data_file_ext, _ext in {
k: k.extension for k in list(DataFileExt)
}.items()
}.get(ext)
@staticmethod
def from_path(path: Path) -> typ.Self | None:
if DataFileExt.is_path_compatible(path):
data_file_ext = DataFileExt.from_ext(''.join(path.suffixes))
if data_file_ext is not None:
return data_file_ext
msg = f'DataFileExt: Path "{path}" is compatible, but could not find valid extension'
raise RuntimeError(msg)
return None
####################
# - Compatibility
####################
@staticmethod
def is_ext_compatible(ext: str):
return ext in _DATA_FILE_EXTS
@staticmethod
def is_path_compatible(path: Path):
return path.is_file() and DataFileExt.is_ext_compatible(''.join(path.suffixes))
####################
# - Node
####################
class DataFileImporterNode(base.MaxwellSimNode):
node_type = ct.NodeType.DataFileImporter
bl_label = 'Data File Importer'
input_sockets: typ.ClassVar = {
'File Path': sockets.FilePathSocketDef(),
}
output_sockets: typ.ClassVar = {
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc),
}
####################
# - Properties
####################
@events.on_value_changed(
socket_name={'File Path'},
input_sockets={'File Path'},
input_socket_kinds={'File Path': ct.FlowKind.Value},
input_sockets_optional={'File Path': True},
)
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
has_file_path = not ct.FlowSignal.check(input_sockets['File Path'])
has_file_path = ct.FlowSignal.check_single(
input_sockets['File Path'], ct.FlowSignal.FlowPending
)
if has_file_path:
self.file_path = bl_cache.Signal.InvalidateCache
@bl_cache.cached_bl_property()
def file_path(self) -> Path:
"""Retrieve the input file path."""
file_path = self._compute_input(
'File Path', kind=ct.FlowKind.Value, optional=True
)
has_file_path = not ct.FlowSignal.check(file_path)
if has_file_path:
return file_path
return None
@bl_cache.cached_bl_property(depends_on={'file_path'})
def data_file_ext(self) -> DataFileExt | None:
"""Retrieve the file extension by concatenating all suffixes."""
if self.file_path is not None:
return DataFileExt.from_path(self.file_path)
return None
####################
# - Output Info
####################
@bl_cache.cached_bl_property(depends_on={'file_path'})
def expr_info(self) -> ct.InfoFlow | None:
"""Retrieve the output expression's `InfoFlow`."""
info = self.compute_output('Expr', kind=ct.FlowKind.Info)
has_info = not ct.FlowKind.check(info)
if has_info:
return info
return None
####################
# - UI
####################
def draw_label(self):
"""Show the extracted file name (w/extension) in the node's header label.
Notes:
Called by Blender to determine the text to place in the node's header.
"""
if self.file_path is not None:
return 'Load File: ' + self.file_path.name
return self.bl_label
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
"""Show information about the loaded file."""
if self.data_file_ext is not None:
box = layout.box()
row = box.row()
row.alignment = 'CENTER'
row.label(text='Data File')
row = box.row()
row.alignment = 'CENTER'
row.label(text=self.file_path.name)
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
pass
####################
# - Events
####################
@events.on_value_changed(
socket_name='File Path',
input_sockets={'File Path'},
)
def on_file_changed(self, input_sockets) -> None:
pass
####################
# - FlowKind.Array|LazyValueFunc
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.LazyValueFunc,
input_sockets={'File Path'},
)
def compute_func(self, input_sockets: dict) -> td.Simulation:
"""Declare a lazy, composable function that returns the loaded data.
Returns:
A completely empty `ParamsFlow`, ready to be composed.
"""
file_path = input_sockets['File Path']
has_file_path = not ct.FlowSignal.check(input_sockets['File Path'])
if has_file_path:
data_file_ext = DataFileExt.from_path(file_path)
if data_file_ext is not None:
# Jax Compatibility: Lazy Data Loading
## -> Delay loading of data from file as long as we can.
if data_file_ext.loader_is_jax_compatible:
return ct.LazyValueFuncFlow(
func=lambda: data_file_ext.loader(file_path),
supports_jax=True,
)
# No Jax Compatibility: Eager Data Loading
## -> Load the data now and bind it.
data = data_file_ext.loader(file_path)
return ct.LazyValueFuncFlow(func=lambda: data, supports_jax=True)
return ct.FlowSignal.FlowPending
return ct.FlowSignal.FlowPending
####################
# - FlowKind.Params|Info
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Params,
)
def compute_params(self) -> ct.ParamsFlow:
"""Declare an empty `Data:Params`, to indicate the start of a function-composition pipeline.
Returns:
A completely empty `ParamsFlow`, ready to be composed.
"""
return ct.ParamsFlow()
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Info,
output_sockets={'Expr'},
output_socket_kinds={'Expr': ct.FlowKind.LazyValueFunc},
)
def compute_info(self, output_sockets) -> ct.InfoFlow:
"""Declare an `InfoFlow` based on the data shape.
This currently requires computing the data.
Note, however, that the incremental cache causes this computation only to happen once when a file is selected.
Returns:
A completely empty `ParamsFlow`, ready to be composed.
"""
expr = output_sockets['Expr']
has_expr_func = not ct.FlowSignal.check(expr)
if has_expr_func:
data = expr.func_jax()
# Deduce Dimensionality
_shape = data.shape
shape = _shape if _shape is not None else ()
dim_names = [f'a{i}' for i in range(len(shape))]
# Return InfoFlow
## -> TODO: How to interpret the data should be user-defined.
## -> -- This may require those nice dynamic symbols.
return ct.InfoFlow(
dim_names=dim_names, ## TODO: User
dim_idx={
dim_name: ct.LazyArrayRangeFlow(
start=sp.S(0), ## TODO: User
stop=sp.S(shape[i] - 1), ## TODO: User
steps=shape[dim_names.index(dim_name)],
unit=None, ## TODO: User
)
for i, dim_name in enumerate(dim_names)
},
output_name='_',
output_shape=None,
output_mathtype=spux.MathType.Real, ## TODO: User
output_unit=None, ## TODO: User
)
return ct.FlowSignal.FlowPending
####################
# - Blender Registration
####################
BL_REGISTER = [
DataFileImporterNode,
]
BL_NODES = {
ct.NodeType.DataFileImporter: (ct.NodeCategory.MAXWELLSIM_INPUTS_FILEIMPORTERS)
}

View File

@ -234,7 +234,7 @@ def plot_curves_2d(
y_unit = info.output_unit
for category in range(data.shape[1]):
ax.plot(data[:, 0], data[:, 1])
ax.plot(info.dim_idx_arrays[0], data[:, category])
ax.set_title('2D Curves')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
@ -250,8 +250,9 @@ def plot_filled_curves_2d(
y_name = info.output_name
y_unit = info.output_unit
ax.fill_between(info.dim_arrays[0], data[:, 0], info.dim_arrays[0], data[:, 1])
ax.set_title('2D Curves')
shared_x_idx = info.dim_idx_arrays[0]
ax.fill_between(shared_x_idx, data[:, 0], shared_x_idx, data[:, 1])
ax.set_title('2D Filled Curves')
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))

View File

@ -109,6 +109,10 @@ class SimSymbol:
####################
# - Properties
####################
@property
def name(self) -> str:
return self.sim_node_name.name
@property
def domain(self) -> sp.Interval | sp.Set:
"""Return the domain of valid values for the symbol.
@ -235,6 +239,10 @@ class CommonSimSymbol(enum.StrEnum):
####################
# - Properties
####################
@property
def name(self) -> str:
return self.sim_symbol.name
@property
def sim_symbol_name(self) -> str:
SSN = SimSymbolName