feat: Implemented operate math node.

main
Sofus Albert Høgsbro Rose 2024-04-26 17:22:55 +02:00
parent 7d944a704e
commit b2a7eefb45
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
18 changed files with 1170 additions and 390 deletions

View File

@ -14,6 +14,8 @@
- [x] Extract
- [x] Viz
- [x] Math / Map Math
- [ ] Remove "By x" socket set let socket sets only be "Function"/"Expr"; then add a dynamic enum underneath to select "By x" based on data support.
- [ ] Filter the operations based on data support, ex. use positive-definiteness to guide cholesky.
- [x] Math / Filter Math
- [ ] Math / Reduce Math
- [ ] Math / Operate Math
@ -34,8 +36,6 @@
- [x] Constants / Blender Constant
- [ ] Web / Tidy3D Web Importer
- [ ] Change to output only a `FilePath`, which can be plugged into a Tidy3D File Importer.
- [ ] Implement caching, such that the file will only download if the file doesn't already exist.
- [ ] Have a visual indicator for the current download status, with a manual re-download button.
- [x] File Import / Material Import

View File

@ -68,6 +68,9 @@ class FlowKind(enum.StrEnum):
if kind == cls.LazyArrayRange:
return value.rescale_to_unit(unit_system[socket_type])
if kind == cls.Params:
return value.rescale_to_unit(unit_system[socket_type])
msg = 'Tried to scale unknown kind'
raise ValueError(msg)
@ -322,6 +325,28 @@ class LazyValueFuncFlow:
supports_jax: bool = False
supports_numba: bool = False
# Merging
def __or__(
self,
other: typ.Self,
):
return LazyValueFuncFlow(
func=lambda *args, **kwargs: (
self.func(
*list(args[: len(self.func_args)]),
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
),
other.func(
*list(args[len(self.func_args) :]),
**{k: v for k, v in kwargs.items() if k in other.func_kwargs},
),
),
func_args=self.func_args + other.func_args,
func_kwargs=self.func_kwargs | other.func_kwargs,
supports_jax=self.supports_jax and other.supports_jax,
supports_numba=self.supports_numba and other.supports_numba,
)
# Composition
def compose_within(
self,
@ -691,10 +716,21 @@ class ParamsFlow:
func_args: list[typ.Any] = dataclasses.field(default_factory=list)
func_kwargs: dict[str, typ.Any] = dataclasses.field(default_factory=dict)
def __or__(
self,
other: typ.Self,
):
return ParamsFlow(
func_args=self.func_args + other.func_args,
func_kwargs=self.func_kwargs | other.func_kwargs,
)
def compose_within(
self,
enclosing_func_args: list[tuple[type]] = (),
enclosing_func_kwargs: dict[str, type] = MappingProxyType({}),
enclosing_func_arg_units: dict[str, type] = MappingProxyType({}),
enclosing_func_kwarg_units: dict[str, type] = MappingProxyType({}),
) -> typ.Self:
return ParamsFlow(
func_args=self.func_args + list(enclosing_func_args),
@ -718,26 +754,133 @@ class InfoFlow:
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
@functools.cached_property
def dim_mathtypes(self) -> dict[str, int]:
def dim_mathtypes(self) -> dict[str, spux.MathType]:
return {
dim_name: dim_idx.mathtype for dim_name, dim_idx in self.dim_idx.items()
}
@functools.cached_property
def dim_units(self) -> dict[str, int]:
def dim_units(self) -> dict[str, spux.Unit]:
return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()}
@functools.cached_property
def dim_idx_arrays(self) -> list[ArrayFlow]:
def dim_idx_arrays(self) -> list[jax.Array]:
return [
dim_idx.realize().values
if isinstance(dim_idx, LazyArrayRangeFlow)
else dim_idx.values
for dim_idx in self.dim_idx.values()
]
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
# Output Information
output_names: list[str] = dataclasses.field(default_factory=list)
output_mathtypes: dict[str, spux.MathType] = dataclasses.field(default_factory=dict)
output_units: dict[str, spux.Unit | None] = dataclasses.field(default_factory=dict)
output_name: str = dataclasses.field(default_factory=list)
output_shape: tuple[int, ...] | None = dataclasses.field(default=None)
output_mathtype: spux.MathType = dataclasses.field()
output_unit: spux.Unit | None = dataclasses.field()
# Pinned Dimension Information
pinned_dim_names: list[str] = dataclasses.field(default_factory=list)
pinned_dim_values: dict[str, float | complex] = dataclasses.field(
default_factory=dict
)
pinned_dim_mathtypes: dict[str, spux.MathType] = dataclasses.field(
default_factory=dict
)
pinned_dim_units: dict[str, spux.Unit] = dataclasses.field(default_factory=dict)
####################
# - Methods
####################
def delete_dimension(self, dim_name: str) -> typ.Self:
"""Delete a dimension."""
return InfoFlow(
# Dimensions
dim_names=[
_dim_name for _dim_name in self.dim_names if _dim_name != dim_name
],
dim_idx={
_dim_name: dim_idx
for _dim_name, dim_idx in self.dim_idx.items()
if _dim_name != dim_name
},
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
)
def swap_dimensions(self, dim_0_name: str, dim_1_name: str) -> typ.Self:
"""Delete a dimension."""
# Compute Swapped Dimension Name List
def name_swapper(dim_name):
return (
dim_name
if dim_name not in [dim_0_name, dim_1_name]
else {dim_0_name: dim_1_name, dim_1_name: dim_0_name}[dim_name]
)
dim_names = [name_swapper(dim_name) for dim_name in self.dim_names]
# Compute Info
return InfoFlow(
# Dimensions
dim_names=dim_names,
dim_idx={dim_name: self.dim_idx[dim_name] for dim_name in dim_names},
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
)
def set_output_mathtype(self, output_mathtype: spux.MathType) -> typ.Self:
"""Set the MathType of a particular output name."""
return InfoFlow(
dim_names=self.dim_names,
dim_idx=self.dim_idx,
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=output_mathtype,
output_unit=self.output_unit,
)
def collapse_output(
self,
collapsed_name: str,
collapsed_mathtype: spux.MathType,
collapsed_unit: spux.Unit,
) -> typ.Self:
return InfoFlow(
# Dimensions
dim_names=self.dim_names,
dim_idx=self.dim_idx,
output_name=collapsed_name,
output_shape=None,
output_mathtype=collapsed_mathtype,
output_unit=collapsed_unit,
)
@functools.cached_property
def shift_last_input(self):
"""Shift the last input dimension to the output."""
return InfoFlow(
# Dimensions
dim_names=self.dim_names[:-1],
dim_idx={
dim_name: dim_idx
for dim_name, dim_idx in self.dim_idx.items()
if dim_name != self.dim_names[-1]
},
# Outputs
output_name=self.output_name,
output_shape=(
(self.dim_lens[self.dim_names[-1]],)
if self.output_shape is None
else (self.dim_lens[self.dim_names[-1]], *self.output_shape)
),
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
)

View File

@ -15,6 +15,7 @@ class NodeType(blender_type_enum.BlenderTypeEnum):
FilterMath = enum.auto()
ReduceMath = enum.auto()
OperateMath = enum.auto()
TransformMath = enum.auto()
# Inputs
WaveConstant = enum.auto()

View File

@ -260,7 +260,9 @@ SOCKET_UNITS = {
}
def unit_to_socket_type(unit: spux.Unit) -> ST:
def unit_to_socket_type(
unit: spux.Unit | None, fallback_mathtype: spux.MathType | None = None
) -> ST:
"""Returns a SocketType that accepts the given unit.
Only the unit-compatibility is taken into account; in the case of overlap, several the ordering of `SOCKET_UNITS` determines which is returned.
@ -269,6 +271,14 @@ def unit_to_socket_type(unit: spux.Unit) -> ST:
Returns:
**The first `SocketType` in `SOCKET_UNITS`, which contains the given unit as a valid possibility.
"""
if unit is None and fallback_mathtype is not None:
return {
spux.MathType.Integer: ST.IntegerNumber,
spux.MathType.Rational: ST.RationalNumber,
spux.MathType.Real: ST.RealNumber,
spux.MathType.Complex: ST.ComplexNumber,
}[fallback_mathtype]
for socket_type, _units in SOCKET_UNITS.items():
if unit in _units['values'].values():
return socket_type

View File

@ -215,6 +215,15 @@ class ExtractDataNode(base.MaxwellSimNode):
####################
# - UI
####################
def draw_label(self):
has_sim_data = self.sim_data_monitor_nametype is not None
has_monitor_data = self.monitor_data_components is not None
if has_sim_data or has_monitor_data:
return f'Extract: {self.extract_filter}'
return self.bl_label
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
"""Draw node properties in the node.
@ -223,43 +232,6 @@ class ExtractDataNode(base.MaxwellSimNode):
"""
col.prop(self, self.blfields['extract_filter'], text='')
def draw_info(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
"""Draw dynamic information in the node, for user consideration.
Parameters:
col: UI target for drawing.
"""
has_sim_data = self.sim_data_monitor_nametype is not None
has_monitor_data = self.monitor_data_components is not None
if has_sim_data or has_monitor_data:
# Header
row = col.row()
row.alignment = 'CENTER'
if has_sim_data:
row.label(text=f'{len(self.sim_data_monitor_nametype)} Monitors')
elif has_monitor_data:
row.label(text=f'{self.monitor_data_type} Monitor Data')
# Monitor Data Contents
## TODO: More compact double-split
## TODO: Output shape data.
## TODO: Local ENUM_MANY tabs for visible column selection?
row = col.row()
box = row.box()
grid = box.grid_flow(row_major=True, columns=2, even_columns=True)
if has_sim_data:
for (
monitor_name,
monitor_type,
) in self.sim_data_monitor_nametype.items():
grid.label(text=monitor_name)
grid.label(text=monitor_type.replace('Data', ''))
elif has_monitor_data:
for component_name in self.monitor_data_components:
grid.label(text=component_name)
grid.label(text=self.monitor_data_type)
####################
# - Events
####################
@ -416,9 +388,8 @@ class ExtractDataNode(base.MaxwellSimNode):
else:
return ct.FlowSignal.FlowPending
info_output_names = {
'output_names': [props['extract_filter']],
}
info_output_name = props['extract_filter']
info_output_shape = None
# Compute InfoFlow from XArray
## XYZF: Field / Permittivity / FieldProjectionCartesian
@ -442,13 +413,14 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True,
),
},
**info_output_names,
output_mathtypes={props['extract_filter']: spux.MathType.Complex},
output_units={
props['extract_filter']: spu.volt / spu.micrometer
output_name=props['extract_filter'],
output_shape=None,
output_mathtype=spux.MathType.Complex,
output_unit=(
spu.volt / spu.micrometer
if props['monitor_data_type'] == 'Field'
else None
},
),
)
## XYZT: FieldTime
@ -468,17 +440,14 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True,
),
},
**info_output_names,
output_mathtypes={props['extract_filter']: spux.MathType.Complex},
output_units={
props['extract_filter']: (
output_name=props['extract_filter'],
output_shape=None,
output_mathtype=spux.MathType.Complex,
output_unit=(
spu.volt / spu.micrometer
if props['extract_filter'].startswith('E')
else spu.ampere / spu.micrometer
)
if props['monitor_data_type'] == 'Field'
else None
},
),
)
## F: Flux
@ -492,9 +461,10 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True,
),
},
**info_output_names,
output_mathtypes={props['extract_filter']: spux.MathType.Real},
output_units={props['extract_filter']: spu.watt},
output_name=props['extract_filter'],
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=spu.watt,
)
## T: FluxTime
@ -508,9 +478,10 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True,
),
},
**info_output_names,
output_mathtypes={props['extract_filter']: spux.MathType.Real},
output_units={props['extract_filter']: spu.watt},
output_name=props['extract_filter'],
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=spu.watt,
)
## RThetaPhiF: FieldProjectionAngle
@ -537,15 +508,14 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True,
),
},
**info_output_names,
output_mathtypes={props['extract_filter']: spux.MathType.Real},
output_units={
props['extract_filter']: (
output_name=props['extract_filter'],
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=(
spu.volt / spu.micrometer
if props['extract_filter'].startswith('E')
else spu.ampere / spu.micrometer
)
},
),
)
## UxUyRF: FieldProjectionKSpace
@ -570,15 +540,14 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True,
),
},
**info_output_names,
output_mathtypes={props['extract_filter']: spux.MathType.Real},
output_units={
props['extract_filter']: (
output_name=props['extract_filter'],
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=(
spu.volt / spu.micrometer
if props['extract_filter'].startswith('E')
else spu.ampere / spu.micrometer
)
},
),
)
## OrderxOrderyF: Diffraction
@ -600,15 +569,14 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True,
),
},
**info_output_names,
output_mathtypes={props['extract_filter']: spux.MathType.Real},
output_units={
props['extract_filter']: (
output_name=props['extract_filter'],
output_shape=None,
output_mathtype=spux.MathType.Real,
output_unit=(
spu.volt / spu.micrometer
if props['extract_filter'].startswith('E')
else spu.ampere / spu.micrometer
)
},
),
)
msg = f'Unsupported Monitor Data Type {props["monitor_data_type"]} in "FlowKind.Info" of "{self.bl_label}"'

View File

@ -1,14 +1,16 @@
from . import filter_math, map_math, operate_math, reduce_math
from . import filter_math, map_math, operate_math, reduce_math, transform_math
BL_REGISTER = [
*map_math.BL_REGISTER,
*filter_math.BL_REGISTER,
*reduce_math.BL_REGISTER,
*operate_math.BL_REGISTER,
*transform_math.BL_REGISTER,
]
BL_NODES = {
**map_math.BL_NODES,
**filter_math.BL_NODES,
**reduce_math.BL_NODES,
**operate_math.BL_NODES,
**transform_math.BL_NODES,
}

View File

@ -1,3 +1,5 @@
"""Declares `FilterMathNode`."""
import enum
import typing as typ
@ -15,11 +17,21 @@ log = logger.get(__name__)
class FilterMathNode(base.MaxwellSimNode):
"""Reduces the dimensionality of data.
r"""Applies a function that operates on the shape of the array.
The shape, type, and interpretation of the input/output data is dynamically shown.
# Socket Sets
## Dimensions
Alter the dimensions of the array.
## Interpret
Only alter the interpretation of the array data, which guides what it can be used for.
These operations are **zero cost**, since the data itself is untouched.
Attributes:
operation: Operation to apply to the input.
dim: Dims to use when filtering data
"""
node_type = ct.NodeType.FilterMath
@ -29,8 +41,8 @@ class FilterMathNode(base.MaxwellSimNode):
'Data': sockets.DataSocketDef(format='jax'),
}
input_socket_sets: typ.ClassVar = {
'By Dim': {},
'By Dim Value': {},
'Interpret': {},
'Dimensions': {},
}
output_sockets: typ.ClassVar = {
'Data': sockets.DataSocketDef(format='jax'),
@ -43,10 +55,17 @@ class FilterMathNode(base.MaxwellSimNode):
prop_ui=True, enum_cb=lambda self, _: self.search_operations()
)
dim: enum.Enum = bl_cache.BLField(
# Dimension Selection
dim_0: enum.Enum = bl_cache.BLField(
None, prop_ui=True, enum_cb=lambda self, _: self.search_dims()
)
dim_1: enum.Enum = bl_cache.BLField(
None, prop_ui=True, enum_cb=lambda self, _: self.search_dims()
)
####################
# - Computed
####################
@property
def data_info(self) -> ct.InfoFlow | None:
info = self._compute_input('Data', kind=ct.FlowKind.Info)
@ -60,87 +79,119 @@ class FilterMathNode(base.MaxwellSimNode):
####################
def search_operations(self) -> list[tuple[str, str, str]]:
items = []
if self.active_socket_set == 'By Dim':
if self.active_socket_set == 'Interpret':
items += [
('SQUEEZE', 'del a | #=1', 'Squeeze'),
('DIM_TO_VEC', '→ Vector', 'Shift last dimension to output.'),
('DIMS_TO_MAT', '→ Matrix', 'Shift last 2 dimensions to output.'),
]
if self.active_socket_set == 'By Dim Value':
elif self.active_socket_set == 'Dimensions':
items += [
('FIX', 'del a | i≈v', 'Fix Coordinate'),
('PIN_LEN_ONE', 'pinₐ =1', 'Remove a len(1) dimension'),
(
'PIN',
'pinₐ ≈v',
'Remove a len(n) dimension by selecting an index',
),
('SWAP', 'a₁ ↔ a₂', 'Swap the position of two dimensions'),
]
return [(*item, '', i) for i, item in enumerate(items)]
####################
# - Dim Search
# - Dimensions Search
####################
def search_dims(self) -> list[ct.BLEnumElement]:
if self.data_info is not None:
dims = [
(dim_name, dim_name, dim_name, '', i)
for i, dim_name in enumerate(self.data_info.dim_names)
]
# Squeeze: Dimension Must Have Length=1
## We must also correct the "NUMBER" of the enum.
if self.operation == 'SQUEEZE':
filtered_dims = [
dim for dim in dims if self.data_info.dim_lens[dim[0]] == 1
]
return [(*dim[:-1], i) for i, dim in enumerate(filtered_dims)]
return dims
if self.data_info is None:
return []
if self.operation == 'PIN_LEN_ONE':
dims = [
(dim_name, dim_name, f'Dimension "{dim_name}" of length 1')
for dim_name in self.data_info.dim_names
if self.data_info.dim_lens[dim_name] == 1
]
elif self.operation in ['PIN', 'SWAP']:
dims = [
(dim_name, dim_name, f'Dimension "{dim_name}"')
for dim_name in self.data_info.dim_names
]
else:
return []
return [(*dim, '', i) for i, dim in enumerate(dims)]
####################
# - UI
####################
def draw_label(self):
labels = {
'PIN_LEN_ONE': lambda: f'Filter: Pin {self.dim_0} (len=1)',
'PIN': lambda: f'Filter: Pin {self.dim_0}',
'SWAP': lambda: f'Filter: Swap {self.dim_0}|{self.dim_1}',
'DIM_TO_VEC': lambda: 'Filter: -> Vector',
'DIMS_TO_MAT': lambda: 'Filter: -> Matrix',
}
if (label := labels.get(self.operation)) is not None:
return label()
return self.bl_label
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, self.blfields['operation'], text='')
if self.data_info is not None and self.data_info.dim_names:
layout.prop(self, self.blfields['dim'], text='')
if self.active_socket_set == 'Dimensions':
if self.operation in ['PIN_LEN_ONE', 'PIN']:
layout.prop(self, self.blfields['dim_0'], text='')
if self.operation == 'SWAP':
row = layout.row(align=True)
row.prop(self, self.blfields['dim_0'], text='')
row.prop(self, self.blfields['dim_1'], text='')
####################
# - Events
####################
@events.on_value_changed(
socket_name='Data',
prop_name='active_socket_set',
run_on_init=True,
input_sockets={'Data'},
)
def on_any_change(self, input_sockets: dict):
if all(
not ct.FlowSignal.check_single(
input_socket_value, ct.FlowSignal.FlowPending
)
for input_socket_value in input_sockets.values()
):
def on_socket_set_changed(self):
self.operation = bl_cache.Signal.ResetEnumItems
self.dim = bl_cache.Signal.ResetEnumItems
@events.on_value_changed(
# Trigger
socket_name='Data',
prop_name={'active_socket_set', 'operation'},
run_on_init=True,
# Loaded
props={'operation'},
)
def on_any_change(self, props: dict) -> None:
self.dim_0 = bl_cache.Signal.ResetEnumItems
self.dim_1 = bl_cache.Signal.ResetEnumItems
@events.on_value_changed(
socket_name='Data',
prop_name='dim',
prop_name={'dim_0', 'dim_1', 'operation'},
## run_on_init: Implicitly triggered.
props={'active_socket_set', 'dim'},
props={'operation', 'dim_0', 'dim_1'},
input_sockets={'Data'},
input_socket_kinds={'Data': ct.FlowKind.Info},
)
def on_dim_change(self, props: dict, input_sockets: dict):
if input_sockets['Data'] == ct.FlowSignal.FlowPending:
has_data = not ct.FlowSignal.check(input_sockets['Data'])
if not has_data:
return
# Add/Remove Input Socket "Value"
if (
not ct.Flowsignal.check(input_sockets['Data'])
and props['active_socket_set'] == 'By Dim Value'
and props['dim'] != 'NONE'
):
# "Dimensions"|"PIN": Add/Remove Input Socket
if props['operation'] == 'PIN' and props['dim_0'] != 'NONE':
# Get Current and Wanted Socket Defs
current_bl_socket = self.loose_input_sockets.get('Value')
wanted_socket_def = sockets.SOCKET_DEFS[
ct.unit_to_socket_type(input_sockets['Data'].dim_idx[props['dim']].unit)
ct.unit_to_socket_type(
input_sockets['Data'].dim_idx[props['dim_0']].unit
)
]
# Determine Whether to Declare New Loose Input SOcket
@ -151,7 +202,7 @@ class FilterMathNode(base.MaxwellSimNode):
):
self.loose_input_sockets = {
'Value': wanted_socket_def(),
} ## TODO: Can we do the boilerplate in base.py?
}
elif self.loose_input_sockets:
self.loose_input_sockets = {}
@ -161,40 +212,51 @@ class FilterMathNode(base.MaxwellSimNode):
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.LazyValueFunc,
props={'active_socket_set', 'operation', 'dim'},
props={'operation', 'dim_0', 'dim_1'},
input_sockets={'Data'},
input_socket_kinds={'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info}},
)
def compute_data(self, props: dict, input_sockets: dict):
# Retrieve Inputs
lazy_value_func = input_sockets['Data'][ct.FlowKind.LazyValueFunc]
info = input_sockets['Data'][ct.FlowKind.Info]
# Check Flow
if (
any(ct.FlowSignal.check(inp) for inp in [info, lazy_value_func])
or props['operation'] == 'NONE'
):
if any(ct.FlowSignal.check(inp) for inp in [info, lazy_value_func]):
return ct.FlowSignal.FlowPending
# Compute Bound/Free Parameters
func_args = [int] if props['active_socket_set'] == 'By Dim Value' else []
axis = info.dim_names.index(props['dim'])
# Compute Function Arguments
operation = props['operation']
if operation == 'NONE':
return ct.FlowSignal.FlowPending
# Select Function
filter_func: typ.Callable[[jax.Array], jax.Array] = {
'By Dim': {'SQUEEZE': lambda data: jnp.squeeze(data, axis)},
'By Dim Value': {
'FIX': lambda data, fixed_axis_idx: jnp.take(
data, fixed_axis_idx, axis=axis
)
},
}[props['active_socket_set']][props['operation']]
## Dimension(s)
dim_0 = props['dim_0']
dim_1 = props['dim_1']
if operation in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE':
return ct.FlowSignal.FlowPending
if operation == 'SWAP' and dim_1 == 'NONE':
return ct.FlowSignal.FlowPending
## Axis/Axes
axis_0 = info.dim_names.index(dim_0) if dim_0 != 'NONE' else None
axis_1 = info.dim_names.index(dim_1) if dim_1 != 'NONE' else None
# Compose Output Function
filter_func = {
# Dimensions
'PIN_LEN_ONE': lambda data: jnp.squeeze(data, axis_0),
'PIN': lambda data, fixed_axis_idx: jnp.take(
data, fixed_axis_idx, axis=axis_0
),
'SWAP': lambda data: jnp.swapaxes(data, axis_0, axis_1),
# Interpret
'DIM_TO_VEC': lambda data: data,
'DIMS_TO_MAT': lambda data: data,
}[props['operation']]
# Compose Function for Output
return lazy_value_func.compose_within(
filter_func,
enclosing_func_args=func_args,
enclosing_func_args=[int] if operation == 'PIN' else [],
supports_jax=True,
)
@ -207,7 +269,6 @@ class FilterMathNode(base.MaxwellSimNode):
},
)
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow:
# Retrieve Inputs
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc]
params = output_sockets['Data'][ct.FlowKind.Params]
@ -215,57 +276,54 @@ class FilterMathNode(base.MaxwellSimNode):
if any(ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]):
return ct.FlowSignal.FlowPending
# Compute Array
return ct.ArrayFlow(
values=lazy_value_func.func_jax(*params.func_args, **params.func_kwargs),
unit=None, ## TODO: Unit Propagation
unit=None,
)
####################
# - Compute Auxiliary: Info / Params
# - Compute Auxiliary: Info
####################
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.Info,
props={'active_socket_set', 'dim', 'operation'},
props={'dim_0', 'dim_1', 'operation'},
input_sockets={'Data'},
input_socket_kinds={'Data': ct.FlowKind.Info},
)
def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
# Retrieve Inputs
info = input_sockets['Data']
# Check Flow
if ct.FlowSignal.check(info) or props['dim'] == 'NONE':
if ct.FlowSignal.check(info):
return ct.FlowSignal.FlowPending
# Compute Information
## Compute Info w/By-Operation Change to Dimensions
axis = info.dim_names.index(props['dim'])
# Collect Information
dim_0 = props['dim_0']
dim_1 = props['dim_1']
if (props['active_socket_set'], props['operation']) in [
('By Dim', 'SQUEEZE'),
('By Dim Value', 'FIX'),
] and info.dim_names:
return ct.InfoFlow(
dim_names=info.dim_names[:axis] + info.dim_names[axis + 1 :],
dim_idx={
dim_name: dim_idx
for dim_name, dim_idx in info.dim_idx.items()
if dim_name != props['dim']
},
output_names=info.output_names,
output_mathtypes=info.output_mathtypes,
output_units=info.output_units,
)
if props['operation'] in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE':
return ct.FlowSignal.FlowPending
if props['operation'] == 'SWAP' and dim_1 == 'NONE':
return ct.FlowSignal.FlowPending
msg = f'Active socket set {props["active_socket_set"]} and operation {props["operation"]} don\'t have an InfoFlow defined'
raise RuntimeError(msg)
return {
# Dimensions
'PIN_LEN_ONE': lambda: info.delete_dimension(dim_0),
'PIN': lambda: info.delete_dimension(dim_0),
'SWAP': lambda: info.swap_dimensions(dim_0, dim_1),
# Interpret
'DIM_TO_VEC': lambda: info.shift_last_input,
'DIMS_TO_MAT': lambda: info.shift_last_input.shift_last_input,
}[props['operation']]()
####################
# - Compute Auxiliary: Info
####################
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.Params,
props={'active_socket_set', 'dim', 'operation'},
props={'dim_0', 'dim_1', 'operation'},
input_sockets={'Data', 'Value'},
input_socket_kinds={'Data': {ct.FlowKind.Info, ct.FlowKind.Params}},
input_sockets_optional={'Value': True},
@ -273,35 +331,33 @@ class FilterMathNode(base.MaxwellSimNode):
def compute_composed_params(
self, props: dict, input_sockets: dict
) -> ct.ParamsFlow:
# Retrieve Inputs
info = input_sockets['Data'][ct.FlowKind.Info]
params = input_sockets['Data'][ct.FlowKind.Params]
# Check Flow
if any(ct.FlowSignal.check(inp) for inp in [info, params]):
return ct.FlowSignal.FlowPending
# Compute Composed Parameters
## -> Only operations that add parameters.
## -> A dimension must be selected.
## -> There must be an input value.
if (
(props['active_socket_set'], props['operation'])
in [
('By Dim Value', 'FIX'),
]
and props['dim'] != 'NONE'
and not ct.FlowSignal.check(input_sockets['Value'])
):
# Compute IDX Corresponding to Coordinate Value
## -> Each dimension declares a unit-aware real number at each index.
## -> "Value" is a unit-aware real number from loose input socket.
## -> This finds the dimensional index closest to "Value".
## Total Effect: Indexing by a unit-aware real number.
nearest_idx_to_value = info.dim_idx[props['dim']].nearest_idx_of(
# Collect Information
## Dimensions
dim_0 = props['dim_0']
dim_1 = props['dim_1']
if props['operation'] in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE':
return ct.FlowSignal.FlowPending
if props['operation'] == 'SWAP' and dim_1 == 'NONE':
return ct.FlowSignal.FlowPending
## Pinned Value
pinned_value = input_sockets['Value']
has_pinned_value = not ct.FlowSignal.check(pinned_value)
if props['operation'] == 'PIN' and has_pinned_value:
# Compute IDX Corresponding to Dimension Index
nearest_idx_to_value = info.dim_idx[dim_0].nearest_idx_of(
input_sockets['Value'], require_sorted=True
)
# Compose Parameters
return params.compose_within(enclosing_func_args=[nearest_idx_to_value])
return params

View File

@ -138,6 +138,7 @@ class MapMathNode(base.MaxwellSimNode):
('SQ', '', 'v^2 (by el)'),
('SQRT', '√v', 'sqrt(v) (by el)'),
('INV_SQRT', '1/√v', '1/sqrt(v) (by el)'),
None,
# Trigonometry
('COS', 'cos v', 'cos(v) (by el)'),
('SIN', 'sin v', 'sin(v) (by el)'),
@ -148,6 +149,7 @@ class MapMathNode(base.MaxwellSimNode):
]
elif self.active_socket_set in 'By Vector':
items = [
# Vector -> Number
('NORM_2', '||v||₂', 'norm(v, 2) (by Vec)'),
]
elif self.active_socket_set == 'By Matrix':
@ -157,13 +159,16 @@ class MapMathNode(base.MaxwellSimNode):
('COND', 'κ(V)', 'cond(V) (by Mat)'),
('NORM_FRO', '||V||_F', 'norm(V, frobenius) (by Mat)'),
('RANK', 'rank V', 'rank(V) (by Mat)'),
None,
# Matrix -> Array
('DIAG', 'diag V', 'diag(V) (by Mat)'),
('EIG_VALS', 'eigvals V', 'eigvals(V) (by Mat)'),
('SVD_VALS', 'svdvals V', 'diag(svd(V)) (by Mat)'),
None,
# Matrix -> Matrix
('INV', 'V⁻¹', 'V^(-1) (by Mat)'),
('TRA', 'Vt', 'V^T (by Mat)'),
None,
# Matrix -> Matrices
('QR', 'qr V', 'qr(V) -> Q·R (by Mat)'),
('CHOL', 'chol V', 'cholesky(V) -> V·V† (by Mat)'),
@ -175,7 +180,9 @@ class MapMathNode(base.MaxwellSimNode):
msg = f'Active socket set {self.active_socket_set} is unknown'
raise RuntimeError(msg)
return [(*item, '', i) for i, item in enumerate(items)]
return [
(*item, '', i) if item is not None else None for i, item in enumerate(items)
]
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, self.blfields['operation'], text='')
@ -185,8 +192,9 @@ class MapMathNode(base.MaxwellSimNode):
####################
@events.on_value_changed(
prop_name='active_socket_set',
run_on_init=True,
)
def on_operation_changed(self):
def on_socket_set_changed(self):
self.operation = bl_cache.Signal.ResetEnumItems
####################
@ -204,11 +212,14 @@ class MapMathNode(base.MaxwellSimNode):
input_sockets_optional={'Mapper': True},
)
def compute_data(self, props: dict, input_sockets: dict):
has_data = not ct.FlowSignal.check(input_sockets['Data'])
if (
ct.FlowSignal.check(input_sockets['Data']) or props['operation'] == 'NONE'
) or (
not has_data
or props['operation'] == 'NONE'
or (
props['active_socket_set'] == 'Expr'
and ct.FlowSignal.check(input_sockets['Mapper'])
)
):
return ct.FlowSignal.FlowPending
@ -238,14 +249,14 @@ class MapMathNode(base.MaxwellSimNode):
'NORM_FRO': lambda data: jnp.linalg.matrix_norm(data, ord='fro'),
'RANK': lambda data: jnp.linalg.matrix_rank(data),
# Matrix -> Vec
'DIAG': lambda data: jnp.diag(data),
'EIG_VALS': lambda data: jnp.eigvals(data),
'SVD_VALS': lambda data: jnp.svdvals(data),
'DIAG': lambda data: jnp.diagonal(data, axis1=-2, axis2=-1),
'EIG_VALS': lambda data: jnp.linalg.eigvals(data),
'SVD_VALS': lambda data: jnp.linalg.svdvals(data),
# Matrix -> Matrix
'INV': lambda data: jnp.inv(data),
'INV': lambda data: jnp.linalg.inv(data),
'TRA': lambda data: jnp.matrix_transpose(data),
# Matrix -> Matrices
'QR': lambda data: jnp.inv(data),
'QR': lambda data: jnp.linalg.qr(data),
'CHOL': lambda data: jnp.linalg.cholesky(data),
'SVD': lambda data: jnp.linalg.svd(data),
},
@ -298,28 +309,53 @@ class MapMathNode(base.MaxwellSimNode):
return ct.FlowSignal.FlowPending
# Complex -> Real
if props['active_socket_set'] == 'By Element':
if props['operation'] in [
if props['active_socket_set'] == 'By Element' and props['operation'] in [
'REAL',
'IMAG',
'ABS',
]:
return ct.InfoFlow(
dim_names=info.dim_names,
dim_idx=info.dim_idx,
output_names=info.output_names,
output_mathtypes={
output_name: (
spux.MathType.Real
if output_mathtype == spux.MathType.Complex
else output_mathtype
return info.set_output_mathtype(spux.MathType.Real)
if props['active_socket_set'] == 'By Vector' and props['operation'] in [
'NORM_2'
]:
return {
'NORM_2': lambda: info.collapse_output(
collapsed_name=f'||{info.output_name}||₂',
collapsed_mathtype=spux.MathType.Real,
collapsed_unit=info.output_unit,
)
for output_name, output_mathtype in info.output_mathtypes.items()
},
output_units=info.output_units,
)
if props['active_socket_set'] == 'By Vector':
pass
}[props['operation']]()
if props['active_socket_set'] == 'By Matrix' and props['operation'] in [
'DET',
'COND',
'NORM_FRO',
'RANK',
]:
return {
'DET': lambda: info.collapse_output(
collapsed_name=f'det {info.output_name}',
collapsed_mathtype=info.output_mathtype,
collapsed_unit=info.output_unit,
),
'COND': lambda: info.collapse_output(
collapsed_name=f'κ({info.output_name})',
collapsed_mathtype=spux.MathType.Real,
collapsed_unit=None,
),
'NORM_FRO': lambda: info.collapse_output(
collapsed_name=f'||({info.output_name}||_F',
collapsed_mathtype=spux.MathType.Real,
collapsed_unit=info.output_unit,
),
'RANK': lambda: info.collapse_output(
collapsed_name=f'rank {info.output_name}',
collapsed_mathtype=spux.MathType.Integer,
collapsed_unit=None,
),
}[props['operation']]()
return info
@events.computes_output_socket(

View File

@ -1,9 +1,12 @@
import enum
import typing as typ
import bpy
import jax.numpy as jnp
import sympy as sp
from blender_maxwell.utils import logger
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct
from .... import sockets
@ -13,120 +16,465 @@ log = logger.get(__name__)
class OperateMathNode(base.MaxwellSimNode):
r"""Applies a function that depends on two inputs.
Attributes:
category: The category of operations to apply to the inputs.
**Only valid** categories can be chosen.
operation: The actual operation to apply to the inputs.
**Only valid** operations can be chosen.
"""
node_type = ct.NodeType.OperateMath
bl_label = 'Operate Math'
input_socket_sets: typ.ClassVar = {
'Elementwise': {
'Data L': sockets.AnySocketDef(),
'Data R': sockets.AnySocketDef(),
'Expr | Expr': {
'Expr L': sockets.ExprSocketDef(),
'Expr R': sockets.ExprSocketDef(),
},
## TODO: Filter-array building operations
'Vec-Vec': {
'Data L': sockets.AnySocketDef(),
'Data R': sockets.AnySocketDef(),
'Data | Data': {
'Data L': sockets.DataSocketDef(
format='jax', default_show_info_columns=False
),
'Data R': sockets.DataSocketDef(
format='jax', default_show_info_columns=False
),
},
'Mat-Vec': {
'Data L': sockets.AnySocketDef(),
'Data R': sockets.AnySocketDef(),
'Expr | Data': {
'Expr L': sockets.ExprSocketDef(),
'Data R': sockets.DataSocketDef(
format='jax', default_show_info_columns=False
),
},
}
output_sockets: typ.ClassVar = {
'Data': sockets.AnySocketDef(),
output_socket_sets: typ.ClassVar = {
'Expr | Expr': {
'Expr': sockets.ExprSocketDef(),
},
'Data | Data': {
'Data': sockets.DataSocketDef(
format='jax', default_show_info_columns=False
),
},
'Expr | Data': {
'Data': sockets.DataSocketDef(
format='jax', default_show_info_columns=False
),
},
}
####################
# - Properties
####################
operation: bpy.props.EnumProperty(
name='Op',
description='Operation to apply to the two inputs',
items=lambda self, _: self.search_operations(),
update=lambda self, context: self.on_prop_changed('operation', context),
category: enum.Enum = bl_cache.BLField(
prop_ui=True, enum_cb=lambda self, _: self.search_categories()
)
def search_operations(self) -> list[tuple[str, str, str]]:
operation: enum.Enum = bl_cache.BLField(
prop_ui=True, enum_cb=lambda self, _: self.search_operations()
)
def search_categories(self) -> list[ct.BLEnumElement]:
"""Deduce and return a list of valid categories for the current socket set and input data."""
data_l_info = self._compute_input(
'Data L', kind=ct.FlowKind.Info, optional=True
)
data_r_info = self._compute_input(
'Data R', kind=ct.FlowKind.Info, optional=True
)
has_data_l_info = not ct.FlowSignal.check(data_l_info)
has_data_r_info = not ct.FlowSignal.check(data_r_info)
# Categories by Socket Set
NUMBER_NUMBER = (
'Number | Number',
'Number | Number',
'Operations between numerical elements',
)
NUMBER_VECTOR = (
'Number | Vector',
'Number | Vector',
'Operations between numerical and vector elements',
)
NUMBER_MATRIX = (
'Number | Matrix',
'Number | Matrix',
'Operations between numerical and matrix elements',
)
VECTOR_VECTOR = (
'Vector | Vector',
'Vector | Vector',
'Operations between vector elements',
)
MATRIX_VECTOR = (
'Matrix | Vector',
'Matrix | Vector',
'Operations between vector and matrix elements',
)
MATRIX_MATRIX = (
'Matrix | Matrix',
'Matrix | Matrix',
'Operations between matrix elements',
)
categories = []
## Expr | Expr
if self.active_socket_set == 'Expr | Expr':
return [NUMBER_NUMBER]
## Data | Data
if (
self.active_socket_set == 'Data | Data'
and has_data_l_info
and has_data_r_info
):
# Check Valid Broadcasting
## Number | Number
if data_l_info.output_shape is None and data_r_info.output_shape is None:
categories = [NUMBER_NUMBER]
## Number | Vector
elif (
data_l_info.output_shape is None and len(data_r_info.output_shape) == 1
):
categories = [NUMBER_VECTOR]
## Number | Matrix
elif (
data_l_info.output_shape is None and len(data_r_info.output_shape) == 2
): # noqa: PLR2004
categories = [NUMBER_MATRIX]
## Vector | Vector
elif (
len(data_l_info.output_shape) == 1
and len(data_r_info.output_shape) == 1
):
categories = [VECTOR_VECTOR]
## Matrix | Vector
elif (
len(data_l_info.output_shape) == 2 # noqa: PLR2004
and len(data_r_info.output_shape) == 1
):
categories = [MATRIX_VECTOR]
## Matrix | Matrix
elif (
len(data_l_info.output_shape) == 2 # noqa: PLR2004
and len(data_r_info.output_shape) == 2 # noqa: PLR2004
):
categories = [MATRIX_MATRIX]
## Expr | Data
if self.active_socket_set == 'Expr | Data' and has_data_r_info:
if data_r_info.output_shape is None:
categories = [NUMBER_NUMBER]
else:
categories = {
1: [NUMBER_NUMBER, NUMBER_VECTOR],
2: [NUMBER_NUMBER, NUMBER_MATRIX],
}[len(data_r_info.output_shape)]
return [
(*category, '', i) if category is not None else None
for i, category in enumerate(categories)
]
def search_operations(self) -> list[ct.BLEnumElement]:
items = []
if self.active_socket_set == 'Elementwise':
items = [
('ADD', 'Add', 'L + R (by el)'),
('SUB', 'Subtract', 'L - R (by el)'),
('MUL', 'Multiply', 'L · R (by el)'),
('DIV', 'Divide', 'L ÷ R (by el)'),
('POW', 'Power', 'L^R (by el)'),
('FMOD', 'Trunc Modulo', 'fmod(L,R) (by el)'),
('ATAN2', 'atan2', 'atan2(L,R) (by el)'),
('HEAVISIDE', 'Heaviside', '{0|L<0 1|L>0 R|L=0} (by el)'),
if self.category in ['Number | Number', 'Number | Vector', 'Number | Matrix']:
items += [
('ADD', 'L + R', 'Add'),
('SUB', 'L - R', 'Subtract'),
('MUL', 'L · R', 'Multiply'),
('DIV', 'L ÷ R', 'Divide'),
('POW', 'L^R', 'Power'),
('ATAN2', 'atan2(L,R)', 'atan2(L,R)'),
]
elif self.active_socket_set in 'Vec | Vec':
items = [
('DOT', 'Dot', 'L · R'),
('CROSS', 'Cross', 'L x R (by last-axis'),
if self.category in 'Vector | Vector':
if items:
items += [None]
items += [
('VEC_VEC_DOT', 'L · R', 'Vector-Vector Product'),
('CROSS', 'L x R', 'Cross Product'),
('PROJ', 'proj(L, R)', 'Projection'),
]
elif self.active_socket_set == 'Mat | Vec':
items = [
('DOT', 'Dot', 'L · R'),
('LIN_SOLVE', 'Lin Solve', 'Lx = R -> x (by last-axis of R)'),
('LSQ_SOLVE', 'LSq Solve', 'Lx = R ~> x (by last-axis of R)'),
if self.category == 'Matrix | Vector':
if items:
items += [None]
items += [
('MAT_VEC_DOT', 'L · R', 'Matrix-Vector Product'),
('LIN_SOLVE', 'Lx = R -> x', 'Linear Solve'),
('LSQ_SOLVE', 'Lx = R ~> x', 'Least Squares Solve'),
]
if self.category == 'Matrix | Matrix':
if items:
items += [None]
items += [
('MAT_MAT_DOT', 'L · R', 'Matrix-Matrix Product'),
]
return [
(*item, '', i) if item is not None else None for i, item in enumerate(items)
]
return items
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, 'operation')
layout.prop(self, self.blfields['category'], text='')
layout.prop(self, self.blfields['operation'], text='')
####################
# - Properties
# - Events
####################
@events.on_value_changed(
# Trigger
socket_name={'Expr L', 'Expr R', 'Data L', 'Data R'},
prop_name='active_socket_set',
run_on_init=True,
)
def on_socket_set_changed(self) -> None:
# Recompute Valid Categories
self.category = bl_cache.Signal.ResetEnumItems
self.operation = bl_cache.Signal.ResetEnumItems
@events.on_value_changed(
prop_name='category',
run_on_init=True,
)
def on_category_changed(self) -> None:
self.operation = bl_cache.Signal.ResetEnumItems
####################
# - Output
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Value,
props={'operation'},
input_sockets={'Expr L', 'Expr R'},
)
def compute_expr(self, props: dict, input_sockets: dict):
expr_l = input_sockets['Expr L']
expr_r = input_sockets['Expr R']
return {
'ADD': lambda: expr_l + expr_r,
'SUB': lambda: expr_l - expr_r,
'MUL': lambda: expr_l * expr_r,
'DIV': lambda: expr_l / expr_r,
'POW': lambda: expr_l**expr_r,
'ATAN2': lambda: sp.atan2(expr_r, expr_l),
}[props['operation']]()
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.LazyValueFunc,
props={'operation'},
input_sockets={'Data L', 'Data R'},
input_socket_kinds={
'Data L': ct.FlowKind.LazyValueFunc,
'Data R': ct.FlowKind.LazyValueFunc,
},
input_sockets_optional={
'Data L': True,
'Data R': True,
},
)
def compute_data(self, props: dict, input_sockets: dict):
data_l = input_sockets['Data L']
data_r = input_sockets['Data R']
has_data_l = not ct.FlowSignal.check(data_l)
mapping_func = {
# Number | *
'ADD': lambda datas: datas[0] + datas[1],
'SUB': lambda datas: datas[0] - datas[1],
'MUL': lambda datas: datas[0] * datas[1],
'DIV': lambda datas: datas[0] / datas[1],
'POW': lambda datas: datas[0] ** datas[1],
'ATAN2': lambda datas: jnp.atan2(datas[1], datas[0]),
# Vector | Vector
'VEC_VEC_DOT': lambda datas: jnp.matmul(datas[0], datas[1]),
'CROSS': lambda datas: jnp.cross(datas[0], datas[1]),
# Matrix | Vector
'MAT_VEC_DOT': lambda datas: jnp.matmul(datas[0], datas[1]),
'LIN_SOLVE': lambda datas: jnp.linalg.solve(datas[0], datas[1]),
'LSQ_SOLVE': lambda datas: jnp.linalg.lstsq(datas[0], datas[1]),
# Matrix | Matrix
'MAT_MAT_DOT': lambda datas: jnp.matmul(datas[0], datas[1]),
}[props['operation']]
# Compose by Socket Set
## Data | Data
if has_data_l:
return (data_l | data_r).compose_within(
mapping_func,
supports_jax=True,
)
## Expr | Data
expr_l_lazy_value_func = ct.LazyValueFuncFlow(
func=lambda expr_l_value: expr_l_value,
func_args=[typ.Any],
supports_jax=True,
)
return (expr_l_lazy_value_func | data_r).compose_within(
mapping_func,
supports_jax=True,
)
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.Array,
output_sockets={'Data'},
output_socket_kinds={
'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
},
)
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow:
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc]
params = output_sockets['Data'][ct.FlowKind.Params]
has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func)
has_params = not ct.FlowSignal.check(params)
if has_lazy_value_func and has_params:
return ct.ArrayFlow(
values=lazy_value_func.func_jax(
*params.func_args, **params.func_kwargs
),
unit=None,
)
return ct.FlowSignal.FlowPending
####################
# - Auxiliary: Params
####################
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.Params,
props={'operation'},
input_sockets={'Data L', 'Data R'},
input_sockets={'Expr L', 'Data L', 'Data R'},
input_socket_kinds={
'Expr L': ct.FlowKind.Value,
'Data L': {ct.FlowKind.Info, ct.FlowKind.Params},
'Data R': {ct.FlowKind.Info, ct.FlowKind.Params},
},
input_sockets_optional={
'Expr L': True,
'Data L': True,
'Data R': True,
},
)
def compute_data(self, props: dict, input_sockets: dict):
if self.active_socket_set == 'Elementwise':
# Element-Wise Arithmetic
if props['operation'] == 'ADD':
return input_sockets['Data L'] + input_sockets['Data R']
if props['operation'] == 'SUB':
return input_sockets['Data L'] - input_sockets['Data R']
if props['operation'] == 'MUL':
return input_sockets['Data L'] * input_sockets['Data R']
if props['operation'] == 'DIV':
return input_sockets['Data L'] / input_sockets['Data R']
def compute_data_params(
self, props, input_sockets
) -> ct.ParamsFlow | ct.FlowSignal:
expr_l = input_sockets['Expr L']
data_l_info = input_sockets['Data L'][ct.FlowKind.Info]
data_l_params = input_sockets['Data L'][ct.FlowKind.Params]
data_r_info = input_sockets['Data R'][ct.FlowKind.Info]
data_r_params = input_sockets['Data R'][ct.FlowKind.Params]
# Element-Wise Arithmetic
if props['operation'] == 'POW':
return input_sockets['Data L'] ** input_sockets['Data R']
has_expr_l = not ct.FlowSignal.check(expr_l)
has_data_l_info = not ct.FlowSignal.check(data_l_info)
has_data_l_params = not ct.FlowSignal.check(data_l_params)
has_data_r_info = not ct.FlowSignal.check(data_r_info)
has_data_r_params = not ct.FlowSignal.check(data_r_params)
# Binary Trigonometry
if props['operation'] == 'ATAN2':
return jnp.atan2(input_sockets['Data L'], input_sockets['Data R'])
#log.critical((props, input_sockets))
# Special Functions
if props['operation'] == 'HEAVISIDE':
return jnp.heaviside(input_sockets['Data L'], input_sockets['Data R'])
# Compose by Socket Set
## Data | Data
if (
has_data_l_info
and has_data_l_params
and has_data_r_info
and has_data_r_params
):
return data_l_params | data_r_params
# Linear Algebra
if self.active_socket_set in {'Vec-Vec', 'Mat-Vec'}:
if props['operation'] == 'DOT':
return jnp.dot(input_sockets['Data L'], input_sockets['Data R'])
## Expr | Data
if has_expr_l and has_data_r_info and has_data_r_params:
operation = props['operation']
data_unit = data_r_info.output_unit
elif self.active_socket_set == 'Vec-Vec':
if props['operation'] == 'CROSS':
return jnp.cross(input_sockets['Data L'], input_sockets['Data R'])
# By Operation
## Add/Sub: Scale to Output Unit
if operation in ['ADD', 'SUB', 'MUL', 'DIV']:
if not spux.uses_units(expr_l):
value = spux.sympy_to_python(expr_l)
else:
value = spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit))
elif self.active_socket_set == 'Mat-Vec':
if props['operation'] == 'LIN_SOLVE':
return jnp.linalg.lstsq(
input_sockets['Data L'], input_sockets['Data R']
)
if props['operation'] == 'LSQ_SOLVE':
return jnp.linalg.solve(
input_sockets['Data L'], input_sockets['Data R']
return data_r_params.compose_within(
enclosing_func_args=[value],
)
msg = 'Invalid operation'
raise ValueError(msg)
## Pow: Doesn't Exist (?)
## -> See https://math.stackexchange.com/questions/4326081/units-of-the-exponential-function
if operation == 'POW':
return ct.FlowSignal.FlowPending
## atan2(): Only Length
## -> Implicitly presume that Data L/R use length units.
if operation == 'ATAN2':
if not spux.uses_units(expr_l):
value = spux.sympy_to_python(expr_l)
else:
value = spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit))
return data_r_params.compose_within(
enclosing_func_args=[value],
)
return data_r_params.compose_within(
enclosing_func_args=[
spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit))
]
)
return ct.FlowSignal.FlowPending
####################
# - Auxiliary: Info
####################
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.Info,
input_sockets={'Expr L', 'Data L', 'Data R'},
input_socket_kinds={
'Expr L': ct.FlowKind.Value,
'Data L': ct.FlowKind.Info,
'Data R': ct.FlowKind.Info,
},
input_sockets_optional={
'Expr L': True,
'Data L': True,
'Data R': True,
},
)
def compute_data_info(self, input_sockets: dict) -> ct.InfoFlow:
expr_l = input_sockets['Expr L']
data_l_info = input_sockets['Data L']
data_r_info = input_sockets['Data R']
has_expr_l = not ct.FlowSignal.check(expr_l)
has_data_l_info = not ct.FlowSignal.check(data_l_info)
has_data_r_info = not ct.FlowSignal.check(data_r_info)
# Info by Socket Set
## Data | Data
if has_data_l_info and has_data_r_info:
return data_r_info
## Expr | Data
if has_expr_l and has_data_r_info:
return data_r_info
return ct.FlowSignal.FlowPending
####################

View File

@ -0,0 +1,165 @@
"""Declares `TransformMathNode`."""
import enum
import typing as typ
import bpy
import jax
import jax.numpy as jnp
import sympy as sp
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
class TransformMathNode(base.MaxwellSimNode):
r"""Applies a function to the array as a whole, with arbitrary results.
The shape, type, and interpretation of the input/output data is dynamically shown.
# Socket Sets
## Interpret
Reinterprets the `InfoFlow` of an array, **without changing it**.
Attributes:
operation: Operation to apply to the input.
"""
node_type = ct.NodeType.TransformMath
bl_label = 'Transform Math'
input_sockets: typ.ClassVar = {
'Data': sockets.DataSocketDef(format='jax'),
}
input_socket_sets: typ.ClassVar = {
'Fourier': {},
'Affine': {},
'Convolve': {},
}
output_sockets: typ.ClassVar = {
'Data': sockets.DataSocketDef(format='jax'),
}
####################
# - Properties
####################
operation: enum.Enum = bl_cache.BLField(
prop_ui=True, enum_cb=lambda self, _: self.search_operations()
)
def search_operations(self) -> list[ct.BLEnumElement]:
if self.active_socket_set == 'Fourier': # noqa: SIM114
items = []
elif self.active_socket_set == 'Affine': # noqa: SIM114
items = []
elif self.active_socket_set == 'Convolve':
items = []
else:
msg = f'Active socket set {self.active_socket_set} is unknown'
raise RuntimeError(msg)
return [(*item, '', i) for i, item in enumerate(items)]
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, self.blfields['operation'], text='')
####################
# - Events
####################
@events.on_value_changed(
prop_name='active_socket_set',
)
def on_socket_set_changed(self):
self.operation = bl_cache.Signal.ResetEnumItems
####################
# - Compute: LazyValueFunc / Array
####################
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.LazyValueFunc,
props={'active_socket_set', 'operation'},
input_sockets={'Data'},
input_socket_kinds={
'Data': ct.FlowKind.LazyValueFunc,
},
)
def compute_data(self, props: dict, input_sockets: dict):
has_data = not ct.FlowSignal.check(input_sockets['Data'])
if not has_data or props['operation'] == 'NONE':
return ct.FlowSignal.FlowPending
mapping_func: typ.Callable[[jax.Array], jax.Array] = {
'Fourier': {},
'Affine': {},
'Convolve': {},
}[props['active_socket_set']][props['operation']]
# Compose w/Lazy Root Function Data
return input_sockets['Data'].compose_within(
mapping_func,
supports_jax=True,
)
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.Array,
output_sockets={'Data'},
output_socket_kinds={
'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
},
)
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow:
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc]
params = output_sockets['Data'][ct.FlowKind.Params]
if all(not ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]):
return ct.ArrayFlow(
values=lazy_value_func.func_jax(
*params.func_args, **params.func_kwargs
),
unit=None,
)
return ct.FlowSignal.FlowPending
####################
# - Compute Auxiliary: Info / Params
####################
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.Info,
props={'active_socket_set', 'operation'},
input_sockets={'Data'},
input_socket_kinds={'Data': ct.FlowKind.Info},
)
def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
info = input_sockets['Data']
if ct.FlowSignal.check(info):
return ct.FlowSignal.FlowPending
return info
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.Params,
input_sockets={'Data'},
input_socket_kinds={'Data': ct.FlowKind.Params},
)
def compute_data_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal:
return input_sockets['Data']
####################
# - Blender Registration
####################
BL_REGISTER = [
TransformMathNode,
]
BL_NODES = {ct.NodeType.TransformMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)}

View File

@ -61,32 +61,34 @@ class VizMode(enum.StrEnum):
@staticmethod
def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None:
EMPTY = ()
Z = spux.MathType.Integer
R = spux.MathType.Real
VM = VizMode
valid_viz_modes = {
((), (spux.MathType.Real,)): [VizMode.Hist1D, VizMode.BoxPlot1D],
((spux.MathType.Integer), (spux.MathType.Real)): [
VizMode.Hist1D,
VizMode.BoxPlot1D,
(EMPTY, (None, R)): [VM.Hist1D, VM.BoxPlot1D],
((Z), (None, R)): [
VM.Hist1D,
VM.BoxPlot1D,
],
((spux.MathType.Real,), (spux.MathType.Real,)): [
VizMode.Curve2D,
VizMode.Points2D,
VizMode.Bar,
((R,), (None, R)): [
VM.Curve2D,
VM.Points2D,
VM.Bar,
],
((spux.MathType.Real, spux.MathType.Integer), (spux.MathType.Real,)): [
VizMode.Curves2D,
VizMode.FilledCurves2D,
((R, Z), (None, R)): [
VM.Curves2D,
VM.FilledCurves2D,
],
((spux.MathType.Real, spux.MathType.Real), (spux.MathType.Real,)): [
VizMode.Heatmap2D,
((R, R), (None, R)): [
VM.Heatmap2D,
],
(
(spux.MathType.Real, spux.MathType.Real, spux.MathType.Real),
(spux.MathType.Real,),
): [VizMode.SqueezedHeatmap2D, VizMode.Heatmap3D],
((R, R, R), (None, R)): [VM.SqueezedHeatmap2D, VM.Heatmap3D],
}.get(
(
tuple(info.dim_mathtypes.values()),
tuple(info.output_mathtypes.values()),
(info.output_shape, info.output_mathtype),
)
)
@ -161,10 +163,10 @@ class VizTarget(enum.StrEnum):
@staticmethod
def to_name(value: typ.Self) -> str:
return {
VizTarget.Plot2D: 'Image (Plot)',
VizTarget.Pixels: 'Image (Pixels)',
VizTarget.PixelsPlane: 'Image (Plane)',
VizTarget.Voxels: '3D Field',
VizTarget.Plot2D: 'Plot',
VizTarget.Pixels: 'Pixels',
VizTarget.PixelsPlane: 'Image Plane',
VizTarget.Voxels: 'Voxels',
}[value]
@staticmethod

View File

@ -1036,6 +1036,12 @@ class MaxwellSimNode(bpy.types.Node):
# Generate New Instance ID
self.reset_instance_id()
# Generate New Instance ID for Sockets
## Sockets can't do this themselves.
for bl_sockets in [self.inputs, self.outputs]:
for bl_socket in bl_sockets:
bl_socket.reset_instance_id()
# Generate New Sim Node Name
## Blender will automatically add .001 so that `self.name` is unique.
self.sim_node_name = self.name

View File

@ -14,7 +14,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
bl_label = 'Scientific Constant'
output_sockets: typ.ClassVar = {
'Value': sockets.AnySocketDef(),
'Value': sockets.ExprSocketDef(),
}
####################

View File

@ -913,10 +913,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket):
col = layout.column()
row = col.row()
row.alignment = 'RIGHT'
if self.is_linked:
self.draw_output_label_row(row, text)
else:
row.label(text=text)
# Draw FlowKind.Info related Information
if self.use_info_draw:

View File

@ -12,6 +12,10 @@ from .. import base
log = logger.get(__name__)
def unicode_superscript(n):
return ''.join(['⁰¹²³⁴⁵⁶⁷⁸⁹'[ord(c) - ord('0')] for c in str(n)])
class DataInfoColumn(enum.StrEnum):
Length = enum.auto()
MathType = enum.auto()
@ -49,11 +53,11 @@ class DataBLSocket(base.MaxwellSimSocket):
## TODO: typ.Literal['xarray', 'jax']
show_info_columns: bool = bl_cache.BLField(
False,
True,
prop_ui=True,
)
info_columns: DataInfoColumn = bl_cache.BLField(
{DataInfoColumn.MathType, DataInfoColumn.Length}, prop_ui=True, enum_many=True
{DataInfoColumn.MathType, DataInfoColumn.Unit}, prop_ui=True, enum_many=True
)
####################
@ -71,8 +75,10 @@ class DataBLSocket(base.MaxwellSimSocket):
# - UI
####################
def draw_input_label_row(self, row: bpy.types.UILayout, text) -> None:
if self.format == 'jax':
row.label(text=text)
info = self.compute_data(kind=ct.FlowKind.Info)
if not ct.FlowSignal.check(info) and self.format == 'jax' and info.dim_names:
row.prop(self, self.blfields['info_columns'])
row.prop(
self,
@ -83,7 +89,8 @@ class DataBLSocket(base.MaxwellSimSocket):
)
def draw_output_label_row(self, row: bpy.types.UILayout, text) -> None:
if self.format == 'jax':
info = self.compute_data(kind=ct.FlowKind.Info)
if not ct.FlowSignal.check(info) and self.format == 'jax' and info.dim_names:
row.prop(
self,
self.blfields['show_info_columns'],
@ -92,6 +99,7 @@ class DataBLSocket(base.MaxwellSimSocket):
icon=ct.Icon.ToggleSocketInfo,
)
row.prop(self, self.blfields['info_columns'])
row.label(text=text)
def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None:
@ -118,16 +126,27 @@ class DataBLSocket(base.MaxwellSimSocket):
grid.label(text=spux.sp_to_str(dim_idx.unit))
# Outputs
for output_name in info.output_names:
grid.label(text=output_name)
grid.label(text=info.output_name)
if DataInfoColumn.Length in self.info_columns:
grid.label(text='', icon=ct.Icon.DataSocketOutput)
if DataInfoColumn.MathType in self.info_columns:
grid.label(
text=spux.MathType.to_str(info.output_mathtypes[output_name])
text=(
spux.MathType.to_str(info.output_mathtype)
+ (
'ˣ'.join(
[
unicode_superscript(out_axis)
for out_axis in info.output_shape
]
)
if info.output_shape
else ''
)
)
)
if DataInfoColumn.Unit in self.info_columns:
grid.label(text=spux.sp_to_str(info.output_units[output_name]))
grid.label(text=f'{spux.sp_to_str(info.output_unit)}')
####################
@ -137,9 +156,11 @@ class DataSocketDef(base.SocketDef):
socket_type: ct.SocketType = ct.SocketType.Data
format: typ.Literal['xarray', 'jax', 'monitor_data']
default_show_info_columns: bool = True
def init(self, bl_socket: DataBLSocket) -> None:
bl_socket.format = self.format
bl_socket.default_show_info_columns = self.default_show_info_columns
####################

View File

@ -86,9 +86,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
def lazy_value_func(self) -> ct.LazyValueFuncFlow:
return ct.LazyValueFuncFlow(
func=sp.lambdify(self.symbols, self.value, 'jax'),
func_args=[
(sym.name, spux.sympy_to_python_type(sym)) for sym in self.symbols
],
func_args=[spux.sympy_to_python_type(sym) for sym in self.symbols],
supports_jax=True,
)

View File

@ -31,6 +31,18 @@ class MathType(enum.StrEnum):
Real = enum.auto()
Complex = enum.auto()
def combine(*mathtypes: list[typ.Self]) -> typ.Self:
if MathType.Complex in mathtypes:
return MathType.Complex
elif MathType.Real in mathtypes:
return MathType.Real
elif MathType.Rational in mathtypes:
return MathType.Rational
elif MathType.Integer in mathtypes:
return MathType.Integer
elif MathType.Bool in mathtypes:
return MathType.Bool
@staticmethod
def from_expr(sp_obj: SympyType) -> type:
if isinstance(sp_obj, sp.logic.boolalg.Boolean):
@ -52,13 +64,13 @@ class MathType(enum.StrEnum):
int: MathType.Integer,
float: MathType.Real,
complex: MathType.Complex,
#jnp.int32: MathType.Integer,
#jnp.int64: MathType.Integer,
#jnp.float32: MathType.Real,
#jnp.float64: MathType.Real,
#jnp.complex64: MathType.Complex,
#jnp.complex128: MathType.Complex,
#jnp.bool_: MathType.Bool,
# jnp.int32: MathType.Integer,
# jnp.int64: MathType.Integer,
# jnp.float32: MathType.Real,
# jnp.float64: MathType.Real,
# jnp.complex64: MathType.Complex,
# jnp.complex128: MathType.Complex,
# jnp.bool_: MathType.Bool,
}[dtype]
@staticmethod
@ -595,6 +607,21 @@ ComplexNumber: typ.TypeAlias = ConstrSympyExpr(
)
Number: typ.TypeAlias = IntNumber | RealNumber | ComplexNumber
# Number
PhysicalRealNumber: typ.TypeAlias = ConstrSympyExpr(
allow_variables=False,
allow_units=True,
allowed_sets={'integer', 'rational', 'real'},
allowed_structures={'scalar'},
)
PhysicalComplexNumber: typ.TypeAlias = ConstrSympyExpr(
allow_variables=False,
allow_units=True,
allowed_sets={'integer', 'rational', 'real', 'complex'},
allowed_structures={'scalar'},
)
PhysicalNumber: typ.TypeAlias = PhysicalRealNumber | PhysicalComplexNumber
# Vector
Real3DVector: typ.TypeAlias = ConstrSympyExpr(
allow_variables=False,

View File

@ -117,8 +117,8 @@ def rgba_image_from_2d_map(
def plot_hist_1d(
data: jtyp.Float32[jtyp.Array, ' size'], info, ax: mpl_ax.Axis
) -> None:
y_name = info.output_names[0]
y_unit = info.output_units[y_name]
y_name = info.output_name
y_unit = info.output_unit
ax.hist(data, bins=30, alpha=0.75)
ax.set_title('Histogram')
@ -130,8 +130,8 @@ def plot_box_plot_1d(
data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis
) -> None:
x_name = info.dim_names[0]
y_name = info.output_names[0]
y_unit = info.output_units[y_name]
y_name = info.output_name
y_unit = info.output_unit
ax.boxplot(data)
ax.set_title('Box Plot')
@ -147,8 +147,8 @@ def plot_curve_2d(
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.output_names[0]
y_unit = info.output_units[y_name]
y_name = info.output_name
y_unit = info.output_unit
times.append(time.perf_counter() - times[0])
ax.plot(info.dim_idx_arrays[0], data)
@ -167,8 +167,8 @@ def plot_points_2d(
) -> None:
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.output_names[0]
y_unit = info.output_units[y_name]
y_name = info.output_name
y_unit = info.output_unit
ax.scatter(info.dim_idx_arrays[0], data, alpha=0.6)
ax.set_title('2D Points')
@ -179,8 +179,8 @@ def plot_points_2d(
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None:
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.output_names[0]
y_unit = info.output_units[y_name]
y_name = info.output_name
y_unit = info.output_unit
ax.bar(info.dim_idx_arrays[0], data, alpha=0.7)
ax.set_title('2D Bar')
@ -194,8 +194,8 @@ def plot_curves_2d(
) -> None:
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.output_names[0]
y_unit = info.output_units[y_name]
y_name = info.output_name
y_unit = info.output_unit
for category in range(data.shape[1]):
ax.plot(data[:, 0], data[:, 1])
@ -211,8 +211,8 @@ def plot_filled_curves_2d(
) -> None:
x_name = info.dim_names[0]
x_unit = info.dim_units[x_name]
y_name = info.output_names[0]
y_unit = info.output_units[y_name]
y_name = info.output_name
y_unit = info.output_unit
ax.fill_between(info.dim_arrays[0], data[:, 0], info.dim_arrays[0], data[:, 1])
ax.set_title('2D Curves')