feat: Implemented operate math node.
parent
7d944a704e
commit
b2a7eefb45
4
TODO.md
4
TODO.md
|
@ -14,6 +14,8 @@
|
||||||
- [x] Extract
|
- [x] Extract
|
||||||
- [x] Viz
|
- [x] Viz
|
||||||
- [x] Math / Map Math
|
- [x] Math / Map Math
|
||||||
|
- [ ] Remove "By x" socket set let socket sets only be "Function"/"Expr"; then add a dynamic enum underneath to select "By x" based on data support.
|
||||||
|
- [ ] Filter the operations based on data support, ex. use positive-definiteness to guide cholesky.
|
||||||
- [x] Math / Filter Math
|
- [x] Math / Filter Math
|
||||||
- [ ] Math / Reduce Math
|
- [ ] Math / Reduce Math
|
||||||
- [ ] Math / Operate Math
|
- [ ] Math / Operate Math
|
||||||
|
@ -34,8 +36,6 @@
|
||||||
- [x] Constants / Blender Constant
|
- [x] Constants / Blender Constant
|
||||||
|
|
||||||
- [ ] Web / Tidy3D Web Importer
|
- [ ] Web / Tidy3D Web Importer
|
||||||
- [ ] Change to output only a `FilePath`, which can be plugged into a Tidy3D File Importer.
|
|
||||||
- [ ] Implement caching, such that the file will only download if the file doesn't already exist.
|
|
||||||
- [ ] Have a visual indicator for the current download status, with a manual re-download button.
|
- [ ] Have a visual indicator for the current download status, with a manual re-download button.
|
||||||
|
|
||||||
- [x] File Import / Material Import
|
- [x] File Import / Material Import
|
||||||
|
|
|
@ -68,6 +68,9 @@ class FlowKind(enum.StrEnum):
|
||||||
if kind == cls.LazyArrayRange:
|
if kind == cls.LazyArrayRange:
|
||||||
return value.rescale_to_unit(unit_system[socket_type])
|
return value.rescale_to_unit(unit_system[socket_type])
|
||||||
|
|
||||||
|
if kind == cls.Params:
|
||||||
|
return value.rescale_to_unit(unit_system[socket_type])
|
||||||
|
|
||||||
msg = 'Tried to scale unknown kind'
|
msg = 'Tried to scale unknown kind'
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
@ -322,6 +325,28 @@ class LazyValueFuncFlow:
|
||||||
supports_jax: bool = False
|
supports_jax: bool = False
|
||||||
supports_numba: bool = False
|
supports_numba: bool = False
|
||||||
|
|
||||||
|
# Merging
|
||||||
|
def __or__(
|
||||||
|
self,
|
||||||
|
other: typ.Self,
|
||||||
|
):
|
||||||
|
return LazyValueFuncFlow(
|
||||||
|
func=lambda *args, **kwargs: (
|
||||||
|
self.func(
|
||||||
|
*list(args[: len(self.func_args)]),
|
||||||
|
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
|
||||||
|
),
|
||||||
|
other.func(
|
||||||
|
*list(args[len(self.func_args) :]),
|
||||||
|
**{k: v for k, v in kwargs.items() if k in other.func_kwargs},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
func_args=self.func_args + other.func_args,
|
||||||
|
func_kwargs=self.func_kwargs | other.func_kwargs,
|
||||||
|
supports_jax=self.supports_jax and other.supports_jax,
|
||||||
|
supports_numba=self.supports_numba and other.supports_numba,
|
||||||
|
)
|
||||||
|
|
||||||
# Composition
|
# Composition
|
||||||
def compose_within(
|
def compose_within(
|
||||||
self,
|
self,
|
||||||
|
@ -691,10 +716,21 @@ class ParamsFlow:
|
||||||
func_args: list[typ.Any] = dataclasses.field(default_factory=list)
|
func_args: list[typ.Any] = dataclasses.field(default_factory=list)
|
||||||
func_kwargs: dict[str, typ.Any] = dataclasses.field(default_factory=dict)
|
func_kwargs: dict[str, typ.Any] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
|
def __or__(
|
||||||
|
self,
|
||||||
|
other: typ.Self,
|
||||||
|
):
|
||||||
|
return ParamsFlow(
|
||||||
|
func_args=self.func_args + other.func_args,
|
||||||
|
func_kwargs=self.func_kwargs | other.func_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def compose_within(
|
def compose_within(
|
||||||
self,
|
self,
|
||||||
enclosing_func_args: list[tuple[type]] = (),
|
enclosing_func_args: list[tuple[type]] = (),
|
||||||
enclosing_func_kwargs: dict[str, type] = MappingProxyType({}),
|
enclosing_func_kwargs: dict[str, type] = MappingProxyType({}),
|
||||||
|
enclosing_func_arg_units: dict[str, type] = MappingProxyType({}),
|
||||||
|
enclosing_func_kwarg_units: dict[str, type] = MappingProxyType({}),
|
||||||
) -> typ.Self:
|
) -> typ.Self:
|
||||||
return ParamsFlow(
|
return ParamsFlow(
|
||||||
func_args=self.func_args + list(enclosing_func_args),
|
func_args=self.func_args + list(enclosing_func_args),
|
||||||
|
@ -718,26 +754,133 @@ class InfoFlow:
|
||||||
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
|
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
|
||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def dim_mathtypes(self) -> dict[str, int]:
|
def dim_mathtypes(self) -> dict[str, spux.MathType]:
|
||||||
return {
|
return {
|
||||||
dim_name: dim_idx.mathtype for dim_name, dim_idx in self.dim_idx.items()
|
dim_name: dim_idx.mathtype for dim_name, dim_idx in self.dim_idx.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def dim_units(self) -> dict[str, int]:
|
def dim_units(self) -> dict[str, spux.Unit]:
|
||||||
return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()}
|
return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()}
|
||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def dim_idx_arrays(self) -> list[ArrayFlow]:
|
def dim_idx_arrays(self) -> list[jax.Array]:
|
||||||
return [
|
return [
|
||||||
dim_idx.realize().values
|
dim_idx.realize().values
|
||||||
if isinstance(dim_idx, LazyArrayRangeFlow)
|
if isinstance(dim_idx, LazyArrayRangeFlow)
|
||||||
else dim_idx.values
|
else dim_idx.values
|
||||||
for dim_idx in self.dim_idx.values()
|
for dim_idx in self.dim_idx.values()
|
||||||
]
|
]
|
||||||
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
|
|
||||||
|
|
||||||
# Output Information
|
# Output Information
|
||||||
output_names: list[str] = dataclasses.field(default_factory=list)
|
output_name: str = dataclasses.field(default_factory=list)
|
||||||
output_mathtypes: dict[str, spux.MathType] = dataclasses.field(default_factory=dict)
|
output_shape: tuple[int, ...] | None = dataclasses.field(default=None)
|
||||||
output_units: dict[str, spux.Unit | None] = dataclasses.field(default_factory=dict)
|
output_mathtype: spux.MathType = dataclasses.field()
|
||||||
|
output_unit: spux.Unit | None = dataclasses.field()
|
||||||
|
|
||||||
|
# Pinned Dimension Information
|
||||||
|
pinned_dim_names: list[str] = dataclasses.field(default_factory=list)
|
||||||
|
pinned_dim_values: dict[str, float | complex] = dataclasses.field(
|
||||||
|
default_factory=dict
|
||||||
|
)
|
||||||
|
pinned_dim_mathtypes: dict[str, spux.MathType] = dataclasses.field(
|
||||||
|
default_factory=dict
|
||||||
|
)
|
||||||
|
pinned_dim_units: dict[str, spux.Unit] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Methods
|
||||||
|
####################
|
||||||
|
def delete_dimension(self, dim_name: str) -> typ.Self:
|
||||||
|
"""Delete a dimension."""
|
||||||
|
return InfoFlow(
|
||||||
|
# Dimensions
|
||||||
|
dim_names=[
|
||||||
|
_dim_name for _dim_name in self.dim_names if _dim_name != dim_name
|
||||||
|
],
|
||||||
|
dim_idx={
|
||||||
|
_dim_name: dim_idx
|
||||||
|
for _dim_name, dim_idx in self.dim_idx.items()
|
||||||
|
if _dim_name != dim_name
|
||||||
|
},
|
||||||
|
# Outputs
|
||||||
|
output_name=self.output_name,
|
||||||
|
output_shape=self.output_shape,
|
||||||
|
output_mathtype=self.output_mathtype,
|
||||||
|
output_unit=self.output_unit,
|
||||||
|
)
|
||||||
|
|
||||||
|
def swap_dimensions(self, dim_0_name: str, dim_1_name: str) -> typ.Self:
|
||||||
|
"""Delete a dimension."""
|
||||||
|
|
||||||
|
# Compute Swapped Dimension Name List
|
||||||
|
def name_swapper(dim_name):
|
||||||
|
return (
|
||||||
|
dim_name
|
||||||
|
if dim_name not in [dim_0_name, dim_1_name]
|
||||||
|
else {dim_0_name: dim_1_name, dim_1_name: dim_0_name}[dim_name]
|
||||||
|
)
|
||||||
|
|
||||||
|
dim_names = [name_swapper(dim_name) for dim_name in self.dim_names]
|
||||||
|
|
||||||
|
# Compute Info
|
||||||
|
return InfoFlow(
|
||||||
|
# Dimensions
|
||||||
|
dim_names=dim_names,
|
||||||
|
dim_idx={dim_name: self.dim_idx[dim_name] for dim_name in dim_names},
|
||||||
|
# Outputs
|
||||||
|
output_name=self.output_name,
|
||||||
|
output_shape=self.output_shape,
|
||||||
|
output_mathtype=self.output_mathtype,
|
||||||
|
output_unit=self.output_unit,
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_output_mathtype(self, output_mathtype: spux.MathType) -> typ.Self:
|
||||||
|
"""Set the MathType of a particular output name."""
|
||||||
|
return InfoFlow(
|
||||||
|
dim_names=self.dim_names,
|
||||||
|
dim_idx=self.dim_idx,
|
||||||
|
# Outputs
|
||||||
|
output_name=self.output_name,
|
||||||
|
output_shape=self.output_shape,
|
||||||
|
output_mathtype=output_mathtype,
|
||||||
|
output_unit=self.output_unit,
|
||||||
|
)
|
||||||
|
|
||||||
|
def collapse_output(
|
||||||
|
self,
|
||||||
|
collapsed_name: str,
|
||||||
|
collapsed_mathtype: spux.MathType,
|
||||||
|
collapsed_unit: spux.Unit,
|
||||||
|
) -> typ.Self:
|
||||||
|
return InfoFlow(
|
||||||
|
# Dimensions
|
||||||
|
dim_names=self.dim_names,
|
||||||
|
dim_idx=self.dim_idx,
|
||||||
|
output_name=collapsed_name,
|
||||||
|
output_shape=None,
|
||||||
|
output_mathtype=collapsed_mathtype,
|
||||||
|
output_unit=collapsed_unit,
|
||||||
|
)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def shift_last_input(self):
|
||||||
|
"""Shift the last input dimension to the output."""
|
||||||
|
return InfoFlow(
|
||||||
|
# Dimensions
|
||||||
|
dim_names=self.dim_names[:-1],
|
||||||
|
dim_idx={
|
||||||
|
dim_name: dim_idx
|
||||||
|
for dim_name, dim_idx in self.dim_idx.items()
|
||||||
|
if dim_name != self.dim_names[-1]
|
||||||
|
},
|
||||||
|
# Outputs
|
||||||
|
output_name=self.output_name,
|
||||||
|
output_shape=(
|
||||||
|
(self.dim_lens[self.dim_names[-1]],)
|
||||||
|
if self.output_shape is None
|
||||||
|
else (self.dim_lens[self.dim_names[-1]], *self.output_shape)
|
||||||
|
),
|
||||||
|
output_mathtype=self.output_mathtype,
|
||||||
|
output_unit=self.output_unit,
|
||||||
|
)
|
||||||
|
|
|
@ -15,6 +15,7 @@ class NodeType(blender_type_enum.BlenderTypeEnum):
|
||||||
FilterMath = enum.auto()
|
FilterMath = enum.auto()
|
||||||
ReduceMath = enum.auto()
|
ReduceMath = enum.auto()
|
||||||
OperateMath = enum.auto()
|
OperateMath = enum.auto()
|
||||||
|
TransformMath = enum.auto()
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
WaveConstant = enum.auto()
|
WaveConstant = enum.auto()
|
||||||
|
|
|
@ -260,7 +260,9 @@ SOCKET_UNITS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def unit_to_socket_type(unit: spux.Unit) -> ST:
|
def unit_to_socket_type(
|
||||||
|
unit: spux.Unit | None, fallback_mathtype: spux.MathType | None = None
|
||||||
|
) -> ST:
|
||||||
"""Returns a SocketType that accepts the given unit.
|
"""Returns a SocketType that accepts the given unit.
|
||||||
|
|
||||||
Only the unit-compatibility is taken into account; in the case of overlap, several the ordering of `SOCKET_UNITS` determines which is returned.
|
Only the unit-compatibility is taken into account; in the case of overlap, several the ordering of `SOCKET_UNITS` determines which is returned.
|
||||||
|
@ -269,6 +271,14 @@ def unit_to_socket_type(unit: spux.Unit) -> ST:
|
||||||
Returns:
|
Returns:
|
||||||
**The first `SocketType` in `SOCKET_UNITS`, which contains the given unit as a valid possibility.
|
**The first `SocketType` in `SOCKET_UNITS`, which contains the given unit as a valid possibility.
|
||||||
"""
|
"""
|
||||||
|
if unit is None and fallback_mathtype is not None:
|
||||||
|
return {
|
||||||
|
spux.MathType.Integer: ST.IntegerNumber,
|
||||||
|
spux.MathType.Rational: ST.RationalNumber,
|
||||||
|
spux.MathType.Real: ST.RealNumber,
|
||||||
|
spux.MathType.Complex: ST.ComplexNumber,
|
||||||
|
}[fallback_mathtype]
|
||||||
|
|
||||||
for socket_type, _units in SOCKET_UNITS.items():
|
for socket_type, _units in SOCKET_UNITS.items():
|
||||||
if unit in _units['values'].values():
|
if unit in _units['values'].values():
|
||||||
return socket_type
|
return socket_type
|
||||||
|
|
|
@ -215,6 +215,15 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
####################
|
####################
|
||||||
# - UI
|
# - UI
|
||||||
####################
|
####################
|
||||||
|
def draw_label(self):
|
||||||
|
has_sim_data = self.sim_data_monitor_nametype is not None
|
||||||
|
has_monitor_data = self.monitor_data_components is not None
|
||||||
|
|
||||||
|
if has_sim_data or has_monitor_data:
|
||||||
|
return f'Extract: {self.extract_filter}'
|
||||||
|
|
||||||
|
return self.bl_label
|
||||||
|
|
||||||
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
||||||
"""Draw node properties in the node.
|
"""Draw node properties in the node.
|
||||||
|
|
||||||
|
@ -223,43 +232,6 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
"""
|
"""
|
||||||
col.prop(self, self.blfields['extract_filter'], text='')
|
col.prop(self, self.blfields['extract_filter'], text='')
|
||||||
|
|
||||||
def draw_info(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
|
|
||||||
"""Draw dynamic information in the node, for user consideration.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
col: UI target for drawing.
|
|
||||||
"""
|
|
||||||
has_sim_data = self.sim_data_monitor_nametype is not None
|
|
||||||
has_monitor_data = self.monitor_data_components is not None
|
|
||||||
|
|
||||||
if has_sim_data or has_monitor_data:
|
|
||||||
# Header
|
|
||||||
row = col.row()
|
|
||||||
row.alignment = 'CENTER'
|
|
||||||
if has_sim_data:
|
|
||||||
row.label(text=f'{len(self.sim_data_monitor_nametype)} Monitors')
|
|
||||||
elif has_monitor_data:
|
|
||||||
row.label(text=f'{self.monitor_data_type} Monitor Data')
|
|
||||||
|
|
||||||
# Monitor Data Contents
|
|
||||||
## TODO: More compact double-split
|
|
||||||
## TODO: Output shape data.
|
|
||||||
## TODO: Local ENUM_MANY tabs for visible column selection?
|
|
||||||
row = col.row()
|
|
||||||
box = row.box()
|
|
||||||
grid = box.grid_flow(row_major=True, columns=2, even_columns=True)
|
|
||||||
if has_sim_data:
|
|
||||||
for (
|
|
||||||
monitor_name,
|
|
||||||
monitor_type,
|
|
||||||
) in self.sim_data_monitor_nametype.items():
|
|
||||||
grid.label(text=monitor_name)
|
|
||||||
grid.label(text=monitor_type.replace('Data', ''))
|
|
||||||
elif has_monitor_data:
|
|
||||||
for component_name in self.monitor_data_components:
|
|
||||||
grid.label(text=component_name)
|
|
||||||
grid.label(text=self.monitor_data_type)
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Events
|
# - Events
|
||||||
####################
|
####################
|
||||||
|
@ -416,9 +388,8 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
else:
|
else:
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
info_output_names = {
|
info_output_name = props['extract_filter']
|
||||||
'output_names': [props['extract_filter']],
|
info_output_shape = None
|
||||||
}
|
|
||||||
|
|
||||||
# Compute InfoFlow from XArray
|
# Compute InfoFlow from XArray
|
||||||
## XYZF: Field / Permittivity / FieldProjectionCartesian
|
## XYZF: Field / Permittivity / FieldProjectionCartesian
|
||||||
|
@ -442,13 +413,14 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
is_sorted=True,
|
is_sorted=True,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
**info_output_names,
|
output_name=props['extract_filter'],
|
||||||
output_mathtypes={props['extract_filter']: spux.MathType.Complex},
|
output_shape=None,
|
||||||
output_units={
|
output_mathtype=spux.MathType.Complex,
|
||||||
props['extract_filter']: spu.volt / spu.micrometer
|
output_unit=(
|
||||||
|
spu.volt / spu.micrometer
|
||||||
if props['monitor_data_type'] == 'Field'
|
if props['monitor_data_type'] == 'Field'
|
||||||
else None
|
else None
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
## XYZT: FieldTime
|
## XYZT: FieldTime
|
||||||
|
@ -468,17 +440,14 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
is_sorted=True,
|
is_sorted=True,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
**info_output_names,
|
output_name=props['extract_filter'],
|
||||||
output_mathtypes={props['extract_filter']: spux.MathType.Complex},
|
output_shape=None,
|
||||||
output_units={
|
output_mathtype=spux.MathType.Complex,
|
||||||
props['extract_filter']: (
|
output_unit=(
|
||||||
spu.volt / spu.micrometer
|
spu.volt / spu.micrometer
|
||||||
if props['extract_filter'].startswith('E')
|
|
||||||
else spu.ampere / spu.micrometer
|
|
||||||
)
|
|
||||||
if props['monitor_data_type'] == 'Field'
|
if props['monitor_data_type'] == 'Field'
|
||||||
else None
|
else None
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
## F: Flux
|
## F: Flux
|
||||||
|
@ -492,9 +461,10 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
is_sorted=True,
|
is_sorted=True,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
**info_output_names,
|
output_name=props['extract_filter'],
|
||||||
output_mathtypes={props['extract_filter']: spux.MathType.Real},
|
output_shape=None,
|
||||||
output_units={props['extract_filter']: spu.watt},
|
output_mathtype=spux.MathType.Real,
|
||||||
|
output_unit=spu.watt,
|
||||||
)
|
)
|
||||||
|
|
||||||
## T: FluxTime
|
## T: FluxTime
|
||||||
|
@ -508,9 +478,10 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
is_sorted=True,
|
is_sorted=True,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
**info_output_names,
|
output_name=props['extract_filter'],
|
||||||
output_mathtypes={props['extract_filter']: spux.MathType.Real},
|
output_shape=None,
|
||||||
output_units={props['extract_filter']: spu.watt},
|
output_mathtype=spux.MathType.Real,
|
||||||
|
output_unit=spu.watt,
|
||||||
)
|
)
|
||||||
|
|
||||||
## RThetaPhiF: FieldProjectionAngle
|
## RThetaPhiF: FieldProjectionAngle
|
||||||
|
@ -537,15 +508,14 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
is_sorted=True,
|
is_sorted=True,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
**info_output_names,
|
output_name=props['extract_filter'],
|
||||||
output_mathtypes={props['extract_filter']: spux.MathType.Real},
|
output_shape=None,
|
||||||
output_units={
|
output_mathtype=spux.MathType.Real,
|
||||||
props['extract_filter']: (
|
output_unit=(
|
||||||
spu.volt / spu.micrometer
|
spu.volt / spu.micrometer
|
||||||
if props['extract_filter'].startswith('E')
|
if props['extract_filter'].startswith('E')
|
||||||
else spu.ampere / spu.micrometer
|
else spu.ampere / spu.micrometer
|
||||||
)
|
),
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
## UxUyRF: FieldProjectionKSpace
|
## UxUyRF: FieldProjectionKSpace
|
||||||
|
@ -570,15 +540,14 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
is_sorted=True,
|
is_sorted=True,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
**info_output_names,
|
output_name=props['extract_filter'],
|
||||||
output_mathtypes={props['extract_filter']: spux.MathType.Real},
|
output_shape=None,
|
||||||
output_units={
|
output_mathtype=spux.MathType.Real,
|
||||||
props['extract_filter']: (
|
output_unit=(
|
||||||
spu.volt / spu.micrometer
|
spu.volt / spu.micrometer
|
||||||
if props['extract_filter'].startswith('E')
|
if props['extract_filter'].startswith('E')
|
||||||
else spu.ampere / spu.micrometer
|
else spu.ampere / spu.micrometer
|
||||||
)
|
),
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
## OrderxOrderyF: Diffraction
|
## OrderxOrderyF: Diffraction
|
||||||
|
@ -600,15 +569,14 @@ class ExtractDataNode(base.MaxwellSimNode):
|
||||||
is_sorted=True,
|
is_sorted=True,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
**info_output_names,
|
output_name=props['extract_filter'],
|
||||||
output_mathtypes={props['extract_filter']: spux.MathType.Real},
|
output_shape=None,
|
||||||
output_units={
|
output_mathtype=spux.MathType.Real,
|
||||||
props['extract_filter']: (
|
output_unit=(
|
||||||
spu.volt / spu.micrometer
|
spu.volt / spu.micrometer
|
||||||
if props['extract_filter'].startswith('E')
|
if props['extract_filter'].startswith('E')
|
||||||
else spu.ampere / spu.micrometer
|
else spu.ampere / spu.micrometer
|
||||||
)
|
),
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = f'Unsupported Monitor Data Type {props["monitor_data_type"]} in "FlowKind.Info" of "{self.bl_label}"'
|
msg = f'Unsupported Monitor Data Type {props["monitor_data_type"]} in "FlowKind.Info" of "{self.bl_label}"'
|
||||||
|
|
|
@ -1,14 +1,16 @@
|
||||||
from . import filter_math, map_math, operate_math, reduce_math
|
from . import filter_math, map_math, operate_math, reduce_math, transform_math
|
||||||
|
|
||||||
BL_REGISTER = [
|
BL_REGISTER = [
|
||||||
*map_math.BL_REGISTER,
|
*map_math.BL_REGISTER,
|
||||||
*filter_math.BL_REGISTER,
|
*filter_math.BL_REGISTER,
|
||||||
*reduce_math.BL_REGISTER,
|
*reduce_math.BL_REGISTER,
|
||||||
*operate_math.BL_REGISTER,
|
*operate_math.BL_REGISTER,
|
||||||
|
*transform_math.BL_REGISTER,
|
||||||
]
|
]
|
||||||
BL_NODES = {
|
BL_NODES = {
|
||||||
**map_math.BL_NODES,
|
**map_math.BL_NODES,
|
||||||
**filter_math.BL_NODES,
|
**filter_math.BL_NODES,
|
||||||
**reduce_math.BL_NODES,
|
**reduce_math.BL_NODES,
|
||||||
**operate_math.BL_NODES,
|
**operate_math.BL_NODES,
|
||||||
|
**transform_math.BL_NODES,
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
"""Declares `FilterMathNode`."""
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
|
@ -15,11 +17,21 @@ log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
class FilterMathNode(base.MaxwellSimNode):
|
class FilterMathNode(base.MaxwellSimNode):
|
||||||
"""Reduces the dimensionality of data.
|
r"""Applies a function that operates on the shape of the array.
|
||||||
|
|
||||||
|
The shape, type, and interpretation of the input/output data is dynamically shown.
|
||||||
|
|
||||||
|
# Socket Sets
|
||||||
|
## Dimensions
|
||||||
|
Alter the dimensions of the array.
|
||||||
|
|
||||||
|
## Interpret
|
||||||
|
Only alter the interpretation of the array data, which guides what it can be used for.
|
||||||
|
|
||||||
|
These operations are **zero cost**, since the data itself is untouched.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
operation: Operation to apply to the input.
|
operation: Operation to apply to the input.
|
||||||
dim: Dims to use when filtering data
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
node_type = ct.NodeType.FilterMath
|
node_type = ct.NodeType.FilterMath
|
||||||
|
@ -29,8 +41,8 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
'Data': sockets.DataSocketDef(format='jax'),
|
'Data': sockets.DataSocketDef(format='jax'),
|
||||||
}
|
}
|
||||||
input_socket_sets: typ.ClassVar = {
|
input_socket_sets: typ.ClassVar = {
|
||||||
'By Dim': {},
|
'Interpret': {},
|
||||||
'By Dim Value': {},
|
'Dimensions': {},
|
||||||
}
|
}
|
||||||
output_sockets: typ.ClassVar = {
|
output_sockets: typ.ClassVar = {
|
||||||
'Data': sockets.DataSocketDef(format='jax'),
|
'Data': sockets.DataSocketDef(format='jax'),
|
||||||
|
@ -43,10 +55,17 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
prop_ui=True, enum_cb=lambda self, _: self.search_operations()
|
prop_ui=True, enum_cb=lambda self, _: self.search_operations()
|
||||||
)
|
)
|
||||||
|
|
||||||
dim: enum.Enum = bl_cache.BLField(
|
# Dimension Selection
|
||||||
|
dim_0: enum.Enum = bl_cache.BLField(
|
||||||
|
None, prop_ui=True, enum_cb=lambda self, _: self.search_dims()
|
||||||
|
)
|
||||||
|
dim_1: enum.Enum = bl_cache.BLField(
|
||||||
None, prop_ui=True, enum_cb=lambda self, _: self.search_dims()
|
None, prop_ui=True, enum_cb=lambda self, _: self.search_dims()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Computed
|
||||||
|
####################
|
||||||
@property
|
@property
|
||||||
def data_info(self) -> ct.InfoFlow | None:
|
def data_info(self) -> ct.InfoFlow | None:
|
||||||
info = self._compute_input('Data', kind=ct.FlowKind.Info)
|
info = self._compute_input('Data', kind=ct.FlowKind.Info)
|
||||||
|
@ -60,87 +79,119 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
####################
|
####################
|
||||||
def search_operations(self) -> list[tuple[str, str, str]]:
|
def search_operations(self) -> list[tuple[str, str, str]]:
|
||||||
items = []
|
items = []
|
||||||
if self.active_socket_set == 'By Dim':
|
if self.active_socket_set == 'Interpret':
|
||||||
items += [
|
items += [
|
||||||
('SQUEEZE', 'del a | #=1', 'Squeeze'),
|
('DIM_TO_VEC', '→ Vector', 'Shift last dimension to output.'),
|
||||||
|
('DIMS_TO_MAT', '→ Matrix', 'Shift last 2 dimensions to output.'),
|
||||||
]
|
]
|
||||||
if self.active_socket_set == 'By Dim Value':
|
elif self.active_socket_set == 'Dimensions':
|
||||||
items += [
|
items += [
|
||||||
('FIX', 'del a | i≈v', 'Fix Coordinate'),
|
('PIN_LEN_ONE', 'pinₐ =1', 'Remove a len(1) dimension'),
|
||||||
|
(
|
||||||
|
'PIN',
|
||||||
|
'pinₐ ≈v',
|
||||||
|
'Remove a len(n) dimension by selecting an index',
|
||||||
|
),
|
||||||
|
('SWAP', 'a₁ ↔ a₂', 'Swap the position of two dimensions'),
|
||||||
]
|
]
|
||||||
|
|
||||||
return [(*item, '', i) for i, item in enumerate(items)]
|
return [(*item, '', i) for i, item in enumerate(items)]
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Dim Search
|
# - Dimensions Search
|
||||||
####################
|
####################
|
||||||
def search_dims(self) -> list[ct.BLEnumElement]:
|
def search_dims(self) -> list[ct.BLEnumElement]:
|
||||||
if self.data_info is not None:
|
if self.data_info is None:
|
||||||
dims = [
|
|
||||||
(dim_name, dim_name, dim_name, '', i)
|
|
||||||
for i, dim_name in enumerate(self.data_info.dim_names)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Squeeze: Dimension Must Have Length=1
|
|
||||||
## We must also correct the "NUMBER" of the enum.
|
|
||||||
if self.operation == 'SQUEEZE':
|
|
||||||
filtered_dims = [
|
|
||||||
dim for dim in dims if self.data_info.dim_lens[dim[0]] == 1
|
|
||||||
]
|
|
||||||
return [(*dim[:-1], i) for i, dim in enumerate(filtered_dims)]
|
|
||||||
|
|
||||||
return dims
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
if self.operation == 'PIN_LEN_ONE':
|
||||||
|
dims = [
|
||||||
|
(dim_name, dim_name, f'Dimension "{dim_name}" of length 1')
|
||||||
|
for dim_name in self.data_info.dim_names
|
||||||
|
if self.data_info.dim_lens[dim_name] == 1
|
||||||
|
]
|
||||||
|
elif self.operation in ['PIN', 'SWAP']:
|
||||||
|
dims = [
|
||||||
|
(dim_name, dim_name, f'Dimension "{dim_name}"')
|
||||||
|
for dim_name in self.data_info.dim_names
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [(*dim, '', i) for i, dim in enumerate(dims)]
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - UI
|
# - UI
|
||||||
####################
|
####################
|
||||||
|
def draw_label(self):
|
||||||
|
labels = {
|
||||||
|
'PIN_LEN_ONE': lambda: f'Filter: Pin {self.dim_0} (len=1)',
|
||||||
|
'PIN': lambda: f'Filter: Pin {self.dim_0}',
|
||||||
|
'SWAP': lambda: f'Filter: Swap {self.dim_0}|{self.dim_1}',
|
||||||
|
'DIM_TO_VEC': lambda: 'Filter: -> Vector',
|
||||||
|
'DIMS_TO_MAT': lambda: 'Filter: -> Matrix',
|
||||||
|
}
|
||||||
|
|
||||||
|
if (label := labels.get(self.operation)) is not None:
|
||||||
|
return label()
|
||||||
|
|
||||||
|
return self.bl_label
|
||||||
|
|
||||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||||
layout.prop(self, self.blfields['operation'], text='')
|
layout.prop(self, self.blfields['operation'], text='')
|
||||||
if self.data_info is not None and self.data_info.dim_names:
|
|
||||||
layout.prop(self, self.blfields['dim'], text='')
|
if self.active_socket_set == 'Dimensions':
|
||||||
|
if self.operation in ['PIN_LEN_ONE', 'PIN']:
|
||||||
|
layout.prop(self, self.blfields['dim_0'], text='')
|
||||||
|
|
||||||
|
if self.operation == 'SWAP':
|
||||||
|
row = layout.row(align=True)
|
||||||
|
row.prop(self, self.blfields['dim_0'], text='')
|
||||||
|
row.prop(self, self.blfields['dim_1'], text='')
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Events
|
# - Events
|
||||||
####################
|
####################
|
||||||
@events.on_value_changed(
|
@events.on_value_changed(
|
||||||
socket_name='Data',
|
|
||||||
prop_name='active_socket_set',
|
prop_name='active_socket_set',
|
||||||
run_on_init=True,
|
run_on_init=True,
|
||||||
input_sockets={'Data'},
|
|
||||||
)
|
)
|
||||||
def on_any_change(self, input_sockets: dict):
|
def on_socket_set_changed(self):
|
||||||
if all(
|
|
||||||
not ct.FlowSignal.check_single(
|
|
||||||
input_socket_value, ct.FlowSignal.FlowPending
|
|
||||||
)
|
|
||||||
for input_socket_value in input_sockets.values()
|
|
||||||
):
|
|
||||||
self.operation = bl_cache.Signal.ResetEnumItems
|
self.operation = bl_cache.Signal.ResetEnumItems
|
||||||
self.dim = bl_cache.Signal.ResetEnumItems
|
|
||||||
|
@events.on_value_changed(
|
||||||
|
# Trigger
|
||||||
|
socket_name='Data',
|
||||||
|
prop_name={'active_socket_set', 'operation'},
|
||||||
|
run_on_init=True,
|
||||||
|
# Loaded
|
||||||
|
props={'operation'},
|
||||||
|
)
|
||||||
|
def on_any_change(self, props: dict) -> None:
|
||||||
|
self.dim_0 = bl_cache.Signal.ResetEnumItems
|
||||||
|
self.dim_1 = bl_cache.Signal.ResetEnumItems
|
||||||
|
|
||||||
@events.on_value_changed(
|
@events.on_value_changed(
|
||||||
socket_name='Data',
|
socket_name='Data',
|
||||||
prop_name='dim',
|
prop_name={'dim_0', 'dim_1', 'operation'},
|
||||||
## run_on_init: Implicitly triggered.
|
## run_on_init: Implicitly triggered.
|
||||||
props={'active_socket_set', 'dim'},
|
props={'operation', 'dim_0', 'dim_1'},
|
||||||
input_sockets={'Data'},
|
input_sockets={'Data'},
|
||||||
input_socket_kinds={'Data': ct.FlowKind.Info},
|
input_socket_kinds={'Data': ct.FlowKind.Info},
|
||||||
)
|
)
|
||||||
def on_dim_change(self, props: dict, input_sockets: dict):
|
def on_dim_change(self, props: dict, input_sockets: dict):
|
||||||
if input_sockets['Data'] == ct.FlowSignal.FlowPending:
|
has_data = not ct.FlowSignal.check(input_sockets['Data'])
|
||||||
|
if not has_data:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Add/Remove Input Socket "Value"
|
# "Dimensions"|"PIN": Add/Remove Input Socket
|
||||||
if (
|
if props['operation'] == 'PIN' and props['dim_0'] != 'NONE':
|
||||||
not ct.Flowsignal.check(input_sockets['Data'])
|
|
||||||
and props['active_socket_set'] == 'By Dim Value'
|
|
||||||
and props['dim'] != 'NONE'
|
|
||||||
):
|
|
||||||
# Get Current and Wanted Socket Defs
|
# Get Current and Wanted Socket Defs
|
||||||
current_bl_socket = self.loose_input_sockets.get('Value')
|
current_bl_socket = self.loose_input_sockets.get('Value')
|
||||||
wanted_socket_def = sockets.SOCKET_DEFS[
|
wanted_socket_def = sockets.SOCKET_DEFS[
|
||||||
ct.unit_to_socket_type(input_sockets['Data'].dim_idx[props['dim']].unit)
|
ct.unit_to_socket_type(
|
||||||
|
input_sockets['Data'].dim_idx[props['dim_0']].unit
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Determine Whether to Declare New Loose Input SOcket
|
# Determine Whether to Declare New Loose Input SOcket
|
||||||
|
@ -151,7 +202,7 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
):
|
):
|
||||||
self.loose_input_sockets = {
|
self.loose_input_sockets = {
|
||||||
'Value': wanted_socket_def(),
|
'Value': wanted_socket_def(),
|
||||||
} ## TODO: Can we do the boilerplate in base.py?
|
}
|
||||||
elif self.loose_input_sockets:
|
elif self.loose_input_sockets:
|
||||||
self.loose_input_sockets = {}
|
self.loose_input_sockets = {}
|
||||||
|
|
||||||
|
@ -161,40 +212,51 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Data',
|
'Data',
|
||||||
kind=ct.FlowKind.LazyValueFunc,
|
kind=ct.FlowKind.LazyValueFunc,
|
||||||
props={'active_socket_set', 'operation', 'dim'},
|
props={'operation', 'dim_0', 'dim_1'},
|
||||||
input_sockets={'Data'},
|
input_sockets={'Data'},
|
||||||
input_socket_kinds={'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info}},
|
input_socket_kinds={'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info}},
|
||||||
)
|
)
|
||||||
def compute_data(self, props: dict, input_sockets: dict):
|
def compute_data(self, props: dict, input_sockets: dict):
|
||||||
# Retrieve Inputs
|
|
||||||
lazy_value_func = input_sockets['Data'][ct.FlowKind.LazyValueFunc]
|
lazy_value_func = input_sockets['Data'][ct.FlowKind.LazyValueFunc]
|
||||||
info = input_sockets['Data'][ct.FlowKind.Info]
|
info = input_sockets['Data'][ct.FlowKind.Info]
|
||||||
|
|
||||||
# Check Flow
|
# Check Flow
|
||||||
if (
|
if any(ct.FlowSignal.check(inp) for inp in [info, lazy_value_func]):
|
||||||
any(ct.FlowSignal.check(inp) for inp in [info, lazy_value_func])
|
|
||||||
or props['operation'] == 'NONE'
|
|
||||||
):
|
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
# Compute Bound/Free Parameters
|
# Compute Function Arguments
|
||||||
func_args = [int] if props['active_socket_set'] == 'By Dim Value' else []
|
operation = props['operation']
|
||||||
axis = info.dim_names.index(props['dim'])
|
if operation == 'NONE':
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
# Select Function
|
## Dimension(s)
|
||||||
filter_func: typ.Callable[[jax.Array], jax.Array] = {
|
dim_0 = props['dim_0']
|
||||||
'By Dim': {'SQUEEZE': lambda data: jnp.squeeze(data, axis)},
|
dim_1 = props['dim_1']
|
||||||
'By Dim Value': {
|
if operation in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE':
|
||||||
'FIX': lambda data, fixed_axis_idx: jnp.take(
|
return ct.FlowSignal.FlowPending
|
||||||
data, fixed_axis_idx, axis=axis
|
if operation == 'SWAP' and dim_1 == 'NONE':
|
||||||
)
|
return ct.FlowSignal.FlowPending
|
||||||
},
|
|
||||||
}[props['active_socket_set']][props['operation']]
|
## Axis/Axes
|
||||||
|
axis_0 = info.dim_names.index(dim_0) if dim_0 != 'NONE' else None
|
||||||
|
axis_1 = info.dim_names.index(dim_1) if dim_1 != 'NONE' else None
|
||||||
|
|
||||||
|
# Compose Output Function
|
||||||
|
filter_func = {
|
||||||
|
# Dimensions
|
||||||
|
'PIN_LEN_ONE': lambda data: jnp.squeeze(data, axis_0),
|
||||||
|
'PIN': lambda data, fixed_axis_idx: jnp.take(
|
||||||
|
data, fixed_axis_idx, axis=axis_0
|
||||||
|
),
|
||||||
|
'SWAP': lambda data: jnp.swapaxes(data, axis_0, axis_1),
|
||||||
|
# Interpret
|
||||||
|
'DIM_TO_VEC': lambda data: data,
|
||||||
|
'DIMS_TO_MAT': lambda data: data,
|
||||||
|
}[props['operation']]
|
||||||
|
|
||||||
# Compose Function for Output
|
|
||||||
return lazy_value_func.compose_within(
|
return lazy_value_func.compose_within(
|
||||||
filter_func,
|
filter_func,
|
||||||
enclosing_func_args=func_args,
|
enclosing_func_args=[int] if operation == 'PIN' else [],
|
||||||
supports_jax=True,
|
supports_jax=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -207,7 +269,6 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow:
|
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow:
|
||||||
# Retrieve Inputs
|
|
||||||
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc]
|
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc]
|
||||||
params = output_sockets['Data'][ct.FlowKind.Params]
|
params = output_sockets['Data'][ct.FlowKind.Params]
|
||||||
|
|
||||||
|
@ -215,57 +276,54 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
if any(ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]):
|
if any(ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]):
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
# Compute Array
|
|
||||||
return ct.ArrayFlow(
|
return ct.ArrayFlow(
|
||||||
values=lazy_value_func.func_jax(*params.func_args, **params.func_kwargs),
|
values=lazy_value_func.func_jax(*params.func_args, **params.func_kwargs),
|
||||||
unit=None, ## TODO: Unit Propagation
|
unit=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Compute Auxiliary: Info / Params
|
# - Compute Auxiliary: Info
|
||||||
####################
|
####################
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Data',
|
'Data',
|
||||||
kind=ct.FlowKind.Info,
|
kind=ct.FlowKind.Info,
|
||||||
props={'active_socket_set', 'dim', 'operation'},
|
props={'dim_0', 'dim_1', 'operation'},
|
||||||
input_sockets={'Data'},
|
input_sockets={'Data'},
|
||||||
input_socket_kinds={'Data': ct.FlowKind.Info},
|
input_socket_kinds={'Data': ct.FlowKind.Info},
|
||||||
)
|
)
|
||||||
def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
|
def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
|
||||||
# Retrieve Inputs
|
|
||||||
info = input_sockets['Data']
|
info = input_sockets['Data']
|
||||||
|
|
||||||
# Check Flow
|
# Check Flow
|
||||||
if ct.FlowSignal.check(info) or props['dim'] == 'NONE':
|
if ct.FlowSignal.check(info):
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
# Compute Information
|
# Collect Information
|
||||||
## Compute Info w/By-Operation Change to Dimensions
|
dim_0 = props['dim_0']
|
||||||
axis = info.dim_names.index(props['dim'])
|
dim_1 = props['dim_1']
|
||||||
|
|
||||||
if (props['active_socket_set'], props['operation']) in [
|
if props['operation'] in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE':
|
||||||
('By Dim', 'SQUEEZE'),
|
return ct.FlowSignal.FlowPending
|
||||||
('By Dim Value', 'FIX'),
|
if props['operation'] == 'SWAP' and dim_1 == 'NONE':
|
||||||
] and info.dim_names:
|
return ct.FlowSignal.FlowPending
|
||||||
return ct.InfoFlow(
|
|
||||||
dim_names=info.dim_names[:axis] + info.dim_names[axis + 1 :],
|
|
||||||
dim_idx={
|
|
||||||
dim_name: dim_idx
|
|
||||||
for dim_name, dim_idx in info.dim_idx.items()
|
|
||||||
if dim_name != props['dim']
|
|
||||||
},
|
|
||||||
output_names=info.output_names,
|
|
||||||
output_mathtypes=info.output_mathtypes,
|
|
||||||
output_units=info.output_units,
|
|
||||||
)
|
|
||||||
|
|
||||||
msg = f'Active socket set {props["active_socket_set"]} and operation {props["operation"]} don\'t have an InfoFlow defined'
|
return {
|
||||||
raise RuntimeError(msg)
|
# Dimensions
|
||||||
|
'PIN_LEN_ONE': lambda: info.delete_dimension(dim_0),
|
||||||
|
'PIN': lambda: info.delete_dimension(dim_0),
|
||||||
|
'SWAP': lambda: info.swap_dimensions(dim_0, dim_1),
|
||||||
|
# Interpret
|
||||||
|
'DIM_TO_VEC': lambda: info.shift_last_input,
|
||||||
|
'DIMS_TO_MAT': lambda: info.shift_last_input.shift_last_input,
|
||||||
|
}[props['operation']]()
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Compute Auxiliary: Info
|
||||||
|
####################
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Data',
|
'Data',
|
||||||
kind=ct.FlowKind.Params,
|
kind=ct.FlowKind.Params,
|
||||||
props={'active_socket_set', 'dim', 'operation'},
|
props={'dim_0', 'dim_1', 'operation'},
|
||||||
input_sockets={'Data', 'Value'},
|
input_sockets={'Data', 'Value'},
|
||||||
input_socket_kinds={'Data': {ct.FlowKind.Info, ct.FlowKind.Params}},
|
input_socket_kinds={'Data': {ct.FlowKind.Info, ct.FlowKind.Params}},
|
||||||
input_sockets_optional={'Value': True},
|
input_sockets_optional={'Value': True},
|
||||||
|
@ -273,35 +331,33 @@ class FilterMathNode(base.MaxwellSimNode):
|
||||||
def compute_composed_params(
|
def compute_composed_params(
|
||||||
self, props: dict, input_sockets: dict
|
self, props: dict, input_sockets: dict
|
||||||
) -> ct.ParamsFlow:
|
) -> ct.ParamsFlow:
|
||||||
# Retrieve Inputs
|
|
||||||
info = input_sockets['Data'][ct.FlowKind.Info]
|
info = input_sockets['Data'][ct.FlowKind.Info]
|
||||||
params = input_sockets['Data'][ct.FlowKind.Params]
|
params = input_sockets['Data'][ct.FlowKind.Params]
|
||||||
|
|
||||||
|
# Check Flow
|
||||||
if any(ct.FlowSignal.check(inp) for inp in [info, params]):
|
if any(ct.FlowSignal.check(inp) for inp in [info, params]):
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
# Compute Composed Parameters
|
# Collect Information
|
||||||
## -> Only operations that add parameters.
|
## Dimensions
|
||||||
## -> A dimension must be selected.
|
dim_0 = props['dim_0']
|
||||||
## -> There must be an input value.
|
dim_1 = props['dim_1']
|
||||||
if (
|
|
||||||
(props['active_socket_set'], props['operation'])
|
if props['operation'] in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE':
|
||||||
in [
|
return ct.FlowSignal.FlowPending
|
||||||
('By Dim Value', 'FIX'),
|
if props['operation'] == 'SWAP' and dim_1 == 'NONE':
|
||||||
]
|
return ct.FlowSignal.FlowPending
|
||||||
and props['dim'] != 'NONE'
|
|
||||||
and not ct.FlowSignal.check(input_sockets['Value'])
|
## Pinned Value
|
||||||
):
|
pinned_value = input_sockets['Value']
|
||||||
# Compute IDX Corresponding to Coordinate Value
|
has_pinned_value = not ct.FlowSignal.check(pinned_value)
|
||||||
## -> Each dimension declares a unit-aware real number at each index.
|
|
||||||
## -> "Value" is a unit-aware real number from loose input socket.
|
if props['operation'] == 'PIN' and has_pinned_value:
|
||||||
## -> This finds the dimensional index closest to "Value".
|
# Compute IDX Corresponding to Dimension Index
|
||||||
## Total Effect: Indexing by a unit-aware real number.
|
nearest_idx_to_value = info.dim_idx[dim_0].nearest_idx_of(
|
||||||
nearest_idx_to_value = info.dim_idx[props['dim']].nearest_idx_of(
|
|
||||||
input_sockets['Value'], require_sorted=True
|
input_sockets['Value'], require_sorted=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compose Parameters
|
|
||||||
return params.compose_within(enclosing_func_args=[nearest_idx_to_value])
|
return params.compose_within(enclosing_func_args=[nearest_idx_to_value])
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
|
@ -138,6 +138,7 @@ class MapMathNode(base.MaxwellSimNode):
|
||||||
('SQ', 'v²', 'v^2 (by el)'),
|
('SQ', 'v²', 'v^2 (by el)'),
|
||||||
('SQRT', '√v', 'sqrt(v) (by el)'),
|
('SQRT', '√v', 'sqrt(v) (by el)'),
|
||||||
('INV_SQRT', '1/√v', '1/sqrt(v) (by el)'),
|
('INV_SQRT', '1/√v', '1/sqrt(v) (by el)'),
|
||||||
|
None,
|
||||||
# Trigonometry
|
# Trigonometry
|
||||||
('COS', 'cos v', 'cos(v) (by el)'),
|
('COS', 'cos v', 'cos(v) (by el)'),
|
||||||
('SIN', 'sin v', 'sin(v) (by el)'),
|
('SIN', 'sin v', 'sin(v) (by el)'),
|
||||||
|
@ -148,6 +149,7 @@ class MapMathNode(base.MaxwellSimNode):
|
||||||
]
|
]
|
||||||
elif self.active_socket_set in 'By Vector':
|
elif self.active_socket_set in 'By Vector':
|
||||||
items = [
|
items = [
|
||||||
|
# Vector -> Number
|
||||||
('NORM_2', '||v||₂', 'norm(v, 2) (by Vec)'),
|
('NORM_2', '||v||₂', 'norm(v, 2) (by Vec)'),
|
||||||
]
|
]
|
||||||
elif self.active_socket_set == 'By Matrix':
|
elif self.active_socket_set == 'By Matrix':
|
||||||
|
@ -157,13 +159,16 @@ class MapMathNode(base.MaxwellSimNode):
|
||||||
('COND', 'κ(V)', 'cond(V) (by Mat)'),
|
('COND', 'κ(V)', 'cond(V) (by Mat)'),
|
||||||
('NORM_FRO', '||V||_F', 'norm(V, frobenius) (by Mat)'),
|
('NORM_FRO', '||V||_F', 'norm(V, frobenius) (by Mat)'),
|
||||||
('RANK', 'rank V', 'rank(V) (by Mat)'),
|
('RANK', 'rank V', 'rank(V) (by Mat)'),
|
||||||
|
None,
|
||||||
# Matrix -> Array
|
# Matrix -> Array
|
||||||
('DIAG', 'diag V', 'diag(V) (by Mat)'),
|
('DIAG', 'diag V', 'diag(V) (by Mat)'),
|
||||||
('EIG_VALS', 'eigvals V', 'eigvals(V) (by Mat)'),
|
('EIG_VALS', 'eigvals V', 'eigvals(V) (by Mat)'),
|
||||||
('SVD_VALS', 'svdvals V', 'diag(svd(V)) (by Mat)'),
|
('SVD_VALS', 'svdvals V', 'diag(svd(V)) (by Mat)'),
|
||||||
|
None,
|
||||||
# Matrix -> Matrix
|
# Matrix -> Matrix
|
||||||
('INV', 'V⁻¹', 'V^(-1) (by Mat)'),
|
('INV', 'V⁻¹', 'V^(-1) (by Mat)'),
|
||||||
('TRA', 'Vt', 'V^T (by Mat)'),
|
('TRA', 'Vt', 'V^T (by Mat)'),
|
||||||
|
None,
|
||||||
# Matrix -> Matrices
|
# Matrix -> Matrices
|
||||||
('QR', 'qr V', 'qr(V) -> Q·R (by Mat)'),
|
('QR', 'qr V', 'qr(V) -> Q·R (by Mat)'),
|
||||||
('CHOL', 'chol V', 'cholesky(V) -> V·V† (by Mat)'),
|
('CHOL', 'chol V', 'cholesky(V) -> V·V† (by Mat)'),
|
||||||
|
@ -175,7 +180,9 @@ class MapMathNode(base.MaxwellSimNode):
|
||||||
msg = f'Active socket set {self.active_socket_set} is unknown'
|
msg = f'Active socket set {self.active_socket_set} is unknown'
|
||||||
raise RuntimeError(msg)
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
return [(*item, '', i) for i, item in enumerate(items)]
|
return [
|
||||||
|
(*item, '', i) if item is not None else None for i, item in enumerate(items)
|
||||||
|
]
|
||||||
|
|
||||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||||
layout.prop(self, self.blfields['operation'], text='')
|
layout.prop(self, self.blfields['operation'], text='')
|
||||||
|
@ -185,8 +192,9 @@ class MapMathNode(base.MaxwellSimNode):
|
||||||
####################
|
####################
|
||||||
@events.on_value_changed(
|
@events.on_value_changed(
|
||||||
prop_name='active_socket_set',
|
prop_name='active_socket_set',
|
||||||
|
run_on_init=True,
|
||||||
)
|
)
|
||||||
def on_operation_changed(self):
|
def on_socket_set_changed(self):
|
||||||
self.operation = bl_cache.Signal.ResetEnumItems
|
self.operation = bl_cache.Signal.ResetEnumItems
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
@ -204,11 +212,14 @@ class MapMathNode(base.MaxwellSimNode):
|
||||||
input_sockets_optional={'Mapper': True},
|
input_sockets_optional={'Mapper': True},
|
||||||
)
|
)
|
||||||
def compute_data(self, props: dict, input_sockets: dict):
|
def compute_data(self, props: dict, input_sockets: dict):
|
||||||
|
has_data = not ct.FlowSignal.check(input_sockets['Data'])
|
||||||
if (
|
if (
|
||||||
ct.FlowSignal.check(input_sockets['Data']) or props['operation'] == 'NONE'
|
not has_data
|
||||||
) or (
|
or props['operation'] == 'NONE'
|
||||||
|
or (
|
||||||
props['active_socket_set'] == 'Expr'
|
props['active_socket_set'] == 'Expr'
|
||||||
and ct.FlowSignal.check(input_sockets['Mapper'])
|
and ct.FlowSignal.check(input_sockets['Mapper'])
|
||||||
|
)
|
||||||
):
|
):
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
@ -238,14 +249,14 @@ class MapMathNode(base.MaxwellSimNode):
|
||||||
'NORM_FRO': lambda data: jnp.linalg.matrix_norm(data, ord='fro'),
|
'NORM_FRO': lambda data: jnp.linalg.matrix_norm(data, ord='fro'),
|
||||||
'RANK': lambda data: jnp.linalg.matrix_rank(data),
|
'RANK': lambda data: jnp.linalg.matrix_rank(data),
|
||||||
# Matrix -> Vec
|
# Matrix -> Vec
|
||||||
'DIAG': lambda data: jnp.diag(data),
|
'DIAG': lambda data: jnp.diagonal(data, axis1=-2, axis2=-1),
|
||||||
'EIG_VALS': lambda data: jnp.eigvals(data),
|
'EIG_VALS': lambda data: jnp.linalg.eigvals(data),
|
||||||
'SVD_VALS': lambda data: jnp.svdvals(data),
|
'SVD_VALS': lambda data: jnp.linalg.svdvals(data),
|
||||||
# Matrix -> Matrix
|
# Matrix -> Matrix
|
||||||
'INV': lambda data: jnp.inv(data),
|
'INV': lambda data: jnp.linalg.inv(data),
|
||||||
'TRA': lambda data: jnp.matrix_transpose(data),
|
'TRA': lambda data: jnp.matrix_transpose(data),
|
||||||
# Matrix -> Matrices
|
# Matrix -> Matrices
|
||||||
'QR': lambda data: jnp.inv(data),
|
'QR': lambda data: jnp.linalg.qr(data),
|
||||||
'CHOL': lambda data: jnp.linalg.cholesky(data),
|
'CHOL': lambda data: jnp.linalg.cholesky(data),
|
||||||
'SVD': lambda data: jnp.linalg.svd(data),
|
'SVD': lambda data: jnp.linalg.svd(data),
|
||||||
},
|
},
|
||||||
|
@ -298,28 +309,53 @@ class MapMathNode(base.MaxwellSimNode):
|
||||||
return ct.FlowSignal.FlowPending
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
# Complex -> Real
|
# Complex -> Real
|
||||||
if props['active_socket_set'] == 'By Element':
|
if props['active_socket_set'] == 'By Element' and props['operation'] in [
|
||||||
if props['operation'] in [
|
|
||||||
'REAL',
|
'REAL',
|
||||||
'IMAG',
|
'IMAG',
|
||||||
'ABS',
|
'ABS',
|
||||||
]:
|
]:
|
||||||
return ct.InfoFlow(
|
return info.set_output_mathtype(spux.MathType.Real)
|
||||||
dim_names=info.dim_names,
|
|
||||||
dim_idx=info.dim_idx,
|
if props['active_socket_set'] == 'By Vector' and props['operation'] in [
|
||||||
output_names=info.output_names,
|
'NORM_2'
|
||||||
output_mathtypes={
|
]:
|
||||||
output_name: (
|
return {
|
||||||
spux.MathType.Real
|
'NORM_2': lambda: info.collapse_output(
|
||||||
if output_mathtype == spux.MathType.Complex
|
collapsed_name=f'||{info.output_name}||₂',
|
||||||
else output_mathtype
|
collapsed_mathtype=spux.MathType.Real,
|
||||||
|
collapsed_unit=info.output_unit,
|
||||||
)
|
)
|
||||||
for output_name, output_mathtype in info.output_mathtypes.items()
|
}[props['operation']]()
|
||||||
},
|
|
||||||
output_units=info.output_units,
|
if props['active_socket_set'] == 'By Matrix' and props['operation'] in [
|
||||||
)
|
'DET',
|
||||||
if props['active_socket_set'] == 'By Vector':
|
'COND',
|
||||||
pass
|
'NORM_FRO',
|
||||||
|
'RANK',
|
||||||
|
]:
|
||||||
|
return {
|
||||||
|
'DET': lambda: info.collapse_output(
|
||||||
|
collapsed_name=f'det {info.output_name}',
|
||||||
|
collapsed_mathtype=info.output_mathtype,
|
||||||
|
collapsed_unit=info.output_unit,
|
||||||
|
),
|
||||||
|
'COND': lambda: info.collapse_output(
|
||||||
|
collapsed_name=f'κ({info.output_name})',
|
||||||
|
collapsed_mathtype=spux.MathType.Real,
|
||||||
|
collapsed_unit=None,
|
||||||
|
),
|
||||||
|
'NORM_FRO': lambda: info.collapse_output(
|
||||||
|
collapsed_name=f'||({info.output_name}||_F',
|
||||||
|
collapsed_mathtype=spux.MathType.Real,
|
||||||
|
collapsed_unit=info.output_unit,
|
||||||
|
),
|
||||||
|
'RANK': lambda: info.collapse_output(
|
||||||
|
collapsed_name=f'rank {info.output_name}',
|
||||||
|
collapsed_mathtype=spux.MathType.Integer,
|
||||||
|
collapsed_unit=None,
|
||||||
|
),
|
||||||
|
}[props['operation']]()
|
||||||
|
|
||||||
return info
|
return info
|
||||||
|
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
|
import enum
|
||||||
import typing as typ
|
import typing as typ
|
||||||
|
|
||||||
import bpy
|
import bpy
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import sympy as sp
|
||||||
|
|
||||||
from blender_maxwell.utils import logger
|
from blender_maxwell.utils import bl_cache, logger
|
||||||
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
|
|
||||||
from .... import contracts as ct
|
from .... import contracts as ct
|
||||||
from .... import sockets
|
from .... import sockets
|
||||||
|
@ -13,120 +16,465 @@ log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OperateMathNode(base.MaxwellSimNode):
|
class OperateMathNode(base.MaxwellSimNode):
|
||||||
|
r"""Applies a function that depends on two inputs.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
category: The category of operations to apply to the inputs.
|
||||||
|
**Only valid** categories can be chosen.
|
||||||
|
operation: The actual operation to apply to the inputs.
|
||||||
|
**Only valid** operations can be chosen.
|
||||||
|
"""
|
||||||
|
|
||||||
node_type = ct.NodeType.OperateMath
|
node_type = ct.NodeType.OperateMath
|
||||||
bl_label = 'Operate Math'
|
bl_label = 'Operate Math'
|
||||||
|
|
||||||
input_socket_sets: typ.ClassVar = {
|
input_socket_sets: typ.ClassVar = {
|
||||||
'Elementwise': {
|
'Expr | Expr': {
|
||||||
'Data L': sockets.AnySocketDef(),
|
'Expr L': sockets.ExprSocketDef(),
|
||||||
'Data R': sockets.AnySocketDef(),
|
'Expr R': sockets.ExprSocketDef(),
|
||||||
},
|
},
|
||||||
## TODO: Filter-array building operations
|
'Data | Data': {
|
||||||
'Vec-Vec': {
|
'Data L': sockets.DataSocketDef(
|
||||||
'Data L': sockets.AnySocketDef(),
|
format='jax', default_show_info_columns=False
|
||||||
'Data R': sockets.AnySocketDef(),
|
),
|
||||||
|
'Data R': sockets.DataSocketDef(
|
||||||
|
format='jax', default_show_info_columns=False
|
||||||
|
),
|
||||||
},
|
},
|
||||||
'Mat-Vec': {
|
'Expr | Data': {
|
||||||
'Data L': sockets.AnySocketDef(),
|
'Expr L': sockets.ExprSocketDef(),
|
||||||
'Data R': sockets.AnySocketDef(),
|
'Data R': sockets.DataSocketDef(
|
||||||
|
format='jax', default_show_info_columns=False
|
||||||
|
),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
output_sockets: typ.ClassVar = {
|
output_socket_sets: typ.ClassVar = {
|
||||||
'Data': sockets.AnySocketDef(),
|
'Expr | Expr': {
|
||||||
|
'Expr': sockets.ExprSocketDef(),
|
||||||
|
},
|
||||||
|
'Data | Data': {
|
||||||
|
'Data': sockets.DataSocketDef(
|
||||||
|
format='jax', default_show_info_columns=False
|
||||||
|
),
|
||||||
|
},
|
||||||
|
'Expr | Data': {
|
||||||
|
'Data': sockets.DataSocketDef(
|
||||||
|
format='jax', default_show_info_columns=False
|
||||||
|
),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Properties
|
# - Properties
|
||||||
####################
|
####################
|
||||||
operation: bpy.props.EnumProperty(
|
category: enum.Enum = bl_cache.BLField(
|
||||||
name='Op',
|
prop_ui=True, enum_cb=lambda self, _: self.search_categories()
|
||||||
description='Operation to apply to the two inputs',
|
|
||||||
items=lambda self, _: self.search_operations(),
|
|
||||||
update=lambda self, context: self.on_prop_changed('operation', context),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def search_operations(self) -> list[tuple[str, str, str]]:
|
operation: enum.Enum = bl_cache.BLField(
|
||||||
|
prop_ui=True, enum_cb=lambda self, _: self.search_operations()
|
||||||
|
)
|
||||||
|
|
||||||
|
def search_categories(self) -> list[ct.BLEnumElement]:
|
||||||
|
"""Deduce and return a list of valid categories for the current socket set and input data."""
|
||||||
|
data_l_info = self._compute_input(
|
||||||
|
'Data L', kind=ct.FlowKind.Info, optional=True
|
||||||
|
)
|
||||||
|
data_r_info = self._compute_input(
|
||||||
|
'Data R', kind=ct.FlowKind.Info, optional=True
|
||||||
|
)
|
||||||
|
|
||||||
|
has_data_l_info = not ct.FlowSignal.check(data_l_info)
|
||||||
|
has_data_r_info = not ct.FlowSignal.check(data_r_info)
|
||||||
|
|
||||||
|
# Categories by Socket Set
|
||||||
|
NUMBER_NUMBER = (
|
||||||
|
'Number | Number',
|
||||||
|
'Number | Number',
|
||||||
|
'Operations between numerical elements',
|
||||||
|
)
|
||||||
|
NUMBER_VECTOR = (
|
||||||
|
'Number | Vector',
|
||||||
|
'Number | Vector',
|
||||||
|
'Operations between numerical and vector elements',
|
||||||
|
)
|
||||||
|
NUMBER_MATRIX = (
|
||||||
|
'Number | Matrix',
|
||||||
|
'Number | Matrix',
|
||||||
|
'Operations between numerical and matrix elements',
|
||||||
|
)
|
||||||
|
VECTOR_VECTOR = (
|
||||||
|
'Vector | Vector',
|
||||||
|
'Vector | Vector',
|
||||||
|
'Operations between vector elements',
|
||||||
|
)
|
||||||
|
MATRIX_VECTOR = (
|
||||||
|
'Matrix | Vector',
|
||||||
|
'Matrix | Vector',
|
||||||
|
'Operations between vector and matrix elements',
|
||||||
|
)
|
||||||
|
MATRIX_MATRIX = (
|
||||||
|
'Matrix | Matrix',
|
||||||
|
'Matrix | Matrix',
|
||||||
|
'Operations between matrix elements',
|
||||||
|
)
|
||||||
|
categories = []
|
||||||
|
|
||||||
|
## Expr | Expr
|
||||||
|
if self.active_socket_set == 'Expr | Expr':
|
||||||
|
return [NUMBER_NUMBER]
|
||||||
|
|
||||||
|
## Data | Data
|
||||||
|
if (
|
||||||
|
self.active_socket_set == 'Data | Data'
|
||||||
|
and has_data_l_info
|
||||||
|
and has_data_r_info
|
||||||
|
):
|
||||||
|
# Check Valid Broadcasting
|
||||||
|
## Number | Number
|
||||||
|
if data_l_info.output_shape is None and data_r_info.output_shape is None:
|
||||||
|
categories = [NUMBER_NUMBER]
|
||||||
|
|
||||||
|
## Number | Vector
|
||||||
|
elif (
|
||||||
|
data_l_info.output_shape is None and len(data_r_info.output_shape) == 1
|
||||||
|
):
|
||||||
|
categories = [NUMBER_VECTOR]
|
||||||
|
|
||||||
|
## Number | Matrix
|
||||||
|
elif (
|
||||||
|
data_l_info.output_shape is None and len(data_r_info.output_shape) == 2
|
||||||
|
): # noqa: PLR2004
|
||||||
|
categories = [NUMBER_MATRIX]
|
||||||
|
|
||||||
|
## Vector | Vector
|
||||||
|
elif (
|
||||||
|
len(data_l_info.output_shape) == 1
|
||||||
|
and len(data_r_info.output_shape) == 1
|
||||||
|
):
|
||||||
|
categories = [VECTOR_VECTOR]
|
||||||
|
|
||||||
|
## Matrix | Vector
|
||||||
|
elif (
|
||||||
|
len(data_l_info.output_shape) == 2 # noqa: PLR2004
|
||||||
|
and len(data_r_info.output_shape) == 1
|
||||||
|
):
|
||||||
|
categories = [MATRIX_VECTOR]
|
||||||
|
|
||||||
|
## Matrix | Matrix
|
||||||
|
elif (
|
||||||
|
len(data_l_info.output_shape) == 2 # noqa: PLR2004
|
||||||
|
and len(data_r_info.output_shape) == 2 # noqa: PLR2004
|
||||||
|
):
|
||||||
|
categories = [MATRIX_MATRIX]
|
||||||
|
|
||||||
|
## Expr | Data
|
||||||
|
if self.active_socket_set == 'Expr | Data' and has_data_r_info:
|
||||||
|
if data_r_info.output_shape is None:
|
||||||
|
categories = [NUMBER_NUMBER]
|
||||||
|
else:
|
||||||
|
categories = {
|
||||||
|
1: [NUMBER_NUMBER, NUMBER_VECTOR],
|
||||||
|
2: [NUMBER_NUMBER, NUMBER_MATRIX],
|
||||||
|
}[len(data_r_info.output_shape)]
|
||||||
|
|
||||||
|
return [
|
||||||
|
(*category, '', i) if category is not None else None
|
||||||
|
for i, category in enumerate(categories)
|
||||||
|
]
|
||||||
|
|
||||||
|
def search_operations(self) -> list[ct.BLEnumElement]:
|
||||||
items = []
|
items = []
|
||||||
if self.active_socket_set == 'Elementwise':
|
if self.category in ['Number | Number', 'Number | Vector', 'Number | Matrix']:
|
||||||
items = [
|
items += [
|
||||||
('ADD', 'Add', 'L + R (by el)'),
|
('ADD', 'L + R', 'Add'),
|
||||||
('SUB', 'Subtract', 'L - R (by el)'),
|
('SUB', 'L - R', 'Subtract'),
|
||||||
('MUL', 'Multiply', 'L · R (by el)'),
|
('MUL', 'L · R', 'Multiply'),
|
||||||
('DIV', 'Divide', 'L ÷ R (by el)'),
|
('DIV', 'L ÷ R', 'Divide'),
|
||||||
('POW', 'Power', 'L^R (by el)'),
|
('POW', 'L^R', 'Power'),
|
||||||
('FMOD', 'Trunc Modulo', 'fmod(L,R) (by el)'),
|
('ATAN2', 'atan2(L,R)', 'atan2(L,R)'),
|
||||||
('ATAN2', 'atan2', 'atan2(L,R) (by el)'),
|
|
||||||
('HEAVISIDE', 'Heaviside', '{0|L<0 1|L>0 R|L=0} (by el)'),
|
|
||||||
]
|
]
|
||||||
elif self.active_socket_set in 'Vec | Vec':
|
if self.category in 'Vector | Vector':
|
||||||
items = [
|
if items:
|
||||||
('DOT', 'Dot', 'L · R'),
|
items += [None]
|
||||||
('CROSS', 'Cross', 'L x R (by last-axis'),
|
items += [
|
||||||
|
('VEC_VEC_DOT', 'L · R', 'Vector-Vector Product'),
|
||||||
|
('CROSS', 'L x R', 'Cross Product'),
|
||||||
|
('PROJ', 'proj(L, R)', 'Projection'),
|
||||||
]
|
]
|
||||||
elif self.active_socket_set == 'Mat | Vec':
|
if self.category == 'Matrix | Vector':
|
||||||
items = [
|
if items:
|
||||||
('DOT', 'Dot', 'L · R'),
|
items += [None]
|
||||||
('LIN_SOLVE', 'Lin Solve', 'Lx = R -> x (by last-axis of R)'),
|
items += [
|
||||||
('LSQ_SOLVE', 'LSq Solve', 'Lx = R ~> x (by last-axis of R)'),
|
('MAT_VEC_DOT', 'L · R', 'Matrix-Vector Product'),
|
||||||
|
('LIN_SOLVE', 'Lx = R -> x', 'Linear Solve'),
|
||||||
|
('LSQ_SOLVE', 'Lx = R ~> x', 'Least Squares Solve'),
|
||||||
|
]
|
||||||
|
if self.category == 'Matrix | Matrix':
|
||||||
|
if items:
|
||||||
|
items += [None]
|
||||||
|
items += [
|
||||||
|
('MAT_MAT_DOT', 'L · R', 'Matrix-Matrix Product'),
|
||||||
|
]
|
||||||
|
|
||||||
|
return [
|
||||||
|
(*item, '', i) if item is not None else None for i, item in enumerate(items)
|
||||||
]
|
]
|
||||||
return items
|
|
||||||
|
|
||||||
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||||
layout.prop(self, 'operation')
|
layout.prop(self, self.blfields['category'], text='')
|
||||||
|
layout.prop(self, self.blfields['operation'], text='')
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# - Properties
|
# - Events
|
||||||
|
####################
|
||||||
|
@events.on_value_changed(
|
||||||
|
# Trigger
|
||||||
|
socket_name={'Expr L', 'Expr R', 'Data L', 'Data R'},
|
||||||
|
prop_name='active_socket_set',
|
||||||
|
run_on_init=True,
|
||||||
|
)
|
||||||
|
def on_socket_set_changed(self) -> None:
|
||||||
|
# Recompute Valid Categories
|
||||||
|
self.category = bl_cache.Signal.ResetEnumItems
|
||||||
|
self.operation = bl_cache.Signal.ResetEnumItems
|
||||||
|
|
||||||
|
@events.on_value_changed(
|
||||||
|
prop_name='category',
|
||||||
|
run_on_init=True,
|
||||||
|
)
|
||||||
|
def on_category_changed(self) -> None:
|
||||||
|
self.operation = bl_cache.Signal.ResetEnumItems
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Output
|
||||||
|
####################
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Expr',
|
||||||
|
kind=ct.FlowKind.Value,
|
||||||
|
props={'operation'},
|
||||||
|
input_sockets={'Expr L', 'Expr R'},
|
||||||
|
)
|
||||||
|
def compute_expr(self, props: dict, input_sockets: dict):
|
||||||
|
expr_l = input_sockets['Expr L']
|
||||||
|
expr_r = input_sockets['Expr R']
|
||||||
|
|
||||||
|
return {
|
||||||
|
'ADD': lambda: expr_l + expr_r,
|
||||||
|
'SUB': lambda: expr_l - expr_r,
|
||||||
|
'MUL': lambda: expr_l * expr_r,
|
||||||
|
'DIV': lambda: expr_l / expr_r,
|
||||||
|
'POW': lambda: expr_l**expr_r,
|
||||||
|
'ATAN2': lambda: sp.atan2(expr_r, expr_l),
|
||||||
|
}[props['operation']]()
|
||||||
|
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Data',
|
||||||
|
kind=ct.FlowKind.LazyValueFunc,
|
||||||
|
props={'operation'},
|
||||||
|
input_sockets={'Data L', 'Data R'},
|
||||||
|
input_socket_kinds={
|
||||||
|
'Data L': ct.FlowKind.LazyValueFunc,
|
||||||
|
'Data R': ct.FlowKind.LazyValueFunc,
|
||||||
|
},
|
||||||
|
input_sockets_optional={
|
||||||
|
'Data L': True,
|
||||||
|
'Data R': True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
def compute_data(self, props: dict, input_sockets: dict):
|
||||||
|
data_l = input_sockets['Data L']
|
||||||
|
data_r = input_sockets['Data R']
|
||||||
|
has_data_l = not ct.FlowSignal.check(data_l)
|
||||||
|
|
||||||
|
mapping_func = {
|
||||||
|
# Number | *
|
||||||
|
'ADD': lambda datas: datas[0] + datas[1],
|
||||||
|
'SUB': lambda datas: datas[0] - datas[1],
|
||||||
|
'MUL': lambda datas: datas[0] * datas[1],
|
||||||
|
'DIV': lambda datas: datas[0] / datas[1],
|
||||||
|
'POW': lambda datas: datas[0] ** datas[1],
|
||||||
|
'ATAN2': lambda datas: jnp.atan2(datas[1], datas[0]),
|
||||||
|
# Vector | Vector
|
||||||
|
'VEC_VEC_DOT': lambda datas: jnp.matmul(datas[0], datas[1]),
|
||||||
|
'CROSS': lambda datas: jnp.cross(datas[0], datas[1]),
|
||||||
|
# Matrix | Vector
|
||||||
|
'MAT_VEC_DOT': lambda datas: jnp.matmul(datas[0], datas[1]),
|
||||||
|
'LIN_SOLVE': lambda datas: jnp.linalg.solve(datas[0], datas[1]),
|
||||||
|
'LSQ_SOLVE': lambda datas: jnp.linalg.lstsq(datas[0], datas[1]),
|
||||||
|
# Matrix | Matrix
|
||||||
|
'MAT_MAT_DOT': lambda datas: jnp.matmul(datas[0], datas[1]),
|
||||||
|
}[props['operation']]
|
||||||
|
|
||||||
|
# Compose by Socket Set
|
||||||
|
## Data | Data
|
||||||
|
if has_data_l:
|
||||||
|
return (data_l | data_r).compose_within(
|
||||||
|
mapping_func,
|
||||||
|
supports_jax=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
## Expr | Data
|
||||||
|
expr_l_lazy_value_func = ct.LazyValueFuncFlow(
|
||||||
|
func=lambda expr_l_value: expr_l_value,
|
||||||
|
func_args=[typ.Any],
|
||||||
|
supports_jax=True,
|
||||||
|
)
|
||||||
|
return (expr_l_lazy_value_func | data_r).compose_within(
|
||||||
|
mapping_func,
|
||||||
|
supports_jax=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Data',
|
||||||
|
kind=ct.FlowKind.Array,
|
||||||
|
output_sockets={'Data'},
|
||||||
|
output_socket_kinds={
|
||||||
|
'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow:
|
||||||
|
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc]
|
||||||
|
params = output_sockets['Data'][ct.FlowKind.Params]
|
||||||
|
|
||||||
|
has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func)
|
||||||
|
has_params = not ct.FlowSignal.check(params)
|
||||||
|
|
||||||
|
if has_lazy_value_func and has_params:
|
||||||
|
return ct.ArrayFlow(
|
||||||
|
values=lazy_value_func.func_jax(
|
||||||
|
*params.func_args, **params.func_kwargs
|
||||||
|
),
|
||||||
|
unit=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Auxiliary: Params
|
||||||
####################
|
####################
|
||||||
@events.computes_output_socket(
|
@events.computes_output_socket(
|
||||||
'Data',
|
'Data',
|
||||||
|
kind=ct.FlowKind.Params,
|
||||||
props={'operation'},
|
props={'operation'},
|
||||||
input_sockets={'Data L', 'Data R'},
|
input_sockets={'Expr L', 'Data L', 'Data R'},
|
||||||
|
input_socket_kinds={
|
||||||
|
'Expr L': ct.FlowKind.Value,
|
||||||
|
'Data L': {ct.FlowKind.Info, ct.FlowKind.Params},
|
||||||
|
'Data R': {ct.FlowKind.Info, ct.FlowKind.Params},
|
||||||
|
},
|
||||||
|
input_sockets_optional={
|
||||||
|
'Expr L': True,
|
||||||
|
'Data L': True,
|
||||||
|
'Data R': True,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
def compute_data(self, props: dict, input_sockets: dict):
|
def compute_data_params(
|
||||||
if self.active_socket_set == 'Elementwise':
|
self, props, input_sockets
|
||||||
# Element-Wise Arithmetic
|
) -> ct.ParamsFlow | ct.FlowSignal:
|
||||||
if props['operation'] == 'ADD':
|
expr_l = input_sockets['Expr L']
|
||||||
return input_sockets['Data L'] + input_sockets['Data R']
|
data_l_info = input_sockets['Data L'][ct.FlowKind.Info]
|
||||||
if props['operation'] == 'SUB':
|
data_l_params = input_sockets['Data L'][ct.FlowKind.Params]
|
||||||
return input_sockets['Data L'] - input_sockets['Data R']
|
data_r_info = input_sockets['Data R'][ct.FlowKind.Info]
|
||||||
if props['operation'] == 'MUL':
|
data_r_params = input_sockets['Data R'][ct.FlowKind.Params]
|
||||||
return input_sockets['Data L'] * input_sockets['Data R']
|
|
||||||
if props['operation'] == 'DIV':
|
|
||||||
return input_sockets['Data L'] / input_sockets['Data R']
|
|
||||||
|
|
||||||
# Element-Wise Arithmetic
|
has_expr_l = not ct.FlowSignal.check(expr_l)
|
||||||
if props['operation'] == 'POW':
|
has_data_l_info = not ct.FlowSignal.check(data_l_info)
|
||||||
return input_sockets['Data L'] ** input_sockets['Data R']
|
has_data_l_params = not ct.FlowSignal.check(data_l_params)
|
||||||
|
has_data_r_info = not ct.FlowSignal.check(data_r_info)
|
||||||
|
has_data_r_params = not ct.FlowSignal.check(data_r_params)
|
||||||
|
|
||||||
# Binary Trigonometry
|
#log.critical((props, input_sockets))
|
||||||
if props['operation'] == 'ATAN2':
|
|
||||||
return jnp.atan2(input_sockets['Data L'], input_sockets['Data R'])
|
|
||||||
|
|
||||||
# Special Functions
|
# Compose by Socket Set
|
||||||
if props['operation'] == 'HEAVISIDE':
|
## Data | Data
|
||||||
return jnp.heaviside(input_sockets['Data L'], input_sockets['Data R'])
|
if (
|
||||||
|
has_data_l_info
|
||||||
|
and has_data_l_params
|
||||||
|
and has_data_r_info
|
||||||
|
and has_data_r_params
|
||||||
|
):
|
||||||
|
return data_l_params | data_r_params
|
||||||
|
|
||||||
# Linear Algebra
|
## Expr | Data
|
||||||
if self.active_socket_set in {'Vec-Vec', 'Mat-Vec'}:
|
if has_expr_l and has_data_r_info and has_data_r_params:
|
||||||
if props['operation'] == 'DOT':
|
operation = props['operation']
|
||||||
return jnp.dot(input_sockets['Data L'], input_sockets['Data R'])
|
data_unit = data_r_info.output_unit
|
||||||
|
|
||||||
elif self.active_socket_set == 'Vec-Vec':
|
# By Operation
|
||||||
if props['operation'] == 'CROSS':
|
## Add/Sub: Scale to Output Unit
|
||||||
return jnp.cross(input_sockets['Data L'], input_sockets['Data R'])
|
if operation in ['ADD', 'SUB', 'MUL', 'DIV']:
|
||||||
|
if not spux.uses_units(expr_l):
|
||||||
|
value = spux.sympy_to_python(expr_l)
|
||||||
|
else:
|
||||||
|
value = spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit))
|
||||||
|
|
||||||
elif self.active_socket_set == 'Mat-Vec':
|
return data_r_params.compose_within(
|
||||||
if props['operation'] == 'LIN_SOLVE':
|
enclosing_func_args=[value],
|
||||||
return jnp.linalg.lstsq(
|
|
||||||
input_sockets['Data L'], input_sockets['Data R']
|
|
||||||
)
|
|
||||||
if props['operation'] == 'LSQ_SOLVE':
|
|
||||||
return jnp.linalg.solve(
|
|
||||||
input_sockets['Data L'], input_sockets['Data R']
|
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = 'Invalid operation'
|
## Pow: Doesn't Exist (?)
|
||||||
raise ValueError(msg)
|
## -> See https://math.stackexchange.com/questions/4326081/units-of-the-exponential-function
|
||||||
|
if operation == 'POW':
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
## atan2(): Only Length
|
||||||
|
## -> Implicitly presume that Data L/R use length units.
|
||||||
|
if operation == 'ATAN2':
|
||||||
|
if not spux.uses_units(expr_l):
|
||||||
|
value = spux.sympy_to_python(expr_l)
|
||||||
|
else:
|
||||||
|
value = spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit))
|
||||||
|
|
||||||
|
return data_r_params.compose_within(
|
||||||
|
enclosing_func_args=[value],
|
||||||
|
)
|
||||||
|
|
||||||
|
return data_r_params.compose_within(
|
||||||
|
enclosing_func_args=[
|
||||||
|
spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Auxiliary: Info
|
||||||
|
####################
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Data',
|
||||||
|
kind=ct.FlowKind.Info,
|
||||||
|
input_sockets={'Expr L', 'Data L', 'Data R'},
|
||||||
|
input_socket_kinds={
|
||||||
|
'Expr L': ct.FlowKind.Value,
|
||||||
|
'Data L': ct.FlowKind.Info,
|
||||||
|
'Data R': ct.FlowKind.Info,
|
||||||
|
},
|
||||||
|
input_sockets_optional={
|
||||||
|
'Expr L': True,
|
||||||
|
'Data L': True,
|
||||||
|
'Data R': True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
def compute_data_info(self, input_sockets: dict) -> ct.InfoFlow:
|
||||||
|
expr_l = input_sockets['Expr L']
|
||||||
|
data_l_info = input_sockets['Data L']
|
||||||
|
data_r_info = input_sockets['Data R']
|
||||||
|
|
||||||
|
has_expr_l = not ct.FlowSignal.check(expr_l)
|
||||||
|
has_data_l_info = not ct.FlowSignal.check(data_l_info)
|
||||||
|
has_data_r_info = not ct.FlowSignal.check(data_r_info)
|
||||||
|
|
||||||
|
# Info by Socket Set
|
||||||
|
## Data | Data
|
||||||
|
if has_data_l_info and has_data_r_info:
|
||||||
|
return data_r_info
|
||||||
|
|
||||||
|
## Expr | Data
|
||||||
|
if has_expr_l and has_data_r_info:
|
||||||
|
return data_r_info
|
||||||
|
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
|
@ -0,0 +1,165 @@
|
||||||
|
"""Declares `TransformMathNode`."""
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import typing as typ
|
||||||
|
|
||||||
|
import bpy
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import sympy as sp
|
||||||
|
|
||||||
|
from blender_maxwell.utils import bl_cache, logger
|
||||||
|
from blender_maxwell.utils import extra_sympy_units as spux
|
||||||
|
|
||||||
|
from .... import contracts as ct
|
||||||
|
from .... import sockets
|
||||||
|
from ... import base, events
|
||||||
|
|
||||||
|
log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformMathNode(base.MaxwellSimNode):
|
||||||
|
r"""Applies a function to the array as a whole, with arbitrary results.
|
||||||
|
|
||||||
|
The shape, type, and interpretation of the input/output data is dynamically shown.
|
||||||
|
|
||||||
|
# Socket Sets
|
||||||
|
## Interpret
|
||||||
|
Reinterprets the `InfoFlow` of an array, **without changing it**.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
operation: Operation to apply to the input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
node_type = ct.NodeType.TransformMath
|
||||||
|
bl_label = 'Transform Math'
|
||||||
|
|
||||||
|
input_sockets: typ.ClassVar = {
|
||||||
|
'Data': sockets.DataSocketDef(format='jax'),
|
||||||
|
}
|
||||||
|
input_socket_sets: typ.ClassVar = {
|
||||||
|
'Fourier': {},
|
||||||
|
'Affine': {},
|
||||||
|
'Convolve': {},
|
||||||
|
}
|
||||||
|
output_sockets: typ.ClassVar = {
|
||||||
|
'Data': sockets.DataSocketDef(format='jax'),
|
||||||
|
}
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Properties
|
||||||
|
####################
|
||||||
|
operation: enum.Enum = bl_cache.BLField(
|
||||||
|
prop_ui=True, enum_cb=lambda self, _: self.search_operations()
|
||||||
|
)
|
||||||
|
|
||||||
|
def search_operations(self) -> list[ct.BLEnumElement]:
|
||||||
|
if self.active_socket_set == 'Fourier': # noqa: SIM114
|
||||||
|
items = []
|
||||||
|
elif self.active_socket_set == 'Affine': # noqa: SIM114
|
||||||
|
items = []
|
||||||
|
elif self.active_socket_set == 'Convolve':
|
||||||
|
items = []
|
||||||
|
else:
|
||||||
|
msg = f'Active socket set {self.active_socket_set} is unknown'
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
|
return [(*item, '', i) for i, item in enumerate(items)]
|
||||||
|
|
||||||
|
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
|
||||||
|
layout.prop(self, self.blfields['operation'], text='')
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Events
|
||||||
|
####################
|
||||||
|
@events.on_value_changed(
|
||||||
|
prop_name='active_socket_set',
|
||||||
|
)
|
||||||
|
def on_socket_set_changed(self):
|
||||||
|
self.operation = bl_cache.Signal.ResetEnumItems
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Compute: LazyValueFunc / Array
|
||||||
|
####################
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Data',
|
||||||
|
kind=ct.FlowKind.LazyValueFunc,
|
||||||
|
props={'active_socket_set', 'operation'},
|
||||||
|
input_sockets={'Data'},
|
||||||
|
input_socket_kinds={
|
||||||
|
'Data': ct.FlowKind.LazyValueFunc,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
def compute_data(self, props: dict, input_sockets: dict):
|
||||||
|
has_data = not ct.FlowSignal.check(input_sockets['Data'])
|
||||||
|
if not has_data or props['operation'] == 'NONE':
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
mapping_func: typ.Callable[[jax.Array], jax.Array] = {
|
||||||
|
'Fourier': {},
|
||||||
|
'Affine': {},
|
||||||
|
'Convolve': {},
|
||||||
|
}[props['active_socket_set']][props['operation']]
|
||||||
|
|
||||||
|
# Compose w/Lazy Root Function Data
|
||||||
|
return input_sockets['Data'].compose_within(
|
||||||
|
mapping_func,
|
||||||
|
supports_jax=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Data',
|
||||||
|
kind=ct.FlowKind.Array,
|
||||||
|
output_sockets={'Data'},
|
||||||
|
output_socket_kinds={
|
||||||
|
'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow:
|
||||||
|
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc]
|
||||||
|
params = output_sockets['Data'][ct.FlowKind.Params]
|
||||||
|
|
||||||
|
if all(not ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]):
|
||||||
|
return ct.ArrayFlow(
|
||||||
|
values=lazy_value_func.func_jax(
|
||||||
|
*params.func_args, **params.func_kwargs
|
||||||
|
),
|
||||||
|
unit=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Compute Auxiliary: Info / Params
|
||||||
|
####################
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Data',
|
||||||
|
kind=ct.FlowKind.Info,
|
||||||
|
props={'active_socket_set', 'operation'},
|
||||||
|
input_sockets={'Data'},
|
||||||
|
input_socket_kinds={'Data': ct.FlowKind.Info},
|
||||||
|
)
|
||||||
|
def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
|
||||||
|
info = input_sockets['Data']
|
||||||
|
if ct.FlowSignal.check(info):
|
||||||
|
return ct.FlowSignal.FlowPending
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
@events.computes_output_socket(
|
||||||
|
'Data',
|
||||||
|
kind=ct.FlowKind.Params,
|
||||||
|
input_sockets={'Data'},
|
||||||
|
input_socket_kinds={'Data': ct.FlowKind.Params},
|
||||||
|
)
|
||||||
|
def compute_data_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal:
|
||||||
|
return input_sockets['Data']
|
||||||
|
|
||||||
|
|
||||||
|
####################
|
||||||
|
# - Blender Registration
|
||||||
|
####################
|
||||||
|
BL_REGISTER = [
|
||||||
|
TransformMathNode,
|
||||||
|
]
|
||||||
|
BL_NODES = {ct.NodeType.TransformMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)}
|
|
@ -61,32 +61,34 @@ class VizMode(enum.StrEnum):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None:
|
def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None:
|
||||||
|
EMPTY = ()
|
||||||
|
Z = spux.MathType.Integer
|
||||||
|
R = spux.MathType.Real
|
||||||
|
VM = VizMode
|
||||||
|
|
||||||
valid_viz_modes = {
|
valid_viz_modes = {
|
||||||
((), (spux.MathType.Real,)): [VizMode.Hist1D, VizMode.BoxPlot1D],
|
(EMPTY, (None, R)): [VM.Hist1D, VM.BoxPlot1D],
|
||||||
((spux.MathType.Integer), (spux.MathType.Real)): [
|
((Z), (None, R)): [
|
||||||
VizMode.Hist1D,
|
VM.Hist1D,
|
||||||
VizMode.BoxPlot1D,
|
VM.BoxPlot1D,
|
||||||
],
|
],
|
||||||
((spux.MathType.Real,), (spux.MathType.Real,)): [
|
((R,), (None, R)): [
|
||||||
VizMode.Curve2D,
|
VM.Curve2D,
|
||||||
VizMode.Points2D,
|
VM.Points2D,
|
||||||
VizMode.Bar,
|
VM.Bar,
|
||||||
],
|
],
|
||||||
((spux.MathType.Real, spux.MathType.Integer), (spux.MathType.Real,)): [
|
((R, Z), (None, R)): [
|
||||||
VizMode.Curves2D,
|
VM.Curves2D,
|
||||||
VizMode.FilledCurves2D,
|
VM.FilledCurves2D,
|
||||||
],
|
],
|
||||||
((spux.MathType.Real, spux.MathType.Real), (spux.MathType.Real,)): [
|
((R, R), (None, R)): [
|
||||||
VizMode.Heatmap2D,
|
VM.Heatmap2D,
|
||||||
],
|
],
|
||||||
(
|
((R, R, R), (None, R)): [VM.SqueezedHeatmap2D, VM.Heatmap3D],
|
||||||
(spux.MathType.Real, spux.MathType.Real, spux.MathType.Real),
|
|
||||||
(spux.MathType.Real,),
|
|
||||||
): [VizMode.SqueezedHeatmap2D, VizMode.Heatmap3D],
|
|
||||||
}.get(
|
}.get(
|
||||||
(
|
(
|
||||||
tuple(info.dim_mathtypes.values()),
|
tuple(info.dim_mathtypes.values()),
|
||||||
tuple(info.output_mathtypes.values()),
|
(info.output_shape, info.output_mathtype),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -161,10 +163,10 @@ class VizTarget(enum.StrEnum):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_name(value: typ.Self) -> str:
|
def to_name(value: typ.Self) -> str:
|
||||||
return {
|
return {
|
||||||
VizTarget.Plot2D: 'Image (Plot)',
|
VizTarget.Plot2D: 'Plot',
|
||||||
VizTarget.Pixels: 'Image (Pixels)',
|
VizTarget.Pixels: 'Pixels',
|
||||||
VizTarget.PixelsPlane: 'Image (Plane)',
|
VizTarget.PixelsPlane: 'Image Plane',
|
||||||
VizTarget.Voxels: '3D Field',
|
VizTarget.Voxels: 'Voxels',
|
||||||
}[value]
|
}[value]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -1036,6 +1036,12 @@ class MaxwellSimNode(bpy.types.Node):
|
||||||
# Generate New Instance ID
|
# Generate New Instance ID
|
||||||
self.reset_instance_id()
|
self.reset_instance_id()
|
||||||
|
|
||||||
|
# Generate New Instance ID for Sockets
|
||||||
|
## Sockets can't do this themselves.
|
||||||
|
for bl_sockets in [self.inputs, self.outputs]:
|
||||||
|
for bl_socket in bl_sockets:
|
||||||
|
bl_socket.reset_instance_id()
|
||||||
|
|
||||||
# Generate New Sim Node Name
|
# Generate New Sim Node Name
|
||||||
## Blender will automatically add .001 so that `self.name` is unique.
|
## Blender will automatically add .001 so that `self.name` is unique.
|
||||||
self.sim_node_name = self.name
|
self.sim_node_name = self.name
|
||||||
|
|
|
@ -14,7 +14,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
|
||||||
bl_label = 'Scientific Constant'
|
bl_label = 'Scientific Constant'
|
||||||
|
|
||||||
output_sockets: typ.ClassVar = {
|
output_sockets: typ.ClassVar = {
|
||||||
'Value': sockets.AnySocketDef(),
|
'Value': sockets.ExprSocketDef(),
|
||||||
}
|
}
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
|
@ -913,10 +913,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket):
|
||||||
col = layout.column()
|
col = layout.column()
|
||||||
row = col.row()
|
row = col.row()
|
||||||
row.alignment = 'RIGHT'
|
row.alignment = 'RIGHT'
|
||||||
if self.is_linked:
|
|
||||||
self.draw_output_label_row(row, text)
|
self.draw_output_label_row(row, text)
|
||||||
else:
|
|
||||||
row.label(text=text)
|
|
||||||
|
|
||||||
# Draw FlowKind.Info related Information
|
# Draw FlowKind.Info related Information
|
||||||
if self.use_info_draw:
|
if self.use_info_draw:
|
||||||
|
|
|
@ -12,6 +12,10 @@ from .. import base
|
||||||
log = logger.get(__name__)
|
log = logger.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def unicode_superscript(n):
|
||||||
|
return ''.join(['⁰¹²³⁴⁵⁶⁷⁸⁹'[ord(c) - ord('0')] for c in str(n)])
|
||||||
|
|
||||||
|
|
||||||
class DataInfoColumn(enum.StrEnum):
|
class DataInfoColumn(enum.StrEnum):
|
||||||
Length = enum.auto()
|
Length = enum.auto()
|
||||||
MathType = enum.auto()
|
MathType = enum.auto()
|
||||||
|
@ -49,11 +53,11 @@ class DataBLSocket(base.MaxwellSimSocket):
|
||||||
## TODO: typ.Literal['xarray', 'jax']
|
## TODO: typ.Literal['xarray', 'jax']
|
||||||
|
|
||||||
show_info_columns: bool = bl_cache.BLField(
|
show_info_columns: bool = bl_cache.BLField(
|
||||||
False,
|
True,
|
||||||
prop_ui=True,
|
prop_ui=True,
|
||||||
)
|
)
|
||||||
info_columns: DataInfoColumn = bl_cache.BLField(
|
info_columns: DataInfoColumn = bl_cache.BLField(
|
||||||
{DataInfoColumn.MathType, DataInfoColumn.Length}, prop_ui=True, enum_many=True
|
{DataInfoColumn.MathType, DataInfoColumn.Unit}, prop_ui=True, enum_many=True
|
||||||
)
|
)
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
@ -71,8 +75,10 @@ class DataBLSocket(base.MaxwellSimSocket):
|
||||||
# - UI
|
# - UI
|
||||||
####################
|
####################
|
||||||
def draw_input_label_row(self, row: bpy.types.UILayout, text) -> None:
|
def draw_input_label_row(self, row: bpy.types.UILayout, text) -> None:
|
||||||
if self.format == 'jax':
|
|
||||||
row.label(text=text)
|
row.label(text=text)
|
||||||
|
|
||||||
|
info = self.compute_data(kind=ct.FlowKind.Info)
|
||||||
|
if not ct.FlowSignal.check(info) and self.format == 'jax' and info.dim_names:
|
||||||
row.prop(self, self.blfields['info_columns'])
|
row.prop(self, self.blfields['info_columns'])
|
||||||
row.prop(
|
row.prop(
|
||||||
self,
|
self,
|
||||||
|
@ -83,7 +89,8 @@ class DataBLSocket(base.MaxwellSimSocket):
|
||||||
)
|
)
|
||||||
|
|
||||||
def draw_output_label_row(self, row: bpy.types.UILayout, text) -> None:
|
def draw_output_label_row(self, row: bpy.types.UILayout, text) -> None:
|
||||||
if self.format == 'jax':
|
info = self.compute_data(kind=ct.FlowKind.Info)
|
||||||
|
if not ct.FlowSignal.check(info) and self.format == 'jax' and info.dim_names:
|
||||||
row.prop(
|
row.prop(
|
||||||
self,
|
self,
|
||||||
self.blfields['show_info_columns'],
|
self.blfields['show_info_columns'],
|
||||||
|
@ -92,6 +99,7 @@ class DataBLSocket(base.MaxwellSimSocket):
|
||||||
icon=ct.Icon.ToggleSocketInfo,
|
icon=ct.Icon.ToggleSocketInfo,
|
||||||
)
|
)
|
||||||
row.prop(self, self.blfields['info_columns'])
|
row.prop(self, self.blfields['info_columns'])
|
||||||
|
|
||||||
row.label(text=text)
|
row.label(text=text)
|
||||||
|
|
||||||
def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None:
|
def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None:
|
||||||
|
@ -118,16 +126,27 @@ class DataBLSocket(base.MaxwellSimSocket):
|
||||||
grid.label(text=spux.sp_to_str(dim_idx.unit))
|
grid.label(text=spux.sp_to_str(dim_idx.unit))
|
||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
for output_name in info.output_names:
|
grid.label(text=info.output_name)
|
||||||
grid.label(text=output_name)
|
|
||||||
if DataInfoColumn.Length in self.info_columns:
|
if DataInfoColumn.Length in self.info_columns:
|
||||||
grid.label(text='', icon=ct.Icon.DataSocketOutput)
|
grid.label(text='', icon=ct.Icon.DataSocketOutput)
|
||||||
if DataInfoColumn.MathType in self.info_columns:
|
if DataInfoColumn.MathType in self.info_columns:
|
||||||
grid.label(
|
grid.label(
|
||||||
text=spux.MathType.to_str(info.output_mathtypes[output_name])
|
text=(
|
||||||
|
spux.MathType.to_str(info.output_mathtype)
|
||||||
|
+ (
|
||||||
|
'ˣ'.join(
|
||||||
|
[
|
||||||
|
unicode_superscript(out_axis)
|
||||||
|
for out_axis in info.output_shape
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if info.output_shape
|
||||||
|
else ''
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if DataInfoColumn.Unit in self.info_columns:
|
if DataInfoColumn.Unit in self.info_columns:
|
||||||
grid.label(text=spux.sp_to_str(info.output_units[output_name]))
|
grid.label(text=f'{spux.sp_to_str(info.output_unit)}')
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
@ -137,9 +156,11 @@ class DataSocketDef(base.SocketDef):
|
||||||
socket_type: ct.SocketType = ct.SocketType.Data
|
socket_type: ct.SocketType = ct.SocketType.Data
|
||||||
|
|
||||||
format: typ.Literal['xarray', 'jax', 'monitor_data']
|
format: typ.Literal['xarray', 'jax', 'monitor_data']
|
||||||
|
default_show_info_columns: bool = True
|
||||||
|
|
||||||
def init(self, bl_socket: DataBLSocket) -> None:
|
def init(self, bl_socket: DataBLSocket) -> None:
|
||||||
bl_socket.format = self.format
|
bl_socket.format = self.format
|
||||||
|
bl_socket.default_show_info_columns = self.default_show_info_columns
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
|
@ -86,9 +86,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
|
||||||
def lazy_value_func(self) -> ct.LazyValueFuncFlow:
|
def lazy_value_func(self) -> ct.LazyValueFuncFlow:
|
||||||
return ct.LazyValueFuncFlow(
|
return ct.LazyValueFuncFlow(
|
||||||
func=sp.lambdify(self.symbols, self.value, 'jax'),
|
func=sp.lambdify(self.symbols, self.value, 'jax'),
|
||||||
func_args=[
|
func_args=[spux.sympy_to_python_type(sym) for sym in self.symbols],
|
||||||
(sym.name, spux.sympy_to_python_type(sym)) for sym in self.symbols
|
|
||||||
],
|
|
||||||
supports_jax=True,
|
supports_jax=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,18 @@ class MathType(enum.StrEnum):
|
||||||
Real = enum.auto()
|
Real = enum.auto()
|
||||||
Complex = enum.auto()
|
Complex = enum.auto()
|
||||||
|
|
||||||
|
def combine(*mathtypes: list[typ.Self]) -> typ.Self:
|
||||||
|
if MathType.Complex in mathtypes:
|
||||||
|
return MathType.Complex
|
||||||
|
elif MathType.Real in mathtypes:
|
||||||
|
return MathType.Real
|
||||||
|
elif MathType.Rational in mathtypes:
|
||||||
|
return MathType.Rational
|
||||||
|
elif MathType.Integer in mathtypes:
|
||||||
|
return MathType.Integer
|
||||||
|
elif MathType.Bool in mathtypes:
|
||||||
|
return MathType.Bool
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_expr(sp_obj: SympyType) -> type:
|
def from_expr(sp_obj: SympyType) -> type:
|
||||||
if isinstance(sp_obj, sp.logic.boolalg.Boolean):
|
if isinstance(sp_obj, sp.logic.boolalg.Boolean):
|
||||||
|
@ -595,6 +607,21 @@ ComplexNumber: typ.TypeAlias = ConstrSympyExpr(
|
||||||
)
|
)
|
||||||
Number: typ.TypeAlias = IntNumber | RealNumber | ComplexNumber
|
Number: typ.TypeAlias = IntNumber | RealNumber | ComplexNumber
|
||||||
|
|
||||||
|
# Number
|
||||||
|
PhysicalRealNumber: typ.TypeAlias = ConstrSympyExpr(
|
||||||
|
allow_variables=False,
|
||||||
|
allow_units=True,
|
||||||
|
allowed_sets={'integer', 'rational', 'real'},
|
||||||
|
allowed_structures={'scalar'},
|
||||||
|
)
|
||||||
|
PhysicalComplexNumber: typ.TypeAlias = ConstrSympyExpr(
|
||||||
|
allow_variables=False,
|
||||||
|
allow_units=True,
|
||||||
|
allowed_sets={'integer', 'rational', 'real', 'complex'},
|
||||||
|
allowed_structures={'scalar'},
|
||||||
|
)
|
||||||
|
PhysicalNumber: typ.TypeAlias = PhysicalRealNumber | PhysicalComplexNumber
|
||||||
|
|
||||||
# Vector
|
# Vector
|
||||||
Real3DVector: typ.TypeAlias = ConstrSympyExpr(
|
Real3DVector: typ.TypeAlias = ConstrSympyExpr(
|
||||||
allow_variables=False,
|
allow_variables=False,
|
||||||
|
|
|
@ -117,8 +117,8 @@ def rgba_image_from_2d_map(
|
||||||
def plot_hist_1d(
|
def plot_hist_1d(
|
||||||
data: jtyp.Float32[jtyp.Array, ' size'], info, ax: mpl_ax.Axis
|
data: jtyp.Float32[jtyp.Array, ' size'], info, ax: mpl_ax.Axis
|
||||||
) -> None:
|
) -> None:
|
||||||
y_name = info.output_names[0]
|
y_name = info.output_name
|
||||||
y_unit = info.output_units[y_name]
|
y_unit = info.output_unit
|
||||||
|
|
||||||
ax.hist(data, bins=30, alpha=0.75)
|
ax.hist(data, bins=30, alpha=0.75)
|
||||||
ax.set_title('Histogram')
|
ax.set_title('Histogram')
|
||||||
|
@ -130,8 +130,8 @@ def plot_box_plot_1d(
|
||||||
data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis
|
data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis
|
||||||
) -> None:
|
) -> None:
|
||||||
x_name = info.dim_names[0]
|
x_name = info.dim_names[0]
|
||||||
y_name = info.output_names[0]
|
y_name = info.output_name
|
||||||
y_unit = info.output_units[y_name]
|
y_unit = info.output_unit
|
||||||
|
|
||||||
ax.boxplot(data)
|
ax.boxplot(data)
|
||||||
ax.set_title('Box Plot')
|
ax.set_title('Box Plot')
|
||||||
|
@ -147,8 +147,8 @@ def plot_curve_2d(
|
||||||
|
|
||||||
x_name = info.dim_names[0]
|
x_name = info.dim_names[0]
|
||||||
x_unit = info.dim_units[x_name]
|
x_unit = info.dim_units[x_name]
|
||||||
y_name = info.output_names[0]
|
y_name = info.output_name
|
||||||
y_unit = info.output_units[y_name]
|
y_unit = info.output_unit
|
||||||
|
|
||||||
times.append(time.perf_counter() - times[0])
|
times.append(time.perf_counter() - times[0])
|
||||||
ax.plot(info.dim_idx_arrays[0], data)
|
ax.plot(info.dim_idx_arrays[0], data)
|
||||||
|
@ -167,8 +167,8 @@ def plot_points_2d(
|
||||||
) -> None:
|
) -> None:
|
||||||
x_name = info.dim_names[0]
|
x_name = info.dim_names[0]
|
||||||
x_unit = info.dim_units[x_name]
|
x_unit = info.dim_units[x_name]
|
||||||
y_name = info.output_names[0]
|
y_name = info.output_name
|
||||||
y_unit = info.output_units[y_name]
|
y_unit = info.output_unit
|
||||||
|
|
||||||
ax.scatter(info.dim_idx_arrays[0], data, alpha=0.6)
|
ax.scatter(info.dim_idx_arrays[0], data, alpha=0.6)
|
||||||
ax.set_title('2D Points')
|
ax.set_title('2D Points')
|
||||||
|
@ -179,8 +179,8 @@ def plot_points_2d(
|
||||||
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None:
|
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None:
|
||||||
x_name = info.dim_names[0]
|
x_name = info.dim_names[0]
|
||||||
x_unit = info.dim_units[x_name]
|
x_unit = info.dim_units[x_name]
|
||||||
y_name = info.output_names[0]
|
y_name = info.output_name
|
||||||
y_unit = info.output_units[y_name]
|
y_unit = info.output_unit
|
||||||
|
|
||||||
ax.bar(info.dim_idx_arrays[0], data, alpha=0.7)
|
ax.bar(info.dim_idx_arrays[0], data, alpha=0.7)
|
||||||
ax.set_title('2D Bar')
|
ax.set_title('2D Bar')
|
||||||
|
@ -194,8 +194,8 @@ def plot_curves_2d(
|
||||||
) -> None:
|
) -> None:
|
||||||
x_name = info.dim_names[0]
|
x_name = info.dim_names[0]
|
||||||
x_unit = info.dim_units[x_name]
|
x_unit = info.dim_units[x_name]
|
||||||
y_name = info.output_names[0]
|
y_name = info.output_name
|
||||||
y_unit = info.output_units[y_name]
|
y_unit = info.output_unit
|
||||||
|
|
||||||
for category in range(data.shape[1]):
|
for category in range(data.shape[1]):
|
||||||
ax.plot(data[:, 0], data[:, 1])
|
ax.plot(data[:, 0], data[:, 1])
|
||||||
|
@ -211,8 +211,8 @@ def plot_filled_curves_2d(
|
||||||
) -> None:
|
) -> None:
|
||||||
x_name = info.dim_names[0]
|
x_name = info.dim_names[0]
|
||||||
x_unit = info.dim_units[x_name]
|
x_unit = info.dim_units[x_name]
|
||||||
y_name = info.output_names[0]
|
y_name = info.output_name
|
||||||
y_unit = info.output_units[y_name]
|
y_unit = info.output_unit
|
||||||
|
|
||||||
ax.fill_between(info.dim_arrays[0], data[:, 0], info.dim_arrays[0], data[:, 1])
|
ax.fill_between(info.dim_arrays[0], data[:, 0], info.dim_arrays[0], data[:, 1])
|
||||||
ax.set_title('2D Curves')
|
ax.set_title('2D Curves')
|
||||||
|
|
Loading…
Reference in New Issue