feat: implemented data loading w/new math ops
We implemented a node to load various kinds of data, notably `.npy`, `.txt`, `.txt.gz`, and `.csv`. The `DataFileImporterNode` really should expose some settings for setting name/mathtype/physical type/unit of each unit, and/or treating a column from 2D data as index coordinates. But the nuances of doing this in a manner general enough to deal with !=2D data was a lot, and we needed similar abilities in the general math system anyway. So, we delved back into the `FilterMathNode` and a little into the `TransformMathNode`. Fundamentally, a few difficult operations came out of this: - Filter:SliceIdx: Slice an array using the usual syntax, as baked into the function. - Filter:PinIdx: Pin an axis by its actual index. - Filter:SetDim: Set the `InfoFlow` index coordinates of an axis to a specific, loose-socket provided 1D array, and use a common symbol to set the name+physical type (and allow specifying an appropriate unit). - Transform:IntDimToComplex: Fold a last length-2 integer-indexed axis into a real output type, which removes the dimension and produces a complex output type. Essentially, this is equivalent to folding it as a vector and treating the `R^2` numbers as real/imaginary, except this is more explicit. By combining all of these, we managed to process and constrain the medium data to be a well-suited, unit-aware (**though not on the output (yet)**) `wl->C` tensor. In particular, the slicing is nice for avoiding discontinuities. Workflow-wise, we'll see how important these are / what else we might want. Also, it turns out Blender's text editor is really quite nice for light data-text viewing.main
parent
a66a28da27
commit
0f2f494868
|
@ -122,7 +122,6 @@ class ArrayFlow:
|
|||
if self.unit is not None
|
||||
else rescale_func(a * self.unit)
|
||||
)
|
||||
log.critical([self.unit, new_unit, rescale_expr])
|
||||
_rescale_func = sp.lambdify(a, rescale_expr, 'jax')
|
||||
values = _rescale_func(self.values)
|
||||
|
||||
|
@ -132,3 +131,13 @@ class ArrayFlow:
|
|||
unit=new_unit,
|
||||
is_sorted=self.is_sorted,
|
||||
)
|
||||
|
||||
def __getitem__(self, subscript: slice):
|
||||
if isinstance(subscript, slice):
|
||||
return ArrayFlow(
|
||||
values=self.values[subscript],
|
||||
unit=self.unit,
|
||||
is_sorted=self.is_sorted,
|
||||
)
|
||||
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -39,6 +39,16 @@ class InfoFlow:
|
|||
default_factory=dict
|
||||
) ## TODO: Rename to dim_idxs
|
||||
|
||||
@functools.cached_property
|
||||
def dim_has_coords(self) -> dict[str, int]:
|
||||
return {
|
||||
dim_name: not (
|
||||
isinstance(dim_idx, LazyArrayRangeFlow)
|
||||
and (dim_idx.start.is_infinite or dim_idx.stop.is_infinite)
|
||||
)
|
||||
for dim_name, dim_idx in self.dim_idx.items()
|
||||
}
|
||||
|
||||
@functools.cached_property
|
||||
def dim_lens(self) -> dict[str, int]:
|
||||
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
|
||||
|
@ -99,9 +109,29 @@ class InfoFlow:
|
|||
####################
|
||||
# - Methods
|
||||
####################
|
||||
def slice_dim(self, dim_name: str, slice_tuple: tuple[int, int, int]) -> typ.Self:
|
||||
return InfoFlow(
|
||||
# Dimensions
|
||||
dim_names=self.dim_names,
|
||||
dim_idx={
|
||||
_dim_name: (
|
||||
dim_idx
|
||||
if _dim_name != dim_name
|
||||
else dim_idx[slice_tuple[0] : slice_tuple[1] : slice_tuple[2]]
|
||||
)
|
||||
for _dim_name, dim_idx in self.dim_idx.items()
|
||||
},
|
||||
# Outputs
|
||||
output_name=self.output_name,
|
||||
output_shape=self.output_shape,
|
||||
output_mathtype=self.output_mathtype,
|
||||
output_unit=self.output_unit,
|
||||
)
|
||||
|
||||
def replace_dim(
|
||||
self, old_dim_name: str, new_dim_idx: tuple[str, ArrayFlow | LazyArrayRangeFlow]
|
||||
) -> typ.Self:
|
||||
"""Replace a dimension (and its indexing) with a new name and index array/range."""
|
||||
return InfoFlow(
|
||||
# Dimensions
|
||||
dim_names=[
|
||||
|
@ -122,6 +152,7 @@ class InfoFlow:
|
|||
)
|
||||
|
||||
def rescale_dim_idxs(self, new_dim_idxs: dict[str, LazyArrayRangeFlow]) -> typ.Self:
|
||||
"""Replace several dimensional indices with new index arrays/ranges."""
|
||||
return InfoFlow(
|
||||
# Dimensions
|
||||
dim_names=self.dim_names,
|
||||
|
@ -156,7 +187,7 @@ class InfoFlow:
|
|||
)
|
||||
|
||||
def swap_dimensions(self, dim_0_name: str, dim_1_name: str) -> typ.Self:
|
||||
"""Delete a dimension."""
|
||||
"""Swap the position of two dimensions."""
|
||||
|
||||
# Compute Swapped Dimension Name List
|
||||
def name_swapper(dim_name):
|
||||
|
@ -181,7 +212,7 @@ class InfoFlow:
|
|||
)
|
||||
|
||||
def set_output_mathtype(self, output_mathtype: spux.MathType) -> typ.Self:
|
||||
"""Set the MathType of a particular output name."""
|
||||
"""Set the MathType of the output."""
|
||||
return InfoFlow(
|
||||
dim_names=self.dim_names,
|
||||
dim_idx=self.dim_idx,
|
||||
|
@ -198,6 +229,7 @@ class InfoFlow:
|
|||
collapsed_mathtype: spux.MathType,
|
||||
collapsed_unit: spux.Unit,
|
||||
) -> typ.Self:
|
||||
"""Replace the (scalar) output with the given corrected values."""
|
||||
return InfoFlow(
|
||||
# Dimensions
|
||||
dim_names=self.dim_names,
|
||||
|
|
|
@ -418,7 +418,11 @@ class LazyArrayRangeFlow:
|
|||
self,
|
||||
symbol_values: dict[spux.Symbol, typ.Any] = MappingProxyType({}),
|
||||
) -> ArrayFlow | LazyValueFuncFlow:
|
||||
return (self.realize_stop() - self.realize_start()) / self.steps
|
||||
raw_step_size = (self.realize_stop() - self.realize_start() + 1) / self.steps
|
||||
|
||||
if self.mathtype is spux.MathType.Integer and raw_step_size.is_integer():
|
||||
return int(raw_step_size)
|
||||
return raw_step_size
|
||||
|
||||
def realize(
|
||||
self,
|
||||
|
@ -463,3 +467,28 @@ class LazyArrayRangeFlow:
|
|||
@functools.cached_property
|
||||
def realize_array(self) -> ArrayFlow:
|
||||
return self.realize()
|
||||
|
||||
def __getitem__(self, subscript: slice):
|
||||
if isinstance(subscript, slice) and self.scaling == ScalingMode.Lin:
|
||||
# Parse Slice
|
||||
start = subscript.start if subscript.start is not None else 0
|
||||
stop = subscript.stop if subscript.stop is not None else self.steps
|
||||
step = subscript.step if subscript.step is not None else 1
|
||||
|
||||
slice_steps = (stop - start + step - 1) // step
|
||||
|
||||
# Compute New Start/Stop
|
||||
step_size = self.realize_step_size()
|
||||
new_start = step_size * start
|
||||
new_stop = new_start + step_size * slice_steps
|
||||
|
||||
return LazyArrayRangeFlow(
|
||||
start=sp.S(new_start),
|
||||
stop=sp.S(new_stop),
|
||||
steps=slice_steps,
|
||||
scaling=self.scaling,
|
||||
unit=self.unit,
|
||||
symbols=self.symbols,
|
||||
)
|
||||
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -58,9 +58,11 @@ class ParamsFlow:
|
|||
|
||||
## TODO: MutableDenseMatrix causes error with 'in' check bc it isn't hashable.
|
||||
return [
|
||||
(
|
||||
spux.scale_to_unit_system(arg, unit_system, use_jax_array=True)
|
||||
if arg not in symbol_values
|
||||
else symbol_values[arg]
|
||||
)
|
||||
for arg in self.func_args
|
||||
]
|
||||
|
||||
|
|
|
@ -20,9 +20,11 @@ import enum
|
|||
import typing as typ
|
||||
|
||||
import bpy
|
||||
import jax.lax as jlax
|
||||
import jax.numpy as jnp
|
||||
import sympy as sp
|
||||
|
||||
from blender_maxwell.utils import bl_cache, logger
|
||||
from blender_maxwell.utils import bl_cache, logger, sim_symbols
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
|
||||
from .... import contracts as ct
|
||||
|
@ -43,64 +45,201 @@ class FilterOperation(enum.StrEnum):
|
|||
Swap: Swap the positions of two dimensions.
|
||||
"""
|
||||
|
||||
# Dimensions
|
||||
# Slice
|
||||
SliceIdx = enum.auto()
|
||||
|
||||
# Pin
|
||||
PinLen1 = enum.auto()
|
||||
Pin = enum.auto()
|
||||
PinIdx = enum.auto()
|
||||
|
||||
# Reinterpret
|
||||
Swap = enum.auto()
|
||||
SetDim = enum.auto()
|
||||
|
||||
# Fold
|
||||
DimToVec = enum.auto()
|
||||
DimsToMat = enum.auto()
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
@staticmethod
|
||||
def to_name(value: typ.Self) -> str:
|
||||
FO = FilterOperation
|
||||
return {
|
||||
# Dimensions
|
||||
# Slice
|
||||
FO.SliceIdx: 'a[...]',
|
||||
# Pin
|
||||
FO.PinLen1: 'pinₐ =1',
|
||||
FO.Pin: 'pinₐ ≈v',
|
||||
FO.PinIdx: 'pinₐ =a[v]',
|
||||
# Reinterpret
|
||||
FO.Swap: 'a₁ ↔ a₂',
|
||||
# Interpret
|
||||
FO.DimToVec: '→ Vector',
|
||||
FO.DimsToMat: '→ Matrix',
|
||||
FO.SetDim: 'setₐ =v',
|
||||
}[value]
|
||||
|
||||
@staticmethod
|
||||
def to_icon(value: typ.Self) -> str:
|
||||
return ''
|
||||
|
||||
def are_dims_valid(self, dim_0: int | None, dim_1: int | None):
|
||||
return not (
|
||||
(
|
||||
dim_0 is None
|
||||
and self
|
||||
in [FilterOperation.PinLen1, FilterOperation.Pin, FilterOperation.Swap]
|
||||
)
|
||||
or (dim_1 is None and self == FilterOperation.Swap)
|
||||
def bl_enum_element(self, i: int) -> ct.BLEnumElement:
|
||||
FO = FilterOperation
|
||||
return (
|
||||
str(self),
|
||||
FO.to_name(self),
|
||||
FO.to_name(self),
|
||||
FO.to_icon(self),
|
||||
i,
|
||||
)
|
||||
|
||||
def jax_func(self, axis_0: int | None, axis_1: int | None):
|
||||
####################
|
||||
# - Ops from Info
|
||||
####################
|
||||
@staticmethod
|
||||
def by_info(info: ct.InfoFlow) -> list[typ.Self]:
|
||||
FO = FilterOperation
|
||||
operations = []
|
||||
|
||||
# Slice
|
||||
if info.dim_names:
|
||||
operations.append(FO.SliceIdx)
|
||||
|
||||
# Pin
|
||||
## PinLen1
|
||||
## -> There must be a dimension with length 1.
|
||||
if 1 in list(info.dim_lens.values()):
|
||||
operations.append(FO.PinLen1)
|
||||
|
||||
## Pin | PinIdx
|
||||
## -> There must be a dimension, full stop.
|
||||
if info.dim_names:
|
||||
operations += [FO.Pin, FO.PinIdx]
|
||||
|
||||
# Reinterpret
|
||||
## Swap
|
||||
## -> There must be at least two dimensions.
|
||||
if len(info.dim_names) >= 2: # noqa: PLR2004
|
||||
operations.append(FO.Swap)
|
||||
|
||||
## SetDim
|
||||
## -> There must be a dimension to correct.
|
||||
if info.dim_names:
|
||||
operations.append(FO.SetDim)
|
||||
|
||||
return operations
|
||||
|
||||
####################
|
||||
# - Computed Properties
|
||||
####################
|
||||
@property
|
||||
def func_args(self) -> list[spux.MathType]:
|
||||
FO = FilterOperation
|
||||
return {
|
||||
# Interpret
|
||||
FilterOperation.DimToVec: lambda data: data,
|
||||
FilterOperation.DimsToMat: lambda data: data,
|
||||
# Dimensions
|
||||
FilterOperation.PinLen1: lambda data: jnp.squeeze(data, axis_0),
|
||||
FilterOperation.Pin: lambda data, fixed_axis_idx: jnp.take(
|
||||
data, fixed_axis_idx, axis=axis_0
|
||||
),
|
||||
FilterOperation.Swap: lambda data: jnp.swapaxes(data, axis_0, axis_1),
|
||||
# Pin
|
||||
FO.Pin: [spux.MathType.Integer],
|
||||
FO.PinIdx: [spux.MathType.Integer],
|
||||
}.get(self, [])
|
||||
|
||||
####################
|
||||
# - Methods
|
||||
####################
|
||||
@property
|
||||
def num_dim_inputs(self) -> None:
|
||||
FO = FilterOperation
|
||||
return {
|
||||
# Slice
|
||||
FO.SliceIdx: 1,
|
||||
# Pin
|
||||
FO.PinLen1: 1,
|
||||
FO.Pin: 1,
|
||||
FO.PinIdx: 1,
|
||||
# Reinterpret
|
||||
FO.Swap: 2,
|
||||
FO.SetDim: 1,
|
||||
}[self]
|
||||
|
||||
def transform_info(self, info: ct.InfoFlow, dim_0: str, dim_1: str):
|
||||
def valid_dims(self, info: ct.InfoFlow) -> list[typ.Self]:
|
||||
FO = FilterOperation
|
||||
match self:
|
||||
case FO.SliceIdx:
|
||||
return info.dim_names
|
||||
|
||||
# PinLen1: Only allow dimensions with length=1.
|
||||
case FO.PinLen1:
|
||||
return [
|
||||
dim_name
|
||||
for dim_name in info.dim_names
|
||||
if info.dim_lens[dim_name] == 1
|
||||
]
|
||||
|
||||
# Pin: Only allow dimensions with known indexing.
|
||||
case FO.Pin:
|
||||
return [
|
||||
dim_name
|
||||
for dim_name in info.dim_names
|
||||
if info.dim_has_coords[dim_name] != 0
|
||||
]
|
||||
|
||||
case FO.PinIdx | FO.Swap:
|
||||
return info.dim_names
|
||||
|
||||
case FO.SetDim:
|
||||
return [
|
||||
dim_name
|
||||
for dim_name in info.dim_names
|
||||
if info.dim_mathtypes[dim_name] == spux.MathType.Integer
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
def are_dims_valid(
|
||||
self, info: ct.InfoFlow, dim_0: str | None, dim_1: str | None
|
||||
) -> bool:
|
||||
"""Check whether the given dimension inputs are valid in the context of this operation, and of the information."""
|
||||
return (self.num_dim_inputs in [1, 2] and dim_0 in self.valid_dims(info)) or (
|
||||
self.num_dim_inputs == 2 and dim_1 in self.valid_dims(info)
|
||||
)
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
def jax_func(
|
||||
self,
|
||||
axis_0: int | None,
|
||||
axis_1: int | None,
|
||||
slice_tuple: tuple[int, int, int] | None = None,
|
||||
):
|
||||
FO = FilterOperation
|
||||
return {
|
||||
# Interpret
|
||||
FilterOperation.DimToVec: lambda: info.shift_last_input,
|
||||
FilterOperation.DimsToMat: lambda: info.shift_last_input.shift_last_input,
|
||||
# Dimensions
|
||||
FilterOperation.PinLen1: lambda: info.delete_dimension(dim_0),
|
||||
FilterOperation.Pin: lambda: info.delete_dimension(dim_0),
|
||||
FilterOperation.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
|
||||
# Pin
|
||||
FO.SliceIdx: lambda expr: jlax.slice_in_dim(
|
||||
expr, slice_tuple[0], slice_tuple[1], slice_tuple[2], axis=axis_0
|
||||
),
|
||||
# Pin
|
||||
FO.PinLen1: lambda expr: jnp.squeeze(expr, axis_0),
|
||||
FO.Pin: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
|
||||
FO.PinIdx: lambda expr, idx: jnp.take(expr, idx, axis=axis_0),
|
||||
# Reinterpret
|
||||
FO.Swap: lambda expr: jnp.swapaxes(expr, axis_0, axis_1),
|
||||
FO.SetDim: lambda expr: expr,
|
||||
}[self]
|
||||
|
||||
def transform_info(
|
||||
self,
|
||||
info: ct.InfoFlow,
|
||||
dim_0: str,
|
||||
dim_1: str,
|
||||
slice_tuple: tuple[int, int, int] | None = None,
|
||||
corrected_dim: tuple[str, tuple[str, ct.ArrayFlow | ct.LazyArrayRangeFlow]]
|
||||
| None = None,
|
||||
):
|
||||
FO = FilterOperation
|
||||
return {
|
||||
FO.SliceIdx: lambda: info.slice_dim(dim_0, slice_tuple),
|
||||
# Pin
|
||||
FO.PinLen1: lambda: info.delete_dimension(dim_0),
|
||||
FO.Pin: lambda: info.delete_dimension(dim_0),
|
||||
FO.PinIdx: lambda: info.delete_dimension(dim_0),
|
||||
# Reinterpret
|
||||
FO.Swap: lambda: info.swap_dimensions(dim_0, dim_1),
|
||||
FO.SetDim: lambda: info.replace_dim(*corrected_dim),
|
||||
}[self]()
|
||||
|
||||
|
||||
|
@ -133,115 +272,192 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
}
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
# - Properties: Expr InfoFlow
|
||||
####################
|
||||
operation: FilterOperation = bl_cache.BLField(
|
||||
FilterOperation.PinLen1,
|
||||
prop_ui=True,
|
||||
@events.on_value_changed(
|
||||
socket_name={'Expr'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||
input_sockets_optional={'Expr': True},
|
||||
)
|
||||
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
|
||||
has_info = not ct.FlowSignal.check(input_sockets['Expr'])
|
||||
|
||||
info_pending = ct.FlowSignal.check_single(
|
||||
input_sockets['Expr'], ct.FlowSignal.FlowPending
|
||||
)
|
||||
|
||||
# Dimension Selection
|
||||
dim_0: enum.StrEnum = bl_cache.BLField(enum_cb=lambda self, _: self.search_dims())
|
||||
dim_1: enum.StrEnum = bl_cache.BLField(enum_cb=lambda self, _: self.search_dims())
|
||||
if has_info and not info_pending:
|
||||
self.expr_info = bl_cache.Signal.InvalidateCache
|
||||
|
||||
####################
|
||||
# - Computed
|
||||
####################
|
||||
@property
|
||||
def data_info(self) -> ct.InfoFlow | None:
|
||||
info = self._compute_input('Expr', kind=ct.FlowKind.Info)
|
||||
if not ct.FlowSignal.check(info):
|
||||
@bl_cache.cached_bl_property()
|
||||
def expr_info(self) -> ct.InfoFlow | None:
|
||||
info = self._compute_input('Expr', kind=ct.FlowKind.Info, optional=True)
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
if has_info:
|
||||
return info
|
||||
|
||||
return None
|
||||
|
||||
####################
|
||||
# - Search Dimensions
|
||||
# - Properties: Operation
|
||||
####################
|
||||
operation: FilterOperation = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_operations(),
|
||||
cb_depends_on={'expr_info'},
|
||||
)
|
||||
|
||||
def search_operations(self) -> list[ct.BLEnumElement]:
|
||||
if self.expr_info is not None:
|
||||
return [
|
||||
operation.bl_enum_element(i)
|
||||
for i, operation in enumerate(FilterOperation.by_info(self.expr_info))
|
||||
]
|
||||
return []
|
||||
|
||||
####################
|
||||
# - Properties: Dimension Selection
|
||||
####################
|
||||
dim_0: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_dims(),
|
||||
cb_depends_on={'operation', 'expr_info'},
|
||||
)
|
||||
dim_1: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_dims(),
|
||||
cb_depends_on={'operation', 'expr_info'},
|
||||
)
|
||||
|
||||
def search_dims(self) -> list[ct.BLEnumElement]:
|
||||
if self.data_info is None:
|
||||
if self.expr_info is not None and self.operation is not None:
|
||||
return [
|
||||
(dim_name, dim_name, dim_name, '', i)
|
||||
for i, dim_name in enumerate(self.operation.valid_dims(self.expr_info))
|
||||
]
|
||||
return []
|
||||
|
||||
if self.operation == FilterOperation.PinLen1:
|
||||
dims = [
|
||||
(dim_name, dim_name, f'Dimension "{dim_name}" of length 1')
|
||||
for dim_name in self.data_info.dim_names
|
||||
if self.data_info.dim_lens[dim_name] == 1
|
||||
####################
|
||||
# - Properties: Slice
|
||||
####################
|
||||
slice_tuple: tuple[int, int, int] = bl_cache.BLField([0, 1, 1])
|
||||
|
||||
####################
|
||||
# - Properties: Unit
|
||||
####################
|
||||
set_dim_symbol: sim_symbols.CommonSimSymbol = bl_cache.BLField(
|
||||
sim_symbols.CommonSimSymbol.X
|
||||
)
|
||||
|
||||
set_dim_active_unit: enum.StrEnum = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_valid_units(),
|
||||
cb_depends_on={'set_dim_symbol'},
|
||||
)
|
||||
|
||||
def search_valid_units(self) -> list[ct.BLEnumElement]:
|
||||
"""Compute Blender enum elements of valid units for the current `physical_type`."""
|
||||
physical_type = self.set_dim_symbol.sim_symbol.physical_type
|
||||
if physical_type is not spux.PhysicalType.NonPhysical:
|
||||
return [
|
||||
(sp.sstr(unit), spux.sp_to_str(unit), sp.sstr(unit), '', i)
|
||||
for i, unit in enumerate(physical_type.valid_units)
|
||||
]
|
||||
elif self.operation in [FilterOperation.Pin, FilterOperation.Swap]:
|
||||
dims = [
|
||||
(dim_name, dim_name, f'Dimension "{dim_name}"')
|
||||
for dim_name in self.data_info.dim_names
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
return [(*dim, '', i) for i, dim in enumerate(dims)]
|
||||
@bl_cache.cached_bl_property(depends_on={'set_dim_active_unit'})
|
||||
def set_dim_unit(self) -> spux.Unit | None:
|
||||
if self.set_dim_active_unit is not None:
|
||||
return spux.unit_str_to_unit(self.set_dim_active_unit)
|
||||
|
||||
return None
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
def draw_label(self):
|
||||
FO = FilterOperation
|
||||
labels = {
|
||||
FO.PinLen1: lambda: f'Filter: Pin {self.dim_0} (len=1)',
|
||||
FO.Pin: lambda: f'Filter: Pin {self.dim_0}',
|
||||
FO.Swap: lambda: f'Filter: Swap {self.dim_0}|{self.dim_1}',
|
||||
FO.DimToVec: lambda: 'Filter: -> Vector',
|
||||
FO.DimsToMat: lambda: 'Filter: -> Matrix',
|
||||
}
|
||||
match self.operation:
|
||||
# Slice
|
||||
case FO.SliceIdx:
|
||||
slice_str = ':'.join([str(v) for v in self.slice_tuple])
|
||||
return f'Filter: {self.dim_0}[{slice_str}]'
|
||||
|
||||
if (label := labels.get(self.operation)) is not None:
|
||||
return label()
|
||||
# Pin
|
||||
case FO.PinLen1:
|
||||
return f'Filter: Pin {self.dim_0}[0]'
|
||||
case FO.Pin:
|
||||
return f'Filter: Pin {self.dim_0}[...]'
|
||||
case FO.PinIdx:
|
||||
pin_idx_axis = self._compute_input(
|
||||
'Axis', kind=ct.FlowKind.Value, optional=True
|
||||
)
|
||||
has_pin_idx_axis = not ct.FlowSignal.check(pin_idx_axis)
|
||||
if has_pin_idx_axis:
|
||||
return f'Filter: Pin {self.dim_0}[{pin_idx_axis}]'
|
||||
return self.bl_label
|
||||
|
||||
# Reinterpret
|
||||
case FO.Swap:
|
||||
return f'Filter: Swap [{self.dim_0}]|[{self.dim_1}]'
|
||||
case FO.SetDim:
|
||||
return f'Filter: Set [{self.dim_0}]'
|
||||
|
||||
case _:
|
||||
return self.bl_label
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
layout.prop(self, self.blfields['operation'], text='')
|
||||
|
||||
if self.operation in [FilterOperation.PinLen1, FilterOperation.Pin]:
|
||||
if self.operation is not None:
|
||||
match self.operation.num_dim_inputs:
|
||||
case 1:
|
||||
layout.prop(self, self.blfields['dim_0'], text='')
|
||||
|
||||
if self.operation == FilterOperation.Swap:
|
||||
case 2:
|
||||
row = layout.row(align=True)
|
||||
row.prop(self, self.blfields['dim_0'], text='')
|
||||
row.prop(self, self.blfields['dim_1'], text='')
|
||||
|
||||
if self.operation is FilterOperation.SliceIdx:
|
||||
layout.prop(self, self.blfields['slice_tuple'], text='')
|
||||
|
||||
if self.operation is FilterOperation.SetDim:
|
||||
row = layout.row(align=True)
|
||||
row.prop(self, self.blfields['set_dim_symbol'], text='')
|
||||
row.prop(self, self.blfields['set_dim_active_unit'], text='')
|
||||
|
||||
####################
|
||||
# - Events
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
# Trigger
|
||||
socket_name='Expr',
|
||||
prop_name={'operation'},
|
||||
run_on_init=True,
|
||||
)
|
||||
def on_input_changed(self) -> None:
|
||||
self.dim_0 = bl_cache.Signal.ResetEnumItems
|
||||
self.dim_1 = bl_cache.Signal.ResetEnumItems
|
||||
|
||||
@events.on_value_changed(
|
||||
# Trigger
|
||||
socket_name='Expr',
|
||||
prop_name={'dim_0', 'dim_1', 'operation'},
|
||||
run_on_init=True,
|
||||
prop_name={'operation', 'dim_0', 'dim_1'},
|
||||
# Loaded
|
||||
props={'operation', 'dim_0', 'dim_1'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||
)
|
||||
def on_pin_changed(self, props: dict, input_sockets: dict):
|
||||
def on_pin_factors_changed(self, props: dict, input_sockets: dict):
|
||||
"""Synchronize loose input sockets to match the dimension-pinning method declared in `self.operation`.
|
||||
|
||||
To "pin" an axis, a particular index must be chosen to "extract".
|
||||
One might choose axes of length 1 ("squeeze"), choose a particular index, or choose a coordinate that maps to a particular index.
|
||||
|
||||
Those last two options requires more information from the user: Which index?
|
||||
Which coordinate?
|
||||
To answer these questions, we create an appropriate loose input socket containing this data, so the user can make their decision.
|
||||
"""
|
||||
info = input_sockets['Expr']
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
if not has_info:
|
||||
return
|
||||
|
||||
# "Dimensions"|"PIN": Add/Remove Input Socket
|
||||
if props['operation'] == FilterOperation.Pin and props['dim_0'] is not None:
|
||||
# Pin Dim by-Value: Synchronize Input Socket
|
||||
## -> The user will be given a socket w/correct mathtype, unit, etc. .
|
||||
## -> Internally, this value will map to a particular index.
|
||||
if props['operation'] is FilterOperation.Pin and props['dim_0'] is not None:
|
||||
# Deduce Pinned Information
|
||||
pinned_unit = info.dim_units[props['dim_0']]
|
||||
pinned_mathtype = info.dim_mathtypes[props['dim_0']]
|
||||
pinned_physical_type = spux.PhysicalType.from_unit(pinned_unit)
|
||||
|
||||
wanted_mathtype = (
|
||||
spux.MathType.Complex
|
||||
if pinned_mathtype == spux.MathType.Complex
|
||||
|
@ -250,9 +466,11 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
)
|
||||
|
||||
# Get Current and Wanted Socket Defs
|
||||
## -> 'Value' may already exist. If not, all is well.
|
||||
current_bl_socket = self.loose_input_sockets.get('Value')
|
||||
|
||||
# Determine Whether to Declare New Loose Input SOcket
|
||||
# Determine Whether to Construct
|
||||
## -> If nothing needs to change, then nothing changes.
|
||||
if (
|
||||
current_bl_socket is None
|
||||
or current_bl_socket.size is not spux.NumberSize1D.Scalar
|
||||
|
@ -262,22 +480,68 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
self.loose_input_sockets = {
|
||||
'Value': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.Value,
|
||||
size=spux.NumberSize1D.Scalar,
|
||||
physical_type=pinned_physical_type,
|
||||
mathtype=wanted_mathtype,
|
||||
default_unit=pinned_unit,
|
||||
),
|
||||
}
|
||||
|
||||
# Pin Dim by-Index: Synchronize Input Socket
|
||||
## -> The user will be given a simple integer socket.
|
||||
elif (
|
||||
props['operation'] is FilterOperation.PinIdx and props['dim_0'] is not None
|
||||
):
|
||||
current_bl_socket = self.loose_input_sockets.get('Axis')
|
||||
if (
|
||||
current_bl_socket is None
|
||||
or current_bl_socket.size is not spux.NumberSize1D.Scalar
|
||||
or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical
|
||||
or current_bl_socket.mathtype != spux.MathType.Integer
|
||||
):
|
||||
self.loose_input_sockets = {
|
||||
'Axis': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.Value,
|
||||
mathtype=spux.MathType.Integer,
|
||||
)
|
||||
}
|
||||
|
||||
# Set Dim: Synchronize Input Socket
|
||||
## -> The user must provide a (ℤ) -> ℝ array.
|
||||
## -> It must be of identical length to the replaced axis.
|
||||
elif (
|
||||
props['operation'] is FilterOperation.SetDim
|
||||
and props['dim_0'] is not None
|
||||
and info.dim_mathtypes[props['dim_0']] is spux.MathType.Integer
|
||||
and info.dim_physical_types[props['dim_0']] is spux.PhysicalType.NonPhysical
|
||||
):
|
||||
# Deduce Axis Information
|
||||
current_bl_socket = self.loose_input_sockets.get('Dim')
|
||||
if (
|
||||
current_bl_socket is None
|
||||
or current_bl_socket.active_kind != ct.FlowKind.LazyValueFunc
|
||||
or current_bl_socket.mathtype != spux.MathType.Real
|
||||
or current_bl_socket.physical_type != spux.PhysicalType.NonPhysical
|
||||
):
|
||||
self.loose_input_sockets = {
|
||||
'Dim': sockets.ExprSocketDef(
|
||||
active_kind=ct.FlowKind.LazyValueFunc,
|
||||
mathtype=spux.MathType.Real,
|
||||
physical_type=spux.PhysicalType.NonPhysical,
|
||||
show_info_columns=True,
|
||||
)
|
||||
}
|
||||
|
||||
# No Loose Value: Remove Input Sockets
|
||||
elif self.loose_input_sockets:
|
||||
self.loose_input_sockets = {}
|
||||
|
||||
####################
|
||||
# - Output
|
||||
# - FlowKind.Value|LazyValueFunc
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.LazyValueFunc,
|
||||
props={'operation', 'dim_0', 'dim_1'},
|
||||
props={'operation', 'dim_0', 'dim_1', 'slice_tuple'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info}},
|
||||
)
|
||||
|
@ -296,82 +560,120 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
has_lazy_value_func
|
||||
and has_info
|
||||
and operation is not None
|
||||
and operation.are_dims_valid(dim_0, dim_1)
|
||||
and operation.are_dims_valid(info, dim_0, dim_1)
|
||||
):
|
||||
axis_0 = info.dim_names.index(dim_0) if dim_0 is not None else None
|
||||
axis_1 = info.dim_names.index(dim_1) if dim_1 is not None else None
|
||||
slice_tuple = (
|
||||
props['slice_tuple']
|
||||
if self.operation is FilterOperation.SliceIdx
|
||||
else None
|
||||
)
|
||||
|
||||
return lazy_value_func.compose_within(
|
||||
operation.jax_func(axis_0, axis_1),
|
||||
enclosing_func_args=[int] if operation == FilterOperation.Pin else [],
|
||||
operation.jax_func(axis_0, axis_1, slice_tuple),
|
||||
enclosing_func_args=operation.func_args,
|
||||
supports_jax=True,
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Array,
|
||||
output_sockets={'Expr'},
|
||||
output_socket_kinds={
|
||||
'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
|
||||
},
|
||||
unit_systems={'BlenderUnits': ct.UNITS_BLENDER},
|
||||
)
|
||||
def compute_array(self, output_sockets, unit_systems) -> ct.ArrayFlow:
|
||||
lazy_value_func = output_sockets['Expr'][ct.FlowKind.LazyValueFunc]
|
||||
params = output_sockets['Expr'][ct.FlowKind.Params]
|
||||
|
||||
has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func)
|
||||
has_params = not ct.FlowSignal.check(params)
|
||||
|
||||
if has_lazy_value_func and has_params:
|
||||
unit_system = unit_systems['BlenderUnits']
|
||||
return ct.ArrayFlow(
|
||||
values=lazy_value_func.func_jax(
|
||||
*params.scaled_func_args(unit_system),
|
||||
**params.scaled_func_kwargs(unit_system),
|
||||
),
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - Auxiliary: Info
|
||||
# - FlowKind.Info
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Info,
|
||||
props={'dim_0', 'dim_1', 'operation'},
|
||||
input_sockets={'Expr'},
|
||||
input_socket_kinds={'Expr': ct.FlowKind.Info},
|
||||
props={
|
||||
'dim_0',
|
||||
'dim_1',
|
||||
'operation',
|
||||
'slice_tuple',
|
||||
'set_dim_symbol',
|
||||
'set_dim_active_unit',
|
||||
},
|
||||
input_sockets={'Expr', 'Dim'},
|
||||
input_socket_kinds={
|
||||
'Expr': ct.FlowKind.Info,
|
||||
'Dim': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params, ct.FlowKind.Info},
|
||||
},
|
||||
input_sockets_optional={'Dim': True},
|
||||
)
|
||||
def compute_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
|
||||
def compute_info(self, props, input_sockets) -> ct.InfoFlow:
|
||||
operation = props['operation']
|
||||
info = input_sockets['Expr']
|
||||
dim_coords = input_sockets['Dim'][ct.FlowKind.LazyValueFunc]
|
||||
dim_params = input_sockets['Dim'][ct.FlowKind.Params]
|
||||
dim_info = input_sockets['Dim'][ct.FlowKind.Info]
|
||||
dim_symbol = props['set_dim_symbol']
|
||||
dim_active_unit = props['set_dim_active_unit']
|
||||
|
||||
has_info = not ct.FlowSignal.check(info)
|
||||
has_dim_coords = not ct.FlowSignal.check(dim_coords)
|
||||
has_dim_params = not ct.FlowSignal.check(dim_params)
|
||||
has_dim_info = not ct.FlowSignal.check(dim_info)
|
||||
|
||||
# Dimension(s)
|
||||
dim_0 = props['dim_0']
|
||||
dim_1 = props['dim_1']
|
||||
slice_tuple = props['slice_tuple']
|
||||
if has_info and operation is not None:
|
||||
# Set Dimension: Retrieve Array
|
||||
if props['operation'] is FilterOperation.SetDim:
|
||||
if (
|
||||
has_info
|
||||
and operation is not None
|
||||
and operation.are_dims_valid(dim_0, dim_1)
|
||||
dim_0 is not None
|
||||
# Check Replaced Dimension
|
||||
and has_dim_coords
|
||||
and len(dim_coords.func_args) == 1
|
||||
and dim_coords.func_args[0] is spux.MathType.Integer
|
||||
and not dim_coords.func_kwargs
|
||||
and dim_coords.supports_jax
|
||||
# Check Params
|
||||
and has_dim_params
|
||||
and len(dim_params.func_args) == 1
|
||||
and not dim_params.func_kwargs
|
||||
# Check Info
|
||||
and has_dim_info
|
||||
):
|
||||
return operation.transform_info(info, dim_0, dim_1)
|
||||
# Retrieve Dimension Coordinate Array
|
||||
## -> It must be strictly compatible.
|
||||
values = dim_coords.func_jax(int(dim_params.func_args[0]))
|
||||
if (
|
||||
len(values.shape) != 1
|
||||
or values.shape[0] != info.dim_lens[dim_0]
|
||||
):
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
# Transform Info w/Corrected Dimension
|
||||
## -> The existing dimension will be replaced.
|
||||
if dim_active_unit is not None:
|
||||
dim_unit = spux.unit_str_to_unit(dim_active_unit)
|
||||
else:
|
||||
dim_unit = None
|
||||
|
||||
new_dim_idx = ct.ArrayFlow(
|
||||
values=values,
|
||||
unit=dim_unit,
|
||||
)
|
||||
corrected_dim = [dim_0, (dim_symbol.name, new_dim_idx)]
|
||||
return operation.transform_info(
|
||||
info, dim_0, dim_1, corrected_dim=corrected_dim
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
return operation.transform_info(info, dim_0, dim_1, slice_tuple=slice_tuple)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - Auxiliary: Params
|
||||
# - FlowKind.Params
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Params,
|
||||
props={'dim_0', 'dim_1', 'operation'},
|
||||
input_sockets={'Expr', 'Value'},
|
||||
input_socket_kinds={'Expr': {ct.FlowKind.Info, ct.FlowKind.Params}},
|
||||
input_sockets_optional={'Value': True},
|
||||
input_sockets={'Expr', 'Value', 'Axis'},
|
||||
input_socket_kinds={
|
||||
'Expr': {ct.FlowKind.Info, ct.FlowKind.Params},
|
||||
},
|
||||
input_sockets_optional={'Value': True, 'Axis': True},
|
||||
)
|
||||
def compute_params(self, props: dict, input_sockets: dict) -> ct.ParamsFlow:
|
||||
operation = props['operation']
|
||||
|
@ -388,20 +690,30 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
has_info
|
||||
and has_params
|
||||
and operation is not None
|
||||
and operation.are_dims_valid(dim_0, dim_1)
|
||||
and operation.are_dims_valid(info, dim_0, dim_1)
|
||||
):
|
||||
## Pinned Value
|
||||
# Retrieve Pinned Value
|
||||
pinned_value = input_sockets['Value']
|
||||
has_pinned_value = not ct.FlowSignal.check(pinned_value)
|
||||
|
||||
if props['operation'] == FilterOperation.Pin and has_pinned_value:
|
||||
pinned_axis = input_sockets['Axis']
|
||||
has_pinned_axis = not ct.FlowSignal.check(pinned_axis)
|
||||
|
||||
# Pin by-Value: Compute Nearest IDX
|
||||
## -> Presume a sorted index array to be able to use binary search.
|
||||
if props['operation'] is FilterOperation.Pin and has_pinned_value:
|
||||
nearest_idx_to_value = info.dim_idx[dim_0].nearest_idx_of(
|
||||
pinned_value, require_sorted=True
|
||||
)
|
||||
|
||||
return params.compose_within(enclosing_func_args=[nearest_idx_to_value])
|
||||
|
||||
# Pin by-Index
|
||||
if props['operation'] is FilterOperation.PinIdx and has_pinned_axis:
|
||||
return params.compose_within(enclosing_func_args=[pinned_axis])
|
||||
|
||||
return params
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ import typing as typ
|
|||
|
||||
import bpy
|
||||
import jax.numpy as jnp
|
||||
import jaxtyping as jtyp
|
||||
import sympy as sp
|
||||
import sympy.physics.units as spu
|
||||
|
||||
|
@ -47,13 +48,20 @@ class TransformOperation(enum.StrEnum):
|
|||
InvFFT: Compute the inverse fourier transform of the input expression.
|
||||
"""
|
||||
|
||||
# Index
|
||||
# Covariant Transform
|
||||
FreqToVacWL = enum.auto()
|
||||
VacWLToFreq = enum.auto()
|
||||
|
||||
# Fold
|
||||
IntDimToComplex = enum.auto()
|
||||
DimToVec = enum.auto()
|
||||
DimsToMat = enum.auto()
|
||||
|
||||
# Fourier
|
||||
FFT1D = enum.auto()
|
||||
InvFFT1D = enum.auto()
|
||||
# Affine
|
||||
|
||||
# TODO: Affine
|
||||
## TODO
|
||||
|
||||
####################
|
||||
|
@ -63,9 +71,14 @@ class TransformOperation(enum.StrEnum):
|
|||
def to_name(value: typ.Self) -> str:
|
||||
TO = TransformOperation
|
||||
return {
|
||||
# By Number
|
||||
# Covariant Transform
|
||||
TO.FreqToVacWL: '𝑓 → λᵥ',
|
||||
TO.VacWLToFreq: 'λᵥ → 𝑓',
|
||||
# Fold
|
||||
TO.IntDimToComplex: '→ ℂ',
|
||||
TO.DimToVec: '→ Vector',
|
||||
TO.DimsToMat: '→ Matrix',
|
||||
# Fourier
|
||||
TO.FFT1D: 't → 𝑓',
|
||||
TO.InvFFT1D: '𝑓 → t',
|
||||
}[value]
|
||||
|
@ -92,7 +105,8 @@ class TransformOperation(enum.StrEnum):
|
|||
TO = TransformOperation
|
||||
operations = []
|
||||
|
||||
# Freq <-> VacWL
|
||||
# Covariant Transform
|
||||
## Freq <-> VacWL
|
||||
for dim_name in info.dim_names:
|
||||
if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq:
|
||||
operations.append(TO.FreqToVacWL)
|
||||
|
@ -100,7 +114,23 @@ class TransformOperation(enum.StrEnum):
|
|||
if info.dim_physical_types[dim_name] == spux.PhysicalType.Freq:
|
||||
operations.append(TO.VacWLToFreq)
|
||||
|
||||
# 1D Fourier
|
||||
# Fold
|
||||
## (Last) Int Dim (=2) to Complex
|
||||
if len(info.dim_names) >= 1:
|
||||
last_dim_name = info.dim_names[-1]
|
||||
if info.dim_lens[last_dim_name] == 2: # noqa: PLR2004
|
||||
operations.append(TO.IntDimToComplex)
|
||||
|
||||
## To Vector
|
||||
if len(info.dim_names) >= 1:
|
||||
operations.append(TO.DimToVec)
|
||||
|
||||
## To Matrix
|
||||
if len(info.dim_names) >= 2: # noqa: PLR2004
|
||||
operations.append(TO.DimsToMat)
|
||||
|
||||
# Fourier
|
||||
## 1D Fourier
|
||||
if info.dim_names:
|
||||
last_physical_type = info.dim_physical_types[info.dim_names[-1]]
|
||||
if last_physical_type == spux.PhysicalType.Time:
|
||||
|
@ -117,9 +147,13 @@ class TransformOperation(enum.StrEnum):
|
|||
def sp_func(self):
|
||||
TO = TransformOperation
|
||||
return {
|
||||
# Index
|
||||
# Covariant Transform
|
||||
TO.FreqToVacWL: lambda expr: expr,
|
||||
TO.VacWLToFreq: lambda expr: expr,
|
||||
# Fold
|
||||
# TO.IntDimToComplex: lambda expr: expr, ## TODO: Won't work?
|
||||
TO.DimToVec: lambda expr: expr,
|
||||
TO.DimsToMat: lambda expr: expr,
|
||||
# Fourier
|
||||
TO.FFT1D: lambda expr: sp.fourier_transform(
|
||||
expr, sim_symbols.t, sim_symbols.freq
|
||||
|
@ -133,15 +167,26 @@ class TransformOperation(enum.StrEnum):
|
|||
def jax_func(self):
|
||||
TO = TransformOperation
|
||||
return {
|
||||
# Index
|
||||
# Covariant Transform
|
||||
TO.FreqToVacWL: lambda expr: expr,
|
||||
TO.VacWLToFreq: lambda expr: expr,
|
||||
# Fold
|
||||
## -> To Complex: With a little imagination, this is a noop :)
|
||||
## -> **Requires** dims[-1] to be integer-indexed w/length of 2.
|
||||
TO.IntDimToComplex: lambda expr: expr.view(dtype=jnp.complex64).squeeze(),
|
||||
TO.DimToVec: lambda expr: expr,
|
||||
TO.DimsToMat: lambda expr: expr,
|
||||
# Fourier
|
||||
TO.FFT1D: lambda expr: jnp.fft(expr),
|
||||
TO.InvFFT1D: lambda expr: jnp.ifft(expr),
|
||||
}[self]
|
||||
|
||||
def transform_info(self, info: ct.InfoFlow | None) -> ct.InfoFlow | None:
|
||||
def transform_info(
|
||||
self,
|
||||
info: ct.InfoFlow | None,
|
||||
data: jtyp.Shaped[jtyp.Array, '...'] | None = None,
|
||||
unit: spux.Unit | None = None,
|
||||
) -> ct.InfoFlow | None:
|
||||
TO = TransformOperation
|
||||
if not info.dim_names:
|
||||
return None
|
||||
|
@ -169,6 +214,12 @@ class TransformOperation(enum.StrEnum):
|
|||
),
|
||||
],
|
||||
),
|
||||
# Fold
|
||||
TO.IntDimToComplex: lambda: info.delete_dimension(
|
||||
info.dim_names[-1]
|
||||
).set_output_mathtype(spux.MathType.Complex),
|
||||
TO.DimToVec: lambda: info.shift_last_input,
|
||||
TO.DimsToMat: lambda: info.shift_last_input.shift_last_input,
|
||||
# Fourier
|
||||
TO.FFT1D: lambda: info.replace_dim(
|
||||
info.dim_names[-1],
|
||||
|
@ -216,7 +267,7 @@ class TransformMathNode(base.MaxwellSimNode):
|
|||
}
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
# - Properties: Expr InfoFlow
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
socket_name={'Expr'},
|
||||
|
@ -243,6 +294,9 @@ class TransformMathNode(base.MaxwellSimNode):
|
|||
|
||||
return None
|
||||
|
||||
####################
|
||||
# - Properties: Operation
|
||||
####################
|
||||
operation: TransformOperation = bl_cache.BLField(
|
||||
enum_cb=lambda self, _: self.search_operations(),
|
||||
cb_depends_on={'expr_info'},
|
||||
|
@ -258,9 +312,6 @@ class TransformMathNode(base.MaxwellSimNode):
|
|||
]
|
||||
return []
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
layout.prop(self, self.blfields['operation'], text='')
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
|
@ -339,6 +390,7 @@ class TransformMathNode(base.MaxwellSimNode):
|
|||
|
||||
if has_info and operation is not None:
|
||||
transformed_info = operation.transform_info(info)
|
||||
|
||||
if transformed_info is None:
|
||||
return ct.FlowSignal.FlowPending
|
||||
return transformed_info
|
||||
|
|
|
@ -14,11 +14,13 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from . import tidy_3d_file_importer
|
||||
from . import data_file_importer, tidy_3d_file_importer
|
||||
|
||||
BL_REGISTER = [
|
||||
*data_file_importer.BL_REGISTER,
|
||||
*tidy_3d_file_importer.BL_REGISTER,
|
||||
]
|
||||
BL_NODES = {
|
||||
**data_file_importer.BL_NODES,
|
||||
**tidy_3d_file_importer.BL_NODES,
|
||||
}
|
||||
|
|
|
@ -0,0 +1,356 @@
|
|||
# blender_maxwell
|
||||
# Copyright (C) 2024 blender_maxwell Project Contributors
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import enum
|
||||
import typing as typ
|
||||
from pathlib import Path
|
||||
|
||||
import bpy
|
||||
import jax.numpy as jnp
|
||||
import jaxtyping as jtyp
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import sympy as sp
|
||||
import tidy3d as td
|
||||
|
||||
from blender_maxwell.utils import bl_cache, logger
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
|
||||
from .... import contracts as ct
|
||||
from .... import sockets
|
||||
from ... import base, events
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
####################
|
||||
# - Data File Extensions
|
||||
####################
|
||||
_DATA_FILE_EXTS = {
|
||||
'.txt',
|
||||
'.txt.gz',
|
||||
'.csv',
|
||||
'.npy',
|
||||
}
|
||||
|
||||
|
||||
class DataFileExt(enum.StrEnum):
|
||||
Txt = enum.auto()
|
||||
TxtGz = enum.auto()
|
||||
Csv = enum.auto()
|
||||
Npy = enum.auto()
|
||||
|
||||
####################
|
||||
# - Enum Elements
|
||||
####################
|
||||
@staticmethod
|
||||
def to_name(v: typ.Self) -> str:
|
||||
return DataFileExt(v).extension
|
||||
|
||||
@staticmethod
|
||||
def to_icon(v: typ.Self) -> str:
|
||||
return ''
|
||||
|
||||
####################
|
||||
# - Computed Properties
|
||||
####################
|
||||
@property
|
||||
def extension(self) -> str:
|
||||
"""Map to the actual string extension."""
|
||||
E = DataFileExt
|
||||
return {
|
||||
E.Txt: '.txt',
|
||||
E.TxtGz: '.txt.gz',
|
||||
E.Csv: '.csv',
|
||||
E.Npy: '.npy',
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def loader(self) -> typ.Callable[[Path], jtyp.Shaped[jtyp.Array, '...']]:
|
||||
def load_txt(path: Path):
|
||||
return jnp.asarray(np.loadtxt(path))
|
||||
|
||||
def load_csv(path: Path):
|
||||
return jnp.asarray(pd.read_csv(path).values)
|
||||
|
||||
def load_npy(path: Path):
|
||||
return jnp.load(path)
|
||||
|
||||
E = DataFileExt
|
||||
return {
|
||||
E.Txt: load_txt,
|
||||
E.TxtGz: load_txt,
|
||||
E.Csv: load_csv,
|
||||
E.Npy: load_npy,
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def loader_is_jax_compatible(self) -> bool:
|
||||
E = DataFileExt
|
||||
return {
|
||||
E.Txt: True,
|
||||
E.TxtGz: True,
|
||||
E.Csv: False,
|
||||
E.Npy: True,
|
||||
}[self]
|
||||
|
||||
####################
|
||||
# - Creation
|
||||
####################
|
||||
@staticmethod
|
||||
def from_ext(ext: str) -> typ.Self | None:
|
||||
return {
|
||||
_ext: _data_file_ext
|
||||
for _data_file_ext, _ext in {
|
||||
k: k.extension for k in list(DataFileExt)
|
||||
}.items()
|
||||
}.get(ext)
|
||||
|
||||
@staticmethod
|
||||
def from_path(path: Path) -> typ.Self | None:
|
||||
if DataFileExt.is_path_compatible(path):
|
||||
data_file_ext = DataFileExt.from_ext(''.join(path.suffixes))
|
||||
if data_file_ext is not None:
|
||||
return data_file_ext
|
||||
|
||||
msg = f'DataFileExt: Path "{path}" is compatible, but could not find valid extension'
|
||||
raise RuntimeError(msg)
|
||||
|
||||
return None
|
||||
|
||||
####################
|
||||
# - Compatibility
|
||||
####################
|
||||
@staticmethod
|
||||
def is_ext_compatible(ext: str):
|
||||
return ext in _DATA_FILE_EXTS
|
||||
|
||||
@staticmethod
|
||||
def is_path_compatible(path: Path):
|
||||
return path.is_file() and DataFileExt.is_ext_compatible(''.join(path.suffixes))
|
||||
|
||||
|
||||
####################
|
||||
# - Node
|
||||
####################
|
||||
class DataFileImporterNode(base.MaxwellSimNode):
|
||||
node_type = ct.NodeType.DataFileImporter
|
||||
bl_label = 'Data File Importer'
|
||||
|
||||
input_sockets: typ.ClassVar = {
|
||||
'File Path': sockets.FilePathSocketDef(),
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.LazyValueFunc),
|
||||
}
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
socket_name={'File Path'},
|
||||
input_sockets={'File Path'},
|
||||
input_socket_kinds={'File Path': ct.FlowKind.Value},
|
||||
input_sockets_optional={'File Path': True},
|
||||
)
|
||||
def on_input_exprs_changed(self, input_sockets) -> None: # noqa: D102
|
||||
has_file_path = not ct.FlowSignal.check(input_sockets['File Path'])
|
||||
|
||||
has_file_path = ct.FlowSignal.check_single(
|
||||
input_sockets['File Path'], ct.FlowSignal.FlowPending
|
||||
)
|
||||
|
||||
if has_file_path:
|
||||
self.file_path = bl_cache.Signal.InvalidateCache
|
||||
|
||||
@bl_cache.cached_bl_property()
|
||||
def file_path(self) -> Path:
|
||||
"""Retrieve the input file path."""
|
||||
file_path = self._compute_input(
|
||||
'File Path', kind=ct.FlowKind.Value, optional=True
|
||||
)
|
||||
has_file_path = not ct.FlowSignal.check(file_path)
|
||||
if has_file_path:
|
||||
return file_path
|
||||
|
||||
return None
|
||||
|
||||
@bl_cache.cached_bl_property(depends_on={'file_path'})
|
||||
def data_file_ext(self) -> DataFileExt | None:
|
||||
"""Retrieve the file extension by concatenating all suffixes."""
|
||||
if self.file_path is not None:
|
||||
return DataFileExt.from_path(self.file_path)
|
||||
return None
|
||||
|
||||
####################
|
||||
# - Output Info
|
||||
####################
|
||||
@bl_cache.cached_bl_property(depends_on={'file_path'})
|
||||
def expr_info(self) -> ct.InfoFlow | None:
|
||||
"""Retrieve the output expression's `InfoFlow`."""
|
||||
info = self.compute_output('Expr', kind=ct.FlowKind.Info)
|
||||
has_info = not ct.FlowKind.check(info)
|
||||
if has_info:
|
||||
return info
|
||||
return None
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
def draw_label(self):
|
||||
"""Show the extracted file name (w/extension) in the node's header label.
|
||||
|
||||
Notes:
|
||||
Called by Blender to determine the text to place in the node's header.
|
||||
"""
|
||||
if self.file_path is not None:
|
||||
return 'Load File: ' + self.file_path.name
|
||||
|
||||
return self.bl_label
|
||||
|
||||
def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
"""Show information about the loaded file."""
|
||||
if self.data_file_ext is not None:
|
||||
box = layout.box()
|
||||
row = box.row()
|
||||
row.alignment = 'CENTER'
|
||||
row.label(text='Data File')
|
||||
|
||||
row = box.row()
|
||||
row.alignment = 'CENTER'
|
||||
row.label(text=self.file_path.name)
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
pass
|
||||
|
||||
####################
|
||||
# - Events
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
socket_name='File Path',
|
||||
input_sockets={'File Path'},
|
||||
)
|
||||
def on_file_changed(self, input_sockets) -> None:
|
||||
pass
|
||||
|
||||
####################
|
||||
# - FlowKind.Array|LazyValueFunc
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.LazyValueFunc,
|
||||
input_sockets={'File Path'},
|
||||
)
|
||||
def compute_func(self, input_sockets: dict) -> td.Simulation:
|
||||
"""Declare a lazy, composable function that returns the loaded data.
|
||||
|
||||
Returns:
|
||||
A completely empty `ParamsFlow`, ready to be composed.
|
||||
"""
|
||||
file_path = input_sockets['File Path']
|
||||
|
||||
has_file_path = not ct.FlowSignal.check(input_sockets['File Path'])
|
||||
|
||||
if has_file_path:
|
||||
data_file_ext = DataFileExt.from_path(file_path)
|
||||
if data_file_ext is not None:
|
||||
# Jax Compatibility: Lazy Data Loading
|
||||
## -> Delay loading of data from file as long as we can.
|
||||
if data_file_ext.loader_is_jax_compatible:
|
||||
return ct.LazyValueFuncFlow(
|
||||
func=lambda: data_file_ext.loader(file_path),
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
# No Jax Compatibility: Eager Data Loading
|
||||
## -> Load the data now and bind it.
|
||||
data = data_file_ext.loader(file_path)
|
||||
return ct.LazyValueFuncFlow(func=lambda: data, supports_jax=True)
|
||||
return ct.FlowSignal.FlowPending
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - FlowKind.Params|Info
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Params,
|
||||
)
|
||||
def compute_params(self) -> ct.ParamsFlow:
|
||||
"""Declare an empty `Data:Params`, to indicate the start of a function-composition pipeline.
|
||||
|
||||
Returns:
|
||||
A completely empty `ParamsFlow`, ready to be composed.
|
||||
"""
|
||||
return ct.ParamsFlow()
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Info,
|
||||
output_sockets={'Expr'},
|
||||
output_socket_kinds={'Expr': ct.FlowKind.LazyValueFunc},
|
||||
)
|
||||
def compute_info(self, output_sockets) -> ct.InfoFlow:
|
||||
"""Declare an `InfoFlow` based on the data shape.
|
||||
|
||||
This currently requires computing the data.
|
||||
Note, however, that the incremental cache causes this computation only to happen once when a file is selected.
|
||||
|
||||
Returns:
|
||||
A completely empty `ParamsFlow`, ready to be composed.
|
||||
"""
|
||||
expr = output_sockets['Expr']
|
||||
|
||||
has_expr_func = not ct.FlowSignal.check(expr)
|
||||
|
||||
if has_expr_func:
|
||||
data = expr.func_jax()
|
||||
|
||||
# Deduce Dimensionality
|
||||
_shape = data.shape
|
||||
shape = _shape if _shape is not None else ()
|
||||
dim_names = [f'a{i}' for i in range(len(shape))]
|
||||
|
||||
# Return InfoFlow
|
||||
## -> TODO: How to interpret the data should be user-defined.
|
||||
## -> -- This may require those nice dynamic symbols.
|
||||
return ct.InfoFlow(
|
||||
dim_names=dim_names, ## TODO: User
|
||||
dim_idx={
|
||||
dim_name: ct.LazyArrayRangeFlow(
|
||||
start=sp.S(0), ## TODO: User
|
||||
stop=sp.S(shape[i] - 1), ## TODO: User
|
||||
steps=shape[dim_names.index(dim_name)],
|
||||
unit=None, ## TODO: User
|
||||
)
|
||||
for i, dim_name in enumerate(dim_names)
|
||||
},
|
||||
output_name='_',
|
||||
output_shape=None,
|
||||
output_mathtype=spux.MathType.Real, ## TODO: User
|
||||
output_unit=None, ## TODO: User
|
||||
)
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
||||
####################
|
||||
# - Blender Registration
|
||||
####################
|
||||
BL_REGISTER = [
|
||||
DataFileImporterNode,
|
||||
]
|
||||
BL_NODES = {
|
||||
ct.NodeType.DataFileImporter: (ct.NodeCategory.MAXWELLSIM_INPUTS_FILEIMPORTERS)
|
||||
}
|
|
@ -234,7 +234,7 @@ def plot_curves_2d(
|
|||
y_unit = info.output_unit
|
||||
|
||||
for category in range(data.shape[1]):
|
||||
ax.plot(data[:, 0], data[:, 1])
|
||||
ax.plot(info.dim_idx_arrays[0], data[:, category])
|
||||
|
||||
ax.set_title('2D Curves')
|
||||
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
|
||||
|
@ -250,8 +250,9 @@ def plot_filled_curves_2d(
|
|||
y_name = info.output_name
|
||||
y_unit = info.output_unit
|
||||
|
||||
ax.fill_between(info.dim_arrays[0], data[:, 0], info.dim_arrays[0], data[:, 1])
|
||||
ax.set_title('2D Curves')
|
||||
shared_x_idx = info.dim_idx_arrays[0]
|
||||
ax.fill_between(shared_x_idx, data[:, 0], shared_x_idx, data[:, 1])
|
||||
ax.set_title('2D Filled Curves')
|
||||
ax.set_xlabel(f'{x_name}' + (f'({x_unit})' if x_unit is not None else ''))
|
||||
ax.set_ylabel(f'{y_name}' + (f'({y_unit})' if y_unit is not None else ''))
|
||||
|
||||
|
|
|
@ -109,6 +109,10 @@ class SimSymbol:
|
|||
####################
|
||||
# - Properties
|
||||
####################
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.sim_node_name.name
|
||||
|
||||
@property
|
||||
def domain(self) -> sp.Interval | sp.Set:
|
||||
"""Return the domain of valid values for the symbol.
|
||||
|
@ -235,6 +239,10 @@ class CommonSimSymbol(enum.StrEnum):
|
|||
####################
|
||||
# - Properties
|
||||
####################
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.sim_symbol.name
|
||||
|
||||
@property
|
||||
def sim_symbol_name(self) -> str:
|
||||
SSN = SimSymbolName
|
||||
|
|
Loading…
Reference in New Issue