feat: Implemented operate math node.

main
Sofus Albert Høgsbro Rose 2024-04-26 17:22:55 +02:00
parent 7d944a704e
commit b2a7eefb45
Signed by: so-rose
GPG Key ID: AD901CB0F3701434
18 changed files with 1170 additions and 390 deletions

View File

@ -14,6 +14,8 @@
- [x] Extract - [x] Extract
- [x] Viz - [x] Viz
- [x] Math / Map Math - [x] Math / Map Math
- [ ] Remove "By x" socket set let socket sets only be "Function"/"Expr"; then add a dynamic enum underneath to select "By x" based on data support.
- [ ] Filter the operations based on data support, ex. use positive-definiteness to guide cholesky.
- [x] Math / Filter Math - [x] Math / Filter Math
- [ ] Math / Reduce Math - [ ] Math / Reduce Math
- [ ] Math / Operate Math - [ ] Math / Operate Math
@ -34,8 +36,6 @@
- [x] Constants / Blender Constant - [x] Constants / Blender Constant
- [ ] Web / Tidy3D Web Importer - [ ] Web / Tidy3D Web Importer
- [ ] Change to output only a `FilePath`, which can be plugged into a Tidy3D File Importer.
- [ ] Implement caching, such that the file will only download if the file doesn't already exist.
- [ ] Have a visual indicator for the current download status, with a manual re-download button. - [ ] Have a visual indicator for the current download status, with a manual re-download button.
- [x] File Import / Material Import - [x] File Import / Material Import

View File

@ -68,6 +68,9 @@ class FlowKind(enum.StrEnum):
if kind == cls.LazyArrayRange: if kind == cls.LazyArrayRange:
return value.rescale_to_unit(unit_system[socket_type]) return value.rescale_to_unit(unit_system[socket_type])
if kind == cls.Params:
return value.rescale_to_unit(unit_system[socket_type])
msg = 'Tried to scale unknown kind' msg = 'Tried to scale unknown kind'
raise ValueError(msg) raise ValueError(msg)
@ -322,6 +325,28 @@ class LazyValueFuncFlow:
supports_jax: bool = False supports_jax: bool = False
supports_numba: bool = False supports_numba: bool = False
# Merging
def __or__(
self,
other: typ.Self,
):
return LazyValueFuncFlow(
func=lambda *args, **kwargs: (
self.func(
*list(args[: len(self.func_args)]),
**{k: v for k, v in kwargs.items() if k in self.func_kwargs},
),
other.func(
*list(args[len(self.func_args) :]),
**{k: v for k, v in kwargs.items() if k in other.func_kwargs},
),
),
func_args=self.func_args + other.func_args,
func_kwargs=self.func_kwargs | other.func_kwargs,
supports_jax=self.supports_jax and other.supports_jax,
supports_numba=self.supports_numba and other.supports_numba,
)
# Composition # Composition
def compose_within( def compose_within(
self, self,
@ -691,10 +716,21 @@ class ParamsFlow:
func_args: list[typ.Any] = dataclasses.field(default_factory=list) func_args: list[typ.Any] = dataclasses.field(default_factory=list)
func_kwargs: dict[str, typ.Any] = dataclasses.field(default_factory=dict) func_kwargs: dict[str, typ.Any] = dataclasses.field(default_factory=dict)
def __or__(
self,
other: typ.Self,
):
return ParamsFlow(
func_args=self.func_args + other.func_args,
func_kwargs=self.func_kwargs | other.func_kwargs,
)
def compose_within( def compose_within(
self, self,
enclosing_func_args: list[tuple[type]] = (), enclosing_func_args: list[tuple[type]] = (),
enclosing_func_kwargs: dict[str, type] = MappingProxyType({}), enclosing_func_kwargs: dict[str, type] = MappingProxyType({}),
enclosing_func_arg_units: dict[str, type] = MappingProxyType({}),
enclosing_func_kwarg_units: dict[str, type] = MappingProxyType({}),
) -> typ.Self: ) -> typ.Self:
return ParamsFlow( return ParamsFlow(
func_args=self.func_args + list(enclosing_func_args), func_args=self.func_args + list(enclosing_func_args),
@ -718,26 +754,133 @@ class InfoFlow:
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()} return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
@functools.cached_property @functools.cached_property
def dim_mathtypes(self) -> dict[str, int]: def dim_mathtypes(self) -> dict[str, spux.MathType]:
return { return {
dim_name: dim_idx.mathtype for dim_name, dim_idx in self.dim_idx.items() dim_name: dim_idx.mathtype for dim_name, dim_idx in self.dim_idx.items()
} }
@functools.cached_property @functools.cached_property
def dim_units(self) -> dict[str, int]: def dim_units(self) -> dict[str, spux.Unit]:
return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()} return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()}
@functools.cached_property @functools.cached_property
def dim_idx_arrays(self) -> list[ArrayFlow]: def dim_idx_arrays(self) -> list[jax.Array]:
return [ return [
dim_idx.realize().values dim_idx.realize().values
if isinstance(dim_idx, LazyArrayRangeFlow) if isinstance(dim_idx, LazyArrayRangeFlow)
else dim_idx.values else dim_idx.values
for dim_idx in self.dim_idx.values() for dim_idx in self.dim_idx.values()
] ]
return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()}
# Output Information # Output Information
output_names: list[str] = dataclasses.field(default_factory=list) output_name: str = dataclasses.field(default_factory=list)
output_mathtypes: dict[str, spux.MathType] = dataclasses.field(default_factory=dict) output_shape: tuple[int, ...] | None = dataclasses.field(default=None)
output_units: dict[str, spux.Unit | None] = dataclasses.field(default_factory=dict) output_mathtype: spux.MathType = dataclasses.field()
output_unit: spux.Unit | None = dataclasses.field()
# Pinned Dimension Information
pinned_dim_names: list[str] = dataclasses.field(default_factory=list)
pinned_dim_values: dict[str, float | complex] = dataclasses.field(
default_factory=dict
)
pinned_dim_mathtypes: dict[str, spux.MathType] = dataclasses.field(
default_factory=dict
)
pinned_dim_units: dict[str, spux.Unit] = dataclasses.field(default_factory=dict)
####################
# - Methods
####################
def delete_dimension(self, dim_name: str) -> typ.Self:
"""Delete a dimension."""
return InfoFlow(
# Dimensions
dim_names=[
_dim_name for _dim_name in self.dim_names if _dim_name != dim_name
],
dim_idx={
_dim_name: dim_idx
for _dim_name, dim_idx in self.dim_idx.items()
if _dim_name != dim_name
},
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
)
def swap_dimensions(self, dim_0_name: str, dim_1_name: str) -> typ.Self:
"""Delete a dimension."""
# Compute Swapped Dimension Name List
def name_swapper(dim_name):
return (
dim_name
if dim_name not in [dim_0_name, dim_1_name]
else {dim_0_name: dim_1_name, dim_1_name: dim_0_name}[dim_name]
)
dim_names = [name_swapper(dim_name) for dim_name in self.dim_names]
# Compute Info
return InfoFlow(
# Dimensions
dim_names=dim_names,
dim_idx={dim_name: self.dim_idx[dim_name] for dim_name in dim_names},
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
)
def set_output_mathtype(self, output_mathtype: spux.MathType) -> typ.Self:
"""Set the MathType of a particular output name."""
return InfoFlow(
dim_names=self.dim_names,
dim_idx=self.dim_idx,
# Outputs
output_name=self.output_name,
output_shape=self.output_shape,
output_mathtype=output_mathtype,
output_unit=self.output_unit,
)
def collapse_output(
self,
collapsed_name: str,
collapsed_mathtype: spux.MathType,
collapsed_unit: spux.Unit,
) -> typ.Self:
return InfoFlow(
# Dimensions
dim_names=self.dim_names,
dim_idx=self.dim_idx,
output_name=collapsed_name,
output_shape=None,
output_mathtype=collapsed_mathtype,
output_unit=collapsed_unit,
)
@functools.cached_property
def shift_last_input(self):
"""Shift the last input dimension to the output."""
return InfoFlow(
# Dimensions
dim_names=self.dim_names[:-1],
dim_idx={
dim_name: dim_idx
for dim_name, dim_idx in self.dim_idx.items()
if dim_name != self.dim_names[-1]
},
# Outputs
output_name=self.output_name,
output_shape=(
(self.dim_lens[self.dim_names[-1]],)
if self.output_shape is None
else (self.dim_lens[self.dim_names[-1]], *self.output_shape)
),
output_mathtype=self.output_mathtype,
output_unit=self.output_unit,
)

View File

@ -15,6 +15,7 @@ class NodeType(blender_type_enum.BlenderTypeEnum):
FilterMath = enum.auto() FilterMath = enum.auto()
ReduceMath = enum.auto() ReduceMath = enum.auto()
OperateMath = enum.auto() OperateMath = enum.auto()
TransformMath = enum.auto()
# Inputs # Inputs
WaveConstant = enum.auto() WaveConstant = enum.auto()

View File

@ -260,7 +260,9 @@ SOCKET_UNITS = {
} }
def unit_to_socket_type(unit: spux.Unit) -> ST: def unit_to_socket_type(
unit: spux.Unit | None, fallback_mathtype: spux.MathType | None = None
) -> ST:
"""Returns a SocketType that accepts the given unit. """Returns a SocketType that accepts the given unit.
Only the unit-compatibility is taken into account; in the case of overlap, several the ordering of `SOCKET_UNITS` determines which is returned. Only the unit-compatibility is taken into account; in the case of overlap, several the ordering of `SOCKET_UNITS` determines which is returned.
@ -269,6 +271,14 @@ def unit_to_socket_type(unit: spux.Unit) -> ST:
Returns: Returns:
**The first `SocketType` in `SOCKET_UNITS`, which contains the given unit as a valid possibility. **The first `SocketType` in `SOCKET_UNITS`, which contains the given unit as a valid possibility.
""" """
if unit is None and fallback_mathtype is not None:
return {
spux.MathType.Integer: ST.IntegerNumber,
spux.MathType.Rational: ST.RationalNumber,
spux.MathType.Real: ST.RealNumber,
spux.MathType.Complex: ST.ComplexNumber,
}[fallback_mathtype]
for socket_type, _units in SOCKET_UNITS.items(): for socket_type, _units in SOCKET_UNITS.items():
if unit in _units['values'].values(): if unit in _units['values'].values():
return socket_type return socket_type

View File

@ -215,6 +215,15 @@ class ExtractDataNode(base.MaxwellSimNode):
#################### ####################
# - UI # - UI
#################### ####################
def draw_label(self):
has_sim_data = self.sim_data_monitor_nametype is not None
has_monitor_data = self.monitor_data_components is not None
if has_sim_data or has_monitor_data:
return f'Extract: {self.extract_filter}'
return self.bl_label
def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None: def draw_props(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
"""Draw node properties in the node. """Draw node properties in the node.
@ -223,43 +232,6 @@ class ExtractDataNode(base.MaxwellSimNode):
""" """
col.prop(self, self.blfields['extract_filter'], text='') col.prop(self, self.blfields['extract_filter'], text='')
def draw_info(self, _: bpy.types.Context, col: bpy.types.UILayout) -> None:
"""Draw dynamic information in the node, for user consideration.
Parameters:
col: UI target for drawing.
"""
has_sim_data = self.sim_data_monitor_nametype is not None
has_monitor_data = self.monitor_data_components is not None
if has_sim_data or has_monitor_data:
# Header
row = col.row()
row.alignment = 'CENTER'
if has_sim_data:
row.label(text=f'{len(self.sim_data_monitor_nametype)} Monitors')
elif has_monitor_data:
row.label(text=f'{self.monitor_data_type} Monitor Data')
# Monitor Data Contents
## TODO: More compact double-split
## TODO: Output shape data.
## TODO: Local ENUM_MANY tabs for visible column selection?
row = col.row()
box = row.box()
grid = box.grid_flow(row_major=True, columns=2, even_columns=True)
if has_sim_data:
for (
monitor_name,
monitor_type,
) in self.sim_data_monitor_nametype.items():
grid.label(text=monitor_name)
grid.label(text=monitor_type.replace('Data', ''))
elif has_monitor_data:
for component_name in self.monitor_data_components:
grid.label(text=component_name)
grid.label(text=self.monitor_data_type)
#################### ####################
# - Events # - Events
#################### ####################
@ -416,9 +388,8 @@ class ExtractDataNode(base.MaxwellSimNode):
else: else:
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
info_output_names = { info_output_name = props['extract_filter']
'output_names': [props['extract_filter']], info_output_shape = None
}
# Compute InfoFlow from XArray # Compute InfoFlow from XArray
## XYZF: Field / Permittivity / FieldProjectionCartesian ## XYZF: Field / Permittivity / FieldProjectionCartesian
@ -442,13 +413,14 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True, is_sorted=True,
), ),
}, },
**info_output_names, output_name=props['extract_filter'],
output_mathtypes={props['extract_filter']: spux.MathType.Complex}, output_shape=None,
output_units={ output_mathtype=spux.MathType.Complex,
props['extract_filter']: spu.volt / spu.micrometer output_unit=(
spu.volt / spu.micrometer
if props['monitor_data_type'] == 'Field' if props['monitor_data_type'] == 'Field'
else None else None
}, ),
) )
## XYZT: FieldTime ## XYZT: FieldTime
@ -468,17 +440,14 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True, is_sorted=True,
), ),
}, },
**info_output_names, output_name=props['extract_filter'],
output_mathtypes={props['extract_filter']: spux.MathType.Complex}, output_shape=None,
output_units={ output_mathtype=spux.MathType.Complex,
props['extract_filter']: ( output_unit=(
spu.volt / spu.micrometer spu.volt / spu.micrometer
if props['extract_filter'].startswith('E')
else spu.ampere / spu.micrometer
)
if props['monitor_data_type'] == 'Field' if props['monitor_data_type'] == 'Field'
else None else None
}, ),
) )
## F: Flux ## F: Flux
@ -492,9 +461,10 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True, is_sorted=True,
), ),
}, },
**info_output_names, output_name=props['extract_filter'],
output_mathtypes={props['extract_filter']: spux.MathType.Real}, output_shape=None,
output_units={props['extract_filter']: spu.watt}, output_mathtype=spux.MathType.Real,
output_unit=spu.watt,
) )
## T: FluxTime ## T: FluxTime
@ -508,9 +478,10 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True, is_sorted=True,
), ),
}, },
**info_output_names, output_name=props['extract_filter'],
output_mathtypes={props['extract_filter']: spux.MathType.Real}, output_shape=None,
output_units={props['extract_filter']: spu.watt}, output_mathtype=spux.MathType.Real,
output_unit=spu.watt,
) )
## RThetaPhiF: FieldProjectionAngle ## RThetaPhiF: FieldProjectionAngle
@ -537,15 +508,14 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True, is_sorted=True,
), ),
}, },
**info_output_names, output_name=props['extract_filter'],
output_mathtypes={props['extract_filter']: spux.MathType.Real}, output_shape=None,
output_units={ output_mathtype=spux.MathType.Real,
props['extract_filter']: ( output_unit=(
spu.volt / spu.micrometer spu.volt / spu.micrometer
if props['extract_filter'].startswith('E') if props['extract_filter'].startswith('E')
else spu.ampere / spu.micrometer else spu.ampere / spu.micrometer
) ),
},
) )
## UxUyRF: FieldProjectionKSpace ## UxUyRF: FieldProjectionKSpace
@ -570,15 +540,14 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True, is_sorted=True,
), ),
}, },
**info_output_names, output_name=props['extract_filter'],
output_mathtypes={props['extract_filter']: spux.MathType.Real}, output_shape=None,
output_units={ output_mathtype=spux.MathType.Real,
props['extract_filter']: ( output_unit=(
spu.volt / spu.micrometer spu.volt / spu.micrometer
if props['extract_filter'].startswith('E') if props['extract_filter'].startswith('E')
else spu.ampere / spu.micrometer else spu.ampere / spu.micrometer
) ),
},
) )
## OrderxOrderyF: Diffraction ## OrderxOrderyF: Diffraction
@ -600,15 +569,14 @@ class ExtractDataNode(base.MaxwellSimNode):
is_sorted=True, is_sorted=True,
), ),
}, },
**info_output_names, output_name=props['extract_filter'],
output_mathtypes={props['extract_filter']: spux.MathType.Real}, output_shape=None,
output_units={ output_mathtype=spux.MathType.Real,
props['extract_filter']: ( output_unit=(
spu.volt / spu.micrometer spu.volt / spu.micrometer
if props['extract_filter'].startswith('E') if props['extract_filter'].startswith('E')
else spu.ampere / spu.micrometer else spu.ampere / spu.micrometer
) ),
},
) )
msg = f'Unsupported Monitor Data Type {props["monitor_data_type"]} in "FlowKind.Info" of "{self.bl_label}"' msg = f'Unsupported Monitor Data Type {props["monitor_data_type"]} in "FlowKind.Info" of "{self.bl_label}"'

View File

@ -1,14 +1,16 @@
from . import filter_math, map_math, operate_math, reduce_math from . import filter_math, map_math, operate_math, reduce_math, transform_math
BL_REGISTER = [ BL_REGISTER = [
*map_math.BL_REGISTER, *map_math.BL_REGISTER,
*filter_math.BL_REGISTER, *filter_math.BL_REGISTER,
*reduce_math.BL_REGISTER, *reduce_math.BL_REGISTER,
*operate_math.BL_REGISTER, *operate_math.BL_REGISTER,
*transform_math.BL_REGISTER,
] ]
BL_NODES = { BL_NODES = {
**map_math.BL_NODES, **map_math.BL_NODES,
**filter_math.BL_NODES, **filter_math.BL_NODES,
**reduce_math.BL_NODES, **reduce_math.BL_NODES,
**operate_math.BL_NODES, **operate_math.BL_NODES,
**transform_math.BL_NODES,
} }

View File

@ -1,3 +1,5 @@
"""Declares `FilterMathNode`."""
import enum import enum
import typing as typ import typing as typ
@ -15,11 +17,21 @@ log = logger.get(__name__)
class FilterMathNode(base.MaxwellSimNode): class FilterMathNode(base.MaxwellSimNode):
"""Reduces the dimensionality of data. r"""Applies a function that operates on the shape of the array.
The shape, type, and interpretation of the input/output data is dynamically shown.
# Socket Sets
## Dimensions
Alter the dimensions of the array.
## Interpret
Only alter the interpretation of the array data, which guides what it can be used for.
These operations are **zero cost**, since the data itself is untouched.
Attributes: Attributes:
operation: Operation to apply to the input. operation: Operation to apply to the input.
dim: Dims to use when filtering data
""" """
node_type = ct.NodeType.FilterMath node_type = ct.NodeType.FilterMath
@ -29,8 +41,8 @@ class FilterMathNode(base.MaxwellSimNode):
'Data': sockets.DataSocketDef(format='jax'), 'Data': sockets.DataSocketDef(format='jax'),
} }
input_socket_sets: typ.ClassVar = { input_socket_sets: typ.ClassVar = {
'By Dim': {}, 'Interpret': {},
'By Dim Value': {}, 'Dimensions': {},
} }
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {
'Data': sockets.DataSocketDef(format='jax'), 'Data': sockets.DataSocketDef(format='jax'),
@ -43,10 +55,17 @@ class FilterMathNode(base.MaxwellSimNode):
prop_ui=True, enum_cb=lambda self, _: self.search_operations() prop_ui=True, enum_cb=lambda self, _: self.search_operations()
) )
dim: enum.Enum = bl_cache.BLField( # Dimension Selection
dim_0: enum.Enum = bl_cache.BLField(
None, prop_ui=True, enum_cb=lambda self, _: self.search_dims()
)
dim_1: enum.Enum = bl_cache.BLField(
None, prop_ui=True, enum_cb=lambda self, _: self.search_dims() None, prop_ui=True, enum_cb=lambda self, _: self.search_dims()
) )
####################
# - Computed
####################
@property @property
def data_info(self) -> ct.InfoFlow | None: def data_info(self) -> ct.InfoFlow | None:
info = self._compute_input('Data', kind=ct.FlowKind.Info) info = self._compute_input('Data', kind=ct.FlowKind.Info)
@ -60,87 +79,119 @@ class FilterMathNode(base.MaxwellSimNode):
#################### ####################
def search_operations(self) -> list[tuple[str, str, str]]: def search_operations(self) -> list[tuple[str, str, str]]:
items = [] items = []
if self.active_socket_set == 'By Dim': if self.active_socket_set == 'Interpret':
items += [ items += [
('SQUEEZE', 'del a | #=1', 'Squeeze'), ('DIM_TO_VEC', '→ Vector', 'Shift last dimension to output.'),
('DIMS_TO_MAT', '→ Matrix', 'Shift last 2 dimensions to output.'),
] ]
if self.active_socket_set == 'By Dim Value': elif self.active_socket_set == 'Dimensions':
items += [ items += [
('FIX', 'del a | i≈v', 'Fix Coordinate'), ('PIN_LEN_ONE', 'pinₐ =1', 'Remove a len(1) dimension'),
(
'PIN',
'pinₐ ≈v',
'Remove a len(n) dimension by selecting an index',
),
('SWAP', 'a₁ ↔ a₂', 'Swap the position of two dimensions'),
] ]
return [(*item, '', i) for i, item in enumerate(items)] return [(*item, '', i) for i, item in enumerate(items)]
#################### ####################
# - Dim Search # - Dimensions Search
#################### ####################
def search_dims(self) -> list[ct.BLEnumElement]: def search_dims(self) -> list[ct.BLEnumElement]:
if self.data_info is not None: if self.data_info is None:
dims = [
(dim_name, dim_name, dim_name, '', i)
for i, dim_name in enumerate(self.data_info.dim_names)
]
# Squeeze: Dimension Must Have Length=1
## We must also correct the "NUMBER" of the enum.
if self.operation == 'SQUEEZE':
filtered_dims = [
dim for dim in dims if self.data_info.dim_lens[dim[0]] == 1
]
return [(*dim[:-1], i) for i, dim in enumerate(filtered_dims)]
return dims
return [] return []
if self.operation == 'PIN_LEN_ONE':
dims = [
(dim_name, dim_name, f'Dimension "{dim_name}" of length 1')
for dim_name in self.data_info.dim_names
if self.data_info.dim_lens[dim_name] == 1
]
elif self.operation in ['PIN', 'SWAP']:
dims = [
(dim_name, dim_name, f'Dimension "{dim_name}"')
for dim_name in self.data_info.dim_names
]
else:
return []
return [(*dim, '', i) for i, dim in enumerate(dims)]
#################### ####################
# - UI # - UI
#################### ####################
def draw_label(self):
labels = {
'PIN_LEN_ONE': lambda: f'Filter: Pin {self.dim_0} (len=1)',
'PIN': lambda: f'Filter: Pin {self.dim_0}',
'SWAP': lambda: f'Filter: Swap {self.dim_0}|{self.dim_1}',
'DIM_TO_VEC': lambda: 'Filter: -> Vector',
'DIMS_TO_MAT': lambda: 'Filter: -> Matrix',
}
if (label := labels.get(self.operation)) is not None:
return label()
return self.bl_label
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, self.blfields['operation'], text='') layout.prop(self, self.blfields['operation'], text='')
if self.data_info is not None and self.data_info.dim_names:
layout.prop(self, self.blfields['dim'], text='') if self.active_socket_set == 'Dimensions':
if self.operation in ['PIN_LEN_ONE', 'PIN']:
layout.prop(self, self.blfields['dim_0'], text='')
if self.operation == 'SWAP':
row = layout.row(align=True)
row.prop(self, self.blfields['dim_0'], text='')
row.prop(self, self.blfields['dim_1'], text='')
#################### ####################
# - Events # - Events
#################### ####################
@events.on_value_changed( @events.on_value_changed(
socket_name='Data',
prop_name='active_socket_set', prop_name='active_socket_set',
run_on_init=True, run_on_init=True,
input_sockets={'Data'},
) )
def on_any_change(self, input_sockets: dict): def on_socket_set_changed(self):
if all(
not ct.FlowSignal.check_single(
input_socket_value, ct.FlowSignal.FlowPending
)
for input_socket_value in input_sockets.values()
):
self.operation = bl_cache.Signal.ResetEnumItems self.operation = bl_cache.Signal.ResetEnumItems
self.dim = bl_cache.Signal.ResetEnumItems
@events.on_value_changed(
# Trigger
socket_name='Data',
prop_name={'active_socket_set', 'operation'},
run_on_init=True,
# Loaded
props={'operation'},
)
def on_any_change(self, props: dict) -> None:
self.dim_0 = bl_cache.Signal.ResetEnumItems
self.dim_1 = bl_cache.Signal.ResetEnumItems
@events.on_value_changed( @events.on_value_changed(
socket_name='Data', socket_name='Data',
prop_name='dim', prop_name={'dim_0', 'dim_1', 'operation'},
## run_on_init: Implicitly triggered. ## run_on_init: Implicitly triggered.
props={'active_socket_set', 'dim'}, props={'operation', 'dim_0', 'dim_1'},
input_sockets={'Data'}, input_sockets={'Data'},
input_socket_kinds={'Data': ct.FlowKind.Info}, input_socket_kinds={'Data': ct.FlowKind.Info},
) )
def on_dim_change(self, props: dict, input_sockets: dict): def on_dim_change(self, props: dict, input_sockets: dict):
if input_sockets['Data'] == ct.FlowSignal.FlowPending: has_data = not ct.FlowSignal.check(input_sockets['Data'])
if not has_data:
return return
# Add/Remove Input Socket "Value" # "Dimensions"|"PIN": Add/Remove Input Socket
if ( if props['operation'] == 'PIN' and props['dim_0'] != 'NONE':
not ct.Flowsignal.check(input_sockets['Data'])
and props['active_socket_set'] == 'By Dim Value'
and props['dim'] != 'NONE'
):
# Get Current and Wanted Socket Defs # Get Current and Wanted Socket Defs
current_bl_socket = self.loose_input_sockets.get('Value') current_bl_socket = self.loose_input_sockets.get('Value')
wanted_socket_def = sockets.SOCKET_DEFS[ wanted_socket_def = sockets.SOCKET_DEFS[
ct.unit_to_socket_type(input_sockets['Data'].dim_idx[props['dim']].unit) ct.unit_to_socket_type(
input_sockets['Data'].dim_idx[props['dim_0']].unit
)
] ]
# Determine Whether to Declare New Loose Input SOcket # Determine Whether to Declare New Loose Input SOcket
@ -151,7 +202,7 @@ class FilterMathNode(base.MaxwellSimNode):
): ):
self.loose_input_sockets = { self.loose_input_sockets = {
'Value': wanted_socket_def(), 'Value': wanted_socket_def(),
} ## TODO: Can we do the boilerplate in base.py? }
elif self.loose_input_sockets: elif self.loose_input_sockets:
self.loose_input_sockets = {} self.loose_input_sockets = {}
@ -161,40 +212,51 @@ class FilterMathNode(base.MaxwellSimNode):
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Data',
kind=ct.FlowKind.LazyValueFunc, kind=ct.FlowKind.LazyValueFunc,
props={'active_socket_set', 'operation', 'dim'}, props={'operation', 'dim_0', 'dim_1'},
input_sockets={'Data'}, input_sockets={'Data'},
input_socket_kinds={'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info}}, input_socket_kinds={'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info}},
) )
def compute_data(self, props: dict, input_sockets: dict): def compute_data(self, props: dict, input_sockets: dict):
# Retrieve Inputs
lazy_value_func = input_sockets['Data'][ct.FlowKind.LazyValueFunc] lazy_value_func = input_sockets['Data'][ct.FlowKind.LazyValueFunc]
info = input_sockets['Data'][ct.FlowKind.Info] info = input_sockets['Data'][ct.FlowKind.Info]
# Check Flow # Check Flow
if ( if any(ct.FlowSignal.check(inp) for inp in [info, lazy_value_func]):
any(ct.FlowSignal.check(inp) for inp in [info, lazy_value_func])
or props['operation'] == 'NONE'
):
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
# Compute Bound/Free Parameters # Compute Function Arguments
func_args = [int] if props['active_socket_set'] == 'By Dim Value' else [] operation = props['operation']
axis = info.dim_names.index(props['dim']) if operation == 'NONE':
return ct.FlowSignal.FlowPending
# Select Function ## Dimension(s)
filter_func: typ.Callable[[jax.Array], jax.Array] = { dim_0 = props['dim_0']
'By Dim': {'SQUEEZE': lambda data: jnp.squeeze(data, axis)}, dim_1 = props['dim_1']
'By Dim Value': { if operation in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE':
'FIX': lambda data, fixed_axis_idx: jnp.take( return ct.FlowSignal.FlowPending
data, fixed_axis_idx, axis=axis if operation == 'SWAP' and dim_1 == 'NONE':
) return ct.FlowSignal.FlowPending
},
}[props['active_socket_set']][props['operation']] ## Axis/Axes
axis_0 = info.dim_names.index(dim_0) if dim_0 != 'NONE' else None
axis_1 = info.dim_names.index(dim_1) if dim_1 != 'NONE' else None
# Compose Output Function
filter_func = {
# Dimensions
'PIN_LEN_ONE': lambda data: jnp.squeeze(data, axis_0),
'PIN': lambda data, fixed_axis_idx: jnp.take(
data, fixed_axis_idx, axis=axis_0
),
'SWAP': lambda data: jnp.swapaxes(data, axis_0, axis_1),
# Interpret
'DIM_TO_VEC': lambda data: data,
'DIMS_TO_MAT': lambda data: data,
}[props['operation']]
# Compose Function for Output
return lazy_value_func.compose_within( return lazy_value_func.compose_within(
filter_func, filter_func,
enclosing_func_args=func_args, enclosing_func_args=[int] if operation == 'PIN' else [],
supports_jax=True, supports_jax=True,
) )
@ -207,7 +269,6 @@ class FilterMathNode(base.MaxwellSimNode):
}, },
) )
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow: def compute_array(self, output_sockets: dict) -> ct.ArrayFlow:
# Retrieve Inputs
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc] lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc]
params = output_sockets['Data'][ct.FlowKind.Params] params = output_sockets['Data'][ct.FlowKind.Params]
@ -215,57 +276,54 @@ class FilterMathNode(base.MaxwellSimNode):
if any(ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]): if any(ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]):
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
# Compute Array
return ct.ArrayFlow( return ct.ArrayFlow(
values=lazy_value_func.func_jax(*params.func_args, **params.func_kwargs), values=lazy_value_func.func_jax(*params.func_args, **params.func_kwargs),
unit=None, ## TODO: Unit Propagation unit=None,
) )
#################### ####################
# - Compute Auxiliary: Info / Params # - Compute Auxiliary: Info
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Data',
kind=ct.FlowKind.Info, kind=ct.FlowKind.Info,
props={'active_socket_set', 'dim', 'operation'}, props={'dim_0', 'dim_1', 'operation'},
input_sockets={'Data'}, input_sockets={'Data'},
input_socket_kinds={'Data': ct.FlowKind.Info}, input_socket_kinds={'Data': ct.FlowKind.Info},
) )
def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow: def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
# Retrieve Inputs
info = input_sockets['Data'] info = input_sockets['Data']
# Check Flow # Check Flow
if ct.FlowSignal.check(info) or props['dim'] == 'NONE': if ct.FlowSignal.check(info):
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
# Compute Information # Collect Information
## Compute Info w/By-Operation Change to Dimensions dim_0 = props['dim_0']
axis = info.dim_names.index(props['dim']) dim_1 = props['dim_1']
if (props['active_socket_set'], props['operation']) in [ if props['operation'] in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE':
('By Dim', 'SQUEEZE'), return ct.FlowSignal.FlowPending
('By Dim Value', 'FIX'), if props['operation'] == 'SWAP' and dim_1 == 'NONE':
] and info.dim_names: return ct.FlowSignal.FlowPending
return ct.InfoFlow(
dim_names=info.dim_names[:axis] + info.dim_names[axis + 1 :],
dim_idx={
dim_name: dim_idx
for dim_name, dim_idx in info.dim_idx.items()
if dim_name != props['dim']
},
output_names=info.output_names,
output_mathtypes=info.output_mathtypes,
output_units=info.output_units,
)
msg = f'Active socket set {props["active_socket_set"]} and operation {props["operation"]} don\'t have an InfoFlow defined' return {
raise RuntimeError(msg) # Dimensions
'PIN_LEN_ONE': lambda: info.delete_dimension(dim_0),
'PIN': lambda: info.delete_dimension(dim_0),
'SWAP': lambda: info.swap_dimensions(dim_0, dim_1),
# Interpret
'DIM_TO_VEC': lambda: info.shift_last_input,
'DIMS_TO_MAT': lambda: info.shift_last_input.shift_last_input,
}[props['operation']]()
####################
# - Compute Auxiliary: Info
####################
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Data',
kind=ct.FlowKind.Params, kind=ct.FlowKind.Params,
props={'active_socket_set', 'dim', 'operation'}, props={'dim_0', 'dim_1', 'operation'},
input_sockets={'Data', 'Value'}, input_sockets={'Data', 'Value'},
input_socket_kinds={'Data': {ct.FlowKind.Info, ct.FlowKind.Params}}, input_socket_kinds={'Data': {ct.FlowKind.Info, ct.FlowKind.Params}},
input_sockets_optional={'Value': True}, input_sockets_optional={'Value': True},
@ -273,35 +331,33 @@ class FilterMathNode(base.MaxwellSimNode):
def compute_composed_params( def compute_composed_params(
self, props: dict, input_sockets: dict self, props: dict, input_sockets: dict
) -> ct.ParamsFlow: ) -> ct.ParamsFlow:
# Retrieve Inputs
info = input_sockets['Data'][ct.FlowKind.Info] info = input_sockets['Data'][ct.FlowKind.Info]
params = input_sockets['Data'][ct.FlowKind.Params] params = input_sockets['Data'][ct.FlowKind.Params]
# Check Flow
if any(ct.FlowSignal.check(inp) for inp in [info, params]): if any(ct.FlowSignal.check(inp) for inp in [info, params]):
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
# Compute Composed Parameters # Collect Information
## -> Only operations that add parameters. ## Dimensions
## -> A dimension must be selected. dim_0 = props['dim_0']
## -> There must be an input value. dim_1 = props['dim_1']
if (
(props['active_socket_set'], props['operation']) if props['operation'] in ['PIN_LEN_ONE', 'PIN', 'SWAP'] and dim_0 == 'NONE':
in [ return ct.FlowSignal.FlowPending
('By Dim Value', 'FIX'), if props['operation'] == 'SWAP' and dim_1 == 'NONE':
] return ct.FlowSignal.FlowPending
and props['dim'] != 'NONE'
and not ct.FlowSignal.check(input_sockets['Value']) ## Pinned Value
): pinned_value = input_sockets['Value']
# Compute IDX Corresponding to Coordinate Value has_pinned_value = not ct.FlowSignal.check(pinned_value)
## -> Each dimension declares a unit-aware real number at each index.
## -> "Value" is a unit-aware real number from loose input socket. if props['operation'] == 'PIN' and has_pinned_value:
## -> This finds the dimensional index closest to "Value". # Compute IDX Corresponding to Dimension Index
## Total Effect: Indexing by a unit-aware real number. nearest_idx_to_value = info.dim_idx[dim_0].nearest_idx_of(
nearest_idx_to_value = info.dim_idx[props['dim']].nearest_idx_of(
input_sockets['Value'], require_sorted=True input_sockets['Value'], require_sorted=True
) )
# Compose Parameters
return params.compose_within(enclosing_func_args=[nearest_idx_to_value]) return params.compose_within(enclosing_func_args=[nearest_idx_to_value])
return params return params

View File

@ -138,6 +138,7 @@ class MapMathNode(base.MaxwellSimNode):
('SQ', '', 'v^2 (by el)'), ('SQ', '', 'v^2 (by el)'),
('SQRT', '√v', 'sqrt(v) (by el)'), ('SQRT', '√v', 'sqrt(v) (by el)'),
('INV_SQRT', '1/√v', '1/sqrt(v) (by el)'), ('INV_SQRT', '1/√v', '1/sqrt(v) (by el)'),
None,
# Trigonometry # Trigonometry
('COS', 'cos v', 'cos(v) (by el)'), ('COS', 'cos v', 'cos(v) (by el)'),
('SIN', 'sin v', 'sin(v) (by el)'), ('SIN', 'sin v', 'sin(v) (by el)'),
@ -148,6 +149,7 @@ class MapMathNode(base.MaxwellSimNode):
] ]
elif self.active_socket_set in 'By Vector': elif self.active_socket_set in 'By Vector':
items = [ items = [
# Vector -> Number
('NORM_2', '||v||₂', 'norm(v, 2) (by Vec)'), ('NORM_2', '||v||₂', 'norm(v, 2) (by Vec)'),
] ]
elif self.active_socket_set == 'By Matrix': elif self.active_socket_set == 'By Matrix':
@ -157,13 +159,16 @@ class MapMathNode(base.MaxwellSimNode):
('COND', 'κ(V)', 'cond(V) (by Mat)'), ('COND', 'κ(V)', 'cond(V) (by Mat)'),
('NORM_FRO', '||V||_F', 'norm(V, frobenius) (by Mat)'), ('NORM_FRO', '||V||_F', 'norm(V, frobenius) (by Mat)'),
('RANK', 'rank V', 'rank(V) (by Mat)'), ('RANK', 'rank V', 'rank(V) (by Mat)'),
None,
# Matrix -> Array # Matrix -> Array
('DIAG', 'diag V', 'diag(V) (by Mat)'), ('DIAG', 'diag V', 'diag(V) (by Mat)'),
('EIG_VALS', 'eigvals V', 'eigvals(V) (by Mat)'), ('EIG_VALS', 'eigvals V', 'eigvals(V) (by Mat)'),
('SVD_VALS', 'svdvals V', 'diag(svd(V)) (by Mat)'), ('SVD_VALS', 'svdvals V', 'diag(svd(V)) (by Mat)'),
None,
# Matrix -> Matrix # Matrix -> Matrix
('INV', 'V⁻¹', 'V^(-1) (by Mat)'), ('INV', 'V⁻¹', 'V^(-1) (by Mat)'),
('TRA', 'Vt', 'V^T (by Mat)'), ('TRA', 'Vt', 'V^T (by Mat)'),
None,
# Matrix -> Matrices # Matrix -> Matrices
('QR', 'qr V', 'qr(V) -> Q·R (by Mat)'), ('QR', 'qr V', 'qr(V) -> Q·R (by Mat)'),
('CHOL', 'chol V', 'cholesky(V) -> V·V† (by Mat)'), ('CHOL', 'chol V', 'cholesky(V) -> V·V† (by Mat)'),
@ -175,7 +180,9 @@ class MapMathNode(base.MaxwellSimNode):
msg = f'Active socket set {self.active_socket_set} is unknown' msg = f'Active socket set {self.active_socket_set} is unknown'
raise RuntimeError(msg) raise RuntimeError(msg)
return [(*item, '', i) for i, item in enumerate(items)] return [
(*item, '', i) if item is not None else None for i, item in enumerate(items)
]
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, self.blfields['operation'], text='') layout.prop(self, self.blfields['operation'], text='')
@ -185,8 +192,9 @@ class MapMathNode(base.MaxwellSimNode):
#################### ####################
@events.on_value_changed( @events.on_value_changed(
prop_name='active_socket_set', prop_name='active_socket_set',
run_on_init=True,
) )
def on_operation_changed(self): def on_socket_set_changed(self):
self.operation = bl_cache.Signal.ResetEnumItems self.operation = bl_cache.Signal.ResetEnumItems
#################### ####################
@ -204,11 +212,14 @@ class MapMathNode(base.MaxwellSimNode):
input_sockets_optional={'Mapper': True}, input_sockets_optional={'Mapper': True},
) )
def compute_data(self, props: dict, input_sockets: dict): def compute_data(self, props: dict, input_sockets: dict):
has_data = not ct.FlowSignal.check(input_sockets['Data'])
if ( if (
ct.FlowSignal.check(input_sockets['Data']) or props['operation'] == 'NONE' not has_data
) or ( or props['operation'] == 'NONE'
or (
props['active_socket_set'] == 'Expr' props['active_socket_set'] == 'Expr'
and ct.FlowSignal.check(input_sockets['Mapper']) and ct.FlowSignal.check(input_sockets['Mapper'])
)
): ):
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
@ -238,14 +249,14 @@ class MapMathNode(base.MaxwellSimNode):
'NORM_FRO': lambda data: jnp.linalg.matrix_norm(data, ord='fro'), 'NORM_FRO': lambda data: jnp.linalg.matrix_norm(data, ord='fro'),
'RANK': lambda data: jnp.linalg.matrix_rank(data), 'RANK': lambda data: jnp.linalg.matrix_rank(data),
# Matrix -> Vec # Matrix -> Vec
'DIAG': lambda data: jnp.diag(data), 'DIAG': lambda data: jnp.diagonal(data, axis1=-2, axis2=-1),
'EIG_VALS': lambda data: jnp.eigvals(data), 'EIG_VALS': lambda data: jnp.linalg.eigvals(data),
'SVD_VALS': lambda data: jnp.svdvals(data), 'SVD_VALS': lambda data: jnp.linalg.svdvals(data),
# Matrix -> Matrix # Matrix -> Matrix
'INV': lambda data: jnp.inv(data), 'INV': lambda data: jnp.linalg.inv(data),
'TRA': lambda data: jnp.matrix_transpose(data), 'TRA': lambda data: jnp.matrix_transpose(data),
# Matrix -> Matrices # Matrix -> Matrices
'QR': lambda data: jnp.inv(data), 'QR': lambda data: jnp.linalg.qr(data),
'CHOL': lambda data: jnp.linalg.cholesky(data), 'CHOL': lambda data: jnp.linalg.cholesky(data),
'SVD': lambda data: jnp.linalg.svd(data), 'SVD': lambda data: jnp.linalg.svd(data),
}, },
@ -298,28 +309,53 @@ class MapMathNode(base.MaxwellSimNode):
return ct.FlowSignal.FlowPending return ct.FlowSignal.FlowPending
# Complex -> Real # Complex -> Real
if props['active_socket_set'] == 'By Element': if props['active_socket_set'] == 'By Element' and props['operation'] in [
if props['operation'] in [
'REAL', 'REAL',
'IMAG', 'IMAG',
'ABS', 'ABS',
]: ]:
return ct.InfoFlow( return info.set_output_mathtype(spux.MathType.Real)
dim_names=info.dim_names,
dim_idx=info.dim_idx, if props['active_socket_set'] == 'By Vector' and props['operation'] in [
output_names=info.output_names, 'NORM_2'
output_mathtypes={ ]:
output_name: ( return {
spux.MathType.Real 'NORM_2': lambda: info.collapse_output(
if output_mathtype == spux.MathType.Complex collapsed_name=f'||{info.output_name}||₂',
else output_mathtype collapsed_mathtype=spux.MathType.Real,
collapsed_unit=info.output_unit,
) )
for output_name, output_mathtype in info.output_mathtypes.items() }[props['operation']]()
},
output_units=info.output_units, if props['active_socket_set'] == 'By Matrix' and props['operation'] in [
) 'DET',
if props['active_socket_set'] == 'By Vector': 'COND',
pass 'NORM_FRO',
'RANK',
]:
return {
'DET': lambda: info.collapse_output(
collapsed_name=f'det {info.output_name}',
collapsed_mathtype=info.output_mathtype,
collapsed_unit=info.output_unit,
),
'COND': lambda: info.collapse_output(
collapsed_name=f'κ({info.output_name})',
collapsed_mathtype=spux.MathType.Real,
collapsed_unit=None,
),
'NORM_FRO': lambda: info.collapse_output(
collapsed_name=f'||({info.output_name}||_F',
collapsed_mathtype=spux.MathType.Real,
collapsed_unit=info.output_unit,
),
'RANK': lambda: info.collapse_output(
collapsed_name=f'rank {info.output_name}',
collapsed_mathtype=spux.MathType.Integer,
collapsed_unit=None,
),
}[props['operation']]()
return info return info
@events.computes_output_socket( @events.computes_output_socket(

View File

@ -1,9 +1,12 @@
import enum
import typing as typ import typing as typ
import bpy import bpy
import jax.numpy as jnp import jax.numpy as jnp
import sympy as sp
from blender_maxwell.utils import logger from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct from .... import contracts as ct
from .... import sockets from .... import sockets
@ -13,120 +16,465 @@ log = logger.get(__name__)
class OperateMathNode(base.MaxwellSimNode): class OperateMathNode(base.MaxwellSimNode):
r"""Applies a function that depends on two inputs.
Attributes:
category: The category of operations to apply to the inputs.
**Only valid** categories can be chosen.
operation: The actual operation to apply to the inputs.
**Only valid** operations can be chosen.
"""
node_type = ct.NodeType.OperateMath node_type = ct.NodeType.OperateMath
bl_label = 'Operate Math' bl_label = 'Operate Math'
input_socket_sets: typ.ClassVar = { input_socket_sets: typ.ClassVar = {
'Elementwise': { 'Expr | Expr': {
'Data L': sockets.AnySocketDef(), 'Expr L': sockets.ExprSocketDef(),
'Data R': sockets.AnySocketDef(), 'Expr R': sockets.ExprSocketDef(),
}, },
## TODO: Filter-array building operations 'Data | Data': {
'Vec-Vec': { 'Data L': sockets.DataSocketDef(
'Data L': sockets.AnySocketDef(), format='jax', default_show_info_columns=False
'Data R': sockets.AnySocketDef(), ),
'Data R': sockets.DataSocketDef(
format='jax', default_show_info_columns=False
),
}, },
'Mat-Vec': { 'Expr | Data': {
'Data L': sockets.AnySocketDef(), 'Expr L': sockets.ExprSocketDef(),
'Data R': sockets.AnySocketDef(), 'Data R': sockets.DataSocketDef(
format='jax', default_show_info_columns=False
),
}, },
} }
output_sockets: typ.ClassVar = { output_socket_sets: typ.ClassVar = {
'Data': sockets.AnySocketDef(), 'Expr | Expr': {
'Expr': sockets.ExprSocketDef(),
},
'Data | Data': {
'Data': sockets.DataSocketDef(
format='jax', default_show_info_columns=False
),
},
'Expr | Data': {
'Data': sockets.DataSocketDef(
format='jax', default_show_info_columns=False
),
},
} }
#################### ####################
# - Properties # - Properties
#################### ####################
operation: bpy.props.EnumProperty( category: enum.Enum = bl_cache.BLField(
name='Op', prop_ui=True, enum_cb=lambda self, _: self.search_categories()
description='Operation to apply to the two inputs',
items=lambda self, _: self.search_operations(),
update=lambda self, context: self.on_prop_changed('operation', context),
) )
def search_operations(self) -> list[tuple[str, str, str]]: operation: enum.Enum = bl_cache.BLField(
prop_ui=True, enum_cb=lambda self, _: self.search_operations()
)
def search_categories(self) -> list[ct.BLEnumElement]:
"""Deduce and return a list of valid categories for the current socket set and input data."""
data_l_info = self._compute_input(
'Data L', kind=ct.FlowKind.Info, optional=True
)
data_r_info = self._compute_input(
'Data R', kind=ct.FlowKind.Info, optional=True
)
has_data_l_info = not ct.FlowSignal.check(data_l_info)
has_data_r_info = not ct.FlowSignal.check(data_r_info)
# Categories by Socket Set
NUMBER_NUMBER = (
'Number | Number',
'Number | Number',
'Operations between numerical elements',
)
NUMBER_VECTOR = (
'Number | Vector',
'Number | Vector',
'Operations between numerical and vector elements',
)
NUMBER_MATRIX = (
'Number | Matrix',
'Number | Matrix',
'Operations between numerical and matrix elements',
)
VECTOR_VECTOR = (
'Vector | Vector',
'Vector | Vector',
'Operations between vector elements',
)
MATRIX_VECTOR = (
'Matrix | Vector',
'Matrix | Vector',
'Operations between vector and matrix elements',
)
MATRIX_MATRIX = (
'Matrix | Matrix',
'Matrix | Matrix',
'Operations between matrix elements',
)
categories = []
## Expr | Expr
if self.active_socket_set == 'Expr | Expr':
return [NUMBER_NUMBER]
## Data | Data
if (
self.active_socket_set == 'Data | Data'
and has_data_l_info
and has_data_r_info
):
# Check Valid Broadcasting
## Number | Number
if data_l_info.output_shape is None and data_r_info.output_shape is None:
categories = [NUMBER_NUMBER]
## Number | Vector
elif (
data_l_info.output_shape is None and len(data_r_info.output_shape) == 1
):
categories = [NUMBER_VECTOR]
## Number | Matrix
elif (
data_l_info.output_shape is None and len(data_r_info.output_shape) == 2
): # noqa: PLR2004
categories = [NUMBER_MATRIX]
## Vector | Vector
elif (
len(data_l_info.output_shape) == 1
and len(data_r_info.output_shape) == 1
):
categories = [VECTOR_VECTOR]
## Matrix | Vector
elif (
len(data_l_info.output_shape) == 2 # noqa: PLR2004
and len(data_r_info.output_shape) == 1
):
categories = [MATRIX_VECTOR]
## Matrix | Matrix
elif (
len(data_l_info.output_shape) == 2 # noqa: PLR2004
and len(data_r_info.output_shape) == 2 # noqa: PLR2004
):
categories = [MATRIX_MATRIX]
## Expr | Data
if self.active_socket_set == 'Expr | Data' and has_data_r_info:
if data_r_info.output_shape is None:
categories = [NUMBER_NUMBER]
else:
categories = {
1: [NUMBER_NUMBER, NUMBER_VECTOR],
2: [NUMBER_NUMBER, NUMBER_MATRIX],
}[len(data_r_info.output_shape)]
return [
(*category, '', i) if category is not None else None
for i, category in enumerate(categories)
]
def search_operations(self) -> list[ct.BLEnumElement]:
items = [] items = []
if self.active_socket_set == 'Elementwise': if self.category in ['Number | Number', 'Number | Vector', 'Number | Matrix']:
items = [ items += [
('ADD', 'Add', 'L + R (by el)'), ('ADD', 'L + R', 'Add'),
('SUB', 'Subtract', 'L - R (by el)'), ('SUB', 'L - R', 'Subtract'),
('MUL', 'Multiply', 'L · R (by el)'), ('MUL', 'L · R', 'Multiply'),
('DIV', 'Divide', 'L ÷ R (by el)'), ('DIV', 'L ÷ R', 'Divide'),
('POW', 'Power', 'L^R (by el)'), ('POW', 'L^R', 'Power'),
('FMOD', 'Trunc Modulo', 'fmod(L,R) (by el)'), ('ATAN2', 'atan2(L,R)', 'atan2(L,R)'),
('ATAN2', 'atan2', 'atan2(L,R) (by el)'),
('HEAVISIDE', 'Heaviside', '{0|L<0 1|L>0 R|L=0} (by el)'),
] ]
elif self.active_socket_set in 'Vec | Vec': if self.category in 'Vector | Vector':
items = [ if items:
('DOT', 'Dot', 'L · R'), items += [None]
('CROSS', 'Cross', 'L x R (by last-axis'), items += [
('VEC_VEC_DOT', 'L · R', 'Vector-Vector Product'),
('CROSS', 'L x R', 'Cross Product'),
('PROJ', 'proj(L, R)', 'Projection'),
] ]
elif self.active_socket_set == 'Mat | Vec': if self.category == 'Matrix | Vector':
items = [ if items:
('DOT', 'Dot', 'L · R'), items += [None]
('LIN_SOLVE', 'Lin Solve', 'Lx = R -> x (by last-axis of R)'), items += [
('LSQ_SOLVE', 'LSq Solve', 'Lx = R ~> x (by last-axis of R)'), ('MAT_VEC_DOT', 'L · R', 'Matrix-Vector Product'),
('LIN_SOLVE', 'Lx = R -> x', 'Linear Solve'),
('LSQ_SOLVE', 'Lx = R ~> x', 'Least Squares Solve'),
]
if self.category == 'Matrix | Matrix':
if items:
items += [None]
items += [
('MAT_MAT_DOT', 'L · R', 'Matrix-Matrix Product'),
]
return [
(*item, '', i) if item is not None else None for i, item in enumerate(items)
] ]
return items
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, 'operation') layout.prop(self, self.blfields['category'], text='')
layout.prop(self, self.blfields['operation'], text='')
#################### ####################
# - Properties # - Events
####################
@events.on_value_changed(
# Trigger
socket_name={'Expr L', 'Expr R', 'Data L', 'Data R'},
prop_name='active_socket_set',
run_on_init=True,
)
def on_socket_set_changed(self) -> None:
# Recompute Valid Categories
self.category = bl_cache.Signal.ResetEnumItems
self.operation = bl_cache.Signal.ResetEnumItems
@events.on_value_changed(
prop_name='category',
run_on_init=True,
)
def on_category_changed(self) -> None:
self.operation = bl_cache.Signal.ResetEnumItems
####################
# - Output
####################
@events.computes_output_socket(
'Expr',
kind=ct.FlowKind.Value,
props={'operation'},
input_sockets={'Expr L', 'Expr R'},
)
def compute_expr(self, props: dict, input_sockets: dict):
expr_l = input_sockets['Expr L']
expr_r = input_sockets['Expr R']
return {
'ADD': lambda: expr_l + expr_r,
'SUB': lambda: expr_l - expr_r,
'MUL': lambda: expr_l * expr_r,
'DIV': lambda: expr_l / expr_r,
'POW': lambda: expr_l**expr_r,
'ATAN2': lambda: sp.atan2(expr_r, expr_l),
}[props['operation']]()
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.LazyValueFunc,
props={'operation'},
input_sockets={'Data L', 'Data R'},
input_socket_kinds={
'Data L': ct.FlowKind.LazyValueFunc,
'Data R': ct.FlowKind.LazyValueFunc,
},
input_sockets_optional={
'Data L': True,
'Data R': True,
},
)
def compute_data(self, props: dict, input_sockets: dict):
data_l = input_sockets['Data L']
data_r = input_sockets['Data R']
has_data_l = not ct.FlowSignal.check(data_l)
mapping_func = {
# Number | *
'ADD': lambda datas: datas[0] + datas[1],
'SUB': lambda datas: datas[0] - datas[1],
'MUL': lambda datas: datas[0] * datas[1],
'DIV': lambda datas: datas[0] / datas[1],
'POW': lambda datas: datas[0] ** datas[1],
'ATAN2': lambda datas: jnp.atan2(datas[1], datas[0]),
# Vector | Vector
'VEC_VEC_DOT': lambda datas: jnp.matmul(datas[0], datas[1]),
'CROSS': lambda datas: jnp.cross(datas[0], datas[1]),
# Matrix | Vector
'MAT_VEC_DOT': lambda datas: jnp.matmul(datas[0], datas[1]),
'LIN_SOLVE': lambda datas: jnp.linalg.solve(datas[0], datas[1]),
'LSQ_SOLVE': lambda datas: jnp.linalg.lstsq(datas[0], datas[1]),
# Matrix | Matrix
'MAT_MAT_DOT': lambda datas: jnp.matmul(datas[0], datas[1]),
}[props['operation']]
# Compose by Socket Set
## Data | Data
if has_data_l:
return (data_l | data_r).compose_within(
mapping_func,
supports_jax=True,
)
## Expr | Data
expr_l_lazy_value_func = ct.LazyValueFuncFlow(
func=lambda expr_l_value: expr_l_value,
func_args=[typ.Any],
supports_jax=True,
)
return (expr_l_lazy_value_func | data_r).compose_within(
mapping_func,
supports_jax=True,
)
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.Array,
output_sockets={'Data'},
output_socket_kinds={
'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
},
)
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow:
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc]
params = output_sockets['Data'][ct.FlowKind.Params]
has_lazy_value_func = not ct.FlowSignal.check(lazy_value_func)
has_params = not ct.FlowSignal.check(params)
if has_lazy_value_func and has_params:
return ct.ArrayFlow(
values=lazy_value_func.func_jax(
*params.func_args, **params.func_kwargs
),
unit=None,
)
return ct.FlowSignal.FlowPending
####################
# - Auxiliary: Params
#################### ####################
@events.computes_output_socket( @events.computes_output_socket(
'Data', 'Data',
kind=ct.FlowKind.Params,
props={'operation'}, props={'operation'},
input_sockets={'Data L', 'Data R'}, input_sockets={'Expr L', 'Data L', 'Data R'},
input_socket_kinds={
'Expr L': ct.FlowKind.Value,
'Data L': {ct.FlowKind.Info, ct.FlowKind.Params},
'Data R': {ct.FlowKind.Info, ct.FlowKind.Params},
},
input_sockets_optional={
'Expr L': True,
'Data L': True,
'Data R': True,
},
) )
def compute_data(self, props: dict, input_sockets: dict): def compute_data_params(
if self.active_socket_set == 'Elementwise': self, props, input_sockets
# Element-Wise Arithmetic ) -> ct.ParamsFlow | ct.FlowSignal:
if props['operation'] == 'ADD': expr_l = input_sockets['Expr L']
return input_sockets['Data L'] + input_sockets['Data R'] data_l_info = input_sockets['Data L'][ct.FlowKind.Info]
if props['operation'] == 'SUB': data_l_params = input_sockets['Data L'][ct.FlowKind.Params]
return input_sockets['Data L'] - input_sockets['Data R'] data_r_info = input_sockets['Data R'][ct.FlowKind.Info]
if props['operation'] == 'MUL': data_r_params = input_sockets['Data R'][ct.FlowKind.Params]
return input_sockets['Data L'] * input_sockets['Data R']
if props['operation'] == 'DIV':
return input_sockets['Data L'] / input_sockets['Data R']
# Element-Wise Arithmetic has_expr_l = not ct.FlowSignal.check(expr_l)
if props['operation'] == 'POW': has_data_l_info = not ct.FlowSignal.check(data_l_info)
return input_sockets['Data L'] ** input_sockets['Data R'] has_data_l_params = not ct.FlowSignal.check(data_l_params)
has_data_r_info = not ct.FlowSignal.check(data_r_info)
has_data_r_params = not ct.FlowSignal.check(data_r_params)
# Binary Trigonometry #log.critical((props, input_sockets))
if props['operation'] == 'ATAN2':
return jnp.atan2(input_sockets['Data L'], input_sockets['Data R'])
# Special Functions # Compose by Socket Set
if props['operation'] == 'HEAVISIDE': ## Data | Data
return jnp.heaviside(input_sockets['Data L'], input_sockets['Data R']) if (
has_data_l_info
and has_data_l_params
and has_data_r_info
and has_data_r_params
):
return data_l_params | data_r_params
# Linear Algebra ## Expr | Data
if self.active_socket_set in {'Vec-Vec', 'Mat-Vec'}: if has_expr_l and has_data_r_info and has_data_r_params:
if props['operation'] == 'DOT': operation = props['operation']
return jnp.dot(input_sockets['Data L'], input_sockets['Data R']) data_unit = data_r_info.output_unit
elif self.active_socket_set == 'Vec-Vec': # By Operation
if props['operation'] == 'CROSS': ## Add/Sub: Scale to Output Unit
return jnp.cross(input_sockets['Data L'], input_sockets['Data R']) if operation in ['ADD', 'SUB', 'MUL', 'DIV']:
if not spux.uses_units(expr_l):
value = spux.sympy_to_python(expr_l)
else:
value = spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit))
elif self.active_socket_set == 'Mat-Vec': return data_r_params.compose_within(
if props['operation'] == 'LIN_SOLVE': enclosing_func_args=[value],
return jnp.linalg.lstsq(
input_sockets['Data L'], input_sockets['Data R']
)
if props['operation'] == 'LSQ_SOLVE':
return jnp.linalg.solve(
input_sockets['Data L'], input_sockets['Data R']
) )
msg = 'Invalid operation' ## Pow: Doesn't Exist (?)
raise ValueError(msg) ## -> See https://math.stackexchange.com/questions/4326081/units-of-the-exponential-function
if operation == 'POW':
return ct.FlowSignal.FlowPending
## atan2(): Only Length
## -> Implicitly presume that Data L/R use length units.
if operation == 'ATAN2':
if not spux.uses_units(expr_l):
value = spux.sympy_to_python(expr_l)
else:
value = spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit))
return data_r_params.compose_within(
enclosing_func_args=[value],
)
return data_r_params.compose_within(
enclosing_func_args=[
spux.sympy_to_python(spux.scale_to_unit(expr_l, data_unit))
]
)
return ct.FlowSignal.FlowPending
####################
# - Auxiliary: Info
####################
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.Info,
input_sockets={'Expr L', 'Data L', 'Data R'},
input_socket_kinds={
'Expr L': ct.FlowKind.Value,
'Data L': ct.FlowKind.Info,
'Data R': ct.FlowKind.Info,
},
input_sockets_optional={
'Expr L': True,
'Data L': True,
'Data R': True,
},
)
def compute_data_info(self, input_sockets: dict) -> ct.InfoFlow:
expr_l = input_sockets['Expr L']
data_l_info = input_sockets['Data L']
data_r_info = input_sockets['Data R']
has_expr_l = not ct.FlowSignal.check(expr_l)
has_data_l_info = not ct.FlowSignal.check(data_l_info)
has_data_r_info = not ct.FlowSignal.check(data_r_info)
# Info by Socket Set
## Data | Data
if has_data_l_info and has_data_r_info:
return data_r_info
## Expr | Data
if has_expr_l and has_data_r_info:
return data_r_info
return ct.FlowSignal.FlowPending
#################### ####################

View File

@ -0,0 +1,165 @@
"""Declares `TransformMathNode`."""
import enum
import typing as typ
import bpy
import jax
import jax.numpy as jnp
import sympy as sp
from blender_maxwell.utils import bl_cache, logger
from blender_maxwell.utils import extra_sympy_units as spux
from .... import contracts as ct
from .... import sockets
from ... import base, events
log = logger.get(__name__)
class TransformMathNode(base.MaxwellSimNode):
r"""Applies a function to the array as a whole, with arbitrary results.
The shape, type, and interpretation of the input/output data is dynamically shown.
# Socket Sets
## Interpret
Reinterprets the `InfoFlow` of an array, **without changing it**.
Attributes:
operation: Operation to apply to the input.
"""
node_type = ct.NodeType.TransformMath
bl_label = 'Transform Math'
input_sockets: typ.ClassVar = {
'Data': sockets.DataSocketDef(format='jax'),
}
input_socket_sets: typ.ClassVar = {
'Fourier': {},
'Affine': {},
'Convolve': {},
}
output_sockets: typ.ClassVar = {
'Data': sockets.DataSocketDef(format='jax'),
}
####################
# - Properties
####################
operation: enum.Enum = bl_cache.BLField(
prop_ui=True, enum_cb=lambda self, _: self.search_operations()
)
def search_operations(self) -> list[ct.BLEnumElement]:
if self.active_socket_set == 'Fourier': # noqa: SIM114
items = []
elif self.active_socket_set == 'Affine': # noqa: SIM114
items = []
elif self.active_socket_set == 'Convolve':
items = []
else:
msg = f'Active socket set {self.active_socket_set} is unknown'
raise RuntimeError(msg)
return [(*item, '', i) for i, item in enumerate(items)]
def draw_props(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None:
layout.prop(self, self.blfields['operation'], text='')
####################
# - Events
####################
@events.on_value_changed(
prop_name='active_socket_set',
)
def on_socket_set_changed(self):
self.operation = bl_cache.Signal.ResetEnumItems
####################
# - Compute: LazyValueFunc / Array
####################
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.LazyValueFunc,
props={'active_socket_set', 'operation'},
input_sockets={'Data'},
input_socket_kinds={
'Data': ct.FlowKind.LazyValueFunc,
},
)
def compute_data(self, props: dict, input_sockets: dict):
has_data = not ct.FlowSignal.check(input_sockets['Data'])
if not has_data or props['operation'] == 'NONE':
return ct.FlowSignal.FlowPending
mapping_func: typ.Callable[[jax.Array], jax.Array] = {
'Fourier': {},
'Affine': {},
'Convolve': {},
}[props['active_socket_set']][props['operation']]
# Compose w/Lazy Root Function Data
return input_sockets['Data'].compose_within(
mapping_func,
supports_jax=True,
)
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.Array,
output_sockets={'Data'},
output_socket_kinds={
'Data': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Params},
},
)
def compute_array(self, output_sockets: dict) -> ct.ArrayFlow:
lazy_value_func = output_sockets['Data'][ct.FlowKind.LazyValueFunc]
params = output_sockets['Data'][ct.FlowKind.Params]
if all(not ct.FlowSignal.check(inp) for inp in [lazy_value_func, params]):
return ct.ArrayFlow(
values=lazy_value_func.func_jax(
*params.func_args, **params.func_kwargs
),
unit=None,
)
return ct.FlowSignal.FlowPending
####################
# - Compute Auxiliary: Info / Params
####################
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.Info,
props={'active_socket_set', 'operation'},
input_sockets={'Data'},
input_socket_kinds={'Data': ct.FlowKind.Info},
)
def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow:
info = input_sockets['Data']
if ct.FlowSignal.check(info):
return ct.FlowSignal.FlowPending
return info
@events.computes_output_socket(
'Data',
kind=ct.FlowKind.Params,
input_sockets={'Data'},
input_socket_kinds={'Data': ct.FlowKind.Params},
)
def compute_data_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal:
return input_sockets['Data']
####################
# - Blender Registration
####################
BL_REGISTER = [
TransformMathNode,
]
BL_NODES = {ct.NodeType.TransformMath: (ct.NodeCategory.MAXWELLSIM_ANALYSIS_MATH)}

View File

@ -61,32 +61,34 @@ class VizMode(enum.StrEnum):
@staticmethod @staticmethod
def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None: def valid_modes_for(info: ct.InfoFlow) -> list[typ.Self] | None:
EMPTY = ()
Z = spux.MathType.Integer
R = spux.MathType.Real
VM = VizMode
valid_viz_modes = { valid_viz_modes = {
((), (spux.MathType.Real,)): [VizMode.Hist1D, VizMode.BoxPlot1D], (EMPTY, (None, R)): [VM.Hist1D, VM.BoxPlot1D],
((spux.MathType.Integer), (spux.MathType.Real)): [ ((Z), (None, R)): [
VizMode.Hist1D, VM.Hist1D,
VizMode.BoxPlot1D, VM.BoxPlot1D,
], ],
((spux.MathType.Real,), (spux.MathType.Real,)): [ ((R,), (None, R)): [
VizMode.Curve2D, VM.Curve2D,
VizMode.Points2D, VM.Points2D,
VizMode.Bar, VM.Bar,
], ],
((spux.MathType.Real, spux.MathType.Integer), (spux.MathType.Real,)): [ ((R, Z), (None, R)): [
VizMode.Curves2D, VM.Curves2D,
VizMode.FilledCurves2D, VM.FilledCurves2D,
], ],
((spux.MathType.Real, spux.MathType.Real), (spux.MathType.Real,)): [ ((R, R), (None, R)): [
VizMode.Heatmap2D, VM.Heatmap2D,
], ],
( ((R, R, R), (None, R)): [VM.SqueezedHeatmap2D, VM.Heatmap3D],
(spux.MathType.Real, spux.MathType.Real, spux.MathType.Real),
(spux.MathType.Real,),
): [VizMode.SqueezedHeatmap2D, VizMode.Heatmap3D],
}.get( }.get(
( (
tuple(info.dim_mathtypes.values()), tuple(info.dim_mathtypes.values()),
tuple(info.output_mathtypes.values()), (info.output_shape, info.output_mathtype),
) )
) )
@ -161,10 +163,10 @@ class VizTarget(enum.StrEnum):
@staticmethod @staticmethod
def to_name(value: typ.Self) -> str: def to_name(value: typ.Self) -> str:
return { return {
VizTarget.Plot2D: 'Image (Plot)', VizTarget.Plot2D: 'Plot',
VizTarget.Pixels: 'Image (Pixels)', VizTarget.Pixels: 'Pixels',
VizTarget.PixelsPlane: 'Image (Plane)', VizTarget.PixelsPlane: 'Image Plane',
VizTarget.Voxels: '3D Field', VizTarget.Voxels: 'Voxels',
}[value] }[value]
@staticmethod @staticmethod

View File

@ -1036,6 +1036,12 @@ class MaxwellSimNode(bpy.types.Node):
# Generate New Instance ID # Generate New Instance ID
self.reset_instance_id() self.reset_instance_id()
# Generate New Instance ID for Sockets
## Sockets can't do this themselves.
for bl_sockets in [self.inputs, self.outputs]:
for bl_socket in bl_sockets:
bl_socket.reset_instance_id()
# Generate New Sim Node Name # Generate New Sim Node Name
## Blender will automatically add .001 so that `self.name` is unique. ## Blender will automatically add .001 so that `self.name` is unique.
self.sim_node_name = self.name self.sim_node_name = self.name

View File

@ -14,7 +14,7 @@ class ScientificConstantNode(base.MaxwellSimNode):
bl_label = 'Scientific Constant' bl_label = 'Scientific Constant'
output_sockets: typ.ClassVar = { output_sockets: typ.ClassVar = {
'Value': sockets.AnySocketDef(), 'Value': sockets.ExprSocketDef(),
} }
#################### ####################

View File

@ -913,10 +913,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket):
col = layout.column() col = layout.column()
row = col.row() row = col.row()
row.alignment = 'RIGHT' row.alignment = 'RIGHT'
if self.is_linked:
self.draw_output_label_row(row, text) self.draw_output_label_row(row, text)
else:
row.label(text=text)
# Draw FlowKind.Info related Information # Draw FlowKind.Info related Information
if self.use_info_draw: if self.use_info_draw:

View File

@ -12,6 +12,10 @@ from .. import base
log = logger.get(__name__) log = logger.get(__name__)
def unicode_superscript(n):
return ''.join(['⁰¹²³⁴⁵⁶⁷⁸⁹'[ord(c) - ord('0')] for c in str(n)])
class DataInfoColumn(enum.StrEnum): class DataInfoColumn(enum.StrEnum):
Length = enum.auto() Length = enum.auto()
MathType = enum.auto() MathType = enum.auto()
@ -49,11 +53,11 @@ class DataBLSocket(base.MaxwellSimSocket):
## TODO: typ.Literal['xarray', 'jax'] ## TODO: typ.Literal['xarray', 'jax']
show_info_columns: bool = bl_cache.BLField( show_info_columns: bool = bl_cache.BLField(
False, True,
prop_ui=True, prop_ui=True,
) )
info_columns: DataInfoColumn = bl_cache.BLField( info_columns: DataInfoColumn = bl_cache.BLField(
{DataInfoColumn.MathType, DataInfoColumn.Length}, prop_ui=True, enum_many=True {DataInfoColumn.MathType, DataInfoColumn.Unit}, prop_ui=True, enum_many=True
) )
#################### ####################
@ -71,8 +75,10 @@ class DataBLSocket(base.MaxwellSimSocket):
# - UI # - UI
#################### ####################
def draw_input_label_row(self, row: bpy.types.UILayout, text) -> None: def draw_input_label_row(self, row: bpy.types.UILayout, text) -> None:
if self.format == 'jax':
row.label(text=text) row.label(text=text)
info = self.compute_data(kind=ct.FlowKind.Info)
if not ct.FlowSignal.check(info) and self.format == 'jax' and info.dim_names:
row.prop(self, self.blfields['info_columns']) row.prop(self, self.blfields['info_columns'])
row.prop( row.prop(
self, self,
@ -83,7 +89,8 @@ class DataBLSocket(base.MaxwellSimSocket):
) )
def draw_output_label_row(self, row: bpy.types.UILayout, text) -> None: def draw_output_label_row(self, row: bpy.types.UILayout, text) -> None:
if self.format == 'jax': info = self.compute_data(kind=ct.FlowKind.Info)
if not ct.FlowSignal.check(info) and self.format == 'jax' and info.dim_names:
row.prop( row.prop(
self, self,
self.blfields['show_info_columns'], self.blfields['show_info_columns'],
@ -92,6 +99,7 @@ class DataBLSocket(base.MaxwellSimSocket):
icon=ct.Icon.ToggleSocketInfo, icon=ct.Icon.ToggleSocketInfo,
) )
row.prop(self, self.blfields['info_columns']) row.prop(self, self.blfields['info_columns'])
row.label(text=text) row.label(text=text)
def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None: def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None:
@ -118,16 +126,27 @@ class DataBLSocket(base.MaxwellSimSocket):
grid.label(text=spux.sp_to_str(dim_idx.unit)) grid.label(text=spux.sp_to_str(dim_idx.unit))
# Outputs # Outputs
for output_name in info.output_names: grid.label(text=info.output_name)
grid.label(text=output_name)
if DataInfoColumn.Length in self.info_columns: if DataInfoColumn.Length in self.info_columns:
grid.label(text='', icon=ct.Icon.DataSocketOutput) grid.label(text='', icon=ct.Icon.DataSocketOutput)
if DataInfoColumn.MathType in self.info_columns: if DataInfoColumn.MathType in self.info_columns:
grid.label( grid.label(
text=spux.MathType.to_str(info.output_mathtypes[output_name]) text=(
spux.MathType.to_str(info.output_mathtype)
+ (
'ˣ'.join(
[
unicode_superscript(out_axis)
for out_axis in info.output_shape
]
)
if info.output_shape
else ''
)
)
) )
if DataInfoColumn.Unit in self.info_columns: if DataInfoColumn.Unit in self.info_columns:
grid.label(text=spux.sp_to_str(info.output_units[output_name])) grid.label(text=f'{spux.sp_to_str(info.output_unit)}')
#################### ####################
@ -137,9 +156,11 @@ class DataSocketDef(base.SocketDef):
socket_type: ct.SocketType = ct.SocketType.Data socket_type: ct.SocketType = ct.SocketType.Data
format: typ.Literal['xarray', 'jax', 'monitor_data'] format: typ.Literal['xarray', 'jax', 'monitor_data']
default_show_info_columns: bool = True
def init(self, bl_socket: DataBLSocket) -> None: def init(self, bl_socket: DataBLSocket) -> None:
bl_socket.format = self.format bl_socket.format = self.format
bl_socket.default_show_info_columns = self.default_show_info_columns
#################### ####################

View File

@ -86,9 +86,7 @@ class ExprBLSocket(base.MaxwellSimSocket):
def lazy_value_func(self) -> ct.LazyValueFuncFlow: def lazy_value_func(self) -> ct.LazyValueFuncFlow:
return ct.LazyValueFuncFlow( return ct.LazyValueFuncFlow(
func=sp.lambdify(self.symbols, self.value, 'jax'), func=sp.lambdify(self.symbols, self.value, 'jax'),
func_args=[ func_args=[spux.sympy_to_python_type(sym) for sym in self.symbols],
(sym.name, spux.sympy_to_python_type(sym)) for sym in self.symbols
],
supports_jax=True, supports_jax=True,
) )

View File

@ -31,6 +31,18 @@ class MathType(enum.StrEnum):
Real = enum.auto() Real = enum.auto()
Complex = enum.auto() Complex = enum.auto()
def combine(*mathtypes: list[typ.Self]) -> typ.Self:
if MathType.Complex in mathtypes:
return MathType.Complex
elif MathType.Real in mathtypes:
return MathType.Real
elif MathType.Rational in mathtypes:
return MathType.Rational
elif MathType.Integer in mathtypes:
return MathType.Integer
elif MathType.Bool in mathtypes:
return MathType.Bool
@staticmethod @staticmethod
def from_expr(sp_obj: SympyType) -> type: def from_expr(sp_obj: SympyType) -> type:
if isinstance(sp_obj, sp.logic.boolalg.Boolean): if isinstance(sp_obj, sp.logic.boolalg.Boolean):
@ -595,6 +607,21 @@ ComplexNumber: typ.TypeAlias = ConstrSympyExpr(
) )
Number: typ.TypeAlias = IntNumber | RealNumber | ComplexNumber Number: typ.TypeAlias = IntNumber | RealNumber | ComplexNumber
# Number
PhysicalRealNumber: typ.TypeAlias = ConstrSympyExpr(
allow_variables=False,
allow_units=True,
allowed_sets={'integer', 'rational', 'real'},
allowed_structures={'scalar'},
)
PhysicalComplexNumber: typ.TypeAlias = ConstrSympyExpr(
allow_variables=False,
allow_units=True,
allowed_sets={'integer', 'rational', 'real', 'complex'},
allowed_structures={'scalar'},
)
PhysicalNumber: typ.TypeAlias = PhysicalRealNumber | PhysicalComplexNumber
# Vector # Vector
Real3DVector: typ.TypeAlias = ConstrSympyExpr( Real3DVector: typ.TypeAlias = ConstrSympyExpr(
allow_variables=False, allow_variables=False,

View File

@ -117,8 +117,8 @@ def rgba_image_from_2d_map(
def plot_hist_1d( def plot_hist_1d(
data: jtyp.Float32[jtyp.Array, ' size'], info, ax: mpl_ax.Axis data: jtyp.Float32[jtyp.Array, ' size'], info, ax: mpl_ax.Axis
) -> None: ) -> None:
y_name = info.output_names[0] y_name = info.output_name
y_unit = info.output_units[y_name] y_unit = info.output_unit
ax.hist(data, bins=30, alpha=0.75) ax.hist(data, bins=30, alpha=0.75)
ax.set_title('Histogram') ax.set_title('Histogram')
@ -130,8 +130,8 @@ def plot_box_plot_1d(
data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis data: jtyp.Float32[jtyp.Array, ' heights'], info, ax: mpl_ax.Axis
) -> None: ) -> None:
x_name = info.dim_names[0] x_name = info.dim_names[0]
y_name = info.output_names[0] y_name = info.output_name
y_unit = info.output_units[y_name] y_unit = info.output_unit
ax.boxplot(data) ax.boxplot(data)
ax.set_title('Box Plot') ax.set_title('Box Plot')
@ -147,8 +147,8 @@ def plot_curve_2d(
x_name = info.dim_names[0] x_name = info.dim_names[0]
x_unit = info.dim_units[x_name] x_unit = info.dim_units[x_name]
y_name = info.output_names[0] y_name = info.output_name
y_unit = info.output_units[y_name] y_unit = info.output_unit
times.append(time.perf_counter() - times[0]) times.append(time.perf_counter() - times[0])
ax.plot(info.dim_idx_arrays[0], data) ax.plot(info.dim_idx_arrays[0], data)
@ -167,8 +167,8 @@ def plot_points_2d(
) -> None: ) -> None:
x_name = info.dim_names[0] x_name = info.dim_names[0]
x_unit = info.dim_units[x_name] x_unit = info.dim_units[x_name]
y_name = info.output_names[0] y_name = info.output_name
y_unit = info.output_units[y_name] y_unit = info.output_unit
ax.scatter(info.dim_idx_arrays[0], data, alpha=0.6) ax.scatter(info.dim_idx_arrays[0], data, alpha=0.6)
ax.set_title('2D Points') ax.set_title('2D Points')
@ -179,8 +179,8 @@ def plot_points_2d(
def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None: def plot_bar(data: jtyp.Float32[jtyp.Array, ' points'], info, ax: mpl_ax.Axis) -> None:
x_name = info.dim_names[0] x_name = info.dim_names[0]
x_unit = info.dim_units[x_name] x_unit = info.dim_units[x_name]
y_name = info.output_names[0] y_name = info.output_name
y_unit = info.output_units[y_name] y_unit = info.output_unit
ax.bar(info.dim_idx_arrays[0], data, alpha=0.7) ax.bar(info.dim_idx_arrays[0], data, alpha=0.7)
ax.set_title('2D Bar') ax.set_title('2D Bar')
@ -194,8 +194,8 @@ def plot_curves_2d(
) -> None: ) -> None:
x_name = info.dim_names[0] x_name = info.dim_names[0]
x_unit = info.dim_units[x_name] x_unit = info.dim_units[x_name]
y_name = info.output_names[0] y_name = info.output_name
y_unit = info.output_units[y_name] y_unit = info.output_unit
for category in range(data.shape[1]): for category in range(data.shape[1]):
ax.plot(data[:, 0], data[:, 1]) ax.plot(data[:, 0], data[:, 1])
@ -211,8 +211,8 @@ def plot_filled_curves_2d(
) -> None: ) -> None:
x_name = info.dim_names[0] x_name = info.dim_names[0]
x_unit = info.dim_units[x_name] x_unit = info.dim_units[x_name]
y_name = info.output_names[0] y_name = info.output_name
y_unit = info.output_units[y_name] y_unit = info.output_unit
ax.fill_between(info.dim_arrays[0], data[:, 0], info.dim_arrays[0], data[:, 1]) ax.fill_between(info.dim_arrays[0], data[:, 0], info.dim_arrays[0], data[:, 1])
ax.set_title('2D Curves') ax.set_title('2D Curves')