feat: Implemented operate math node.
parent
7d944a704e
commit
b2a7eefb45
4
TODO.md
4
TODO.md
|
@ -14,6 +14,8 @@
|
|||
- [x] Extract
|
||||
- [x] Viz
|
||||
- [x] Math / Map Math
|
||||
- [ ] Remove "By x" socket set let socket sets only be "Function"/"Expr"; then add a dynamic enum underneath to select "By x" based on data support.
|
||||
- [ ] Filter the operations based on data support, ex. use positive-definiteness to guide cholesky.
|
||||
- [x] Math / Filter Math
|
||||
- [ ] Math / Reduce Math
|
||||
- [ ] Math / Operate Math
|
||||
|
@ -34,8 +36,6 @@
|
|||
- [x] Constants / Blender Constant
|
||||
|
||||
- [ ] Web / Tidy3D Web Importer
|
||||
- [ ] Change to output only a `FilePath`, which can be plugged into a Tidy3D File Importer.
|
||||
- [ ] Implement caching, such that the file will only download if the file doesn't already exist.
|
||||
- [ ] Have a visual indicator for the current download status, with a manual re-download button.
|
||||
|
||||
- [x] File Import / Material Import
|
||||
|
|
|
@ -68,6 +68,9 @@ class FlowKind(enum.StrEnum):
|
|||
if kind == cls.LazyArrayRange:
|
||||
return value.rescale_to_unit(unit_system[socket_type])
|
||||
|
||||
if kind == cls.Params:
|
||||
return value.rescale_to_unit(unit_system[socket_type])
|
||||
|
||||
msg = 'Tried to scale unknown kind'
|
||||
raise ValueError(msg)
|
||||
|
||||
|
@ -322,6 +325,28 @@ class LazyValueFuncFlow:
|
|||
supports_jax: bool = False
|
||||
supports_numba: bool = False
|
||||
|
||||
# Merging
|
||||
def __or__(
|
||||
self,
|
||||
other: typ.Self,
|
||||
):
|
||||
return LazyValueFuncFlow(
|
||||
func=lambda *args, **kwargs: (
|
||||
self.func(
|
||||
*list(args[: len(self.func_args)]),
|
||||
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
|
||||
),
|
||||
other.func(
|
||||
*list(args[len(self.func_args) :]),
|
||||
**{k: v for k, v in kwargs.items() if k in other.func_kwargs},
|
||||
),
|
||||
),
|
||||
func_args=self.func_args + other.func_args,
|
||||
func_kwargs=self.func_kwargs | other.func_kwargs,
|
||||
supports_jax=self.supports_jax and other.supports_jax,
|
||||
supports_numba=self.supports_numba and other.supports_numba,
|
||||
)
|
||||
|
||||
# Composition
|
||||
def compose_within(
|
||||
self,
|
||||
|
@ -691,10 +716,21 @@ class ParamsFlow:
|
|||
func_args: list[typ.Any] = dataclasses.field(default_factory=list)
|
||||
func_kwargs: dict[str, typ.Any] = dataclasses.field(default_factory=dict)
|
||||
|
||||
def __or__(
|
||||
self,
|
||||
other: typ.Self,
|
||||
):
|
||||
return ParamsFlow(
|
||||
func_args=self.func_args + other.func_args,
|
||||
func_kwargs=self.func_kwargs | other.func_kwargs,
|
||||
)
|
||||
|
||||
def compose_within(
|
||||
self,
|
||||
enclosing_func_args: list[tuple[type]] = (),
|
||||
enclosing_func_kwargs: dict[str, type] = MappingProxyType({}),
|
||||
enclosing_func_arg_units: dict[str, type] = MappingProxyType({}),
|
||||
enclosing_func_kwarg_units: dict[str, type] = MappingProxyType({}),
|
||||
) -> typ.Self:
|
||||
return ParamsFlow(
|
||||
func_args=self.func_args + list(enclosing_func_args),
|
||||
|
@ -718,26 +754,133 @@ class InfoFlow:
|
|||
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
|
||||
|
||||
@functools.cached_property
|
||||
def dim_mathtypes(self) -> dict[str, int]:
|
||||
def dim_mathtypes(self) -> dict[str, spux.MathType]:
|
||||
return {
|
||||
dim_name: dim_idx.mathtype for dim_name, dim_idx in self.dim_idx.items()
|
||||
}
|
||||
|
||||
@functools.cached_property
|
||||
def dim_units(self) -> dict[str, int]:
|
||||
def dim_units(self) -> dict[str, spux.Unit]:
|
||||
return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()}
|
||||
|
||||
@functools.cached_property
|
||||
def dim_idx_arrays(self) -> list[ArrayFlow]:
|
||||
def dim_idx_arrays(self) -> list[jax.Array]:
|
||||
return [
|
||||
dim_idx.realize().values
|
||||
if isinstance(dim_idx, LazyArrayRangeFlow)
|
||||
else dim_idx.values
|
||||
for dim_idx in self.dim_idx.values()
|
||||
]
|
||||
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
|
||||
|
||||
# Output Information
|
||||
output_names: list[str] = dataclasses.field(default_factory=list)
|
||||
output_mathtypes: dict[str, spux.MathType] = dataclasses.field(default_factory=dict)
|
||||
output_units: dict[str, spux.Unit | None] = dataclasses.field(default_factory=dict)
|
||||
output_name: str = dataclasses.field(default_factory=list)
|
||||
output_shape: tuple[int, ...] | None = dataclasses.field(default=None)
|
||||
output_mathtype: spux.MathType = dataclasses.field()
|
||||
output_unit: spux.Unit | None = dataclasses.field()
|
||||
|
||||
# Pinned Dimension Information
|
||||
pinned_dim_names: list[str] = dataclasses.field(default_factory=list)
|
||||
pinned_dim_values: dict[str, float | complex] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
pinned_dim_mathtypes: dict[str, spux.MathType] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
pinned_dim_units: dict[str, spux.Unit] = dataclasses.field(default_factory=dict)
|
||||
|
||||
####################
|
||||
# - Methods
|
||||
####################
|
||||
def delete_dimension(self, dim_name: str) -> typ.Self:
|
||||
"""Delete a dimension."""
|
||||
return InfoFlow(
|
||||
# Dimensions
|
||||
dim_names=[
|
||||
_dim_name for _dim_name in self.dim_names if _dim_name != dim_name
|
||||
],
|
||||
dim_idx={
|
||||
_dim_name: dim_idx
|
||||
for _dim_name, dim_idx in self.dim_idx.items()
|
||||
if _dim_name != dim_name
|
||||
},
|
||||
# Outputs
|
||||
output_name=self.output_name,
|
||||
output_shape=self.output_shape,
|
||||
output_mathtype=self.output_mathtype,
|
||||
output_unit=self.output_unit,
|
||||
)
|
||||
|
||||
def swap_dimensions(self, dim_0_name: str, dim_1_name: str) -> typ.Self:
|
||||
"""Delete a dimension."""
|
||||
|
||||
# Compute Swapped Dimension Name List
|
||||
def name_swapper(dim_name):
|
||||
return (
|
||||
dim_name
|
||||
if dim_name not in [dim_0_name, dim_1_name]
|
||||
else {dim_0_name: dim_1_name, dim_1_name: dim_0_name}[dim_name]
|
||||
)
|
||||
|
||||
dim_names = [name_swapper(dim_name) for dim_name in self.dim_names]
|
||||
|
||||
# Compute Info
|
||||
return InfoFlow(
|
||||
# Dimensions
|
||||
dim_names=dim_names,
|
||||
dim_idx={dim_name: self.dim_idx[dim_name] for dim_name in dim_names},
|
||||
# Outputs
|
||||
output_name=self.output_name,
|
||||
output_shape=self.output_shape,
|
||||
output_mathtype=self.output_mathtype,
|
||||
output_unit=self.output_unit,
|
||||
)
|
||||
|
||||
def set_output_mathtype(self, output_mathtype: spux.MathType) -> typ.Self:
|
||||
"""Set the MathType of a particular output name."""
|
||||
return InfoFlow(
|
||||
dim_names=self.dim_names,
|
||||
dim_idx=self.dim_idx,
|
||||
# Outputs
|
||||
output_name=self.output_name,
|
||||
output_shape=self.output_shape,
|
||||
output_mathtype=output_mathtype,
|
||||
output_unit=self.output_unit,
|
||||
)
|
||||
|
||||
def collapse_output(
|
||||
self,
|
||||
collapsed_name: str,
|
||||
collapsed_mathtype: spux.MathType,
|
||||
collapsed_unit: spux.Unit,
|
||||
) -> typ.Self:
|
||||
return InfoFlow(
|
||||
# Dimensions
|
||||
dim_names=self.dim_names,
|
||||
dim_idx=self.dim_idx,
|
||||
output_name=collapsed_name,
|
||||
output_shape=None,
|
||||
output_mathtype=collapsed_mathtype,
|
||||
output_unit=collapsed_unit,
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def shift_last_input(self):
|
||||
"""Shift the last input dimension to the output."""
|
||||
return InfoFlow(
|
||||
# Dimensions
|
||||
dim_names=self.dim_names[:-1],
|
||||
dim_idx={
|
||||
dim_name: dim_idx
|
||||
for dim_name, dim_idx in self.dim_idx.items()
|
||||
if dim_name != self.dim_names[-1]
|
||||
},
|
||||
# Outputs
|
||||
output_name=self.output_name,
|
||||
output_shape=(
|
||||
(self.dim_lens[self.dim_names[-1]],)
|
||||
if self.output_shape is None
|
||||
else (self.dim_lens[self.dim_names[-1]], *self.output_shape)
|
||||
),
|
||||
output_mathtype=self.output_mathtype,
|
||||
output_unit=self.output_unit,
|
||||
)
|
||||
|
|
|
@ -15,6 +15,7 @@ class NodeType(blender_type_enum.BlenderTypeEnum):
|
|||
FilterMath = enum.auto()
|
||||
ReduceMath = enum.auto()
|
||||
OperateMath = enum.auto()
|
||||
TransformMath = enum.auto()
|
||||
|
||||
# Inputs
|
||||
WaveConstant = enum.auto()
|
||||
|
|
|
@ -260,7 +260,9 @@ SOCKET_UNITS = {
|
|||
}
|
||||
|
||||
|
||||
def unit_to_socket_type(unit: spux.Unit) -> ST:
|
||||
def unit_to_socket_type(
|
||||
unit: spux.Unit | None, fallback_mathtype: spux.MathType | None = None
|
||||
) -> ST:
|
||||
"""Returns a SocketType that accepts the given unit.
|
||||
|
||||
Only the unit-compatibility is taken into account; in the case of overlap, several the ordering of `SOCKET_UNITS` determines which is returned.
|
||||
|
@ -269,6 +271,14 @@ def unit_to_socket_type(unit: spux.Unit) -> ST:
|
|||
Returns:
|
||||
**The first `SocketType` in `SOCKET_UNITS`, which contains the given unit as a valid possibility.
|
||||
"""
|
||||
if unit is None and fallback_mathtype is not None:
|
||||
return {
|
||||
spux.MathType.Integer: ST.IntegerNumber,
|
||||
spux.MathType.Rational: ST.RationalNumber,
|
||||
spux.MathType.Real: ST.RealNumber,
|
||||
spux.MathType.Complex: ST.ComplexNumber,
|
||||
}[fallback_mathtype]
|
||||
|
||||
for socket_type, _units in SOCKET_UNITS.items():
|
||||
if unit in _units['values'].values():
|
||||
return socket_type
|
||||
|
|
|
@ -215,6 +215,15 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
####################
|
||||
# - UI
|
||||
####################
|
||||
def draw_label(self):
|
||||
has_sim_data = self.sim_data_monitor_nametype is not None
|
||||
has_monitor_data = self.monitor_data_components is not None
|
||||
|
||||
if has_sim_data or has_monitor_data:
|
||||
return f'Extract: {self.extract_filter}'
|
||||
|
||||
return self.bl_label
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
||||
"""Draw node properties in the node.
|
||||
|
||||
|
@ -223,43 +232,6 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
"""
|
||||
col.prop(self, self.blfields['extract_filter'], text='')
|
||||
|
||||
def draw_info(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
||||
"""Draw dynamic information in the node, for user consideration.
|
||||
|
||||
Parameters:
|
||||
col: UI target for drawing.
|
||||
"""
|
||||
has_sim_data = self.sim_data_monitor_nametype is not None
|
||||
has_monitor_data = self.monitor_data_components is not None
|
||||
|
||||
if has_sim_data or has_monitor_data:
|
||||
# Header
|
||||
row = col.row()
|
||||
row.alignment = 'CENTER'
|
||||
if has_sim_data:
|
||||
row.label(text=f'{len(self.sim_data_monitor_nametype)} Monitors')
|
||||
elif has_monitor_data:
|
||||
row.label(text=f'{self.monitor_data_type} Monitor Data')
|
||||
|
||||
# Monitor Data Contents
|
||||
## TODO: More compact double-split
|
||||
## TODO: Output shape data.
|
||||
## TODO: Local ENUM_MANY tabs for visible column selection?
|
||||
row = col.row()
|
||||
box = row.box()
|
||||
grid = box.grid_flow(row_major=True, columns=2, even_columns=True)
|
||||
if has_sim_data:
|
||||
for (
|
||||
monitor_name,
|
||||
monitor_type,
|
||||
) in self.sim_data_monitor_nametype.items():
|
||||
grid.label(text=monitor_name)
|
||||
grid.label(text=monitor_type.replace('Data', ''))
|
||||
elif has_monitor_data:
|
||||
for component_name in self.monitor_data_components:
|
||||
grid.label(text=component_name)
|
||||
grid.label(text=self.monitor_data_type)
|
||||
|
||||
####################
|
||||
# - Events
|
||||
####################
|
||||
|
@ -416,9 +388,8 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
else:
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
info_output_names = {
|
||||
'output_names': [props['extract_filter']],
|
||||
}
|
||||
info_output_name = props['extract_filter']
|
||||
info_output_shape = None
|
||||
|
||||
# Compute InfoFlow from XArray
|
||||
## XYZF: Field / Permittivity / FieldProjectionCartesian
|
||||
|
@ -442,13 +413,14 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
is_sorted=True,
|
||||
),
|
||||
},
|
||||
**info_output_names,
|
||||
output_mathtypes={props['extract_filter']: spux.MathType.Complex},
|
||||
output_units={
|
||||
props['extract_filter']: spu.volt / spu.micrometer
|
||||
output_name=props['extract_filter'],
|
||||
output_shape=None,
|
||||
output_mathtype=spux.MathType.Complex,
|
||||
output_unit=(
|
||||
spu.volt / spu.micrometer
|
||||
if props['monitor_data_type'] == 'Field'
|
||||
else None
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
## XYZT: FieldTime
|
||||
|
@ -468,17 +440,14 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
is_sorted=True,
|
||||
),
|
||||
},
|
||||
**info_output_names,
|
||||
output_mathtypes={props['extract_filter']: spux.MathType.Complex},
|
||||
output_units={
|
||||
props['extract_filter']: (
|
||||
output_name=props['extract_filter'],
|
||||
output_shape=None,
|
||||
output_mathtype=spux.MathType.Complex,
|
||||
output_unit=(
|
||||
spu.volt / spu.micrometer
|
||||
if props['extract_filter'].startswith('E')
|
||||
else spu.ampere / spu.micrometer
|
||||
)
|
||||
if props['monitor_data_type'] == 'Field'
|
||||
else None
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
## F: Flux
|
||||
|
@ -492,9 +461,10 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
is_sorted=True,
|
||||
),
|
||||
},
|
||||
**info_output_names,
|
||||
output_mathtypes={props['extract_filter']: spux.MathType.Real},
|
||||
output_units={props['extract_filter']: spu.watt},
|
||||
output_name=props['extract_filter'],
|
||||
output_shape=None,
|
||||
output_mathtype=spux.MathType.Real,
|
||||
output_unit=spu.watt,
|
||||
)
|
||||
|
||||
## T: FluxTime
|
||||
|
@ -508,9 +478,10 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
is_sorted=True,
|
||||
),
|
||||
},
|
||||
**info_output_names,
|
||||
output_mathtypes={props['extract_filter']: spux.MathType.Real},
|
||||
output_units={props['extract_filter']: spu.watt},
|
||||
output_name=props['extract_filter'],
|
||||
output_shape=None,
|
||||
output_mathtype=spux.MathType.Real,
|
||||
output_unit=spu.watt,
|
||||
)
|
||||
|
||||
## RThetaPhiF: FieldProjectionAngle
|
||||
|
@ -537,15 +508,14 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
is_sorted=True,
|
||||
),
|
||||
},
|
||||
**info_output_names,
|
||||
output_mathtypes={props['extract_filter']: spux.MathType.Real},
|
||||
output_units={
|
||||
props['extract_filter']: (
|
||||
output_name=props['extract_filter'],
|
||||
output_shape=None,
|
||||
output_mathtype=spux.MathType.Real,
|
||||
output_unit=(
|
||||
spu.volt / spu.micrometer
|
||||
if props['extract_filter'].startswith('E')
|
||||
else spu.ampere / spu.micrometer
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
## UxUyRF: FieldProjectionKSpace
|
||||
|
@ -570,15 +540,14 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
is_sorted=True,
|
||||
),
|
||||
},
|
||||
**info_output_names,
|
||||
output_mathtypes={props['extract_filter']: spux.MathType.Real},
|
||||
output_units={
|
||||
props['extract_filter']: (
|
||||
output_name=props['extract_filter'],
|
||||
output_shape=None,
|
||||
output_mathtype=spux.MathType.Real,
|
||||
output_unit=(
|
||||
spu.volt / spu.micrometer
|
||||
if props['extract_filter'].startswith('E')
|
||||
else spu.ampere / spu.micrometer
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
## OrderxOrderyF: Diffraction
|
||||
|
@ -600,15 +569,14 @@ class ExtractDataNode(base.MaxwellSimNode):
|
|||
is_sorted=True,
|
||||
),
|
||||
},
|
||||
**info_output_names,
|
||||
output_mathtypes={props['extract_filter']: spux.MathType.Real},
|
||||
output_units={
|
||||
props['extract_filter']: (
|
||||
output_name=props['extract_filter'],
|
||||
output_shape=None,
|
||||
output_mathtype=spux.MathType.Real,
|
||||
output_unit=(
|
||||
spu.volt / spu.micrometer
|
||||
if props['extract_filter'].startswith('E')
|
||||
else spu.ampere / spu.micrometer
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
msg = f'Unsupported Monitor Data Type {props["monitor_data_type"]} in "FlowKind.Info" of "{self.bl_label}"'
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
from . import filter_math, map_math, operate_math, reduce_math
|
||||
from . import filter_math, map_math, operate_math, reduce_math, transform_math
|
||||
|
||||
BL_REGISTER = [
|
||||
*map_math.BL_REGISTER,
|
||||
*filter_math.BL_REGISTER,
|
||||
*reduce_math.BL_REGISTER,
|
||||
*operate_math.BL_REGISTER,
|
||||
*transform_math.BL_REGISTER,
|
||||
]
|
||||
BL_NODES = {
|
||||
**map_math.BL_NODES,
|
||||
**filter_math.BL_NODES,
|
||||
**reduce_math.BL_NODES,
|
||||
**operate_math.BL_NODES,
|
||||
**transform_math.BL_NODES,
|
||||
}
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
"""Declares `FilterMathNode`."""
|
||||
|
||||
import enum
|
||||
import typing as typ
|
||||
|
||||
|
@ -15,11 +17,21 @@ log = logger.get(__name__)
|
|||
|
||||
|
||||
class FilterMathNode(base.MaxwellSimNode):
|
||||
"""Reduces the dimensionality of data.
|
||||
r"""Applies a function that operates on the shape of the array.
|
||||
|
||||
The shape, type, and interpretation of the input/output data is dynamically shown.
|
||||
|
||||
# Socket Sets
|
||||
## Dimensions
|
||||
Alter the dimensions of the array.
|
||||
|
||||
## Interpret
|
||||
Only alter the interpretation of the array data, which guides what it can be used for.
|
||||
|
||||
These operations are **zero cost**, since the data itself is untouched.
|
||||
|
||||
Attributes:
|
||||
operation: Operation to apply to the input.
|
||||
dim: Dims to use when filtering data
|
||||
"""
|
||||
|
||||
node_type = ct.NodeType.FilterMath
|
||||
|
@ -29,8 +41,8 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
'Data': sockets.DataSocketDef(format='jax'),
|
||||
}
|
||||
input_socket_sets: typ.ClassVar = {
|
||||
'By Dim': {},
|
||||
'By Dim Value': {},
|
||||
'Interpret': {},
|
||||
'Dimensions': {},
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Data': sockets.DataSocketDef(format='jax'),
|
||||
|
@ -43,10 +55,17 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
prop_ui=True, enum_cb=lambda self, _: self.search_operations()
|
||||
)
|
||||
|
||||
dim: enum.Enum = bl_cache.BLField(
|
||||
# Dimension Selection
|
||||
dim_0: enum.Enum = bl_cache.BLField(
|
||||
None, prop_ui=True, enum_cb=lambda self, _: self.search_dims()
|
||||
)
|
||||
dim_1: enum.Enum = bl_cache.BLField(
|
||||
None, prop_ui=True, enum_cb=lambda self, _: self.search_dims()
|
||||
)
|
||||
|
||||
####################
|
||||
# - Computed
|
||||
####################
|
||||
@property
|
||||
def data_info(self) -> ct.InfoFlow | None:
|
||||
info = self._compute_input('Data', kind=ct.FlowKind.Info)
|
||||
|
@ -60,87 +79,119 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
####################
|
||||
def search_operations(self) -> list[tuple[str, str, str]]:
|
||||
items = []
|
||||
if self.active_socket_set == 'By Dim':
|
||||
if self.active_socket_set == 'Interpret':
|
||||
items += [
|
||||
('SQUEEZE', 'del a | #=1', 'Squeeze'),
|
||||
('DIM_TO_VEC', '→ Vector', 'Shift last dimension to output.'),
|
||||
('DIMS_TO_MAT', '→ Matrix', 'Shift last 2 dimensions to output.'),
|
||||
]
|
||||
if self.active_socket_set == 'By Dim Value':
|
||||
elif self.active_socket_set == 'Dimensions':
|
||||
items += [
|
||||
('FIX', 'del a | i≈v', 'Fix Coordinate'),
|
||||
('PIN_LEN_ONE', 'pinₐ =1', 'Remove a len(1) dimension'),
|
||||
(
|
||||
'PIN',
|
||||
'pinₐ ≈v',
|
||||
'Remove a len(n) dimension by selecting an index',
|
||||
),
|
||||
('SWAP', 'a₁ ↔ a₂', 'Swap the position of two dimensions'),
|
||||
]
|
||||
|
||||
return [(*item, '', i) for i, item in enumerate(items)]
|
||||
|
||||
####################
|
||||
# - Dim Search
|
||||
# - Dimensions Search
|
||||
####################
|
||||
def search_dims(self) -> list[ct.BLEnumElement]:
|
||||
if self.data_info is not None:
|
||||
dims = [
|
||||
(dim_name, dim_name, dim_name, '', i)
|
||||
for i, dim_name in enumerate(self.data_info.dim_names)
|
||||
]
|
||||
|
||||
# Squeeze: Dimension Must Have Length=1
|
||||
## We must also correct the "NUMBER" of the enum.
|
||||
if self.operation == 'SQUEEZE':
|
||||
filtered_dims = [
|
||||
dim for dim in dims if self.data_info.dim_lens[dim[0]] == 1
|
||||
]
|
||||
return [(*dim[:-1], i) for i, dim in enumerate(filtered_dims)]
|
||||
|
||||
return dims
|
||||
if self.data_info is None:
|
||||
return []
|
||||
|
||||
if self.operation == 'PIN_LEN_ONE':
|
||||
dims = [
|
||||
(dim_name, dim_name, f'Dimension "{dim_name}" of length 1')
|
||||
for dim_name in self.data_info.dim_names
|
||||
if self.data_info.dim_lens[dim_name] == 1
|
||||
]
|
||||
elif self.operation in ['PIN', 'SWAP']:
|
||||
dims = [
|
||||
(dim_name, dim_name, f'Dimension "{dim_name}"')
|
||||
for dim_name in self.data_info.dim_names
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
return [(*dim, '', i) for i, dim in enumerate(dims)]
|
||||
|
||||
####################
|
||||
# - UI
|
||||
####################
|
||||
def draw_label(self):
|
||||
labels = {
|
||||
'PIN_LEN_ONE': lambda: f'Filter: Pin {self.dim_0} (len=1)',
|
||||
'PIN': lambda: f'Filter: Pin {self.dim_0}',
|
||||
'SWAP': lambda: f'Filter: Swap {self.dim_0}|{self.dim_1}',
|
||||
'DIM_TO_VEC': lambda: 'Filter: -> Vector',
|
||||
'DIMS_TO_MAT': lambda: 'Filter: -> Matrix',
|
||||
}
|
||||
|
||||
if (label := labels.get(self.operation)) is not None:
|
||||
return label()
|
||||
|
||||
return self.bl_label
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
layout.prop(self, self.blfields['operation'], text='')
|
||||
if self.data_info is not None and self.data_info.dim_names:
|
||||
layout.prop(self, self.blfields['dim'], text='')
|
||||
|
||||
if self.active_socket_set == 'Dimensions':
|
||||
if self.operation in ['PIN_LEN_ONE', 'PIN']:
|
||||
layout.prop(self, self.blfields['dim_0'], text='')
|
||||
|
||||
if self.operation == 'SWAP':
|
||||
row = layout.row(align=True)
|
||||
row.prop(self, self.blfields['dim_0'], text='')
|
||||
row.prop(self, self.blfields['dim_1'], text='')
|
||||
|
||||
####################
|
||||
# - Events
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
socket_name='Data',
|
||||
prop_name='active_socket_set',
|
||||
run_on_init=True,
|
||||
input_sockets={'Data'},
|
||||
)
|
||||
def on_any_change(self, input_sockets: dict):
|
||||
if all(
|
||||
not ct.FlowSignal.check_single(
|
||||
input_socket_value, ct.FlowSignal.FlowPending
|
||||
)
|
||||
for input_socket_value in input_sockets.values()
|
||||
):
|
||||
def on_socket_set_changed(self):
|
||||
self.operation = bl_cache.Signal.ResetEnumItems
|
||||
self.dim = bl_cache.Signal.ResetEnumItems
|
||||
|
||||
@events.on_value_changed(
|
||||
# Trigger
|
||||
socket_name='Data',
|
||||
prop_name={'active_socket_set', 'operation'},
|
||||
run_on_init=True,
|
||||
# Loaded
|
||||
props={'operation'},
|
||||
)
|
||||
def on_any_change(self, props: dict) -> None:
|
||||
self.dim_0 = bl_cache.Signal.ResetEnumItems
|
||||
self.dim_1 = bl_cache.Signal.ResetEnumItems
|
||||
|
||||
@events.on_value_changed(
|
||||
socket_name='Data',
|
||||
prop_name='dim',
|
||||
prop_name={'dim_0', 'dim_1', 'operation'},
|
||||
## run_on_init: Implicitly triggered.
|
||||
props={'active_socket_set', 'dim'},
|
||||
props={'operation', 'dim_0', 'dim_1'},
|
||||
input_sockets={'Data'},
|
||||
input_socket_kinds={'Data': ct.FlowKind.Info},
|
||||
)
|
||||
def on_dim_change(self, props: dict, input_sockets: dict):
|
||||
if input_sockets['Data'] == ct.FlowSignal.FlowPending:
|
||||
has_data = not ct.FlowSignal.check(input_sockets['Data'])
|
||||
if not has_data:
|
||||
return
|
||||
|
||||
# Add/Remove Input Socket "Value"
|
||||
if (
|
||||
not ct.Flowsignal.check(input_sockets['Data'])
|
||||
and props['active_socket_set'] == 'By Dim Value'
|
||||
and props['dim'] != 'NONE'
|
||||
):
|
||||
# "Dimensions"|"PIN": Add/Remove Input Socket
|
||||
if props['operation'] == 'PIN' and props['dim_0'] != 'NONE':
|
||||
# Get Current and Wanted Socket Defs
|
||||
current_bl_socket = self.loose_input_sockets.get('Value')
|
||||
wanted_socket_def = sockets.SOCKET_DEFS[
|
||||
ct.unit_to_socket_type(input_sockets['Data'].dim_idx[props['dim']].unit)
|
||||
ct.unit_to_socket_type(
|
||||
input_sockets['Data'].dim_idx[props['dim_0']].unit
|
||||
)
|
||||
]
|
||||
|
||||
# Determine Whether to Declare New Loose Input SOcket
|
||||
|
@ -151,7 +202,7 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
):
|
||||
self.loose_input_sockets = {
|
||||
'Value': wanted_socket_def(),
|
||||
} ## TODO: Can we do the boilerplate in base.py?
|
||||
}
|
||||
elif self.loose_input_sockets:
|
||||
self.loose_input_sockets = {}
|
||||
|
||||
|
@ -161,40 +212,51 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
@events.computes_output_socket(
|
||||
'Data',
|
||||
kind=ct.FlowKind.LazyValueFunc,
|
||||
props={'active_socket_set', 'operation', 'dim'},
|
||||
props={'operation', 'dim_0', 'dim_1'},
|
||||
input_sockets={'Data'},
|
||||
input_socket_kinds={'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info}},
|
||||
)
|
||||
def compute_data(self, props: dict, input_sockets: dict):
|
||||
# Retrieve Inputs
|
||||
lazy_value_func = input_sockets['Data'][ct.FlowKind.LazyValueFunc]
|
||||
info = input_sockets['Data'][ct.FlowKind.Info]
|
||||
|
||||
# Check Flow
|
||||
if (
|
||||
any(ct.FlowSignal.check(inp) for inp in [info, lazy_value_func])
|
||||
or props['operation'] == 'NONE'
|
||||
):
|
||||
if any(ct.FlowSignal.check(inp) for inp in [info, lazy_value_func]):
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
# Compute Bound/Free Parameters
|
||||
func_args = [int] if props['active_socket_set'] == 'By Dim Value' else []
|
||||
axis = info.dim_names.index(props['dim'])
|
||||
# Compute Function Arguments
|
||||
operation = props['operation']
|
||||
if operation == 'NONE':
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
# Select Function
|
||||
filter_func: typ.Callable[[jax.Array], jax.Array] = {
|
||||
'By Dim': {'SQUEEZE': lambda data: jnp.squeeze(data, axis)},
|
||||
'By Dim Value': {
|
||||
'FIX': lambda data, fixed_axis_idx: jnp.take(
|
||||
data, fixed_axis_idx, axis=axis
|
||||
)
|
||||
},
|
||||
}[props['active_socket_set']][props['operation']]
|
||||
## Dimension(s)
|
||||
dim_0 = props['dim_0']
|
||||
dim_1 = props['dim_1']
|
||||
if operation in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE':
|
||||
return ct.FlowSignal.FlowPending
|
||||
if operation == 'SWAP' and dim_1 == 'NONE':
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
## Axis/Axes
|
||||
axis_0 = info.dim_names.index(dim_0) if dim_0 != 'NONE' else None
|
||||
axis_1 = info.dim_names.index(dim_1) if dim_1 != 'NONE' else None
|
||||
|
||||
# Compose Output Function
|
||||
filter_func = {
|
||||
# Dimensions
|
||||
'PIN_LEN_ONE': lambda data: jnp.squeeze(data, axis_0),
|
||||
'PIN': lambda data, fixed_axis_idx: jnp.take(
|
||||
data, fixed_axis_idx, axis=axis_0
|
||||
),
|
||||
'SWAP': lambda data: jnp.swapaxes(data, axis_0, axis_1),
|
||||
# Interpret
|
||||
'DIM_TO_VEC': lambda data: data,
|
||||
'DIMS_TO_MAT': lambda data: data,
|
||||
}[props['operation']]
|
||||
|
||||
# Compose Function for Output
|
||||
return lazy_value_func.compose_within(
|
||||
filter_func,
|
||||
enclosing_func_args=func_args,
|
||||
enclosing_func_args=[int] if operation == 'PIN' else [],
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
|
@ -207,7 +269,6 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
},
|
||||
)
|
||||
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow:
|
||||
# Retrieve Inputs
|
||||
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc]
|
||||
params = output_sockets['Data'][ct.FlowKind.Params]
|
||||
|
||||
|
@ -215,57 +276,54 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
if any(ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]):
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
# Compute Array
|
||||
return ct.ArrayFlow(
|
||||
values=lazy_value_func.func_jax(*params.func_args, **params.func_kwargs),
|
||||
unit=None, ## TODO: Unit Propagation
|
||||
unit=None,
|
||||
)
|
||||
|
||||
####################
|
||||
# - Compute Auxiliary: Info / Params
|
||||
# - Compute Auxiliary: Info
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
kind=ct.FlowKind.Info,
|
||||
props={'active_socket_set', 'dim', 'operation'},
|
||||
props={'dim_0', 'dim_1', 'operation'},
|
||||
input_sockets={'Data'},
|
||||
input_socket_kinds={'Data': ct.FlowKind.Info},
|
||||
)
|
||||
def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
|
||||
# Retrieve Inputs
|
||||
info = input_sockets['Data']
|
||||
|
||||
# Check Flow
|
||||
if ct.FlowSignal.check(info) or props['dim'] == 'NONE':
|
||||
if ct.FlowSignal.check(info):
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
# Compute Information
|
||||
## Compute Info w/By-Operation Change to Dimensions
|
||||
axis = info.dim_names.index(props['dim'])
|
||||
# Collect Information
|
||||
dim_0 = props['dim_0']
|
||||
dim_1 = props['dim_1']
|
||||
|
||||
if (props['active_socket_set'], props['operation']) in [
|
||||
('By Dim', 'SQUEEZE'),
|
||||
('By Dim Value', 'FIX'),
|
||||
] and info.dim_names:
|
||||
return ct.InfoFlow(
|
||||
dim_names=info.dim_names[:axis] + info.dim_names[axis + 1 :],
|
||||
dim_idx={
|
||||
dim_name: dim_idx
|
||||
for dim_name, dim_idx in info.dim_idx.items()
|
||||
if dim_name != props['dim']
|
||||
},
|
||||
output_names=info.output_names,
|
||||
output_mathtypes=info.output_mathtypes,
|
||||
output_units=info.output_units,
|
||||
)
|
||||
if props['operation'] in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE':
|
||||
return ct.FlowSignal.FlowPending
|
||||
if props['operation'] == 'SWAP' and dim_1 == 'NONE':
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
msg = f'Active socket set {props["active_socket_set"]} and operation {props["operation"]} don\'t have an InfoFlow defined'
|
||||
raise RuntimeError(msg)
|
||||
return {
|
||||
# Dimensions
|
||||
'PIN_LEN_ONE': lambda: info.delete_dimension(dim_0),
|
||||
'PIN': lambda: info.delete_dimension(dim_0),
|
||||
'SWAP': lambda: info.swap_dimensions(dim_0, dim_1),
|
||||
# Interpret
|
||||
'DIM_TO_VEC': lambda: info.shift_last_input,
|
||||
'DIMS_TO_MAT': lambda: info.shift_last_input.shift_last_input,
|
||||
}[props['operation']]()
|
||||
|
||||
####################
|
||||
# - Compute Auxiliary: Info
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
kind=ct.FlowKind.Params,
|
||||
props={'active_socket_set', 'dim', 'operation'},
|
||||
props={'dim_0', 'dim_1', 'operation'},
|
||||
input_sockets={'Data', 'Value'},
|
||||
input_socket_kinds={'Data': {ct.FlowKind.Info, ct.FlowKind.Params}},
|
||||
input_sockets_optional={'Value': True},
|
||||
|
@ -273,35 +331,33 @@ class FilterMathNode(base.MaxwellSimNode):
|
|||
def compute_composed_params(
|
||||
self, props: dict, input_sockets: dict
|
||||
) -> ct.ParamsFlow:
|
||||
# Retrieve Inputs
|
||||
info = input_sockets['Data'][ct.FlowKind.Info]
|
||||
params = input_sockets['Data'][ct.FlowKind.Params]
|
||||
|
||||
# Check Flow
|
||||
if any(ct.FlowSignal.check(inp) for inp in [info, params]):
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
# Compute Composed Parameters
|
||||
## -> Only operations that add parameters.
|
||||
## -> A dimension must be selected.
|
||||
## -> There must be an input value.
|
||||
if (
|
||||
(props['active_socket_set'], props['operation'])
|
||||
in [
|
||||
('By Dim Value', 'FIX'),
|
||||
]
|
||||
and props['dim'] != 'NONE'
|
||||
and not ct.FlowSignal.check(input_sockets['Value'])
|
||||
):
|
||||
# Compute IDX Corresponding to Coordinate Value
|
||||
## -> Each dimension declares a unit-aware real number at each index.
|
||||
## -> "Value" is a unit-aware real number from loose input socket.
|
||||
## -> This finds the dimensional index closest to "Value".
|
||||
## Total Effect: Indexing by a unit-aware real number.
|
||||
nearest_idx_to_value = info.dim_idx[props['dim']].nearest_idx_of(
|
||||
# Collect Information
|
||||
## Dimensions
|
||||
dim_0 = props['dim_0']
|
||||
dim_1 = props['dim_1']
|
||||
|
||||
if props['operation'] in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE':
|
||||
return ct.FlowSignal.FlowPending
|
||||
if props['operation'] == 'SWAP' and dim_1 == 'NONE':
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
## Pinned Value
|
||||
pinned_value = input_sockets['Value']
|
||||
has_pinned_value = not ct.FlowSignal.check(pinned_value)
|
||||
|
||||
if props['operation'] == 'PIN' and has_pinned_value:
|
||||
# Compute IDX Corresponding to Dimension Index
|
||||
nearest_idx_to_value = info.dim_idx[dim_0].nearest_idx_of(
|
||||
input_sockets['Value'], require_sorted=True
|
||||
)
|
||||
|
||||
# Compose Parameters
|
||||
return params.compose_within(enclosing_func_args=[nearest_idx_to_value])
|
||||
|
||||
return params
|
||||
|
|
|
@ -138,6 +138,7 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
('SQ', 'v²', 'v^2 (by el)'),
|
||||
('SQRT', '√v', 'sqrt(v) (by el)'),
|
||||
('INV_SQRT', '1/√v', '1/sqrt(v) (by el)'),
|
||||
None,
|
||||
# Trigonometry
|
||||
('COS', 'cos v', 'cos(v) (by el)'),
|
||||
('SIN', 'sin v', 'sin(v) (by el)'),
|
||||
|
@ -148,6 +149,7 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
]
|
||||
elif self.active_socket_set in 'By Vector':
|
||||
items = [
|
||||
# Vector -> Number
|
||||
('NORM_2', '||v||₂', 'norm(v, 2) (by Vec)'),
|
||||
]
|
||||
elif self.active_socket_set == 'By Matrix':
|
||||
|
@ -157,13 +159,16 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
('COND', 'κ(V)', 'cond(V) (by Mat)'),
|
||||
('NORM_FRO', '||V||_F', 'norm(V, frobenius) (by Mat)'),
|
||||
('RANK', 'rank V', 'rank(V) (by Mat)'),
|
||||
None,
|
||||
# Matrix -> Array
|
||||
('DIAG', 'diag V', 'diag(V) (by Mat)'),
|
||||
('EIG_VALS', 'eigvals V', 'eigvals(V) (by Mat)'),
|
||||
('SVD_VALS', 'svdvals V', 'diag(svd(V)) (by Mat)'),
|
||||
None,
|
||||
# Matrix -> Matrix
|
||||
('INV', 'V⁻¹', 'V^(-1) (by Mat)'),
|
||||
('TRA', 'Vt', 'V^T (by Mat)'),
|
||||
None,
|
||||
# Matrix -> Matrices
|
||||
('QR', 'qr V', 'qr(V) -> Q·R (by Mat)'),
|
||||
('CHOL', 'chol V', 'cholesky(V) -> V·V† (by Mat)'),
|
||||
|
@ -175,7 +180,9 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
msg = f'Active socket set {self.active_socket_set} is unknown'
|
||||
raise RuntimeError(msg)
|
||||
|
||||
return [(*item, '', i) for i, item in enumerate(items)]
|
||||
return [
|
||||
(*item, '', i) if item is not None else None for i, item in enumerate(items)
|
||||
]
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
layout.prop(self, self.blfields['operation'], text='')
|
||||
|
@ -185,8 +192,9 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
####################
|
||||
@events.on_value_changed(
|
||||
prop_name='active_socket_set',
|
||||
run_on_init=True,
|
||||
)
|
||||
def on_operation_changed(self):
|
||||
def on_socket_set_changed(self):
|
||||
self.operation = bl_cache.Signal.ResetEnumItems
|
||||
|
||||
####################
|
||||
|
@ -204,11 +212,14 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
input_sockets_optional={'Mapper': True},
|
||||
)
|
||||
def compute_data(self, props: dict, input_sockets: dict):
|
||||
has_data = not ct.FlowSignal.check(input_sockets['Data'])
|
||||
if (
|
||||
ct.FlowSignal.check(input_sockets['Data']) or props['operation'] == 'NONE'
|
||||
) or (
|
||||
not has_data
|
||||
or props['operation'] == 'NONE'
|
||||
or (
|
||||
props['active_socket_set'] == 'Expr'
|
||||
and ct.FlowSignal.check(input_sockets['Mapper'])
|
||||
)
|
||||
):
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
@ -238,14 +249,14 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
'NORM_FRO': lambda data: jnp.linalg.matrix_norm(data, ord='fro'),
|
||||
'RANK': lambda data: jnp.linalg.matrix_rank(data),
|
||||
# Matrix -> Vec
|
||||
'DIAG': lambda data: jnp.diag(data),
|
||||
'EIG_VALS': lambda data: jnp.eigvals(data),
|
||||
'SVD_VALS': lambda data: jnp.svdvals(data),
|
||||
'DIAG': lambda data: jnp.diagonal(data, axis1=-2, axis2=-1),
|
||||
'EIG_VALS': lambda data: jnp.linalg.eigvals(data),
|
||||
'SVD_VALS': lambda data: jnp.linalg.svdvals(data),
|
||||
# Matrix -> Matrix
|
||||
'INV': lambda data: jnp.inv(data),
|
||||
'INV': lambda data: jnp.linalg.inv(data),
|
||||
'TRA': lambda data: jnp.matrix_transpose(data),
|
||||
# Matrix -> Matrices
|
||||
'QR': lambda data: jnp.inv(data),
|
||||
'QR': lambda data: jnp.linalg.qr(data),
|
||||
'CHOL': lambda data: jnp.linalg.cholesky(data),
|
||||
'SVD': lambda data: jnp.linalg.svd(data),
|
||||
},
|
||||
|
@ -298,28 +309,53 @@ class MapMathNode(base.MaxwellSimNode):
|
|||
return ct.FlowSignal.FlowPending
|
||||
|
||||
# Complex -> Real
|
||||
if props['active_socket_set'] == 'By Element':
|
||||
if props['operation'] in [
|
||||
if props['active_socket_set'] == 'By Element' and props['operation'] in [
|
||||
'REAL',
|
||||
'IMAG',
|
||||
'ABS',
|
||||
]:
|
||||
return ct.InfoFlow(
|
||||
dim_names=info.dim_names,
|
||||
dim_idx=info.dim_idx,
|
||||
output_names=info.output_names,
|
||||
output_mathtypes={
|
||||
output_name: (
|
||||
spux.MathType.Real
|
||||
if output_mathtype == spux.MathType.Complex
|
||||
else output_mathtype
|
||||
return info.set_output_mathtype(spux.MathType.Real)
|
||||
|
||||
if props['active_socket_set'] == 'By Vector' and props['operation'] in [
|
||||
'NORM_2'
|
||||
]:
|
||||
return {
|
||||
'NORM_2': lambda: info.collapse_output(
|
||||
collapsed_name=f'||{info.output_name}||₂',
|
||||
collapsed_mathtype=spux.MathType.Real,
|
||||
collapsed_unit=info.output_unit,
|
||||
)
|
||||
for output_name, output_mathtype in info.output_mathtypes.items()
|
||||
},
|
||||
output_units=info.output_units,
|
||||
)
|
||||
if props['active_socket_set'] == 'By Vector':
|
||||
pass
|
||||
}[props['operation']]()
|
||||
|
||||
if props['active_socket_set'] == 'By Matrix' and props['operation'] in [
|
||||
'DET',
|
||||
'COND',
|
||||
'NORM_FRO',
|
||||
'RANK',
|
||||
]:
|
||||
return {
|
||||
'DET': lambda: info.collapse_output(
|
||||
collapsed_name=f'det {info.output_name}',
|
||||
collapsed_mathtype=info.output_mathtype,
|
||||
collapsed_unit=info.output_unit,
|
||||
),
|
||||
'COND': lambda: info.collapse_output(
|
||||
collapsed_name=f'κ({info.output_name})',
|
||||
collapsed_mathtype=spux.MathType.Real,
|
||||
collapsed_unit=None,
|
||||
),
|
||||
'NORM_FRO': lambda: info.collapse_output(
|
||||
collapsed_name=f'||({info.output_name}||_F',
|
||||
collapsed_mathtype=spux.MathType.Real,
|
||||
collapsed_unit=info.output_unit,
|
||||
),
|
||||
'RANK': lambda: info.collapse_output(
|
||||
collapsed_name=f'rank {info.output_name}',
|
||||
collapsed_mathtype=spux.MathType.Integer,
|
||||
collapsed_unit=None,
|
||||
),
|
||||
}[props['operation']]()
|
||||
|
||||
return info
|
||||
|
||||
@events.computes_output_socket(
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
import enum
|
||||
import typing as typ
|
||||
|
||||
import bpy
|
||||
import jax.numpy as jnp
|
||||
import sympy as sp
|
||||
|
||||
from blender_maxwell.utils import logger
|
||||
from blender_maxwell.utils import bl_cache, logger
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
|
||||
from .... import contracts as ct
|
||||
from .... import sockets
|
||||
|
@ -13,120 +16,465 @@ log = logger.get(__name__)
|
|||
|
||||
|
||||
class OperateMathNode(base.MaxwellSimNode):
|
||||
r"""Applies a function that depends on two inputs.
|
||||
|
||||
Attributes:
|
||||
category: The category of operations to apply to the inputs.
|
||||
**Only valid** categories can be chosen.
|
||||
operation: The actual operation to apply to the inputs.
|
||||
**Only valid** operations can be chosen.
|
||||
"""
|
||||
|
||||
node_type = ct.NodeType.OperateMath
|
||||
bl_label = 'Operate Math'
|
||||
|
||||
input_socket_sets: typ.ClassVar = {
|
||||
'Elementwise': {
|
||||
'Data L': sockets.AnySocketDef(),
|
||||
'Data R': sockets.AnySocketDef(),
|
||||
'Expr | Expr': {
|
||||
'Expr L': sockets.ExprSocketDef(),
|
||||
'Expr R': sockets.ExprSocketDef(),
|
||||
},
|
||||
## TODO: Filter-array building operations
|
||||
'Vec-Vec': {
|
||||
'Data L': sockets.AnySocketDef(),
|
||||
'Data R': sockets.AnySocketDef(),
|
||||
'Data | Data': {
|
||||
'Data L': sockets.DataSocketDef(
|
||||
format='jax', default_show_info_columns=False
|
||||
),
|
||||
'Data R': sockets.DataSocketDef(
|
||||
format='jax', default_show_info_columns=False
|
||||
),
|
||||
},
|
||||
'Mat-Vec': {
|
||||
'Data L': sockets.AnySocketDef(),
|
||||
'Data R': sockets.AnySocketDef(),
|
||||
'Expr | Data': {
|
||||
'Expr L': sockets.ExprSocketDef(),
|
||||
'Data R': sockets.DataSocketDef(
|
||||
format='jax', default_show_info_columns=False
|
||||
),
|
||||
},
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Data': sockets.AnySocketDef(),
|
||||
output_socket_sets: typ.ClassVar = {
|
||||
'Expr | Expr': {
|
||||
'Expr': sockets.ExprSocketDef(),
|
||||
},
|
||||
'Data | Data': {
|
||||
'Data': sockets.DataSocketDef(
|
||||
format='jax', default_show_info_columns=False
|
||||
),
|
||||
},
|
||||
'Expr | Data': {
|
||||
'Data': sockets.DataSocketDef(
|
||||
format='jax', default_show_info_columns=False
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
operation: bpy.props.EnumProperty(
|
||||
name='Op',
|
||||
description='Operation to apply to the two inputs',
|
||||
items=lambda self, _: self.search_operations(),
|
||||
update=lambda self, context: self.on_prop_changed('operation', context),
|
||||
category: enum.Enum = bl_cache.BLField(
|
||||
prop_ui=True, enum_cb=lambda self, _: self.search_categories()
|
||||
)
|
||||
|
||||
def search_operations(self) -> list[tuple[str, str, str]]:
|
||||
operation: enum.Enum = bl_cache.BLField(
|
||||
prop_ui=True, enum_cb=lambda self, _: self.search_operations()
|
||||
)
|
||||
|
||||
def search_categories(self) -> list[ct.BLEnumElement]:
|
||||
"""Deduce and return a list of valid categories for the current socket set and input data."""
|
||||
data_l_info = self._compute_input(
|
||||
'Data L', kind=ct.FlowKind.Info, optional=True
|
||||
)
|
||||
data_r_info = self._compute_input(
|
||||
'Data R', kind=ct.FlowKind.Info, optional=True
|
||||
)
|
||||
|
||||
has_data_l_info = not ct.FlowSignal.check(data_l_info)
|
||||
has_data_r_info = not ct.FlowSignal.check(data_r_info)
|
||||
|
||||
# Categories by Socket Set
|
||||
NUMBER_NUMBER = (
|
||||
'Number | Number',
|
||||
'Number | Number',
|
||||
'Operations between numerical elements',
|
||||
)
|
||||
NUMBER_VECTOR = (
|
||||
'Number | Vector',
|
||||
'Number | Vector',
|
||||
'Operations between numerical and vector elements',
|
||||
)
|
||||
NUMBER_MATRIX = (
|
||||
'Number | Matrix',
|
||||
'Number | Matrix',
|
||||
'Operations between numerical and matrix elements',
|
||||
)
|
||||
VECTOR_VECTOR = (
|
||||
'Vector | Vector',
|
||||
'Vector | Vector',
|
||||
'Operations between vector elements',
|
||||
)
|
||||
MATRIX_VECTOR = (
|
||||
'Matrix | Vector',
|
||||
'Matrix | Vector',
|
||||
'Operations between vector and matrix elements',
|
||||
)
|
||||
MATRIX_MATRIX = (
|
||||
'Matrix | Matrix',
|
||||
'Matrix | Matrix',
|
||||
'Operations between matrix elements',
|
||||
)
|
||||
categories = []
|
||||
|
||||
## Expr | Expr
|
||||
if self.active_socket_set == 'Expr | Expr':
|
||||
return [NUMBER_NUMBER]
|
||||
|
||||
## Data | Data
|
||||
if (
|
||||
self.active_socket_set == 'Data | Data'
|
||||
and has_data_l_info
|
||||
and has_data_r_info
|
||||
):
|
||||
# Check Valid Broadcasting
|
||||
## Number | Number
|
||||
if data_l_info.output_shape is None and data_r_info.output_shape is None:
|
||||
categories = [NUMBER_NUMBER]
|
||||
|
||||
## Number | Vector
|
||||
elif (
|
||||
data_l_info.output_shape is None and len(data_r_info.output_shape) == 1
|
||||
):
|
||||
categories = [NUMBER_VECTOR]
|
||||
|
||||
## Number | Matrix
|
||||
elif (
|
||||
data_l_info.output_shape is None and len(data_r_info.output_shape) == 2
|
||||
): # noqa: PLR2004
|
||||
categories = [NUMBER_MATRIX]
|
||||
|
||||
## Vector | Vector
|
||||
elif (
|
||||
len(data_l_info.output_shape) == 1
|
||||
and len(data_r_info.output_shape) == 1
|
||||
):
|
||||
categories = [VECTOR_VECTOR]
|
||||
|
||||
## Matrix | Vector
|
||||
elif (
|
||||
len(data_l_info.output_shape) == 2 # noqa: PLR2004
|
||||
and len(data_r_info.output_shape) == 1
|
||||
):
|
||||
categories = [MATRIX_VECTOR]
|
||||
|
||||
## Matrix | Matrix
|
||||
elif (
|
||||
len(data_l_info.output_shape) == 2 # noqa: PLR2004
|
||||
and len(data_r_info.output_shape) == 2 # noqa: PLR2004
|
||||
):
|
||||
categories = [MATRIX_MATRIX]
|
||||
|
||||
## Expr | Data
|
||||
if self.active_socket_set == 'Expr | Data' and has_data_r_info:
|
||||
if data_r_info.output_shape is None:
|
||||
categories = [NUMBER_NUMBER]
|
||||
else:
|
||||
categories = {
|
||||
1: [NUMBER_NUMBER, NUMBER_VECTOR],
|
||||
2: [NUMBER_NUMBER, NUMBER_MATRIX],
|
||||
}[len(data_r_info.output_shape)]
|
||||
|
||||
return [
|
||||
(*category, '', i) if category is not None else None
|
||||
for i, category in enumerate(categories)
|
||||
]
|
||||
|
||||
def search_operations(self) -> list[ct.BLEnumElement]:
|
||||
items = []
|
||||
if self.active_socket_set == 'Elementwise':
|
||||
items = [
|
||||
('ADD', 'Add', 'L + R (by el)'),
|
||||
('SUB', 'Subtract', 'L - R (by el)'),
|
||||
('MUL', 'Multiply', 'L · R (by el)'),
|
||||
('DIV', 'Divide', 'L ÷ R (by el)'),
|
||||
('POW', 'Power', 'L^R (by el)'),
|
||||
('FMOD', 'Trunc Modulo', 'fmod(L,R) (by el)'),
|
||||
('ATAN2', 'atan2', 'atan2(L,R) (by el)'),
|
||||
('HEAVISIDE', 'Heaviside', '{0|L<0 1|L>0 R|L=0} (by el)'),
|
||||
if self.category in ['Number | Number', 'Number | Vector', 'Number | Matrix']:
|
||||
items += [
|
||||
('ADD', 'L + R', 'Add'),
|
||||
('SUB', 'L - R', 'Subtract'),
|
||||
('MUL', 'L · R', 'Multiply'),
|
||||
('DIV', 'L ÷ R', 'Divide'),
|
||||
('POW', 'L^R', 'Power'),
|
||||
('ATAN2', 'atan2(L,R)', 'atan2(L,R)'),
|
||||
]
|
||||
elif self.active_socket_set in 'Vec | Vec':
|
||||
items = [
|
||||
('DOT', 'Dot', 'L · R'),
|
||||
('CROSS', 'Cross', 'L x R (by last-axis'),
|
||||
if self.category in 'Vector | Vector':
|
||||
if items:
|
||||
items += [None]
|
||||
items += [
|
||||
('VEC_VEC_DOT', 'L · R', 'Vector-Vector Product'),
|
||||
('CROSS', 'L x R', 'Cross Product'),
|
||||
('PROJ', 'proj(L, R)', 'Projection'),
|
||||
]
|
||||
elif self.active_socket_set == 'Mat | Vec':
|
||||
items = [
|
||||
('DOT', 'Dot', 'L · R'),
|
||||
('LIN_SOLVE', 'Lin Solve', 'Lx = R -> x (by last-axis of R)'),
|
||||
('LSQ_SOLVE', 'LSq Solve', 'Lx = R ~> x (by last-axis of R)'),
|
||||
if self.category == 'Matrix | Vector':
|
||||
if items:
|
||||
items += [None]
|
||||
items += [
|
||||
('MAT_VEC_DOT', 'L · R', 'Matrix-Vector Product'),
|
||||
('LIN_SOLVE', 'Lx = R -> x', 'Linear Solve'),
|
||||
('LSQ_SOLVE', 'Lx = R ~> x', 'Least Squares Solve'),
|
||||
]
|
||||
if self.category == 'Matrix | Matrix':
|
||||
if items:
|
||||
items += [None]
|
||||
items += [
|
||||
('MAT_MAT_DOT', 'L · R', 'Matrix-Matrix Product'),
|
||||
]
|
||||
|
||||
return [
|
||||
(*item, '', i) if item is not None else None for i, item in enumerate(items)
|
||||
]
|
||||
return items
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
layout.prop(self, 'operation')
|
||||
layout.prop(self, self.blfields['category'], text='')
|
||||
layout.prop(self, self.blfields['operation'], text='')
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
# - Events
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
# Trigger
|
||||
socket_name={'Expr L', 'Expr R', 'Data L', 'Data R'},
|
||||
prop_name='active_socket_set',
|
||||
run_on_init=True,
|
||||
)
|
||||
def on_socket_set_changed(self) -> None:
|
||||
# Recompute Valid Categories
|
||||
self.category = bl_cache.Signal.ResetEnumItems
|
||||
self.operation = bl_cache.Signal.ResetEnumItems
|
||||
|
||||
@events.on_value_changed(
|
||||
prop_name='category',
|
||||
run_on_init=True,
|
||||
)
|
||||
def on_category_changed(self) -> None:
|
||||
self.operation = bl_cache.Signal.ResetEnumItems
|
||||
|
||||
####################
|
||||
# - Output
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Expr',
|
||||
kind=ct.FlowKind.Value,
|
||||
props={'operation'},
|
||||
input_sockets={'Expr L', 'Expr R'},
|
||||
)
|
||||
def compute_expr(self, props: dict, input_sockets: dict):
|
||||
expr_l = input_sockets['Expr L']
|
||||
expr_r = input_sockets['Expr R']
|
||||
|
||||
return {
|
||||
'ADD': lambda: expr_l + expr_r,
|
||||
'SUB': lambda: expr_l - expr_r,
|
||||
'MUL': lambda: expr_l * expr_r,
|
||||
'DIV': lambda: expr_l / expr_r,
|
||||
'POW': lambda: expr_l**expr_r,
|
||||
'ATAN2': lambda: sp.atan2(expr_r, expr_l),
|
||||
}[props['operation']]()
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
kind=ct.FlowKind.LazyValueFunc,
|
||||
props={'operation'},
|
||||
input_sockets={'Data L', 'Data R'},
|
||||
input_socket_kinds={
|
||||
'Data L': ct.FlowKind.LazyValueFunc,
|
||||
'Data R': ct.FlowKind.LazyValueFunc,
|
||||
},
|
||||
input_sockets_optional={
|
||||
'Data L': True,
|
||||
'Data R': True,
|
||||
},
|
||||
)
|
||||
def compute_data(self, props: dict, input_sockets: dict):
|
||||
data_l = input_sockets['Data L']
|
||||
data_r = input_sockets['Data R']
|
||||
has_data_l = not ct.FlowSignal.check(data_l)
|
||||
|
||||
mapping_func = {
|
||||
# Number | *
|
||||
'ADD': lambda datas: datas[0] + datas[1],
|
||||
'SUB': lambda datas: datas[0] - datas[1],
|
||||
'MUL': lambda datas: datas[0] * datas[1],
|
||||
'DIV': lambda datas: datas[0] / datas[1],
|
||||
'POW': lambda datas: datas[0] ** datas[1],
|
||||
'ATAN2': lambda datas: jnp.atan2(datas[1], datas[0]),
|
||||
# Vector | Vector
|
||||
'VEC_VEC_DOT': lambda datas: jnp.matmul(datas[0], datas[1]),
|
||||
'CROSS': lambda datas: jnp.cross(datas[0], datas[1]),
|
||||
# Matrix | Vector
|
||||
'MAT_VEC_DOT': lambda datas: jnp.matmul(datas[0], datas[1]),
|
||||
'LIN_SOLVE': lambda datas: jnp.linalg.solve(datas[0], datas[1]),
|
||||
'LSQ_SOLVE': lambda datas: jnp.linalg.lstsq(datas[0], datas[1]),
|
||||
# Matrix | Matrix
|
||||
'MAT_MAT_DOT': lambda datas: jnp.matmul(datas[0], datas[1]),
|
||||
}[props['operation']]
|
||||
|
||||
# Compose by Socket Set
|
||||
## Data | Data
|
||||
if has_data_l:
|
||||
return (data_l | data_r).compose_within(
|
||||
mapping_func,
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
## Expr | Data
|
||||
expr_l_lazy_value_func = ct.LazyValueFuncFlow(
|
||||
func=lambda expr_l_value: expr_l_value,
|
||||
func_args=[typ.Any],
|
||||
supports_jax=True,
|
||||
)
|
||||
return (expr_l_lazy_value_func | data_r).compose_within(
|
||||
mapping_func,
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
kind=ct.FlowKind.Array,
|
||||
output_sockets={'Data'},
|
||||
output_socket_kinds={
|
||||
'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
|
||||
},
|
||||
)
|
||||
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow:
|
||||
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc]
|
||||
params = output_sockets['Data'][ct.FlowKind.Params]
|
||||
|
||||
has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func)
|
||||
has_params = not ct.FlowSignal.check(params)
|
||||
|
||||
if has_lazy_value_func and has_params:
|
||||
return ct.ArrayFlow(
|
||||
values=lazy_value_func.func_jax(
|
||||
*params.func_args, **params.func_kwargs
|
||||
),
|
||||
unit=None,
|
||||
)
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - Auxiliary: Params
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
kind=ct.FlowKind.Params,
|
||||
props={'operation'},
|
||||
input_sockets={'Data L', 'Data R'},
|
||||
input_sockets={'Expr L', 'Data L', 'Data R'},
|
||||
input_socket_kinds={
|
||||
'Expr L': ct.FlowKind.Value,
|
||||
'Data L': {ct.FlowKind.Info, ct.FlowKind.Params},
|
||||
'Data R': {ct.FlowKind.Info, ct.FlowKind.Params},
|
||||
},
|
||||
input_sockets_optional={
|
||||
'Expr L': True,
|
||||
'Data L': True,
|
||||
'Data R': True,
|
||||
},
|
||||
)
|
||||
def compute_data(self, props: dict, input_sockets: dict):
|
||||
if self.active_socket_set == 'Elementwise':
|
||||
# Element-Wise Arithmetic
|
||||
if props['operation'] == 'ADD':
|
||||
return input_sockets['Data L'] + input_sockets['Data R']
|
||||
if props['operation'] == 'SUB':
|
||||
return input_sockets['Data L'] - input_sockets['Data R']
|
||||
if props['operation'] == 'MUL':
|
||||
return input_sockets['Data L'] * input_sockets['Data R']
|
||||
if props['operation'] == 'DIV':
|
||||
return input_sockets['Data L'] / input_sockets['Data R']
|
||||
def compute_data_params(
|
||||
self, props, input_sockets
|
||||
) -> ct.ParamsFlow | ct.FlowSignal:
|
||||
expr_l = input_sockets['Expr L']
|
||||
data_l_info = input_sockets['Data L'][ct.FlowKind.Info]
|
||||
data_l_params = input_sockets['Data L'][ct.FlowKind.Params]
|
||||
data_r_info = input_sockets['Data R'][ct.FlowKind.Info]
|
||||
data_r_params = input_sockets['Data R'][ct.FlowKind.Params]
|
||||
|
||||
# Element-Wise Arithmetic
|
||||
if props['operation'] == 'POW':
|
||||
return input_sockets['Data L'] ** input_sockets['Data R']
|
||||
has_expr_l = not ct.FlowSignal.check(expr_l)
|
||||
has_data_l_info = not ct.FlowSignal.check(data_l_info)
|
||||
has_data_l_params = not ct.FlowSignal.check(data_l_params)
|
||||
has_data_r_info = not ct.FlowSignal.check(data_r_info)
|
||||
has_data_r_params = not ct.FlowSignal.check(data_r_params)
|
||||
|
||||
# Binary Trigonometry
|
||||
if props['operation'] == 'ATAN2':
|
||||
return jnp.atan2(input_sockets['Data L'], input_sockets['Data R'])
|
||||
#log.critical((props, input_sockets))
|
||||
|
||||
# Special Functions
|
||||
if props['operation'] == 'HEAVISIDE':
|
||||
return jnp.heaviside(input_sockets['Data L'], input_sockets['Data R'])
|
||||
# Compose by Socket Set
|
||||
## Data | Data
|
||||
if (
|
||||
has_data_l_info
|
||||
and has_data_l_params
|
||||
and has_data_r_info
|
||||
and has_data_r_params
|
||||
):
|
||||
return data_l_params | data_r_params
|
||||
|
||||
# Linear Algebra
|
||||
if self.active_socket_set in {'Vec-Vec', 'Mat-Vec'}:
|
||||
if props['operation'] == 'DOT':
|
||||
return jnp.dot(input_sockets['Data L'], input_sockets['Data R'])
|
||||
## Expr | Data
|
||||
if has_expr_l and has_data_r_info and has_data_r_params:
|
||||
operation = props['operation']
|
||||
data_unit = data_r_info.output_unit
|
||||
|
||||
elif self.active_socket_set == 'Vec-Vec':
|
||||
if props['operation'] == 'CROSS':
|
||||
return jnp.cross(input_sockets['Data L'], input_sockets['Data R'])
|
||||
# By Operation
|
||||
## Add/Sub: Scale to Output Unit
|
||||
if operation in ['ADD', 'SUB', 'MUL', 'DIV']:
|
||||
if not spux.uses_units(expr_l):
|
||||
value = spux.sympy_to_python(expr_l)
|
||||
else:
|
||||
value = spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit))
|
||||
|
||||
elif self.active_socket_set == 'Mat-Vec':
|
||||
if props['operation'] == 'LIN_SOLVE':
|
||||
return jnp.linalg.lstsq(
|
||||
input_sockets['Data L'], input_sockets['Data R']
|
||||
)
|
||||
if props['operation'] == 'LSQ_SOLVE':
|
||||
return jnp.linalg.solve(
|
||||
input_sockets['Data L'], input_sockets['Data R']
|
||||
return data_r_params.compose_within(
|
||||
enclosing_func_args=[value],
|
||||
)
|
||||
|
||||
msg = 'Invalid operation'
|
||||
raise ValueError(msg)
|
||||
## Pow: Doesn't Exist (?)
|
||||
## -> See https://math.stackexchange.com/questions/4326081/units-of-the-exponential-function
|
||||
if operation == 'POW':
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
## atan2(): Only Length
|
||||
## -> Implicitly presume that Data L/R use length units.
|
||||
if operation == 'ATAN2':
|
||||
if not spux.uses_units(expr_l):
|
||||
value = spux.sympy_to_python(expr_l)
|
||||
else:
|
||||
value = spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit))
|
||||
|
||||
return data_r_params.compose_within(
|
||||
enclosing_func_args=[value],
|
||||
)
|
||||
|
||||
return data_r_params.compose_within(
|
||||
enclosing_func_args=[
|
||||
spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit))
|
||||
]
|
||||
)
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - Auxiliary: Info
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
kind=ct.FlowKind.Info,
|
||||
input_sockets={'Expr L', 'Data L', 'Data R'},
|
||||
input_socket_kinds={
|
||||
'Expr L': ct.FlowKind.Value,
|
||||
'Data L': ct.FlowKind.Info,
|
||||
'Data R': ct.FlowKind.Info,
|
||||
},
|
||||
input_sockets_optional={
|
||||
'Expr L': True,
|
||||
'Data L': True,
|
||||
'Data R': True,
|
||||
},
|
||||
)
|
||||
def compute_data_info(self, input_sockets: dict) -> ct.InfoFlow:
|
||||
expr_l = input_sockets['Expr L']
|
||||
data_l_info = input_sockets['Data L']
|
||||
data_r_info = input_sockets['Data R']
|
||||
|
||||
has_expr_l = not ct.FlowSignal.check(expr_l)
|
||||
has_data_l_info = not ct.FlowSignal.check(data_l_info)
|
||||
has_data_r_info = not ct.FlowSignal.check(data_r_info)
|
||||
|
||||
# Info by Socket Set
|
||||
## Data | Data
|
||||
if has_data_l_info and has_data_r_info:
|
||||
return data_r_info
|
||||
|
||||
## Expr | Data
|
||||
if has_expr_l and has_data_r_info:
|
||||
return data_r_info
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
|
||||
####################
|
||||
|
|
|
@ -0,0 +1,165 @@
|
|||
"""Declares `TransformMathNode`."""
|
||||
|
||||
import enum
|
||||
import typing as typ
|
||||
|
||||
import bpy
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import sympy as sp
|
||||
|
||||
from blender_maxwell.utils import bl_cache, logger
|
||||
from blender_maxwell.utils import extra_sympy_units as spux
|
||||
|
||||
from .... import contracts as ct
|
||||
from .... import sockets
|
||||
from ... import base, events
|
||||
|
||||
log = logger.get(__name__)
|
||||
|
||||
|
||||
class TransformMathNode(base.MaxwellSimNode):
|
||||
r"""Applies a function to the array as a whole, with arbitrary results.
|
||||
|
||||
The shape, type, and interpretation of the input/output data is dynamically shown.
|
||||
|
||||
# Socket Sets
|
||||
## Interpret
|
||||
Reinterprets the `InfoFlow` of an array, **without changing it**.
|
||||
|
||||
Attributes:
|
||||
operation: Operation to apply to the input.
|
||||
"""
|
||||
|
||||
node_type = ct.NodeType.TransformMath
|
||||
bl_label = 'Transform Math'
|
||||
|
||||
input_sockets: typ.ClassVar = {
|
||||
'Data': sockets.DataSocketDef(format='jax'),
|
||||
}
|
||||
input_socket_sets: typ.ClassVar = {
|
||||
'Fourier': {},
|
||||
'Affine': {},
|
||||
'Convolve': {},
|
||||
}
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Data': sockets.DataSocketDef(format='jax'),
|
||||
}
|
||||
|
||||
####################
|
||||
# - Properties
|
||||
####################
|
||||
operation: enum.Enum = bl_cache.BLField(
|
||||
prop_ui=True, enum_cb=lambda self, _: self.search_operations()
|
||||
)
|
||||
|
||||
def search_operations(self) -> list[ct.BLEnumElement]:
|
||||
if self.active_socket_set == 'Fourier': # noqa: SIM114
|
||||
items = []
|
||||
elif self.active_socket_set == 'Affine': # noqa: SIM114
|
||||
items = []
|
||||
elif self.active_socket_set == 'Convolve':
|
||||
items = []
|
||||
else:
|
||||
msg = f'Active socket set {self.active_socket_set} is unknown'
|
||||
raise RuntimeError(msg)
|
||||
|
||||
return [(*item, '', i) for i, item in enumerate(items)]
|
||||
|
||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||
layout.prop(self, self.blfields['operation'], text='')
|
||||
|
||||
####################
|
||||
# - Events
|
||||
####################
|
||||
@events.on_value_changed(
|
||||
prop_name='active_socket_set',
|
||||
)
|
||||
def on_socket_set_changed(self):
|
||||
self.operation = bl_cache.Signal.ResetEnumItems
|
||||
|
||||
####################
|
||||
# - Compute: LazyValueFunc / Array
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
kind=ct.FlowKind.LazyValueFunc,
|
||||
props={'active_socket_set', 'operation'},
|
||||
input_sockets={'Data'},
|
||||
input_socket_kinds={
|
||||
'Data': ct.FlowKind.LazyValueFunc,
|
||||
},
|
||||
)
|
||||
def compute_data(self, props: dict, input_sockets: dict):
|
||||
has_data = not ct.FlowSignal.check(input_sockets['Data'])
|
||||
if not has_data or props['operation'] == 'NONE':
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
mapping_func: typ.Callable[[jax.Array], jax.Array] = {
|
||||
'Fourier': {},
|
||||
'Affine': {},
|
||||
'Convolve': {},
|
||||
}[props['active_socket_set']][props['operation']]
|
||||
|
||||
# Compose w/Lazy Root Function Data
|
||||
return input_sockets['Data'].compose_within(
|
||||
mapping_func,
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
kind=ct.FlowKind.Array,
|
||||
output_sockets={'Data'},
|
||||
output_socket_kinds={
|
||||
'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
|
||||
},
|
||||
)
|
||||
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow:
|
||||
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc]
|
||||
params = output_sockets['Data'][ct.FlowKind.Params]
|
||||
|
||||
if all(not ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]):
|
||||
return ct.ArrayFlow(
|
||||
values=lazy_value_func.func_jax(
|
||||
*params.func_args, **params.func_kwargs
|
||||
),
|
||||
unit=None,
|
||||
)
|
||||
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
####################
|
||||
# - Compute Auxiliary: Info / Params
|
||||
####################
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
kind=ct.FlowKind.Info,
|
||||
props={'active_socket_set', 'operation'},
|
||||
input_sockets={'Data'},
|
||||
input_socket_kinds={'Data': ct.FlowKind.Info},
|
||||
)
|
||||
def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
|
||||
info = input_sockets['Data']
|
||||
if ct.FlowSignal.check(info):
|
||||
return ct.FlowSignal.FlowPending
|
||||
|
||||
return info
|
||||
|
||||
@events.computes_output_socket(
|
||||
'Data',
|
||||
kind=ct.FlowKind.Params,
|
||||
input_sockets={'Data'},
|
||||
input_socket_kinds={'Data': ct.FlowKind.Params},
|
||||
)
|
||||
def compute_data_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal:
|
||||
return input_sockets['Data']
|
||||
|
||||
|
||||
####################
|
||||
# - Blender Registration
|
||||
####################
|
||||
BL_REGISTER = [
|
||||
TransformMathNode,
|
||||
]
|
||||
BL_NODES = {ct.NodeType.TransformMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)}
|
|
@ -61,32 +61,34 @@ class VizMode(enum.StrEnum):
|
|||
|
||||
@staticmethod
|
||||
def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None:
|
||||
EMPTY = ()
|
||||
Z = spux.MathType.Integer
|
||||
R = spux.MathType.Real
|
||||
VM = VizMode
|
||||
|
||||
valid_viz_modes = {
|
||||
((), (spux.MathType.Real,)): [VizMode.Hist1D, VizMode.BoxPlot1D],
|
||||
((spux.MathType.Integer), (spux.MathType.Real)): [
|
||||
VizMode.Hist1D,
|
||||
VizMode.BoxPlot1D,
|
||||
(EMPTY, (None, R)): [VM.Hist1D, VM.BoxPlot1D],
|
||||
((Z), (None, R)): [
|
||||
VM.Hist1D,
|
||||
VM.BoxPlot1D,
|
||||
],
|
||||
((spux.MathType.Real,), (spux.MathType.Real,)): [
|
||||
VizMode.Curve2D,
|
||||
VizMode.Points2D,
|
||||
VizMode.Bar,
|
||||
((R,), (None, R)): [
|
||||
VM.Curve2D,
|
||||
VM.Points2D,
|
||||
VM.Bar,
|
||||
],
|
||||
((spux.MathType.Real, spux.MathType.Integer), (spux.MathType.Real,)): [
|
||||
VizMode.Curves2D,
|
||||
VizMode.FilledCurves2D,
|
||||
((R, Z), (None, R)): [
|
||||
VM.Curves2D,
|
||||
VM.FilledCurves2D,
|
||||
],
|
||||
((spux.MathType.Real, spux.MathType.Real), (spux.MathType.Real,)): [
|
||||
VizMode.Heatmap2D,
|
||||
((R, R), (None, R)): [
|
||||
VM.Heatmap2D,
|
||||
],
|
||||
(
|
||||
(spux.MathType.Real, spux.MathType.Real, spux.MathType.Real),
|
||||
(spux.MathType.Real,),
|
||||
): [VizMode.SqueezedHeatmap2D, VizMode.Heatmap3D],
|
||||
((R, R, R), (None, R)): [VM.SqueezedHeatmap2D, VM.Heatmap3D],
|
||||
}.get(
|
||||
(
|
||||
tuple(info.dim_mathtypes.values()),
|
||||
tuple(info.output_mathtypes.values()),
|
||||
(info.output_shape, info.output_mathtype),
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -161,10 +163,10 @@ class VizTarget(enum.StrEnum):
|
|||
@staticmethod
|
||||
def to_name(value: typ.Self) -> str:
|
||||
return {
|
||||
VizTarget.Plot2D: 'Image (Plot)',
|
||||
VizTarget.Pixels: 'Image (Pixels)',
|
||||
VizTarget.PixelsPlane: 'Image (Plane)',
|
||||
VizTarget.Voxels: '3D Field',
|
||||
VizTarget.Plot2D: 'Plot',
|
||||
VizTarget.Pixels: 'Pixels',
|
||||
VizTarget.PixelsPlane: 'Image Plane',
|
||||
VizTarget.Voxels: 'Voxels',
|
||||
}[value]
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -1036,6 +1036,12 @@ class MaxwellSimNode(bpy.types.Node):
|
|||
# Generate New Instance ID
|
||||
self.reset_instance_id()
|
||||
|
||||
# Generate New Instance ID for Sockets
|
||||
## Sockets can't do this themselves.
|
||||
for bl_sockets in [self.inputs, self.outputs]:
|
||||
for bl_socket in bl_sockets:
|
||||
bl_socket.reset_instance_id()
|
||||
|
||||
# Generate New Sim Node Name
|
||||
## Blender will automatically add .001 so that `self.name` is unique.
|
||||
self.sim_node_name = self.name
|
||||
|
|
|
@ -14,7 +14,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
|
|||
bl_label = 'Scientific Constant'
|
||||
|
||||
output_sockets: typ.ClassVar = {
|
||||
'Value': sockets.AnySocketDef(),
|
||||
'Value': sockets.ExprSocketDef(),
|
||||
}
|
||||
|
||||
####################
|
||||
|
|
|
@ -913,10 +913,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket):
|
|||
col = layout.column()
|
||||
row = col.row()
|
||||
row.alignment = 'RIGHT'
|
||||
if self.is_linked:
|
||||
self.draw_output_label_row(row, text)
|
||||
else:
|
||||
row.label(text=text)
|
||||
|
||||
# Draw FlowKind.Info related Information
|
||||
if self.use_info_draw:
|
||||
|
|
|
@ -12,6 +12,10 @@ from .. import base
|
|||
log = logger.get(__name__)
|
||||
|
||||
|
||||
def unicode_superscript(n):
|
||||
return ''.join(['⁰¹²³⁴⁵⁶⁷⁸⁹'[ord(c) - ord('0')] for c in str(n)])
|
||||
|
||||
|
||||
class DataInfoColumn(enum.StrEnum):
|
||||
Length = enum.auto()
|
||||
MathType = enum.auto()
|
||||
|
@ -49,11 +53,11 @@ class DataBLSocket(base.MaxwellSimSocket):
|
|||
## TODO: typ.Literal['xarray', 'jax']
|
||||
|
||||
show_info_columns: bool = bl_cache.BLField(
|
||||
False,
|
||||
True,
|
||||
prop_ui=True,
|
||||
)
|
||||
info_columns: DataInfoColumn = bl_cache.BLField(
|
||||
{DataInfoColumn.MathType, DataInfoColumn.Length}, prop_ui=True, enum_many=True
|
||||
{DataInfoColumn.MathType, DataInfoColumn.Unit}, prop_ui=True, enum_many=True
|
||||
)
|
||||
|
||||
####################
|
||||
|
@ -71,8 +75,10 @@ class DataBLSocket(base.MaxwellSimSocket):
|
|||
# - UI
|
||||
####################
|
||||
def draw_input_label_row(self, row: bpy.types.UILayout, text) -> None:
|
||||
if self.format == 'jax':
|
||||
row.label(text=text)
|
||||
|
||||
info = self.compute_data(kind=ct.FlowKind.Info)
|
||||
if not ct.FlowSignal.check(info) and self.format == 'jax' and info.dim_names:
|
||||
row.prop(self, self.blfields['info_columns'])
|
||||
row.prop(
|
||||
self,
|
||||
|
@ -83,7 +89,8 @@ class DataBLSocket(base.MaxwellSimSocket):
|
|||
)
|
||||
|
||||
def draw_output_label_row(self, row: bpy.types.UILayout, text) -> None:
|
||||
if self.format == 'jax':
|
||||
info = self.compute_data(kind=ct.FlowKind.Info)
|
||||
if not ct.FlowSignal.check(info) and self.format == 'jax' and info.dim_names:
|
||||
row.prop(
|
||||
self,
|
||||
self.blfields['show_info_columns'],
|
||||
|
@ -92,6 +99,7 @@ class DataBLSocket(base.MaxwellSimSocket):
|
|||
icon=ct.Icon.ToggleSocketInfo,
|
||||
)
|
||||
row.prop(self, self.blfields['info_columns'])
|
||||
|
||||
row.label(text=text)
|
||||
|
||||
def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None:
|
||||
|
@ -118,16 +126,27 @@ class DataBLSocket(base.MaxwellSimSocket):
|
|||
grid.label(text=spux.sp_to_str(dim_idx.unit))
|
||||
|
||||
# Outputs
|
||||
for output_name in info.output_names:
|
||||
grid.label(text=output_name)
|
||||
grid.label(text=info.output_name)
|
||||
if DataInfoColumn.Length in self.info_columns:
|
||||
grid.label(text='', icon=ct.Icon.DataSocketOutput)
|
||||
if DataInfoColumn.MathType in self.info_columns:
|
||||
grid.label(
|
||||
text=spux.MathType.to_str(info.output_mathtypes[output_name])
|
||||
text=(
|
||||
spux.MathType.to_str(info.output_mathtype)
|
||||
+ (
|
||||
'ˣ'.join(
|
||||
[
|
||||
unicode_superscript(out_axis)
|
||||
for out_axis in info.output_shape
|
||||
]
|
||||
)
|
||||
if info.output_shape
|
||||
else ''
|
||||
)
|
||||
)
|
||||
)
|
||||
if DataInfoColumn.Unit in self.info_columns:
|
||||
grid.label(text=spux.sp_to_str(info.output_units[output_name]))
|
||||
grid.label(text=f'{spux.sp_to_str(info.output_unit)}')
|
||||
|
||||
|
||||
####################
|
||||
|
@ -137,9 +156,11 @@ class DataSocketDef(base.SocketDef):
|
|||
socket_type: ct.SocketType = ct.SocketType.Data
|
||||
|
||||
format: typ.Literal['xarray', 'jax', 'monitor_data']
|
||||
default_show_info_columns: bool = True
|
||||
|
||||
def init(self, bl_socket: DataBLSocket) -> None:
|
||||
bl_socket.format = self.format
|
||||
bl_socket.default_show_info_columns = self.default_show_info_columns
|
||||
|
||||
|
||||
####################
|
||||
|
|
|
@ -86,9 +86,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
|||
def lazy_value_func(self) -> ct.LazyValueFuncFlow:
|
||||
return ct.LazyValueFuncFlow(
|
||||
func=sp.lambdify(self.symbols, self.value, 'jax'),
|
||||
func_args=[
|
||||
(sym.name, spux.sympy_to_python_type(sym)) for sym in self.symbols
|
||||
],
|
||||
func_args=[spux.sympy_to_python_type(sym) for sym in self.symbols],
|
||||
supports_jax=True,
|
||||
)
|
||||
|
||||
|
|
|
@ -31,6 +31,18 @@ class MathType(enum.StrEnum):
|
|||
Real = enum.auto()
|
||||
Complex = enum.auto()
|
||||
|
||||
def combine(*mathtypes: list[typ.Self]) -> typ.Self:
|
||||
if MathType.Complex in mathtypes:
|
||||
return MathType.Complex
|
||||
elif MathType.Real in mathtypes:
|
||||
return MathType.Real
|
||||
elif MathType.Rational in mathtypes:
|
||||
return MathType.Rational
|
||||
elif MathType.Integer in mathtypes:
|
||||
return MathType.Integer
|
||||
elif MathType.Bool in mathtypes:
|
||||
return MathType.Bool
|
||||
|
||||
@staticmethod
|
||||
def from_expr(sp_obj: SympyType) -> type:
|
||||
if isinstance(sp_obj, sp.logic.boolalg.Boolean):
|
||||
|
@ -595,6 +607,21 @@ ComplexNumber: typ.TypeAlias = ConstrSympyExpr(
|
|||
)
|
||||
Number: typ.TypeAlias = IntNumber | RealNumber | ComplexNumber
|
||||
|
||||
# Number
|
||||
PhysicalRealNumber: typ.TypeAlias = ConstrSympyExpr(
|
||||
allow_variables=False,
|
||||
allow_units=True,
|
||||
allowed_sets={'integer', 'rational', 'real'},
|
||||
allowed_structures={'scalar'},
|
||||
)
|
||||
PhysicalComplexNumber: typ.TypeAlias = ConstrSympyExpr(
|
||||
allow_variables=False,
|
||||
allow_units=True,
|
||||
allowed_sets={'integer', 'rational', 'real', 'complex'},
|
||||
allowed_structures={'scalar'},
|
||||
)
|
||||
PhysicalNumber: typ.TypeAlias = PhysicalRealNumber | PhysicalComplexNumber
|
||||
|
||||
# Vector
|
||||
Real3DVector: typ.TypeAlias = ConstrSympyExpr(
|
||||
allow_variables=False,
|
||||
|
|
|
@ -117,8 +117,8 @@ def rgba_image_from_2d_map(
|
|||
def plot_hist_1d(
|
||||
data: jtyp.Float32[jtyp.Array, ' size'], info, ax: mpl_ax.Axis
|
||||
) -> None:
|
||||
y_name = info.output_names[0]
|
||||
y_unit = info.output_units[y_name]
|
||||
y_name = info.output_name
|
||||
y_unit = info.output_unit
|
||||
|
||||
ax.hist(data, bins=30, alpha=0.75)
|
||||
ax.set_title('Histogram')
|
||||
|
@ -130,8 +130,8 @@ def plot_box_plot_1d(
|
|||
data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis
|
||||
) -> None:
|
||||
x_name = info.dim_names[0]
|
||||
y_name = info.output_names[0]
|
||||
y_unit = info.output_units[y_name]
|
||||
y_name = info.output_name
|
||||
y_unit = info.output_unit
|
||||
|
||||
ax.boxplot(data)
|
||||
ax.set_title('Box Plot')
|
||||
|
@ -147,8 +147,8 @@ def plot_curve_2d(
|
|||
|
||||
x_name = info.dim_names[0]
|
||||
x_unit = info.dim_units[x_name]
|
||||
y_name = info.output_names[0]
|
||||
y_unit = info.output_units[y_name]
|
||||
y_name = info.output_name
|
||||
y_unit = info.output_unit
|
||||
|
||||
times.append(time.perf_counter() - times[0])
|
||||
ax.plot(info.dim_idx_arrays[0], data)
|
||||
|
@ -167,8 +167,8 @@ def plot_points_2d(
|
|||
) -> None:
|
||||
x_name = info.dim_names[0]
|
||||
x_unit = info.dim_units[x_name]
|
||||
y_name = info.output_names[0]
|
||||
y_unit = info.output_units[y_name]
|
||||
y_name = info.output_name
|
||||
y_unit = info.output_unit
|
||||
|
||||
ax.scatter(info.dim_idx_arrays[0], data, alpha=0.6)
|
||||
ax.set_title('2D Points')
|
||||
|
@ -179,8 +179,8 @@ def plot_points_2d(
|
|||
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None:
|
||||
x_name = info.dim_names[0]
|
||||
x_unit = info.dim_units[x_name]
|
||||
y_name = info.output_names[0]
|
||||
y_unit = info.output_units[y_name]
|
||||
y_name = info.output_name
|
||||
y_unit = info.output_unit
|
||||
|
||||
ax.bar(info.dim_idx_arrays[0], data, alpha=0.7)
|
||||
ax.set_title('2D Bar')
|
||||
|
@ -194,8 +194,8 @@ def plot_curves_2d(
|
|||
) -> None:
|
||||
x_name = info.dim_names[0]
|
||||
x_unit = info.dim_units[x_name]
|
||||
y_name = info.output_names[0]
|
||||
y_unit = info.output_units[y_name]
|
||||
y_name = info.output_name
|
||||
y_unit = info.output_unit
|
||||
|
||||
for category in range(data.shape[1]):
|
||||
ax.plot(data[:, 0], data[:, 1])
|
||||
|
@ -211,8 +211,8 @@ def plot_filled_curves_2d(
|
|||
) -> None:
|
||||
x_name = info.dim_names[0]
|
||||
x_unit = info.dim_units[x_name]
|
||||
y_name = info.output_names[0]
|
||||
y_unit = info.output_units[y_name]
|
||||
y_name = info.output_name
|
||||
y_unit = info.output_unit
|
||||
|
||||
ax.fill_between(info.dim_arrays[0], data[:, 0], info.dim_arrays[0], data[:, 1])
|
||||
ax.set_title('2D Curves')
|
||||
|
|
Loading…
Reference in New Issue